File size: 1,662 Bytes
bc1b8eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import shutil
import subprocess
from pathlib import Path
import torch
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension

ROOT = Path(__file__).parent.resolve()

# PyTorch dylibs (libc10, libtorch, ...) are @rpath-linked, so torch/lib is
# embedded as an rpath; that lets the built extension import from any cwd.
_TORCH_LIB = Path(torch.__file__).resolve().parent.joinpath("lib")

# Metal shader artifacts, all located in the project root:
# METAL_SRC (shader source) -> AIR (intermediate object) -> METALLIB (library).
METAL_SRC = ROOT.joinpath("sparse_linear.metal")
AIR = ROOT.joinpath("sparse_linear.air")
METALLIB = ROOT.joinpath("sparse_linear_ops.metallib")

class MetalBuildExt(BuildExtension):
    """``build_ext`` command that also builds the Metal shader library.

    Before the regular extension build, ``METAL_SRC`` is compiled to ``AIR``
    with ``xcrun metal -c`` and linked into ``METALLIB`` with
    ``xcrun metallib``. After the regular build, ``METALLIB`` is copied into
    the directory of every built extension binary so it sits next to it.
    """

    def run(self):
        """Compile the metallib, build the extensions, then place the metallib.

        Raises:
            RuntimeError: if ``xcrun`` or the Metal compiler is unavailable.
            subprocess.CalledProcessError: if compiling or linking the shader
                library fails.
        """
        self._check_metal_toolchain()
        subprocess.check_call(
            ["xcrun", "-sdk", "macosx", "metal", "-c", str(METAL_SRC), "-o", str(AIR)]
        )
        subprocess.check_call(
            ["xcrun", "-sdk", "macosx", "metallib", str(AIR), "-o", str(METALLIB)]
        )
        super().run()
        self._copy_metallib_next_to_extensions()

    @staticmethod
    def _check_metal_toolchain():
        """Fail early, with an actionable message, if the Metal compiler is missing."""
        if shutil.which("xcrun") is None:
            raise RuntimeError("xcrun not found. Install Xcode command line tools.")
        # xcrun can be present (Command Line Tools only) while the `metal`
        # compiler is not; `xcrun --find` exits non-zero in that case.
        probe = subprocess.run(
            ["xcrun", "-sdk", "macosx", "--find", "metal"],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
        if probe.returncode != 0:
            raise RuntimeError(
                "Metal compiler not found ('xcrun -sdk macosx --find metal' failed). "
                "It is not part of the standalone command line tools: install the "
                "full Xcode app and select it with 'xcode-select -s' (newer Xcode "
                "releases may also need 'xcodebuild -downloadComponent MetalToolchain')."
            )

    def _copy_metallib_next_to_extensions(self):
        """Copy METALLIB into the output directory of each built extension."""
        for ext in self.extensions:
            # get_ext_fullpath honours --inplace and the platform's extension
            # suffix, unlike globbing build_lib for a fixed "*.so" pattern.
            dest_dir = Path(self.get_ext_fullpath(ext.name)).resolve().parent
            dest = dest_dir / METALLIB.name
            # In-place builds can put the extension right next to METALLIB;
            # copying a file onto itself would raise shutil.SameFileError.
            if dest.resolve() == METALLIB.resolve():
                continue
            shutil.copy2(METALLIB, dest)

# Package metadata plus one Objective-C++ extension built by MetalBuildExt
# (which also produces the .metallib). Metal and Foundation are linked as
# frameworks; torch/lib is embedded as an rpath for PyTorch's @rpath dylibs.
setup(
    name="sparse_linear_metal",
    version="0.1.0",
    cmdclass={"build_ext": MetalBuildExt},
    ext_modules=[
        CppExtension(
            name="sparse_linear_metal",
            sources=["sparse_linear_ops.mm"],
            extra_compile_args={
                "cxx": ["-std=c++17", "-ObjC++", "-fobjc-arc"],
            },
            extra_link_args=[
                "-framework", "Metal",
                "-framework", "Foundation",
                f"-Wl,-rpath,{_TORCH_LIB}",
            ],
        ),
    ],
)