"""Build script for the ``sparse_linear_metal`` PyTorch extension.

Compiles ``sparse_linear.metal`` into ``sparse_linear_ops.metallib`` with the
Xcode command line tools, builds the Objective-C++ extension from
``sparse_linear_ops.mm``, and places the metallib next to the built extension.
"""

import shutil
import subprocess
from pathlib import Path

import torch
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension

ROOT = Path(__file__).parent.resolve()
# PyTorch dylibs (libc10, libtorch, …) are @rpath-linked; embed torch/lib so
# import works from any cwd.
_TORCH_LIB = Path(torch.__file__).resolve().parent / "lib"

METAL_SRC = ROOT / "sparse_linear.metal"
AIR = ROOT / "sparse_linear.air"
METALLIB = ROOT / "sparse_linear_ops.metallib"


class MetalBuildExt(BuildExtension):
    """``build_ext`` that also compiles the Metal shaders and ships the metallib.

    Before the extensions are built, ``METAL_SRC`` is compiled to ``AIR`` and
    linked into ``METALLIB``. After the build, ``METALLIB`` is copied into the
    directory of every extension this command produced.
    """

    def run(self):
        self._build_metallib()
        super().run()
        self._copy_metallib()

    def _build_metallib(self):
        """Compile METAL_SRC -> AIR -> METALLIB via ``xcrun``.

        Raises:
            RuntimeError: if ``xcrun`` is not on PATH.
            subprocess.CalledProcessError: if either compiler step fails.
        """
        if shutil.which("xcrun") is None:
            raise RuntimeError("xcrun not found. Install Xcode command line tools.")
        subprocess.check_call(
            ["xcrun", "-sdk", "macosx", "metal", "-c", str(METAL_SRC), "-o", str(AIR)]
        )
        subprocess.check_call(
            ["xcrun", "-sdk", "macosx", "metallib", str(AIR), "-o", str(METALLIB)]
        )

    def _copy_metallib(self):
        """Copy METALLIB next to each built extension module."""
        for ext in self.extensions:
            # Ask build_ext where it wrote this extension instead of globbing
            # build_lib for a hard-coded name: this follows --inplace/editable
            # builds and cannot pick up stale artifacts from earlier builds.
            ext_path = Path(self.get_ext_fullpath(ext.name)).resolve()
            dest = ext_path.parent / METALLIB.name
            # An in-place build of this top-level module lands in ROOT, where
            # METALLIB already lives; copying a file onto itself would raise.
            if dest.exists() and dest.samefile(METALLIB):
                continue
            dest.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(METALLIB, dest)


setup(
    name="sparse_linear_metal",
    version="0.1.0",
    ext_modules=[
        CppExtension(
            name="sparse_linear_metal",
            sources=["sparse_linear_ops.mm"],
            extra_compile_args={"cxx": ["-std=c++17", "-ObjC++", "-fobjc-arc"]},
            extra_link_args=[
                "-framework",
                "Metal",
                "-framework",
                "Foundation",
                f"-Wl,-rpath,{_TORCH_LIB}",
            ],
        )
    ],
    cmdclass={"build_ext": MetalBuildExt},
)