--- setup.py.orig 2024-10-02 00:00:00.000000000 +0000 +++ setup.py 2024-10-02 00:00:00.000000000 +0000 @@ -66,6 +66,17 @@ nvcc_cuda_version = parse(output[release_idx].split(",")[0]) return nvcc_cuda_version +# Check for TORCH_CUDA_ARCH_LIST environment variable first +import os +env_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None) +if env_arch_list: + print(f"Using TORCH_CUDA_ARCH_LIST from environment: {env_arch_list}") + arch_list = env_arch_list.replace(" ", ";").split(";") + for arch in arch_list: + arch = arch.strip() + if not arch: + continue + if arch.endswith("+PTX"): + arch = arch[:-4].strip() + if arch: + compute_capabilities.add(arch) + # Iterate over all GPUs on the current machine. Also you can modify this part to specify the architecture if you want to build for specific GPU architectures. compute_capabilities = set() device_count = torch.cuda.device_count()