Spaces:
Runtime error
Runtime error
| import os | |
| import re | |
| import subprocess | |
| import sys | |
| from datetime import date | |
| import setuptools | |
| import torch | |
| from packaging import version as packaging_version | |
| from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension | |
class CustomBuildExtension(BuildExtension):
    """``BuildExtension`` that merges per-toolchain compile flags before building.

    Besides the standard ``"cxx"`` and ``"nvcc"`` entries, an extension's
    ``extra_compile_args`` dict may carry toolchain-specific entries:

    * ``"msvc"`` / ``"nvcc_msvc"`` -- appended to ``"cxx"`` / ``"nvcc"`` when the
      active compiler is MSVC;
    * ``"gcc"`` -- appended to ``"cxx"`` for every other compiler.
    """

    @staticmethod
    def _merge_compile_args(extra_compile_args: dict, compiler_type: str) -> None:
        """Fold toolchain-specific flags into ``"cxx"``/``"nvcc"`` in place.

        Args:
            extra_compile_args: The extension's ``extra_compile_args`` dict; it is
                mutated in place.
            compiler_type: ``distutils`` compiler type, e.g. ``"msvc"`` or ``"unix"``.

        Missing ``"cxx"``/``"nvcc"`` entries are created as empty lists. A missing
        toolchain-specific entry is treated as "no extra flags" rather than
        raising ``KeyError``.
        """
        cxx_args = extra_compile_args.setdefault("cxx", [])
        nvcc_args = extra_compile_args.setdefault("nvcc", [])
        if compiler_type == "msvc":
            cxx_args.extend(extra_compile_args.get("msvc", []))
            nvcc_args.extend(extra_compile_args.get("nvcc_msvc", []))
        else:
            cxx_args.extend(extra_compile_args.get("gcc", []))

    def build_extensions(self):
        """Merge toolchain-specific flags for every extension, then build them."""
        compiler_type = self.compiler.compiler_type
        for ext in self.extensions:
            # A plain setuptools Extension keeps extra_compile_args as a list;
            # only the dict form carries per-toolchain entries to merge.
            if isinstance(ext.extra_compile_args, dict):
                self._merge_compile_args(ext.extra_compile_args, compiler_type)
        super().build_extensions()
| def get_sm_targets() -> list[str]: | |
| nvcc_path = os.path.join(CUDA_HOME, "bin/nvcc") if CUDA_HOME else "nvcc" | |
| try: | |
| nvcc_output = subprocess.check_output([nvcc_path, "--version"]).decode() | |
| match = re.search(r"release (\d+\.\d+), V(\d+\.\d+\.\d+)", nvcc_output) | |
| if match: | |
| nvcc_version = match.group(2) | |
| else: | |
| raise Exception("nvcc version not found") | |
| print(f"Found nvcc version: {nvcc_version}") | |
| except: | |
| raise Exception("nvcc not found") | |
| support_sm120 = packaging_version.parse(nvcc_version) >= packaging_version.parse("12.8") | |
| install_mode = os.getenv("NUNCHAKU_INSTALL_MODE", "FAST") | |
| if install_mode == "FAST": | |
| ret = [] | |
| for i in range(torch.cuda.device_count()): | |
| capability = torch.cuda.get_device_capability(i) | |
| sm = f"{capability[0]}{capability[1]}" | |
| if sm == "120" and support_sm120: | |
| sm = "120a" | |
| ret.append(sm) | |
| return ret | |
| elif install_mode == "ALL": | |
| # All supported architectures (except for experimental ones) | |
| sm_targets = ["75", "80", "86", "89", "90"] | |
| if support_sm120: | |
| sm_targets.append("120a") | |
| return sm_targets | |
| else: | |
| raise ValueError(f"Unknown install mode: {install_mode}") | |
# Translation units compiled into the ``nunchaku._C`` extension module.
FLUX_SOURCES = ["nunchaku/csrc/pybind.cpp"]
ext_modules = []
# The CUDA extension always needs a CUDA toolkit (CUDA_HOME). "FAST" mode (the
# default) derives its targets from the visible GPUs, so it also needs
# torch.cuda.is_available(). "ALL" mode compiles for a fixed architecture list
# and therefore also works on a GPU-less build machine.
if CUDA_HOME is not None and (
    torch.cuda.is_available() or os.getenv("NUNCHAKU_INSTALL_MODE", "FAST") == "ALL"
):
    sm_targets = get_sm_targets()
    arch_flags = [f"-gencode=arch=compute_{sm},code=sm_{sm}" for sm in sm_targets]
    ext_modules.append(
        CUDAExtension(
            "nunchaku._C",
            FLUX_SOURCES,
            extra_compile_args={
                # Host-compiler flags are toolchain specific. CustomBuildExtension
                # appends "gcc" or "msvc" to "cxx" (and "nvcc_msvc" to "nvcc" on
                # MSVC), so GCC-style options stay out of the shared list.
                "cxx": [],
                "nvcc": [
                    "-O3",
                    "-std=c++20",
                    "--use_fast_math",
                    # Undefine the __CUDA_NO_HALF*/BFLOAT16* macros that
                    # PyTorch's default nvcc flags define, re-enabling the
                    # half/bfloat16 operators and conversions.
                    "-U__CUDA_NO_HALF_OPERATORS__",
                    "-U__CUDA_NO_HALF_CONVERSIONS__",
                    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                    "-U__CUDA_NO_HALF2_OPERATORS__",
                ] + arch_flags,
                "msvc": ["/std:c++20"],
                "gcc": ["-O3", "-std=c++20"],
                "nvcc_msvc": [],
            },
            include_dirs=[
                "third_party/cutlass/include",
                "third_party/cutlass/tools/util/include",
            ],
        )
    )
else:
    print("CUDA not available. Installing CPU-only version.")
# Package entry point. ``ext_modules`` is empty when the CUDA extension was not
# configured above. The custom ``build_ext`` command folds per-toolchain
# compile flags into each extension before compiling it.
setuptools.setup(
    name="flux-kontext",
    zip_safe=False,
    packages=setuptools.find_packages(),
    cmdclass={"build_ext": CustomBuildExtension},
    ext_modules=ext_modules,
)