Stylique commited on
Commit
29a3bb4
·
verified ·
1 Parent(s): 90a1041

Upload post_install.py

Browse files
Files changed (1) hide show
  1. post_install.py +81 -43
post_install.py CHANGED
@@ -10,6 +10,10 @@ import subprocess
10
  import shutil
11
  from pathlib import Path
12
 
 
 
 
 
13
  def run_command(command, cwd=None, env=None):
14
  """Run a shell command and return the result"""
15
  print(f"Running: {command}")
@@ -47,27 +51,32 @@ def install_torch_sparse():
47
  except ImportError:
48
  pass
49
 
50
- # Check if PyTorch is already installed with the correct version
51
  try:
52
  import torch
53
  print(f"Current PyTorch version: {torch.__version__}")
54
 
55
- # Check if we need to update PyTorch
56
- if torch.__version__.startswith("2.0.1") and "+cu117" in torch.__version__:
57
- print(f"PyTorch {torch.__version__} already installed with correct CUDA version")
 
58
  else:
59
- print(f"PyTorch {torch.__version__} installed, but need to update to 2.0.1+cu117")
60
- # Uninstall current PyTorch and reinstall with correct version
61
- print("Uninstalling current PyTorch...")
62
- run_command("pip uninstall torch torchvision torchaudio -y")
63
- print("Installing compatible PyTorch version...")
64
- if not run_command("pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu117"):
65
- return False
 
 
 
66
  except ImportError:
67
- # First, install a compatible PyTorch version with CUDA 11.7 (as expected by PyTorch3D)
68
- print("Installing compatible PyTorch version...")
69
  if not run_command("pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu117"):
70
  return False
 
 
71
 
72
  # Check PyTorch installation
73
  print("Checking PyTorch installation...")
@@ -83,8 +92,13 @@ def install_torch_sparse():
83
  print(f"Error checking PyTorch version: {e}")
84
 
85
  # Now install torch-sparse with the compatible version
86
- print("Installing torch-sparse with PyTorch 2.0.1...")
87
- if run_command("pip install torch-sparse -f https://data.pyg.org/whl/torch-2.0.1+cu117.html"):
 
 
 
 
 
88
  print("Successfully installed torch-sparse")
89
 
90
  # Verify torch-sparse is compatible
@@ -97,6 +111,22 @@ def install_torch_sparse():
97
 
98
  return True
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  return False
101
 
102
  def install_torch_scatter():
@@ -112,8 +142,13 @@ def install_torch_scatter():
112
  pass
113
 
114
  # Install torch-scatter with the compatible PyTorch version
115
- print("Installing torch-scatter with PyTorch 2.0.1...")
116
- if run_command("pip install torch-scatter -f https://data.pyg.org/whl/torch-2.0.1+cu117.html"):
 
 
 
 
 
117
  print("Successfully installed torch-scatter")
118
 
119
  # Verify torch-scatter is compatible
@@ -126,6 +161,22 @@ def install_torch_scatter():
126
 
127
  return True
128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  return False
130
 
131
  def install_nvdiffrast():
@@ -187,35 +238,21 @@ def install_pytorch3d():
187
  return True
188
 
189
  # Try different wheel URLs for different Python/PyTorch versions
190
- # First, try to determine the correct wheel URL based on current PyTorch version
191
- try:
192
- import torch
193
- torch_version = torch.__version__
194
- print(f"Determining PyTorch3D wheel URL for PyTorch {torch_version}")
195
-
196
- # Extract PyTorch major.minor version
197
- if "+" in torch_version:
198
- pytorch_base = torch_version.split("+")[0]
199
- else:
200
- pytorch_base = torch_version
201
-
202
- # Extract CUDA version
203
- cuda_version = "cu117" # default
204
- if "+cu" in torch_version:
205
- cuda_version = torch_version.split("+")[1]
206
-
207
- # Try specific wheel URL for current PyTorch version
208
- specific_url = f"https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_{cuda_version}_pyt{pytorch_base.replace('.', '')}/download.html"
209
- print(f"Trying specific wheel URL: {specific_url}")
210
- if run_command(f"pip install pytorch3d -f {specific_url}"):
211
- print(f"Successfully installed PyTorch3D with specific wheel URL")
212
- return True
213
- except Exception as e:
214
- print(f"Error determining specific wheel URL: {e}")
215
 
