Add sparse transformer v19 with Triton-backed KNN scheduler and various backward modes. Includes utilities for synthetic data generation and model training. Implements chunked sparse updates and integrates with existing sparse linear layers.
bc1b8eb | import shutil | |
| import subprocess | |
| from pathlib import Path | |
| import torch | |
| from setuptools import setup | |
| from torch.utils.cpp_extension import BuildExtension, CppExtension | |
# Directory containing this setup script; all build inputs/outputs live here.
ROOT = Path(__file__).parent.resolve()

# PyTorch dylibs (libc10, libtorch, ...) are @rpath-linked; the torch/lib
# directory is embedded as an rpath so the extension imports from any cwd.
_TORCH_LIB = Path(torch.__file__).resolve().parent.joinpath("lib")

# Metal shader source, its intermediate AIR object, and the final library.
METAL_SRC = ROOT.joinpath("sparse_linear.metal")
AIR = ROOT.joinpath("sparse_linear.air")
METALLIB = ROOT.joinpath("sparse_linear_ops.metallib")
class MetalBuildExt(BuildExtension):
    """Build step that compiles the Metal shader before the C++ extension.

    The shader at ``METAL_SRC`` is compiled to ``AIR`` and then packaged into
    ``METALLIB`` with ``xcrun``; afterwards the regular extension build runs
    and the metallib is copied beside every built extension binary.
    """

    def run(self):
        """Compile the shader, run the normal build, then ship the metallib.

        Raises:
            RuntimeError: if ``xcrun`` cannot be found on ``PATH``.
            subprocess.CalledProcessError: if either ``xcrun`` invocation
                exits with a non-zero status.
        """
        if shutil.which("xcrun") is None:
            raise RuntimeError("xcrun not found. Install Xcode command line tools.")

        # Both stages share the same xcrun/SDK prefix; order matters because
        # the metallib stage consumes the AIR file produced by the first.
        xcrun_prefix = ["xcrun", "-sdk", "macosx"]
        stages = (
            ["metal", "-c", str(METAL_SRC), "-o", str(AIR)],
            ["metallib", str(AIR), "-o", str(METALLIB)],
        )
        for stage_args in stages:
            subprocess.check_call(xcrun_prefix + stage_args)

        super().run()

        # Copy metallib next to the built extension .so.
        output_root = Path(self.build_lib)
        for built_ext in output_root.rglob("sparse_linear_metal*.so"):
            shutil.copy2(METALLIB, built_ext.with_name(METALLIB.name))
# Linker flags: pull in the Apple frameworks the Objective-C++ source uses
# and embed torch's lib directory as an rpath.
_metal_link_args = [
    "-framework",
    "Metal",
    "-framework",
    "Foundation",
    f"-Wl,-rpath,{_TORCH_LIB}",
]

# The single Objective-C++ extension module built by this package.
_metal_extension = CppExtension(
    name="sparse_linear_metal",
    sources=["sparse_linear_ops.mm"],
    extra_compile_args={"cxx": ["-std=c++17", "-ObjC++", "-fobjc-arc"]},
    extra_link_args=_metal_link_args,
)

setup(
    name="sparse_linear_metal",
    version="0.1.0",
    ext_modules=[_metal_extension],
    # Route build_ext through the Metal-aware command defined in this file.
    cmdclass={"build_ext": MetalBuildExt},
)