anas commited on
Commit
86c87ce
·
1 Parent(s): 16ab340

Fix models/ops build: respect FORCE_CUDA env var for headless CUDA compilation

Browse files
Files changed (1) hide show
  1. models/ops/setup.py +3 -2
models/ops/setup.py CHANGED
@@ -33,7 +33,8 @@ def get_extensions():
33
  extra_compile_args = {"cxx": []}
34
  define_macros = []
35
 
36
- if torch.cuda.is_available() and CUDA_HOME is not None:
 
37
  extension = CUDAExtension
38
  sources += source_cuda
39
  define_macros += [("WITH_CUDA", None)]
@@ -44,7 +45,7 @@ def get_extensions():
44
  "-D__CUDA_NO_HALF2_OPERATORS__",
45
  ]
46
  else:
47
- raise NotImplementedError('Cuda is not availabel')
48
 
49
  sources = [os.path.join(extensions_dir, s) for s in sources]
50
  include_dirs = [extensions_dir]
 
33
  extra_compile_args = {"cxx": []}
34
  define_macros = []
35
 
36
+ force_cuda = os.environ.get("FORCE_CUDA", "0") == "1"
37
+ if (torch.cuda.is_available() or force_cuda) and CUDA_HOME is not None:
38
  extension = CUDAExtension
39
  sources += source_cuda
40
  define_macros += [("WITH_CUDA", None)]
 
45
  "-D__CUDA_NO_HALF2_OPERATORS__",
46
  ]
47
  else:
48
+ raise NotImplementedError('Cuda is not available')
49
 
50
  sources = [os.path.join(extensions_dir, s) for s in sources]
51
  include_dirs = [extensions_dir]