216
  # Fallback to known working wheel URLs
217
  wheel_urls = [
 
218
  "https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu117_pyt201/download.html",
 
219
  "https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py39_cu117_pyt201/download.html",
220
  "https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py38_cu117_pyt201/download.html"
221
  ]
@@ -316,8 +353,9 @@ def main():
316
  print(f"✓ PyTorch {torch.__version__} - CUDA: {torch.cuda.is_available()}")
317
 
318
  # Check if PyTorch version is compatible
319
- if not torch.__version__.startswith("2.0.1") or "+cu117" not in torch.__version__:
320
  print(f"⚠ Warning: PyTorch version {torch.__version__} may not be compatible with installed extensions")
 
321
 
322
  import torch_sparse
323
  print("✓ torch-sparse")
 
10
  import shutil
11
  from pathlib import Path
12
 
13
+ # Global variables for PyTorch version detection
14
+ PYTORCH_VERSION = None
15
+ CUDA_VERSION = None
16
+
17
  def run_command(command, cwd=None, env=None):
18
  """Run a shell command and return the result"""
19
  print(f"Running: {command}")
 
51
  except ImportError:
52
  pass
53
 
54
+ # Check current PyTorch version and adapt to it
55
  try:
56
  import torch
57
  print(f"Current PyTorch version: {torch.__version__}")
58
 
59
+ # Extract PyTorch version info for compatibility
60
+ if "+" in torch.__version__:
61
+ pytorch_base = torch.__version__.split("+")[0]
62
+ cuda_version = torch.__version__.split("+")[1]
63
  else:
64
+ pytorch_base = torch.__version__
65
+ cuda_version = "cpu"
66
+
67
+ print(f"PyTorch base version: {pytorch_base}, CUDA: {cuda_version}")
68
+
69
+ # Store version info for later use
70
+ global PYTORCH_VERSION, CUDA_VERSION
71
+ PYTORCH_VERSION = pytorch_base
72
+ CUDA_VERSION = cuda_version
73
+
74
  except ImportError:
75
+ print("PyTorch not found, installing default version...")
 
76
  if not run_command("pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu117"):
77
  return False
78
+ PYTORCH_VERSION = "2.0.1"
79
+ CUDA_VERSION = "cu117"
80
 
81
  # Check PyTorch installation
82
  print("Checking PyTorch installation...")
 
92
  print(f"Error checking PyTorch version: {e}")
93
 
94
  # Now install torch-sparse with the compatible version
95
+ print(f"Installing torch-sparse with PyTorch {PYTORCH_VERSION}...")
96
+
97
+ # Try to find compatible torch-sparse wheel
98
+ wheel_url = f"https://data.pyg.org/whl/torch-{PYTORCH_VERSION}+{CUDA_VERSION}.html"
99
+ print(f"Trying wheel URL: {wheel_url}")
100
+
101
+ if run_command(f"pip install torch-sparse -f {wheel_url}"):
102
  print("Successfully installed torch-sparse")
103
 
104
  # Verify torch-sparse is compatible
 
111
 
112
  return True
113
 
