| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import os |
| import platform |
| import subprocess |
| import sys |
| from pathlib import Path |
|
|
| |
| import torch |
| import torchada |
|
|
| |
| from setuptools import find_packages, setup |
| from torch.utils.cpp_extension import BuildExtension, CUDAExtension |
|
|
| root = Path(__file__).parent.resolve() |
| third_party = Path(os.environ.get("SGLANG_MUSA_THIRD_PARTY_DIR", "build/_deps")) |
| arch = platform.machine().lower() |
|
|
|
|
class _RepoInfo:
    """Describe one third-party git repository used by this build.

    Attributes:
        name: Directory name of the checkout below ``third_party``.
        git_repository: URL of the repository.
        git_tag: Revision (tag, branch or commit) to check out.
        git_shallow: Whether the repository should be cloned shallowly.
        source_dir: ``third_party / name``; where the checkout lives.
    """

    def __init__(self, name, git_repository, git_tag, git_shallow=False):
        # The checkout location is derived from the name, never passed in.
        self.source_dir = third_party / name
        self.name = name
        self.git_repository, self.git_tag = git_repository, git_tag
        self.git_shallow = git_shallow
|
|
|
|
# Third-party sources, each pinned to a full commit hash and cloned with
# git_shallow disabled.
_FLASHINFER_REPO = _RepoInfo(
    "flashinfer",
    git_tag="bc29697ba20b7e6bdb728ded98f04788e16ee021",
    git_repository="https://github.com/flashinfer-ai/flashinfer.git",
    git_shallow=False,
)

_MUTLASS_REPO = _RepoInfo(
    "mutlass",
    git_tag="3abd6a728aacd190df0d922514aca8a8bc3c46b7",
    git_repository="https://github.com/MooreThreads/mutlass.git",
    git_shallow=False,
)
|
|
|
|
def _get_version():
    """Return the package version declared in ``pyproject.toml``.

    The file next to this script (``root / "pyproject.toml"``) is scanned line
    by line for the first ``version = ...`` assignment.  Only a key that is
    exactly ``version`` matches, so keys that merely start with that word are
    ignored.  Surrounding single or double quotes and a trailing ``#`` comment
    are removed from the value.

    Returns:
        The version string without quotes.

    Raises:
        RuntimeError: If the file contains no ``version`` assignment.
        OSError: If ``pyproject.toml`` cannot be opened.
    """
    with open(root / "pyproject.toml", encoding="utf-8") as f:
        for line in f:
            # Split on the first "=" only, so values containing "=" survive.
            key, sep, value = line.partition("=")
            if not sep or key.strip() != "version":
                continue
            value = value.strip()
            if value[:1] in ('"', "'"):
                # Quoted value: take everything up to the matching quote,
                # which also discards any trailing comment.
                closing = value.find(value[0], 1)
                if closing != -1:
                    return value[1:closing]
            # Unquoted (or unterminated) value: drop comment and stray quotes.
            return value.split("#", 1)[0].strip().strip("\"'")
    # Failing loudly beats handing ``None`` to setup() as the version.
    raise RuntimeError(
        f"No 'version' entry found in {root / 'pyproject.toml'}"
    )
|
|
|
|
# Value substituted into the -DOPERATOR_NAMESPACE compile definition.
operator_namespace = "sgl_kernel"

# Header search path: the project's own headers first, then the third-party
# checkouts (resolved against the directory containing this script).
include_dirs = [root / "include", root / "include" / "impl", root / "csrc"]
include_dirs += [
    root / _FLASHINFER_REPO.source_dir / subdir for subdir in ("include", "csrc")
]
include_dirs.append(root / _MUTLASS_REPO.source_dir / "include")

# Translation units: the local binding file plus selected flashinfer kernels.
sources = ["csrc/common_extension_musa.cc"] + [
    str(_FLASHINFER_REPO.source_dir / rel_path)
    for rel_path in ("csrc/norm.cu", "csrc/renorm.cu", "csrc/sampling.cu")
]

# NOTE(review): "force_mcc" is not a regular compiler switch; presumably it is
# a marker interpreted by the patched build extension - confirm.
cxx_flags = ["force_mcc"]
libraries = ["c10", "torch", "torch_python"]

# Link options: an rpath pointing at torch/lib relative to the built module,
# the architecture-specific system library directory, and mublasLt.
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib"]
extra_link_args += [f"-L/usr/lib/{arch}-linux-gnu", "-lmublasLt"]
|
|
# Target selection: start from MTGPU_TARGET (falling back to the built-in
# default); a successfully queried device 0 then overrides that value.
default_target = "mp_31"
mtgpu_target = os.environ.get("MTGPU_TARGET", default_target)

if not torch.musa.is_available():
    print(f"Warning: torch.musa not available. Using default target: {mtgpu_target}")
else:
    try:
        prop = torch.musa.get_device_properties(0)
        mtgpu_target = f"mp_{prop.major}{prop.minor}"
    except Exception as e:
        # Best effort: keep the env/default target when the query fails.
        print(f"Warning: Failed to detect GPU properties: {e}")

