Camais03 commited on
Commit
9387c10
·
verified ·
1 Parent(s): 879b44d

Update setup.py

Browse files
Files changed (1) hide show
  1. setup.py +475 -340
setup.py CHANGED
@@ -1,341 +1,476 @@
1
- #!/usr/bin/env python3
2
- """
3
- Setup script for the Image Tagger application.
4
- This script checks and installs all required dependencies.
5
- """
6
-
7
- # Python 3.12+ compatibility patch for pkgutil.ImpImporter
8
- import sys
9
- if sys.version_info >= (3, 12):
10
- import pkgutil
11
- import importlib.machinery
12
-
13
- # Add ImpImporter as a compatibility shim for older packages
14
- if not hasattr(pkgutil, 'ImpImporter'):
15
- class ImpImporter:
16
- def __init__(self, path=None):
17
- self.path = path
18
-
19
- def find_module(self, fullname, path=None):
20
- return None
21
-
22
- pkgutil.ImpImporter = ImpImporter
23
-
24
- import os
25
- import sys
26
- import subprocess
27
- import platform
28
- from pathlib import Path
29
- import re
30
- import urllib.request
31
- import shutil
32
- import tempfile
33
- import time
34
- import webbrowser
35
-
36
- # Define the required packages
37
- SETUPTOOLS_PACKAGES = [
38
- "setuptools>=58.0.0",
39
- "setuptools-distutils>=0.3.0",
40
- "wheel>=0.38.0",
41
- ]
42
-
43
- REQUIRED_PACKAGES = [
44
- "streamlit>=1.21.0",
45
- "pillow>=9.0.0",
46
- # CRITICAL: Pin NumPy to 1.24.x and prevent 2.x installation
47
- "numpy>=1.24.0,<2.0.0",
48
- "ninja>=1.10.0",
49
- "packaging>=20.0",
50
- "matplotlib>=3.5.0",
51
- "tqdm>=4.62.0",
52
- "scipy>=1.7.0",
53
- "safetensors>=0.3.0",
54
- "timm>=0.9.0", # Add this line for PyTorch Image Models
55
- ]
56
-
57
- # Packages to install after PyTorch
58
- POST_TORCH_PACKAGES = [
59
- "einops>=0.6.1",
60
- ]
61
-
62
- CUDA_PACKAGES = {
63
- "11.8": "torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118",
64
- "11.7": "torch==2.0.1+cu117 torchvision==0.15.2+cu117 --index-url https://download.pytorch.org/whl/cu117",
65
- "11.6": "torch==2.0.1+cu116 torchvision==0.15.2+cu116 --index-url https://download.pytorch.org/whl/cu116",
66
- "cpu": "torch==2.0.1+cpu torchvision==0.15.2+cpu --index-url https://download.pytorch.org/whl/cpu"
67
- }
68
-
69
- # ONNX and acceleration packages
70
- ONNX_PACKAGES = [
71
- "onnx>=1.14.0",
72
- "onnxruntime>=1.15.0",
73
- "onnxruntime-gpu>=1.15.0;platform_system!='Darwin'",
74
- ]
75
-
76
- # Colors for terminal output
77
- class Colors:
78
- HEADER = '\033[95m'
79
- BLUE = '\033[94m'
80
- GREEN = '\033[92m'
81
- WARNING = '\033[93m'
82
- FAIL = '\033[91m'
83
- ENDC = '\033[0m'
84
- BOLD = '\033[1m'
85
-
86
- def print_colored(text, color):
87
- """Print text in color"""
88
- if sys.platform == "win32":
89
- print(text)
90
- else:
91
- print(f"{color}{text}{Colors.ENDC}")
92
-
93
- def check_and_fix_numpy():
94
- """Check for NumPy 2.x and fix compatibility issues"""
95
- print_colored("\nChecking NumPy compatibility...", Colors.BLUE)
96
-
97
- pip_path = get_venv_pip()
98
-
99
- try:
100
- result = subprocess.run([pip_path, "show", "numpy"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
101
- if result.returncode == 0:
102
- version_match = re.search(r"Version: ([\d\.]+)", result.stdout)
103
- if version_match:
104
- current_version = version_match.group(1)
105
- print_colored(f"Found NumPy {current_version}", Colors.BLUE)
106
-
107
- if current_version.startswith("2."):
108
- print_colored(f"ERROR: NumPy {current_version} is incompatible with PyTorch!", Colors.FAIL)
109
- print_colored("Uninstalling NumPy 2.x and installing compatible version...", Colors.BLUE)
110
-
111
- # Force uninstall NumPy 2.x
112
- subprocess.run([pip_path, "uninstall", "-y", "numpy"], check=True)
113
-
114
- # Install compatible NumPy version with constraints
115
- subprocess.run([pip_path, "install", "numpy>=1.24.0,<2.0.0"], check=True)
116
-
117
- print_colored("[OK] NumPy downgraded to compatible version", Colors.GREEN)
118
- return True
119
-
120
- elif current_version.startswith("1.24."):
121
- print_colored(f"[OK] NumPy {current_version} is compatible", Colors.GREEN)
122
- return False
123
- else:
124
- print_colored(f"Updating NumPy to recommended version...", Colors.BLUE)
125
- subprocess.run([pip_path, "install", "numpy>=1.24.0,<2.0.0"], check=True)
126
- return True
127
- except Exception as e:
128
- print_colored(f"Error checking NumPy: {e}", Colors.WARNING)
129
-
130
- return False
131
-
132
- def install_packages(cuda_version):
133
- """Install required packages using pip"""
134
- print_colored("\nInstalling required packages...", Colors.BLUE)
135
-
136
- pip_path = get_venv_pip()
137
-
138
- # Upgrade pip first
139
- try:
140
- subprocess.run([pip_path, "install", "--upgrade", "pip"], check=True)
141
- print_colored("[OK] Pip upgraded successfully", Colors.GREEN)
142
- except subprocess.CalledProcessError:
143
- print_colored("Warning: Failed to upgrade pip", Colors.WARNING)
144
-
145
- # Install setuptools packages first
146
- print_colored("\nInstalling setuptools...", Colors.BLUE)
147
- for package in SETUPTOOLS_PACKAGES:
148
- try:
149
- subprocess.run([pip_path, "install", package], check=True)
150
- print_colored(f"[OK] Installed {package}", Colors.GREEN)
151
- except subprocess.CalledProcessError as e:
152
- print_colored(f"Warning: Issue installing {package}: {e}", Colors.WARNING)
153
-
154
- # Check and fix NumPy compatibility before installing other packages
155
- numpy_was_updated = check_and_fix_numpy()
156
-
157
- # Install base packages
158
- for package in REQUIRED_PACKAGES:
159
- try:
160
- print_colored(f"Installing {package}...", Colors.BLUE)
161
- subprocess.run([pip_path, "install", package], check=True)
162
- print_colored(f"[OK] Installed {package}", Colors.GREEN)
163
- except subprocess.CalledProcessError as e:
164
- print_colored(f"Error installing {package}: {e}", Colors.FAIL)
165
- return False
166
-
167
- # If NumPy was updated, we need to reinstall PyTorch to ensure compatibility
168
- if numpy_was_updated:
169
- print_colored("\nNumPy was updated, ensuring PyTorch compatibility...", Colors.BLUE)
170
- try:
171
- # Uninstall existing PyTorch
172
- subprocess.run([pip_path, "uninstall", "-y", "torch", "torchvision"], check=False)
173
- except:
174
- pass # Ignore errors if not installed
175
-
176
- # Install PyTorch with appropriate CUDA version
177
- print_colored(f"\nInstalling PyTorch {'with CUDA support' if cuda_version != 'cpu' else '(CPU version)'}...", Colors.BLUE)
178
- torch_command = CUDA_PACKAGES[cuda_version].split()
179
- try:
180
- subprocess.run([pip_path, "install"] + torch_command, check=True)
181
- print_colored("[OK] PyTorch installed successfully", Colors.GREEN)
182
- except subprocess.CalledProcessError as e:
183
- print_colored(f"Error installing PyTorch: {e}", Colors.FAIL)
184
- return False
185
-
186
- # Install post-PyTorch packages
187
- for package in POST_TORCH_PACKAGES:
188
- try:
189
- subprocess.run([pip_path, "install", package], check=True)
190
- print_colored(f"[OK] Installed {package}", Colors.GREEN)
191
- except subprocess.CalledProcessError as e:
192
- print_colored(f"Error installing {package}: {e}", Colors.FAIL)
193
- return False
194
-
195
- # Final NumPy compatibility check
196
- print_colored("\nPerforming final compatibility check...", Colors.BLUE)
197
- try:
198
- # Test import in the virtual environment
199
- python_path = get_venv_python()
200
- test_cmd = [python_path, "-c", "import torch; import torchvision; import numpy; print('All imports successful')"]
201
- result = subprocess.run(test_cmd, capture_output=True, text=True)
202
-
203
- if result.returncode == 0:
204
- print_colored("[OK] All packages are compatible", Colors.GREEN)
205
- else:
206
- print_colored(f"Warning: Compatibility test failed: {result.stderr}", Colors.WARNING)
207
- # Try to fix by reinstalling with --force-reinstall
208
- print_colored("Attempting to fix with force reinstall...", Colors.BLUE)
209
- subprocess.run([pip_path, "install", "--force-reinstall", "numpy>=1.24.0,<2.0.0"], check=True)
210
-
211
- except Exception as e:
212
- print_colored(f"Warning: Could not perform compatibility check: {e}", Colors.WARNING)
213
-
214
- return True
215
-
216
- def check_python_version():
217
- """Check if Python version is 3.8 or higher"""
218
- print_colored("Checking Python version...", Colors.BLUE)
219
-
220
- version = sys.version_info
221
- if version.major < 3 or (version.major == 3 and version.minor < 8):
222
- print_colored("Error: Python 3.8 or higher is required. You have " + sys.version, Colors.FAIL)
223
- return False
224
-
225
- print_colored(f"[OK] Python {version.major}.{version.minor}.{version.micro} detected", Colors.GREEN)
226
- return True
227
-
228
- def create_virtual_env():
229
- """Create a virtual environment if one doesn't exist"""
230
- print_colored("\nChecking for virtual environment...", Colors.BLUE)
231
-
232
- venv_path = Path("venv")
233
- if venv_path.exists():
234
- print_colored("[OK] Virtual environment already exists", Colors.GREEN)
235
- return True
236
-
237
- print_colored("Creating a new virtual environment...", Colors.BLUE)
238
- try:
239
- subprocess.run([sys.executable, "-m", "venv", "venv"], check=True)
240
- print_colored("[OK] Virtual environment created successfully", Colors.GREEN)
241
- return True
242
- except subprocess.CalledProcessError:
243
- print_colored("Error: Failed to create virtual environment", Colors.FAIL)
244
- return False
245
-
246
- def get_venv_python():
247
- """Get path to Python in the virtual environment"""
248
- if sys.platform == "win32":
249
- return os.path.join("venv", "Scripts", "python.exe")
250
- else:
251
- return os.path.join("venv", "bin", "python")
252
-
253
- def get_venv_pip():
254
- """Get path to pip in the virtual environment"""
255
- if sys.platform == "win32":
256
- return os.path.join("venv", "Scripts", "pip.exe")
257
- else:
258
- return os.path.join("venv", "bin", "pip")
259
-
260
- def check_cuda():
261
- """Check CUDA availability and version"""
262
- print_colored("\nChecking for CUDA...", Colors.BLUE)
263
-
264
- cuda_available = False
265
- cuda_version = None
266
-
267
- try:
268
- if sys.platform == "win32":
269
- process = subprocess.run(["where", "nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
270
- else:
271
- process = subprocess.run(["which", "nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
272
-
273
- if process.returncode == 0:
274
- nvidia_smi = subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
275
- if nvidia_smi.returncode == 0:
276
- cuda_available = True
277
- match = re.search(r"CUDA Version: (\d+\.\d+)", nvidia_smi.stdout)
278
- if match:
279
- cuda_version = match.group(1)
280
- except Exception as e:
281
- print_colored(f"Error checking CUDA: {str(e)}", Colors.WARNING)
282
-
283
- if cuda_available and cuda_version:
284
- print_colored(f"[OK] CUDA {cuda_version} detected", Colors.GREEN)
285
- for supported_version in CUDA_PACKAGES.keys():
286
- if supported_version != "cpu" and float(supported_version) <= float(cuda_version):
287
- return supported_version
288
-
289
- print_colored("No CUDA detected, using CPU-only version", Colors.WARNING)
290
- return "cpu"
291
-
292
- def install_onnx_packages(cuda_version):
293
- """Install ONNX packages"""
294
- print_colored("\nInstalling ONNX packages...", Colors.BLUE)
295
-
296
- pip_path = get_venv_pip()
297
-
298
- try:
299
- subprocess.run([pip_path, "install", "onnx>=1.14.0"], check=True)
300
-
301
- if cuda_version != "cpu":
302
- subprocess.run([pip_path, "install", "onnxruntime-gpu>=1.15.0"], check=True)
303
- else:
304
- subprocess.run([pip_path, "install", "onnxruntime>=1.15.0"], check=True)
305
-
306
- print_colored("[OK] ONNX packages installed", Colors.GREEN)
307
- except subprocess.CalledProcessError as e:
308
- print_colored(f"Warning: ONNX installation issues: {e}", Colors.WARNING)
309
-
310
- return True
311
-
312
- def main():
313
- """Main setup function"""
314
- print_colored("=" * 60, Colors.HEADER)
315
- print_colored(" Image Tagger - Setup Script", Colors.HEADER)
316
- print_colored("=" * 60, Colors.HEADER)
317
-
318
- if not check_python_version():
319
- return False
320
-
321
- if not create_virtual_env():
322
- return False
323
-
324
- cuda_version = check_cuda()
325
-
326
- if not install_packages(cuda_version):
327
- return False
328
-
329
- if not install_onnx_packages(cuda_version):
330
- print_colored("Warning: ONNX packages had issues", Colors.WARNING)
331
-
332
- print_colored("\n" + "=" * 60, Colors.HEADER)
333
- print_colored(" Setup completed successfully!", Colors.GREEN)
334
- print_colored("=" * 60, Colors.HEADER)
335
-
336
- return True
337
-
338
- if __name__ == "__main__":
339
- success = main()
340
- if not success:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
341
  sys.exit(1)
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Setup script for the Image Tagger application.
4
+ This script checks and installs all required dependencies.
5
+ """
6
+
7
+ # Python 3.12+ compatibility patch for pkgutil.ImpImporter
8
+ import sys
9
+ if sys.version_info >= (3, 12):
10
+ import pkgutil
11
+ import importlib.machinery
12
+
13
+ # Add ImpImporter as a compatibility shim for older packages
14
+ if not hasattr(pkgutil, 'ImpImporter'):
15
+ class ImpImporter:
16
+ def __init__(self, path=None):
17
+ self.path = path
18
+
19
+ def find_module(self, fullname, path=None):
20
+ return None
21
+
22
+ pkgutil.ImpImporter = ImpImporter
23
+
24
+ import os
25
+ import sys
26
+ import subprocess
27
+ import platform
28
+ from pathlib import Path
29
+ import re
30
+ import urllib.request
31
+ import shutil
32
+ import tempfile
33
+ import time
34
+ import webbrowser
35
+
36
+ # Define the required packages
37
+ SETUPTOOLS_PACKAGES = [
38
+ "setuptools>=58.0.0",
39
+ "setuptools-distutils>=0.3.0",
40
+ "wheel>=0.38.0",
41
+ ]
42
+
43
+ REQUIRED_PACKAGES = [
44
+ "streamlit>=1.21.0",
45
+ "pillow>=9.0.0",
46
+ "ninja>=1.10.0",
47
+ "packaging>=20.0",
48
+ "matplotlib>=3.5.0",
49
+ "tqdm>=4.62.0",
50
+ "scipy>=1.7.0",
51
+ "safetensors>=0.3.0",
52
+ ]
53
+
54
+
55
+ # Packages to install after PyTorch
56
+ POST_TORCH_PACKAGES = [
57
+ "einops>=0.6.1",
58
+ "timm>=0.9.0",
59
+ ]
60
+
61
+
62
+ CUDA_PACKAGES = {
63
+ "11.8": "torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118",
64
+ "11.7": "torch==2.0.1+cu117 torchvision==0.15.2+cu117 --index-url https://download.pytorch.org/whl/cu117",
65
+ "11.6": "torch==2.0.1+cu116 torchvision==0.15.2+cu116 --index-url https://download.pytorch.org/whl/cu116",
66
+ "cpu": "torch==2.0.1+cpu torchvision==0.15.2+cpu --index-url https://download.pytorch.org/whl/cpu"
67
+ }
68
+
69
+ # ONNX and acceleration packages
70
+ ONNX_PACKAGES = [
71
+ "onnx>=1.14.0",
72
+ "onnxruntime>=1.15.0",
73
+ "onnxruntime-gpu>=1.15.0;platform_system!='Darwin'",
74
+ ]
75
+
76
+ # Colors for terminal output
77
+ class Colors:
78
+ HEADER = '\033[95m'
79
+ BLUE = '\033[94m'
80
+ GREEN = '\033[92m'
81
+ WARNING = '\033[93m'
82
+ FAIL = '\033[91m'
83
+ ENDC = '\033[0m'
84
+ BOLD = '\033[1m'
85
+
86
+ def print_colored(text, color):
87
+ """Print text in color"""
88
+ if sys.platform == "win32":
89
+ print(text)
90
+ else:
91
+ print(f"{color}{text}{Colors.ENDC}")
92
+
93
+ def ensure_compatible_numpy() -> bool:
94
+ """
95
+ Install a NumPy version compatible with the current Python version.
96
+
97
+ - Python < 3.13: NumPy 1.24.x–1.x (we keep <2 to avoid ecosystem surprises)
98
+ - Python >= 3.13: NumPy 2.1+ (since 1.26 wheels don't support 3.13)
99
+ """
100
+ print_colored("\nEnsuring NumPy compatibility...", Colors.BLUE)
101
+
102
+ pip_path = get_venv_pip()
103
+
104
+ if sys.version_info >= (3, 13):
105
+ desired_spec = "numpy>=2.1.0,<3.0.0"
106
+ desired_major = 2
107
+ else:
108
+ desired_spec = "numpy>=1.24.0,<2.0.0"
109
+ desired_major = 1
110
+
111
+ def parse_ver(text: str):
112
+ m = re.search(r"Version:\s*([0-9]+)\.([0-9]+)\.([0-9]+)", text)
113
+ if not m:
114
+ return None
115
+ return (int(m.group(1)), int(m.group(2)), int(m.group(3)))
116
+
117
+ try:
118
+ show = subprocess.run(
119
+ [pip_path, "show", "numpy"],
120
+ stdout=subprocess.PIPE,
121
+ stderr=subprocess.PIPE,
122
+ text=True
123
+ )
124
+
125
+ if show.returncode == 0:
126
+ ver = parse_ver(show.stdout)
127
+ if ver:
128
+ major, minor, patch = ver
129
+ print_colored(f"Found NumPy {major}.{minor}.{patch}", Colors.BLUE)
130
+
131
+ ok = (major == desired_major) and (major != 1 or minor >= 24)
132
+ if ok:
133
+ print_colored("[OK] NumPy version is compatible", Colors.GREEN)
134
+ return False
135
+
136
+ print_colored("NumPy version not compatible for this Python; reinstalling...", Colors.BLUE)
137
+ subprocess.run([pip_path, "uninstall", "-y", "numpy"], check=False)
138
+
139
+ print_colored(f"Installing {desired_spec}...", Colors.BLUE)
140
+ subprocess.run([pip_path, "install", desired_spec], check=True)
141
+ print_colored(f"[OK] Installed {desired_spec}", Colors.GREEN)
142
+ return True
143
+
144
+ except subprocess.CalledProcessError as e:
145
+ print_colored(f"Error installing NumPy: {e}", Colors.FAIL)
146
+ return False
147
+
148
+ def install_packages(cuda_version):
149
+ """Install required packages using pip"""
150
+ print_colored("\nInstalling required packages...", Colors.BLUE)
151
+
152
+ pip_path = get_venv_pip()
153
+
154
+ # Upgrade pip first
155
+ try:
156
+ subprocess.run([pip_path, "install", "--upgrade", "pip"], check=True)
157
+ print_colored("[OK] Pip upgraded successfully", Colors.GREEN)
158
+ except subprocess.CalledProcessError:
159
+ print_colored("Warning: Failed to upgrade pip", Colors.WARNING)
160
+
161
+ # Install setuptools packages first
162
+ print_colored("\nInstalling setuptools...", Colors.BLUE)
163
+ for package in SETUPTOOLS_PACKAGES:
164
+ try:
165
+ subprocess.run([pip_path, "install", package], check=True)
166
+ print_colored(f"[OK] Installed {package}", Colors.GREEN)
167
+ except subprocess.CalledProcessError as e:
168
+ print_colored(f"Warning: Issue installing {package}: {e}", Colors.WARNING)
169
+
170
+ # Check and fix NumPy compatibility before installing other packages
171
+ numpy_was_updated = ensure_compatible_numpy()
172
+
173
+ # Install base packages
174
+ for package in REQUIRED_PACKAGES:
175
+ try:
176
+ print_colored(f"Installing {package}...", Colors.BLUE)
177
+ subprocess.run([pip_path, "install", package], check=True)
178
+ print_colored(f"[OK] Installed {package}", Colors.GREEN)
179
+ except subprocess.CalledProcessError as e:
180
+ print_colored(f"Error installing {package}: {e}", Colors.FAIL)
181
+ return False
182
+
183
+ # Install PyTorch with appropriate CUDA version
184
+ print_colored(
185
+ f"\nInstalling PyTorch ({'GPU '+cuda_version if cuda_version != 'cpu' else 'CPU-only'})...",
186
+ Colors.BLUE
187
+ )
188
+ if not install_pytorch(cuda_version):
189
+ return False
190
+
191
+ # Install post-PyTorch packages
192
+ for package in POST_TORCH_PACKAGES:
193
+ try:
194
+ subprocess.run([pip_path, "install", package], check=True)
195
+ print_colored(f"[OK] Installed {package}", Colors.GREEN)
196
+ except subprocess.CalledProcessError as e:
197
+ print_colored(f"Error installing {package}: {e}", Colors.FAIL)
198
+ return False
199
+
200
+ # Final NumPy compatibility check
201
+ print_colored("\nPerforming final compatibility check...", Colors.BLUE)
202
+ try:
203
+ # Test import in the virtual environment
204
+ python_path = get_venv_python()
205
+ test_cmd = [python_path, "-c", "import torch; import torchvision; import numpy; print('All imports successful')"]
206
+ result = subprocess.run(test_cmd, capture_output=True, text=True)
207
+
208
+ if result.returncode == 0:
209
+ print_colored("[OK] All packages import successfully", Colors.GREEN)
210
+ else:
211
+ print_colored(f"Warning: Import test failed:\n{result.stderr}", Colors.WARNING)
212
+
213
+ except Exception as e:
214
+ print_colored(f"Warning: Could not perform compatibility check: {e}", Colors.WARNING)
215
+
216
+ return True
217
+
218
+ def check_python_version():
219
+ v = sys.version_info
220
+ if (v.major, v.minor) < (3, 10) or (v.major, v.minor) >= (3, 14):
221
+ print_colored(
222
+ f"Error: Python 3.10–3.13 required. You have {v.major}.{v.minor}.",
223
+ Colors.FAIL
224
+ )
225
+ return False
226
+ return True
227
+
228
+ def create_virtual_env():
229
+ """Create a virtual environment if one doesn't exist"""
230
+ print_colored("\nChecking for virtual environment...", Colors.BLUE)
231
+
232
+ venv_path = Path("venv")
233
+ if venv_path.exists():
234
+ print_colored("[OK] Virtual environment already exists", Colors.GREEN)
235
+ return True
236
+
237
+ print_colored("Creating a new virtual environment...", Colors.BLUE)
238
+ try:
239
+ subprocess.run([sys.executable, "-m", "venv", "venv"], check=True)
240
+ print_colored("[OK] Virtual environment created successfully", Colors.GREEN)
241
+ return True
242
+ except subprocess.CalledProcessError:
243
+ print_colored("Error: Failed to create virtual environment", Colors.FAIL)
244
+ return False
245
+
246
+ def get_venv_python():
247
+ """Get path to Python in the virtual environment"""
248
+ if sys.platform == "win32":
249
+ return os.path.join("venv", "Scripts", "python.exe")
250
+ else:
251
+ return os.path.join("venv", "bin", "python")
252
+
253
+ def get_venv_pip():
254
+ """Get path to pip in the virtual environment"""
255
+ if sys.platform == "win32":
256
+ return os.path.join("venv", "Scripts", "pip.exe")
257
+ else:
258
+ return os.path.join("venv", "bin", "pip")
259
+
260
+ def check_cuda():
261
+ """Check CUDA availability and version"""
262
+ print_colored("\nChecking for CUDA...", Colors.BLUE)
263
+
264
+ cuda_available = False
265
+ cuda_version = None
266
+
267
+ try:
268
+ if sys.platform == "win32":
269
+ process = subprocess.run(["where", "nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
270
+ else:
271
+ process = subprocess.run(["which", "nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
272
+
273
+ if process.returncode == 0:
274
+ nvidia_smi = subprocess.run(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
275
+ if nvidia_smi.returncode == 0:
276
+ cuda_available = True
277
+ match = re.search(r"CUDA Version: (\d+\.\d+)", nvidia_smi.stdout)
278
+ if match:
279
+ cuda_version = match.group(1)
280
+ except Exception as e:
281
+ print_colored(f"Error checking CUDA: {str(e)}", Colors.WARNING)
282
+
283
+ if cuda_available and cuda_version:
284
+ print_colored(f"[OK] CUDA {cuda_version} detected", Colors.GREEN)
285
+ return _pick_pytorch_tag_from_driver(cuda_version)
286
+
287
+ print_colored("No CUDA detected, using CPU-only version", Colors.WARNING)
288
+ return "cpu"
289
+
290
+
291
+ # Order matters: we try the newest runtime your driver supports, then fall back.
292
+ PYTORCH_CUDA_TAGS = [
293
+ (13.0, "cu130"),
294
+ (12.9, "cu129"),
295
+ (12.8, "cu128"),
296
+ (12.6, "cu126"),
297
+ (12.4, "cu124"),
298
+ (12.1, "cu121"),
299
+ (11.8, "cu118"),
300
+ ]
301
+
302
+ def _pick_pytorch_tag_from_driver(cuda_version_str: str) -> str:
303
+ try:
304
+ v = float(cuda_version_str)
305
+ except Exception:
306
+ return "cpu"
307
+
308
+ for min_v, tag in PYTORCH_CUDA_TAGS:
309
+ if v >= min_v:
310
+ return tag
311
+ return "cpu"
312
+
313
+ def _torch_install_attempt(pip_path: str, index_url: str, pre: bool = False) -> bool:
314
+ # Always add PyPI as extra index so dependencies are resolvable.
315
+ cmd = [pip_path, "install", "--upgrade"]
316
+ if pre:
317
+ cmd.append("--pre")
318
+ cmd += [
319
+ "torch", "torchvision",
320
+ "--index-url", index_url,
321
+ "--extra-index-url", "https://pypi.org/simple",
322
+ ]
323
+ try:
324
+ subprocess.run(cmd, check=True)
325
+ return True
326
+ except subprocess.CalledProcessError:
327
+ return False
328
+
329
+ def install_pytorch(cuda_tag: str) -> bool:
330
+ """
331
+ Install torch/torchvision.
332
+ - On macOS: install from PyPI (MPS/CPU handled by upstream)
333
+ - On Windows/Linux: install from the PyTorch CUDA/CPU index matching the driver.
334
+ - If Python 3.13+ and stable fails, try nightly as a fallback.
335
+ """
336
+ pip_path = get_venv_pip()
337
+
338
+ print_colored("\nInstalling PyTorch...", Colors.BLUE)
339
+
340
+ # Clean out any prior torch installs (helps reruns / partial installs)
341
+ subprocess.run([pip_path, "uninstall", "-y", "torch", "torchvision", "torchaudio"], check=False)
342
+
343
+ if sys.platform == "darwin":
344
+ try:
345
+ subprocess.run([pip_path, "install", "--upgrade", "torch", "torchvision"], check=True)
346
+ print_colored("[OK] PyTorch installed (macOS)", Colors.GREEN)
347
+ return True
348
+ except subprocess.CalledProcessError as e:
349
+ print_colored(f"Error installing PyTorch on macOS: {e}", Colors.FAIL)
350
+ return False
351
+
352
+ # Build a fallback list from the detected tag downwards, then CPU last.
353
+ all_tags = [tag for _, tag in PYTORCH_CUDA_TAGS]
354
+ if cuda_tag == "cpu":
355
+ candidates = ["cpu"]
356
+ else:
357
+ if cuda_tag in all_tags:
358
+ start = all_tags.index(cuda_tag)
359
+ candidates = all_tags[start:] + ["cpu"]
360
+ else:
361
+ candidates = ["cu118", "cpu"]
362
+
363
+ # 1) Try stable
364
+ for tag in candidates:
365
+ index_url = f"https://download.pytorch.org/whl/{tag}"
366
+ print_colored(f"Trying stable PyTorch from {index_url} ...", Colors.BLUE)
367
+ if _torch_install_attempt(pip_path, index_url=index_url, pre=False):
368
+ print_colored(f"[OK] PyTorch installed successfully ({tag})", Colors.GREEN)
369
+ return True
370
+
371
+ # 2) If on 3.13+, try nightly (optional but helps early adopters)
372
+ if sys.version_info >= (3, 13):
373
+ for tag in candidates:
374
+ index_url = f"https://download.pytorch.org/whl/nightly/{tag}"
375
+ print_colored(f"Trying nightly PyTorch from {index_url} ...", Colors.BLUE)
376
+ if _torch_install_attempt(pip_path, index_url=index_url, pre=True):
377
+ print_colored(f"[OK] PyTorch nightly installed successfully ({tag})", Colors.GREEN)
378
+ return True
379
+
380
+ print_colored(
381
+ "Error: Could not install a compatible PyTorch build for this Python/CUDA combo.\n"
382
+ "Tip: Python 3.12 is the most broadly supported choice if you hit this.",
383
+ Colors.FAIL
384
+ )
385
+ return False
386
+
387
+ def _try_pip_install(pip_path: str, spec: str) -> bool:
388
+ try:
389
+ subprocess.run([pip_path, "install", spec], check=True)
390
+ return True
391
+ except subprocess.CalledProcessError:
392
+ return False
393
+
394
+ def install_onnx_packages(cuda_version):
395
+ """Install ONNX + an available ONNX Runtime for this OS/Python."""
396
+ print_colored("\nInstalling ONNX packages...", Colors.BLUE)
397
+ pip_path = get_venv_pip()
398
+
399
+ # ONNX itself is fine
400
+ if not _try_pip_install(pip_path, "onnx>=1.14.0"):
401
+ print_colored("Error: failed to install onnx", Colors.FAIL)
402
+ return False
403
+
404
+ # Python 3.14+ currently often has no official onnxruntime* wheels on PyPI
405
+ # (PyPI classifiers are up to 3.13 for onnxruntime/onnxruntime-gpu).
406
+ if sys.version_info >= (3, 14):
407
+ print_colored(
408
+ "Warning: ONNX Runtime wheels are typically not available on PyPI for Python 3.14+ yet.\n"
409
+ "ONNX installed, but runtime install is skipped. Use Python 3.12/3.13 for ONNX Runtime.",
410
+ Colors.WARNING
411
+ )
412
+
413
+ # Decide what to try for runtime
414
+ is_windows = (sys.platform == "win32")
415
+ has_cuda = (cuda_version != "cpu") # in your script this means NVIDIA/CUDA detected
416
+
417
+ # Preferred order:
418
+ # - If CUDA detected: onnxruntime-gpu
419
+ # - On Windows without CUDA: onnxruntime-directml (GPU via DirectML)
420
+ # - Fallback: onnxruntime (CPU)
421
+ tried = []
422
+
423
+ if has_cuda:
424
+ tried.append("onnxruntime-gpu>=1.15.0")
425
+ if _try_pip_install(pip_path, tried[-1]):
426
+ print_colored("[OK] Installed onnxruntime-gpu", Colors.GREEN)
427
+ return True
428
+
429
+ if is_windows:
430
+ tried.append("onnxruntime-directml>=1.15.0")
431
+ if _try_pip_install(pip_path, tried[-1]):
432
+ print_colored("[OK] Installed onnxruntime-directml", Colors.GREEN)
433
+ return True
434
+
435
+ tried.append("onnxruntime>=1.15.0")
436
+ if _try_pip_install(pip_path, tried[-1]):
437
+ print_colored("[OK] Installed onnxruntime (CPU)", Colors.GREEN)
438
+ return True
439
+
440
+ print_colored(
441
+ "Warning: Could not install any ONNX Runtime variant.\n"
442
+ f"Tried: {', '.join(tried)}",
443
+ Colors.WARNING
444
+ )
445
+ return False
446
+
447
+ def main():
448
+ """Main setup function"""
449
+ print_colored("=" * 60, Colors.HEADER)
450
+ print_colored(" Image Tagger - Setup Script", Colors.HEADER)
451
+ print_colored("=" * 60, Colors.HEADER)
452
+
453
+ if not check_python_version():
454
+ return False
455
+
456
+ if not create_virtual_env():
457
+ return False
458
+
459
+ cuda_version = check_cuda()
460
+
461
+ if not install_packages(cuda_version):
462
+ return False
463
+
464
+ if not install_onnx_packages(cuda_version):
465
+ print_colored("Warning: ONNX packages had issues", Colors.WARNING)
466
+
467
+ print_colored("\n" + "=" * 60, Colors.HEADER)
468
+ print_colored(" Setup completed successfully!", Colors.GREEN)
469
+ print_colored("=" * 60, Colors.HEADER)
470
+
471
+ return True
472
+
473
+ if __name__ == "__main__":
474
+ success = main()
475
+ if not success:
476
  sys.exit(1)