"""Build script for the ``flux-kontext`` distribution.

When CUDA is usable (``torch.cuda.is_available()`` and ``CUDA_HOME`` is set)
the ``nunchaku._C`` CUDA extension is compiled; otherwise the package is
installed without any extension module.

Environment variables:
    NUNCHAKU_INSTALL_MODE: ``FAST`` (default) builds only for the GPUs visible
        on this machine; ``ALL`` builds for every architecture listed in
        ``ALL_SM_TARGETS`` (plus ``120a`` when nvcc is new enough).
"""

import os
import re
import subprocess
import sys
from datetime import date

import setuptools
import torch
from packaging import version as packaging_version
from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension

# Matches e.g. "release 12.8, V12.8.61" in the output of `nvcc --version`.
NVCC_VERSION_PATTERN = re.compile(r"release (\d+\.\d+), V(\d+\.\d+\.\d+)")

# Architectures built in "ALL" mode (experimental ones are excluded).
ALL_SM_TARGETS = ("75", "80", "86", "89", "90")

# First nvcc release for which the "120a" target is emitted.
MIN_NVCC_FOR_SM120 = "12.8"

FLUX_SOURCES = [
    "nunchaku/csrc/pybind.cpp",
]


class CustomBuildExtension(BuildExtension):
    """``build_ext`` command that folds compiler-specific flags into ``cxx``/``nvcc``."""

    def build_extensions(self):
        """Merge the per-compiler flag lists into the standard keys, then build.

        Extensions may carry the extra keys ``msvc``, ``gcc`` and ``nvcc_msvc``
        in ``extra_compile_args``. Depending on the active host compiler the
        matching lists are appended to ``cxx`` (and ``nvcc`` for MSVC).
        Missing keys are treated as empty lists, and extensions whose
        ``extra_compile_args`` is not a dict are left untouched.
        """
        is_msvc = self.compiler.compiler_type == "msvc"
        for ext in self.extensions:
            args = ext.extra_compile_args
            if not isinstance(args, dict):
                # Plain list form: there are no per-compiler keys to merge.
                continue
            args.setdefault("cxx", [])
            args.setdefault("nvcc", [])
            if is_msvc:
                args["cxx"] += args.get("msvc", [])
                args["nvcc"] += args.get("nvcc_msvc", [])
            else:
                args["cxx"] += args.get("gcc", [])
        super().build_extensions()


def parse_nvcc_version(nvcc_output: str) -> str:
    """Extract the full ``X.Y.Z`` version from ``nvcc --version`` output.

    Args:
        nvcc_output: Decoded text printed by ``nvcc --version``.

    Returns:
        The version following the ``V`` marker, e.g. ``"12.8.61"``.

    Raises:
        RuntimeError: If the expected ``release X.Y, VX.Y.Z`` text is absent.
    """
    match = NVCC_VERSION_PATTERN.search(nvcc_output)
    if match is None:
        raise RuntimeError("nvcc version not found")
    return match.group(2)


def select_sm_targets(
    install_mode: str,
    capabilities: list[tuple[int, int]],
    support_sm120: bool,
) -> list[str]:
    """Choose the SM architecture suffixes to compile for.

    Args:
        install_mode: ``"FAST"`` or ``"ALL"``.
        capabilities: ``(major, minor)`` compute capability of each visible
            device; only consulted in ``FAST`` mode.
        support_sm120: Whether the toolchain can target ``120a``.

    Returns:
        Architecture suffixes in first-seen order without duplicates, so that
        several identical GPUs do not produce repeated ``-gencode`` flags.

    Raises:
        ValueError: If ``install_mode`` is neither ``FAST`` nor ``ALL``.
    """
    if install_mode == "FAST":
        detected = []
        for major, minor in capabilities:
            sm = f"{major}{minor}"
            if sm == "120" and support_sm120:
                sm = "120a"
            detected.append(sm)
        # dict.fromkeys keeps insertion order while dropping duplicates.
        return list(dict.fromkeys(detected))
    if install_mode == "ALL":
        targets = list(ALL_SM_TARGETS)
        if support_sm120:
            targets.append("120a")
        return targets
    raise ValueError(f"Unknown install mode: {install_mode}")


def get_sm_targets() -> list[str]:
    """Return the SM targets for this build, based on nvcc and the install mode.

    Runs ``nvcc --version`` (from ``CUDA_HOME/bin`` when ``CUDA_HOME`` is set,
    otherwise from ``PATH``), reads ``NUNCHAKU_INSTALL_MODE`` and, in ``FAST``
    mode, queries the compute capability of every visible CUDA device.

    Raises:
        RuntimeError: If nvcc cannot be executed or its version cannot be parsed.
        ValueError: If ``NUNCHAKU_INSTALL_MODE`` has an unknown value.
    """
    nvcc_path = os.path.join(CUDA_HOME, "bin", "nvcc") if CUDA_HOME else "nvcc"
    try:
        raw_output = subprocess.check_output([nvcc_path, "--version"])
    except (OSError, subprocess.CalledProcessError) as err:
        # Only launch failures mean "not found"; keep the cause for debugging.
        raise RuntimeError("nvcc not found") from err

    nvcc_version = parse_nvcc_version(raw_output.decode(errors="replace"))
    print(f"Found nvcc version: {nvcc_version}")

    support_sm120 = packaging_version.parse(nvcc_version) >= packaging_version.parse(
        MIN_NVCC_FOR_SM120
    )
    install_mode = os.getenv("NUNCHAKU_INSTALL_MODE", "FAST")

    capabilities = []
    if install_mode == "FAST":
        capabilities = [
            torch.cuda.get_device_capability(index)
            for index in range(torch.cuda.device_count())
        ]
    return select_sm_targets(install_mode, capabilities, support_sm120)


def build_ext_modules() -> list:
    """Return the extension modules to build (empty when CUDA is unavailable)."""
    if not (torch.cuda.is_available() and CUDA_HOME is not None):
        print("CUDA not available. Installing CPU-only version.")
        return []

    sm_targets = get_sm_targets()
    arch_flags = [f"-gencode=arch=compute_{sm},code=sm_{sm}" for sm in sm_targets]
    return [
        CUDAExtension(
            "nunchaku._C",
            FLUX_SOURCES,
            extra_compile_args={
                "cxx": ["-O3", "-std=c++20"],
                "nvcc": [
                    "-O3",
                    "-std=c++20",
                    "--use_fast_math",
                    # -U undefines these macros for the nvcc compilation.
                    "-U__CUDA_NO_HALF_OPERATORS__",
                    "-U__CUDA_NO_HALF_CONVERSIONS__",
                    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                    "-U__CUDA_NO_HALF2_OPERATORS__",
                ]
                + arch_flags,
                # Per-compiler extras merged by CustomBuildExtension.
                "msvc": ["/std:c++20"],
                "gcc": ["-std=c++20"],
                "nvcc_msvc": [],
            },
            include_dirs=[
                "third_party/cutlass/include",
                "third_party/cutlass/tools/util/include",
            ],
        )
    ]


def main() -> None:
    """Configure and run the setuptools build."""
    setuptools.setup(
        name="flux-kontext",
        packages=setuptools.find_packages(),
        ext_modules=build_ext_modules(),
        cmdclass={"build_ext": CustomBuildExtension},
        zip_safe=False,
    )


# setuptools' build backend, pip and `python setup.py ...` all execute this
# file with __name__ == "__main__"; the guard only prevents a build from being
# triggered when the module is merely imported (e.g. by tooling or tests).
if __name__ == "__main__":
    main()