# Only mp_22 and mp_31 are accepted; anything else aborts the build.
if mtgpu_target != "mp_22" and mtgpu_target != "mp_31":
    print(
        f"Warning: Unsupported GPU architecture detected '{mtgpu_target}'. Expected 'mp_22' or 'mp_31'."
    )
    sys.exit(1)
|
|
# Flags passed to the "mcc" compiler entry of the extension, assembled in
# groups; the resulting order is significant ("-x" must precede "musa").
mcc_flags = [
    "-DNDEBUG",
    f"-DOPERATOR_NAMESPACE={operator_namespace}",
    "-O3",
    "-fPIC",
    "-std=c++17",
]
mcc_flags += [f"--cuda-gpu-arch={mtgpu_target}", "-x", "musa", "-mtgpu", "-Od3"]
mcc_flags += [
    "-ffast-math",
    "-fmusa-flush-denormals-to-zero",
    "-fno-strict-aliasing",
]
mcc_flags += [
    "-DUSE_MUSA",
    "-DENABLE_BF16",
    "-DFLASHINFER_ENABLE_F16",
    "-DFLASHINFER_ENABLE_BF16",
]

# The FP8 definitions are added for the mp_31 target only.
if mtgpu_target == "mp_31":
    mcc_flags += [
        "-DENABLE_FP8",
        "-DFLASHINFER_ENABLE_FP8",
        "-DFLASHINFER_ENABLE_FP8_E4M3",
        "-DFLASHINFER_ENABLE_FP8_E5M2",
    ]
|
|
# The single native module produced by this script: sgl_kernel.common_ops.
# Compile flags are keyed per compiler ("mcc" for device code, "cxx" for host).
_common_ops_compile_args = {"mcc": mcc_flags, "cxx": cxx_flags}

ext_modules = [
    CUDAExtension(
        name="sgl_kernel.common_ops",
        sources=sources,
        include_dirs=include_dirs,
        extra_compile_args=_common_ops_compile_args,
        libraries=libraries,
        extra_link_args=extra_link_args,
        py_limited_api=False,
    ),
]
|
|
|
|
class _CustomBuildExt(BuildExtension):
    """Build extension that prepares third-party checkouts before building."""

    @staticmethod
    def _clone_and_checkout(repo_path, repo_url, git_tag, git_shallow):
        """Make ``repo_path`` a checkout of ``repo_url`` at ``git_tag``.

        Args:
            repo_path: ``pathlib.Path`` of the checkout; parent directories
                are created when missing.
            repo_url: URL passed to ``git clone`` when ``repo_path`` does not
                exist yet.
            git_tag: Revision to check out (tag, branch or commit hash).
            git_shallow: When true, history is limited to depth 1.

        Raises:
            subprocess.CalledProcessError: If any git command exits non-zero.
        """
        repo_path.parent.mkdir(parents=True, exist_ok=True)
        already_cloned = repo_path.exists()
        if not already_cloned:
            clone_cmd = ["git", "clone"]
            if git_shallow:
                clone_cmd += ["--depth", "1"]
            clone_cmd += [repo_url, str(repo_path)]
            subprocess.check_call(clone_cmd)

        if git_shallow:
            # A depth-1 clone only holds the tip of the default branch, so a
            # pinned revision is usually not present locally.  Fetch exactly
            # that revision (still at depth 1) and check out what was fetched.
            # Assumes the remote is named "origin", as "git clone" sets it up.
            subprocess.check_call(
                ["git", "fetch", "--depth", "1", "origin", git_tag],
                cwd=repo_path,
            )
            subprocess.check_call(["git", "checkout", "FETCH_HEAD"], cwd=repo_path)
            return

        if already_cloned:
            # Existing full checkout: refresh all remotes so a newly pinned
            # revision becomes available.
            subprocess.check_call(["git", "fetch", "--all"], cwd=repo_path)
        subprocess.check_call(["git", "checkout", git_tag], cwd=repo_path)

    def run(self):
        """Prepare the third-party repositories, then run the regular build.

        Setting ``SGLANG_MUSA_SKIP_THIRD_PARTY=1`` skips the git step entirely.
        """
        if os.environ.get("SGLANG_MUSA_SKIP_THIRD_PARTY", "0") == "1":
            print(
                "Skipping third-party repositories cloning (SGLANG_MUSA_SKIP_THIRD_PARTY=1)"
            )
        else:
            print("Cloning third-party repositories...")
            for repo in (_MUTLASS_REPO, _FLASHINFER_REPO):
                self._clone_and_checkout(
                    repo.source_dir,
                    repo.git_repository,
                    repo.git_tag,
                    repo.git_shallow,
                )
            print("Third-party repositories ready.")

        super().run()
|
|
|
|
# Package definition: Python sources live under "python/", the native module
# is built through the custom build_ext (ninja enabled), and bdist_wheel is
# given py_limited_api "cp39".
setup(
    name="sgl-kernel",
    version=_get_version(),
    package_dir={"": "python"},
    packages=find_packages(where="python"),
    cmdclass={"build_ext": _CustomBuildExt.with_options(use_ninja=True)},
    ext_modules=ext_modules,
    options={"bdist_wheel": {"py_limited_api": "cp39"}},
)
|
|