114
+ # If the specific wheel fails, try alternative versions
115
+ print("Trying alternative torch-sparse versions...")
116
+ alternative_versions = [
117
+ f"https://data.pyg.org/whl/torch-{PYTORCH_VERSION}+{CUDA_VERSION}.html",
118
+ f"https://data.pyg.org/whl/torch-{PYTORCH_VERSION}+cu118.html",
119
+ f"https://data.pyg.org/whl/torch-{PYTORCH_VERSION}+cu117.html",
120
+ f"https://data.pyg.org/whl/torch-{PYTORCH_VERSION}+cu116.html"
121
+ ]
122
+
123
+ for url in alternative_versions:
124
+ if url != wheel_url: # Skip the one we already tried
125
+ print(f"Trying alternative URL: {url}")
126
+ if run_command(f"pip install torch-sparse -f {url}"):
127
+ print("Successfully installed torch-sparse with alternative version")
128
+ return True
129
+
130
  return False
131
 
132
  def install_torch_scatter():
 
142
  pass
143
 
144
  # Install torch-scatter with the compatible PyTorch version
145
+ print(f"Installing torch-scatter with PyTorch {PYTORCH_VERSION}...")
146
+
147
+ # Try to find compatible torch-scatter wheel
148
+ wheel_url = f"https://data.pyg.org/whl/torch-{PYTORCH_VERSION}+{CUDA_VERSION}.html"
149
+ print(f"Trying wheel URL: {wheel_url}")
150
+
151
+ if run_command(f"pip install torch-scatter -f {wheel_url}"):
152
  print("Successfully installed torch-scatter")
153
 
154
  # Verify torch-scatter is compatible
 
161
 
162
  return True
163
 
164
+ # If the specific wheel fails, try alternative versions
165
+ print("Trying alternative torch-scatter versions...")
166
+ alternative_versions = [
167
+ f"https://data.pyg.org/whl/torch-{PYTORCH_VERSION}+{CUDA_VERSION}.html",
168
+ f"https://data.pyg.org/whl/torch-{PYTORCH_VERSION}+cu118.html",
169
+ f"https://data.pyg.org/whl/torch-{PYTORCH_VERSION}+cu117.html",
170
+ f"https://data.pyg.org/whl/torch-{PYTORCH_VERSION}+cu116.html"
171
+ ]
172
+
173
+ for url in alternative_versions:
174
+ if url != wheel_url: # Skip the one we already tried
175
+ print(f"Trying alternative URL: {url}")
176
+ if run_command(f"pip install torch-scatter -f {url}"):
177
+ print("Successfully installed torch-scatter with alternative version")
178
+ return True
179
+
180
  return False
181
 
182
  def install_nvdiffrast():
 
238
  return True
239
 
240
  # Try different wheel URLs for different Python/PyTorch versions
241
+ # Use the detected PyTorch version
242
+ print(f"Determining PyTorch3D wheel URL for PyTorch {PYTORCH_VERSION}+{CUDA_VERSION}")
243
+
244
+ # Try specific wheel URL for current PyTorch version
245
+ specific_url = f"https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_{CUDA_VERSION}_pyt{PYTORCH_VERSION.replace('.', '')}/download.html"
246
+ print(f"Trying specific wheel URL: {specific_url}")
247
+ if run_command(f"pip install pytorch3d -f {specific_url}"):
248
+ print(f"Successfully installed PyTorch3D with specific wheel URL")
249
+ return True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
 
251
  # Fallback to known working wheel URLs
252
  wheel_urls = [
253
+ f"https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_{CUDA_VERSION}_pyt{PYTORCH_VERSION.replace('.', '')}/download.html",
254
  "https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu117_pyt201/download.html",
255
+ "https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu126_pyt271/download.html",
256
  "https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py39_cu117_pyt201/download.html",
257
  "https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py38_cu117_pyt201/download.html"
258
  ]
 
353
  print(f"✓ PyTorch {torch.__version__} - CUDA: {torch.cuda.is_available()}")
354
 
355
  # Check if PyTorch version is compatible
356
+ if PYTORCH_VERSION and not torch.__version__.startswith(PYTORCH_VERSION):
357
  print(f"⚠ Warning: PyTorch version {torch.__version__} may not be compatible with installed extensions")
358
+ print(f"Expected version: {PYTORCH_VERSION}")
359
 
360
  import torch_sparse
361
  print("✓ torch-sparse")