File size: 1,662 Bytes
bc1b8eb | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 | import shutil
import subprocess
from pathlib import Path
import torch
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension
# Absolute directory of this setup script; every build input/output below is
# resolved against it so the build does not depend on the current directory.
ROOT = Path(__file__).parent.resolve()

# PyTorch's shared libraries (libc10, libtorch, ...) are linked via @rpath.
# Their directory is embedded as an rpath at link time so the built extension
# can be imported from any working directory.
_TORCH_LIB = Path(torch.__file__).resolve().parent.joinpath("lib")

# Metal shader pipeline: .metal source -> .air intermediate -> .metallib.
METAL_SRC = ROOT.joinpath("sparse_linear.metal")
AIR = METAL_SRC.with_suffix(".air")
METALLIB = ROOT.joinpath("sparse_linear_ops.metallib")
class MetalBuildExt(BuildExtension):
    """``build_ext`` command that also builds the Metal shader library.

    Before the regular PyTorch extension build, ``METAL_SRC`` is compiled to
    ``AIR`` with ``xcrun metal`` and linked into ``METALLIB`` with
    ``xcrun metallib``. After the extension build, ``METALLIB`` is copied next
    to every ``sparse_linear_metal*.so`` found under ``build_lib``, so the
    library sits beside the extension module in the build output.
    """

    def run(self):
        """Compile the metallib, build the extension, then stage the metallib.

        Raises:
            RuntimeError: if ``xcrun`` or the Metal compiler cannot be found.
            subprocess.CalledProcessError: if compiling or linking the shader
                library fails.
        """
        # Shaders are compiled first so a toolchain problem fails fast,
        # before the (slower) C++/Objective-C++ compilation starts.
        self._compile_metallib()
        super().run()
        self._copy_metallib_to_build_lib()

    @staticmethod
    def _compile_metallib():
        """Compile METAL_SRC -> AIR -> METALLIB with the macOS SDK toolchain."""
        if shutil.which("xcrun") is None:
            raise RuntimeError("xcrun not found. Install Xcode command line tools.")
        # xcrun is present with the Command Line Tools alone, but the `metal`
        # compiler is not. Probe for it explicitly so that setup produces an
        # actionable message instead of a bare CalledProcessError from the
        # compile step below.
        probe = subprocess.run(
            ["xcrun", "-sdk", "macosx", "--find", "metal"],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            check=False,
        )
        if probe.returncode != 0:
            raise RuntimeError(
                "Metal compiler (`metal`) not found by xcrun. It is provided by the "
                "full Xcode app, not the Command Line Tools alone: install Xcode and "
                "select it with `sudo xcode-select -s "
                "/Applications/Xcode.app/Contents/Developer` (newer Xcode releases "
                "may also require downloading the Metal Toolchain component)."
            )
        subprocess.check_call(
            ["xcrun", "-sdk", "macosx", "metal", "-c", str(METAL_SRC), "-o", str(AIR)]
        )
        subprocess.check_call(
            ["xcrun", "-sdk", "macosx", "metallib", str(AIR), "-o", str(METALLIB)]
        )

    def _copy_metallib_to_build_lib(self):
        """Copy METALLIB next to each built extension binary under ``build_lib``."""
        build_lib = Path(self.build_lib)
        for so in build_lib.rglob("sparse_linear_metal*.so"):
            shutil.copy2(METALLIB, so.parent / METALLIB.name)
# Linker flags: the Metal and Foundation frameworks, plus torch's lib directory
# embedded as an rpath so the @rpath-linked PyTorch dylibs resolve at import.
_LINK_ARGS = ["-framework", "Metal", "-framework", "Foundation"]
_LINK_ARGS.append(f"-Wl,-rpath,{_TORCH_LIB}")

# The single Objective-C++ (.mm) translation unit, compiled as C++17 with ARC.
_SPARSE_LINEAR_EXT = CppExtension(
    name="sparse_linear_metal",
    sources=["sparse_linear_ops.mm"],
    extra_compile_args={"cxx": ["-std=c++17", "-ObjC++", "-fobjc-arc"]},
    extra_link_args=_LINK_ARGS,
)

# Package definition; build_ext is overridden so the metallib is built and
# staged alongside the extension.
setup(
    name="sparse_linear_metal",
    version="0.1.0",
    ext_modules=[_SPARSE_LINEAR_EXT],
    cmdclass={"build_ext": MetalBuildExt},
)
|