File size: 979 Bytes
b701455
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
--- setup.py.orig	2024-10-02 00:00:00.000000000 +0000
+++ setup.py	2024-10-02 00:00:00.000000000 +0000
@@ -66,6 +66,17 @@
     nvcc_cuda_version = parse(output[release_idx].split(",")[0])
     return nvcc_cuda_version
 
+# Check for TORCH_CUDA_ARCH_LIST environment variable first
+import os
+env_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
+if env_arch_list:
+    print(f"Using TORCH_CUDA_ARCH_LIST from environment: {env_arch_list}")
+    arch_list = env_arch_list.replace(" ", ";").split(";")
+    for arch in arch_list:
+        arch = arch.strip()
+        if not arch:
+            continue
+        if arch.endswith("+PTX"):
+            arch = arch[:-4].strip()
+        if arch:
+            compute_capabilities.add(arch)
+
 # Iterate over all GPUs on the current machine. Also you can modify this part to specify the architecture if you want to build for specific GPU architectures.
 compute_capabilities = set()
 device_count = torch.cuda.device_count()