Stylique commited on
Commit
3e2e28d
Β·
verified Β·
1 Parent(s): 54b1094

Upload 4 files

Browse files
Files changed (1) hide show
  1. post_install.py +81 -17
post_install.py CHANGED
@@ -71,17 +71,9 @@ def install_torch_sparse():
71
 
72
  if pytorch_base.startswith("2.7"):
73
  print("PyTorch 2.7.x detected - this is very recent and may not have compatible torch-sparse/torch-scatter wheels")
74
- print("Attempting to install a more compatible PyTorch version...")
75
-
76
- # Try to install PyTorch 2.0.1 which has known compatible wheels
77
- if run_command("pip install --force-reinstall torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu117"):
78
- print("Successfully installed PyTorch 2.0.1+cu117")
79
- PYTORCH_VERSION = "2.0.1"
80
- CUDA_VERSION = "cu117"
81
- else:
82
- print("Failed to install PyTorch 2.0.1, will try to work with current version")
83
- PYTORCH_VERSION = pytorch_base
84
- CUDA_VERSION = cuda_version
85
  else:
86
  # Store version info for later use
87
  PYTORCH_VERSION = pytorch_base
@@ -155,7 +147,45 @@ def install_torch_sparse():
155
  print("Successfully installed torch-sparse from source")
156
  return True
157
 
158
- return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
  def install_torch_scatter():
161
  """Install torch-scatter with compatible PyTorch version"""
@@ -217,7 +247,24 @@ def install_torch_scatter():
217
  print("Successfully installed torch-scatter from source")
218
  return True
219
 
220
- return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
 
222
  def install_nvdiffrast():
223
  """Install nvdiffrast"""
@@ -397,11 +444,28 @@ def main():
397
  print(f"⚠ Warning: PyTorch version {torch.__version__} may not be compatible with installed extensions")
398
  print(f"Expected version: {PYTORCH_VERSION}")
399
 
400
- import torch_sparse
401
- print("βœ“ torch-sparse")
 
 
 
 
 
 
 
 
 
 
 
 
 
402
 
403
- import torch_scatter
404
- print("βœ“ torch-scatter")
 
 
 
 
405
 
406
  import nvdiffrast
407
  print("βœ“ nvdiffrast")
 
71
 
72
  if pytorch_base.startswith("2.7"):
73
  print("PyTorch 2.7.x detected - this is very recent and may not have compatible torch-sparse/torch-scatter wheels")
74
+ print("Will try to work with current PyTorch version and attempt alternative installation methods")
75
+ PYTORCH_VERSION = pytorch_base
76
+ CUDA_VERSION = cuda_version
 
 
 
 
 
 
 
 
77
  else:
78
  # Store version info for later use
79
  PYTORCH_VERSION = pytorch_base
 
147
  print("Successfully installed torch-sparse from source")
148
  return True
149
 
150
+ # Try installing with specific build flags for PyTorch 2.7
151
+ print("Trying torch-sparse installation with specific build flags...")
152
+ env = os.environ.copy()
153
+ env['TORCH_CUDA_ARCH_LIST'] = '6.0;6.1;7.0;7.5;8.0;8.6'
154
+ env['FORCE_CUDA'] = '1'
155
+ if run_command("pip install torch-sparse --no-build-isolation", env=env):
156
+ print("Successfully installed torch-sparse with build flags")
157
+ return True
158
+
159
+ # Try installing from git with specific version
160
+ print("Trying torch-sparse installation from git...")
161
+ if run_command("pip install git+https://github.com/rusty1s/pytorch_sparse.git --no-build-isolation"):
162
+ print("Successfully installed torch-sparse from git")
163
+ return True
164
+
165
+ # If all else fails, disable torch-sparse and use built-in PyTorch sparse operations
166
+ print("Failed to install torch-sparse, disabling it and using built-in PyTorch sparse operations")
167
+ disable_torch_sparse()
168
+ return True
169
+
170
+ def disable_torch_sparse():
171
+ """Disable torch-sparse in the code by modifying the configuration"""
172
+ try:
173
+ # Modify the PoissonSystem.py to disable torch-sparse
174
+ poisson_file = "NeuralJacobianFields/PoissonSystem.py"
175
+ if os.path.exists(poisson_file):
176
+ with open(poisson_file, 'r') as f:
177
+ content = f.read()
178
+
179
+ # Replace USE_TORCH_SPARSE = True with False
180
+ content = content.replace("USE_TORCH_SPARSE = True", "USE_TORCH_SPARSE = False")
181
+
182
+ with open(poisson_file, 'w') as f:
183
+ f.write(content)
184
+
185
+ print("Disabled torch-sparse in NeuralJacobianFields/PoissonSystem.py")
186
+ print("Will use built-in PyTorch sparse operations instead")
187
+ except Exception as e:
188
+ print(f"Warning: Could not disable torch-sparse: {e}")
189
 
190
  def install_torch_scatter():
191
  """Install torch-scatter with compatible PyTorch version"""
 
247
  print("Successfully installed torch-scatter from source")
248
  return True
249
 
250
+ # Try installing with specific build flags for PyTorch 2.7
251
+ print("Trying torch-scatter installation with specific build flags...")
252
+ env = os.environ.copy()
253
+ env['TORCH_CUDA_ARCH_LIST'] = '6.0;6.1;7.0;7.5;8.0;8.6'
254
+ env['FORCE_CUDA'] = '1'
255
+ if run_command("pip install torch-scatter --no-build-isolation", env=env):
256
+ print("Successfully installed torch-scatter with build flags")
257
+ return True
258
+
259
+ # Try installing from git with specific version
260
+ print("Trying torch-scatter installation from git...")
261
+ if run_command("pip install git+https://github.com/rusty1s/pytorch_scatter.git --no-build-isolation"):
262
+ print("Successfully installed torch-scatter from git")
263
+ return True
264
+
265
+ # If all else fails, note that torch-scatter is not critical for basic functionality
266
+ print("Failed to install torch-scatter, but this may not be critical for basic functionality")
267
+ return True
268
 
269
  def install_nvdiffrast():
270
  """Install nvdiffrast"""
 
444
  print(f"⚠ Warning: PyTorch version {torch.__version__} may not be compatible with installed extensions")
445
  print(f"Expected version: {PYTORCH_VERSION}")
446
 
447
+ # Check if torch-sparse is available or disabled
448
+ try:
449
+ import torch_sparse
450
+ print("βœ“ torch-sparse")
451
+ except ImportError:
452
+ # Check if torch-sparse was disabled
453
+ try:
454
+ with open("NeuralJacobianFields/PoissonSystem.py", 'r') as f:
455
+ content = f.read()
456
+ if "USE_TORCH_SPARSE = False" in content:
457
+ print("βœ“ torch-sparse (disabled, using built-in PyTorch sparse)")
458
+ else:
459
+ print("βœ— torch-sparse (not available)")
460
+ except:
461
+ print("βœ— torch-sparse (not available)")
462
 
463
+ # Check if torch-scatter is available
464
+ try:
465
+ import torch_scatter
466
+ print("βœ“ torch-scatter")
467
+ except ImportError:
468
+ print("⚠ torch-scatter (not available, may not be critical)")
469
 
470
  import nvdiffrast
471
  print("βœ“ nvdiffrast")