Spaces:
Runtime error
Runtime error
anas commited on
Commit ·
86c87ce
1
Parent(s): 16ab340
Fix models/ops build: respect FORCE_CUDA env var for headless CUDA compilation
Browse files- models/ops/setup.py +3 -2
models/ops/setup.py
CHANGED
|
@@ -33,7 +33,8 @@ def get_extensions():
|
|
| 33 |
extra_compile_args = {"cxx": []}
|
| 34 |
define_macros = []
|
| 35 |
|
| 36 |
-
|
|
|
|
| 37 |
extension = CUDAExtension
|
| 38 |
sources += source_cuda
|
| 39 |
define_macros += [("WITH_CUDA", None)]
|
|
@@ -44,7 +45,7 @@ def get_extensions():
|
|
| 44 |
"-D__CUDA_NO_HALF2_OPERATORS__",
|
| 45 |
]
|
| 46 |
else:
|
| 47 |
-
raise NotImplementedError('Cuda is not
|
| 48 |
|
| 49 |
sources = [os.path.join(extensions_dir, s) for s in sources]
|
| 50 |
include_dirs = [extensions_dir]
|
|
|
|
| 33 |
extra_compile_args = {"cxx": []}
|
| 34 |
define_macros = []
|
| 35 |
|
| 36 |
+
force_cuda = os.environ.get("FORCE_CUDA", "0") == "1"
|
| 37 |
+
if (torch.cuda.is_available() or force_cuda) and CUDA_HOME is not None:
|
| 38 |
extension = CUDAExtension
|
| 39 |
sources += source_cuda
|
| 40 |
define_macros += [("WITH_CUDA", None)]
|
|
|
|
| 45 |
"-D__CUDA_NO_HALF2_OPERATORS__",
|
| 46 |
]
|
| 47 |
else:
|
| 48 |
+
raise NotImplementedError('Cuda is not available')
|
| 49 |
|
| 50 |
sources = [os.path.join(extensions_dir, s) for s in sources]
|
| 51 |
include_dirs = [extensions_dir]
|