# Copyright 2025 SGLang Team. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import os import platform import subprocess import sys from pathlib import Path # isort: off import torch import torchada # noqa: F401 # isort: on from setuptools import find_packages, setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension root = Path(__file__).parent.resolve() third_party = Path(os.environ.get("SGLANG_MUSA_THIRD_PARTY_DIR", "build/_deps")) arch = platform.machine().lower() class _RepoInfo: """Configuration for a third-party git repository.""" def __init__(self, name, git_repository, git_tag, git_shallow=False): self.name = name self.git_repository = git_repository self.git_tag = git_tag self.git_shallow = git_shallow self.source_dir = third_party / name _FLASHINFER_REPO = _RepoInfo( name="flashinfer", git_repository="https://github.com/flashinfer-ai/flashinfer.git", git_tag="bc29697ba20b7e6bdb728ded98f04788e16ee021", git_shallow=False, ) _MUTLASS_REPO = _RepoInfo( name="mutlass", git_repository="https://github.com/MooreThreads/mutlass.git", git_tag="3abd6a728aacd190df0d922514aca8a8bc3c46b7", git_shallow=False, ) def _get_version(): with open(root / "pyproject.toml") as f: for line in f: if line.startswith("version"): return line.split("=")[1].strip().strip('"') operator_namespace = "sgl_kernel" include_dirs = [ root / "include", root / "include" / "impl", root / "csrc", root / _FLASHINFER_REPO.source_dir / "include", root / _FLASHINFER_REPO.source_dir / "csrc", root / _MUTLASS_REPO.source_dir / "include", ] sources = [ "csrc/common_extension_musa.cc", str(_FLASHINFER_REPO.source_dir / "csrc/norm.cu"), str(_FLASHINFER_REPO.source_dir / "csrc/renorm.cu"), str(_FLASHINFER_REPO.source_dir / "csrc/sampling.cu"), ] cxx_flags = ["force_mcc"] libraries = ["c10", "torch", "torch_python"] extra_link_args = [ "-Wl,-rpath,$ORIGIN/../../torch/lib", f"-L/usr/lib/{arch}-linux-gnu", "-lmublasLt", ] default_target = "mp_31" mtgpu_target = os.environ.get("MTGPU_TARGET", default_target) if torch.musa.is_available(): try: prop = torch.musa.get_device_properties(0) mtgpu_target = f"mp_{prop.major}{prop.minor}" except Exception as e: print(f"Warning: Failed to detect GPU properties: {e}") else: print(f"Warning: torch.musa not available. Using default target: {mtgpu_target}") if mtgpu_target not in ["mp_22", "mp_31"]: print( f"Warning: Unsupported GPU architecture detected '{mtgpu_target}'. Expected 'mp_22' or 'mp_31'." ) sys.exit(1) mcc_flags = [ "-DNDEBUG", f"-DOPERATOR_NAMESPACE={operator_namespace}", "-O3", "-fPIC", "-std=c++17", f"--cuda-gpu-arch={mtgpu_target}", "-x", "musa", "-mtgpu", "-Od3", "-ffast-math", "-fmusa-flush-denormals-to-zero", "-fno-strict-aliasing", "-DUSE_MUSA", "-DENABLE_BF16", "-DFLASHINFER_ENABLE_F16", "-DFLASHINFER_ENABLE_BF16", ] if mtgpu_target == "mp_31": mcc_flags.extend( [ "-DENABLE_FP8", "-DFLASHINFER_ENABLE_FP8", "-DFLASHINFER_ENABLE_FP8_E4M3", "-DFLASHINFER_ENABLE_FP8_E5M2", ] ) ext_modules = [ CUDAExtension( name="sgl_kernel.common_ops", sources=sources, include_dirs=include_dirs, extra_compile_args={ "mcc": mcc_flags, "cxx": cxx_flags, }, libraries=libraries, extra_link_args=extra_link_args, py_limited_api=False, ), ] class _CustomBuildExt(BuildExtension): """Custom build extension that clones third-party repositories before building.""" @staticmethod def _clone_and_checkout(repo_path, repo_url, git_tag, git_shallow): """Clone a git repository and checkout a specific tag/commit.""" repo_path.parent.mkdir(parents=True, exist_ok=True) if not repo_path.exists(): clone_cmd = ["git", "clone"] if git_shallow: clone_cmd += ["--depth", "1"] clone_cmd += [repo_url, str(repo_path)] subprocess.check_call(clone_cmd) subprocess.check_call(["git", "checkout", git_tag], cwd=repo_path) else: subprocess.check_call(["git", "fetch", "--all"], cwd=repo_path) subprocess.check_call(["git", "checkout", git_tag], cwd=repo_path) def run(self): if os.environ.get("SGLANG_MUSA_SKIP_THIRD_PARTY", "0") == "1": print( "Skipping third-party repositories cloning (SGLANG_MUSA_SKIP_THIRD_PARTY=1)" ) else: print("Cloning third-party repositories...") self._clone_and_checkout( _MUTLASS_REPO.source_dir, _MUTLASS_REPO.git_repository, _MUTLASS_REPO.git_tag, _MUTLASS_REPO.git_shallow, ) self._clone_and_checkout( _FLASHINFER_REPO.source_dir, _FLASHINFER_REPO.git_repository, _FLASHINFER_REPO.git_tag, _FLASHINFER_REPO.git_shallow, ) print("Third-party repositories ready.") super().run() setup( name="sgl-kernel", version=_get_version(), packages=find_packages(where="python"), package_dir={"": "python"}, ext_modules=ext_modules, cmdclass={"build_ext": _CustomBuildExt.with_options(use_ninja=True)}, options={"bdist_wheel": {"py_limited_api": "cp39"}}, )