File size: 3,614 Bytes
04eaca9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import re
import subprocess
import sys
from datetime import date

import setuptools
import torch
from packaging import version as packaging_version
from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension


class CustomBuildExtension(BuildExtension):
    """Build command that folds compiler-specific flags into the generic ones.

    Extensions may carry extra per-compiler buckets in their
    ``extra_compile_args`` dict (``"msvc"``, ``"nvcc_msvc"``, ``"gcc"``).
    Before delegating to the base ``BuildExtension``, the bucket matching the
    active host compiler is appended to the ``"cxx"`` / ``"nvcc"`` entries.
    """

    @staticmethod
    def _merge_compiler_args(ext, compiler_type):
        """Append the flags for ``compiler_type`` to ``ext``'s generic buckets.

        Args:
            ext: Extension object exposing ``extra_compile_args``.
            compiler_type: Value of ``self.compiler.compiler_type``; ``"msvc"``
                selects the ``"msvc"``/``"nvcc_msvc"`` buckets, anything else
                selects ``"gcc"``.

        The ``extra_compile_args`` dict is updated in place. Missing buckets
        are treated as empty instead of raising ``KeyError``.
        """
        args = ext.extra_compile_args
        if not isinstance(args, dict):
            # A plain flag list has no per-compiler buckets to merge; leave it
            # untouched for the base class to handle.
            return

        cxx_flags = args.setdefault("cxx", [])
        nvcc_flags = args.setdefault("nvcc", [])
        if compiler_type == "msvc":
            cxx_flags.extend(args.get("msvc", []))
            nvcc_flags.extend(args.get("nvcc_msvc", []))
        else:
            cxx_flags.extend(args.get("gcc", []))

    def build_extensions(self):
        """Merge compiler-specific flags for every extension, then build."""
        for ext in self.extensions:
            self._merge_compiler_args(ext, self.compiler.compiler_type)
        super().build_extensions()


def get_sm_targets() -> list[str]:
    """Return the CUDA SM architectures the extension should be compiled for.

    The nvcc version is read from ``nvcc --version`` (taken from
    ``CUDA_HOME/bin`` when ``CUDA_HOME`` is set, otherwise from ``PATH``).
    nvcc 12.8 or newer enables the ``120a`` target.

    The ``NUNCHAKU_INSTALL_MODE`` environment variable selects the target set:

    * ``"FAST"`` (default): one entry per distinct compute capability among
      the CUDA devices visible to torch, in device order. The list is empty
      when no device is visible.
    * ``"ALL"``: ``75, 80, 86, 89, 90`` plus ``120a`` when nvcc supports it.

    Returns:
        SM identifiers such as ``"86"`` or ``"120a"``.

    Raises:
        RuntimeError: if nvcc cannot be executed, or if its version cannot be
            parsed from its output.
        ValueError: if ``NUNCHAKU_INSTALL_MODE`` is neither ``FAST`` nor ``ALL``.
    """
    nvcc_path = os.path.join(CUDA_HOME, "bin", "nvcc") if CUDA_HOME else "nvcc"

    # Only the subprocess call is guarded, so a parse failure below is not
    # misreported as "nvcc not found".
    try:
        raw_output = subprocess.check_output([nvcc_path, "--version"])
    except (OSError, subprocess.CalledProcessError) as err:
        raise RuntimeError("nvcc not found") from err
    nvcc_output = raw_output.decode(errors="replace")

    match = re.search(r"release (\d+\.\d+), V(\d+\.\d+\.\d+)", nvcc_output)
    if match is None:
        raise RuntimeError("nvcc version not found")
    nvcc_version = match.group(2)
    print(f"Found nvcc version: {nvcc_version}")

    support_sm120 = packaging_version.parse(nvcc_version) >= packaging_version.parse("12.8")

    install_mode = os.getenv("NUNCHAKU_INSTALL_MODE", "FAST")
    if install_mode == "FAST":
        detected = []
        for device_index in range(torch.cuda.device_count()):
            major, minor = torch.cuda.get_device_capability(device_index)
            sm = f"{major}{minor}"
            if sm == "120" and support_sm120:
                sm = "120a"
            detected.append(sm)
        # Identical GPUs would otherwise yield duplicate -gencode flags;
        # dict.fromkeys de-duplicates while keeping first-seen order.
        return list(dict.fromkeys(detected))
    if install_mode == "ALL":
        # All supported architectures (except for experimental ones)
        sm_targets = ["75", "80", "86", "89", "90"]
        if support_sm120:
            sm_targets.append("120a")
        return sm_targets
    raise ValueError(f"Unknown install mode: {install_mode}")


# Sources compiled into the ``nunchaku._C`` extension module.
FLUX_SOURCES = [
    "nunchaku/csrc/pybind.cpp",
]

ext_modules = []

# Build the native extension only when torch reports CUDA as available and a
# CUDA toolkit location is known; otherwise install the Python package alone.
if torch.cuda.is_available() and CUDA_HOME is not None:
    sm_targets = get_sm_targets()
    arch_flags = [f"-gencode=arch=compute_{sm},code=sm_{sm}" for sm in sm_targets]

    ext_modules.append(
        CUDAExtension(
            "nunchaku._C",
            FLUX_SOURCES,
            extra_compile_args={
                # Host-compiler flags live in the per-compiler buckets below.
                # CustomBuildExtension appends the matching bucket to "cxx" at
                # build time, so GCC-style flags are not handed to MSVC and
                # -std=c++20 is not passed twice on GCC.
                "cxx": [],
                "nvcc": [
                    "-O3",
                    "-std=c++20",
                    "--use_fast_math",
                    "-U__CUDA_NO_HALF_OPERATORS__",
                    "-U__CUDA_NO_HALF_CONVERSIONS__",
                    "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                    "-U__CUDA_NO_HALF2_OPERATORS__",
                    *arch_flags,
                ],
                "msvc": ["/std:c++20"],
                "gcc": ["-O3", "-std=c++20"],
                "nvcc_msvc": [],
            },
            include_dirs=[
                "third_party/cutlass/include",
                "third_party/cutlass/tools/util/include",
            ],
        )
    )
else:
    print("CUDA not available. Installing CPU-only version.")

setuptools.setup(
    name="flux-kontext",
    packages=setuptools.find_packages(),
    ext_modules=ext_modules,
    cmdclass={"build_ext": CustomBuildExtension},
    zip_safe=False,
)