"""Build script for the sparse_linear_metal extension.

Adds sparse transformer v19 with Triton-backed KNN scheduler and various
backward modes. Includes utilities for synthetic data generation and model
training. Implements chunked sparse updates and integrates with existing
sparse linear layers. (Commit bc1b8eb.)
"""
import shutil
import subprocess
from pathlib import Path
import torch
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension
ROOT = Path(__file__).parent.resolve()
# PyTorch dylibs (libc10, libtorch, …) are @rpath-linked; embed torch/lib so import works from any cwd.
_TORCH_LIB = Path(torch.__file__).resolve().parent / "lib"
METAL_SRC = ROOT / "sparse_linear.metal"
AIR = ROOT / "sparse_linear.air"
METALLIB = ROOT / "sparse_linear_ops.metallib"
class MetalBuildExt(BuildExtension):
    """build_ext that compiles the Metal shader before the C++ extension.

    Pipeline: ``sparse_linear.metal`` -> ``.air`` (compile) -> ``.metallib``
    (link), then the normal torch ``BuildExtension`` build, then the
    ``.metallib`` is copied next to each built ``sparse_linear_metal*.so``
    so the extension can load it relative to its own location.

    Raises:
        RuntimeError: if ``xcrun`` is unavailable or the Metal source is
            missing — checked up front so the failure is actionable instead
            of a cryptic toolchain error mid-build.
        subprocess.CalledProcessError: if either xcrun step fails.
    """

    def run(self):
        # Fail fast with clear messages before invoking the toolchain.
        if shutil.which("xcrun") is None:
            raise RuntimeError("xcrun not found. Install Xcode command line tools.")
        if not METAL_SRC.is_file():
            raise RuntimeError(f"Metal shader source not found: {METAL_SRC}")
        # Compile the shader to AIR, then link it into a metallib.
        subprocess.check_call(["xcrun", "-sdk", "macosx", "metal", "-c", str(METAL_SRC), "-o", str(AIR)])
        subprocess.check_call(["xcrun", "-sdk", "macosx", "metallib", str(AIR), "-o", str(METALLIB)])
        super().run()
        # Copy metallib next to the built extension .so (rglob: the .so may be
        # nested under a platform-specific subdirectory of build_lib).
        build_lib = Path(self.build_lib)
        for so in build_lib.rglob("sparse_linear_metal*.so"):
            shutil.copy2(METALLIB, so.parent / METALLIB.name)
# The Objective-C++ extension: links Metal/Foundation and embeds torch/lib
# in the rpath so the @rpath-linked torch dylibs resolve at import time.
_metal_ext = CppExtension(
    name="sparse_linear_metal",
    sources=["sparse_linear_ops.mm"],
    extra_compile_args={"cxx": ["-std=c++17", "-ObjC++", "-fobjc-arc"]},
    extra_link_args=[
        "-framework",
        "Metal",
        "-framework",
        "Foundation",
        f"-Wl,-rpath,{_TORCH_LIB}",
    ],
)

setup(
    name="sparse_linear_metal",
    version="0.1.0",
    ext_modules=[_metal_ext],
    # Custom build_ext compiles the .metallib before the C++ build.
    cmdclass={"build_ext": MetalBuildExt},
)