Spaces:
Sleeping
Sleeping
File size: 8,671 Bytes
51c066f 2c96300 51c066f 2c96300 51c066f 2c96300 51c066f 2c96300 6612ab5 2c96300 6612ab5 2c96300 51c066f 2c96300 51c066f 2c96300 51c066f 2c96300 51c066f 2c96300 51c066f 2c96300 51c066f 2c96300 51c066f 2c96300 51c066f 2c96300 51c066f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 |
#!/usr/bin/env python3
"""
Smart dependency installer
Detects environment and installs appropriate PyTorch version
"""
import os
import platform
import re
import subprocess
import sys
def detect_environment():
    """Return 'hf_spaces' when running on Hugging Face Spaces, else 'local'.

    HF Spaces always sets the SPACE_ID environment variable, so its
    presence is used as the detection signal.
    """
    if os.environ.get('SPACE_ID') is not None:
        return 'hf_spaces'
    return 'local'
def detect_gpu_info():
    """Detect the NVIDIA GPU model and the installed CUDA version.

    Returns:
        tuple: (gpu_model, cuda_version) where both are strings, or None
        when no NVIDIA GPU / nvidia-smi is available. cuda_version is a
        "major.minor" string (e.g. "12.4").
    """
    gpu_model = None
    cuda_version = None
    try:
        # Query the GPU name via nvidia-smi.
        result = subprocess.run(
            ['nvidia-smi', '--query-gpu=gpu_name', '--format=csv,noheader'],
            capture_output=True,
            text=True,
            timeout=5
        )
        if result.returncode != 0:
            # nvidia-smi is present but reported no usable GPU. Bail out
            # instead of probing CUDA anyway — the original fall-through
            # returned (None, '12.4') here, which misled the caller.
            return None, None
        gpu_model = result.stdout.strip()
        print(f" Detected GPU: {gpu_model}")

        # Prefer the toolkit version reported by nvcc when available.
        try:
            nvcc_result = subprocess.run(
                ['nvcc', '--version'],
                capture_output=True,
                text=True,
                timeout=5
            )
            if nvcc_result.returncode == 0 and 'release' in nvcc_result.stdout:
                # Output contains e.g. "... release 12.1, V12.1.105"
                version = nvcc_result.stdout.split('release')[1].strip().split(',')[0].strip()
                major_minor = '.'.join(version.split('.')[:2])
                print(f" Detected CUDA version: {major_minor}")
                cuda_version = major_minor
        except (FileNotFoundError, subprocess.TimeoutExpired):
            pass

        # nvcc unavailable: fall back to the driver-reported version
        # printed in the plain nvidia-smi banner ("CUDA Version: 12.4").
        if not cuda_version:
            result = subprocess.run(
                ['nvidia-smi'],
                capture_output=True,
                text=True,
                timeout=5
            )
            for line in result.stdout.split('\n'):
                if 'CUDA Version:' in line:
                    version = line.split('CUDA Version:')[1].strip().split()[0]
                    major_minor = '.'.join(version.split('.')[:2])
                    print(f" Detected CUDA version from nvidia-smi: {major_minor}")
                    cuda_version = major_minor
                    break

        # GPU present but CUDA version unknown: assume a recent toolkit.
        if not cuda_version:
            print(" NVIDIA GPU detected but CUDA version unknown, using CUDA 12.4")
            cuda_version = '12.4'
    except (FileNotFoundError, subprocess.TimeoutExpired):
        # nvidia-smi not installed (or hung): treat as no NVIDIA GPU.
        pass
    return gpu_model, cuda_version
def requires_pytorch_2_6(gpu_model):
    """Check if GPU requires PyTorch 2.6.0+ (for Blackwell/compute capability 12.0+).

    Args:
        gpu_model: GPU name string from nvidia-smi, or None.

    Returns:
        bool: True for RTX 50x0-series (Blackwell) consumer GPUs.
    """
    if not gpu_model:
        return False
    # Match the 50x0 consumer series (5060/5070/5080/5090) as a whole
    # token, with or without an "RTX" prefix ("RTX 5080", "RTX5080").
    # The previous substring test ('rtx 50' in name) also matched the
    # Ada-generation "RTX 5000" workstation card — a false positive.
    return re.search(r'\b(?:rtx\s*)?50[6-9]0\b', gpu_model.lower()) is not None
def get_pytorch_install_command(env):
    """Build the PyTorch pip install spec for the given environment.

    Args:
        env: 'hf_spaces' or 'local' (see detect_environment()).

    Returns:
        tuple: (package_specs, index_url) where package_specs is a list of
        pip requirement strings and index_url is a pip --index-url string
        or None for the default index.
    """
    if env == 'hf_spaces':
        # ZeroGPU-compatible pinned version.
        return (['torch==2.2.0'], None)

    system = platform.system()
    # Apple Silicon: default wheels ship with MPS support.
    if system == 'Darwin' and platform.machine() == 'arm64':
        print(" Detected Apple Silicon, installing PyTorch with MPS support")
        return (['torch>=2.2.0'], None)

    if system in ('Linux', 'Windows'):
        gpu_model, cuda_version = detect_gpu_info()
        if cuda_version:
            # Blackwell GPUs need nightly builds with CUDA 12.8 (sm_120).
            if requires_pytorch_2_6(gpu_model):
                print(f" Detected Blackwell GPU ({gpu_model})")
                print(" Installing PyTorch nightly with CUDA 12.8 support (sm_120 compatible)")
                print(" Note: RTX 5080 requires PyTorch built with CUDA 12.8+ for full support")
                return (['torch', 'torchvision', 'torchaudio'],
                        'https://download.pytorch.org/whl/nightly/cu128')

            # Map the detected CUDA version to the closest published wheel
            # index (PyTorch only publishes wheels for selected versions).
            cuda_map = {
                '11.8': ('cu118', 'https://download.pytorch.org/whl/cu118'),
                '12.1': ('cu121', 'https://download.pytorch.org/whl/cu121'),
                '12.2': ('cu121', 'https://download.pytorch.org/whl/cu121'),  # use 12.1 wheels
                '12.3': ('cu121', 'https://download.pytorch.org/whl/cu121'),  # use 12.1 wheels
                '12.4': ('cu124', 'https://download.pytorch.org/whl/cu124'),
                '12.5': ('cu124', 'https://download.pytorch.org/whl/cu124'),  # use 12.4 wheels
                '12.6': ('cu124', 'https://download.pytorch.org/whl/cu124'),  # use 12.4 wheels
                '12.7': ('cu124', 'https://download.pytorch.org/whl/cu124'),  # use 12.4 wheels
                '12.8': ('cu128', 'https://download.pytorch.org/whl/nightly/cu128'),  # sm_120 support
                '13.0': ('cu128', 'https://download.pytorch.org/whl/nightly/cu128'),  # use 12.8 nightly
            }
            cuda_suffix, index_url = cuda_map.get(
                cuda_version, ('cu124', 'https://download.pytorch.org/whl/cu124'))
            print(f" Installing PyTorch with CUDA {cuda_version} support ({cuda_suffix})")
            return (['torch', 'torchvision', 'torchaudio'], index_url)

        print(" No CUDA detected, installing CPU-only PyTorch")
        return (['torch>=2.2.0'], None)

    # Any other OS: default to CPU-only wheels.
    return (['torch>=2.2.0'], None)
def install_dependencies():
    """Detect the environment and install PyTorch plus base dependencies.

    Installs PyTorch first (optionally from a CUDA-specific wheel index,
    falling back to CPU-only wheels on failure), then the remaining pinned
    dependencies, and finally verifies the torch install in a subprocess.

    Raises:
        subprocess.CalledProcessError: if a pip install step fails
        (including the CPU-only fallback).
    """
    env = detect_environment()
    print("=" * 60)
    print(f"🚀 Detected environment: {env}")
    print("=" * 60)

    # PyTorch is resolved separately so it can use a custom index URL.
    pytorch_packages, index_url = get_pytorch_install_command(env)

    # Base dependencies (PyTorch excluded — handled above).
    base_deps = [
        'gradio==5.49.1',
        'transformers==4.57.1',
        'safetensors==0.6.2',
        'accelerate==0.26.1',
        'sentencepiece==0.2.0',
        'protobuf==4.25.1',
        'huggingface-hub>=0.19.0',
        'python-dotenv==1.0.0',
    ]
    # The `spaces` package is only needed on HF Spaces (ZeroGPU decorators).
    if env == 'hf_spaces':
        base_deps.append('spaces')

    print("=" * 60)
    print("📦 Installing PyTorch...")
    print("=" * 60)
    pytorch_cmd = [sys.executable, '-m', 'pip', 'install', '--upgrade'] + pytorch_packages
    if index_url:
        pytorch_cmd.extend(['--index-url', index_url])
    try:
        subprocess.check_call(pytorch_cmd)
        print("✅ PyTorch installed successfully!")
    except subprocess.CalledProcessError as e:
        # CUDA wheel index can fail (network, unsupported platform);
        # fall back to the default-index CPU build rather than aborting.
        print(f"❌ PyTorch installation failed: {e}")
        print(" Falling back to CPU-only PyTorch...")
        subprocess.check_call([
            sys.executable, '-m', 'pip', 'install', '--upgrade', 'torch>=2.2.0'
        ])

    print("=" * 60)
    print(f"📦 Installing remaining dependencies ({len(base_deps)} packages)...")
    print("=" * 60)
    subprocess.check_call([
        sys.executable, '-m', 'pip', 'install', '--upgrade'
    ] + base_deps)

    # Verify in a fresh interpreter so a broken torch import cannot crash
    # this script. NOTE: the verification snippet previously nested
    # \"-escaped double quotes inside a double-quoted f-string, which is a
    # SyntaxError on Python < 3.12; single quotes avoid that.
    print("=" * 60)
    print("🔍 Verifying PyTorch installation...")
    print("=" * 60)
    verify_code = (
        "import torch; "
        "print(f'PyTorch: {torch.__version__}'); "
        "print(f'CUDA available: {torch.cuda.is_available()}'); "
        "print(f'CUDA version: {torch.version.cuda or \"N/A\"}')"
    )
    try:
        result = subprocess.run(
            [sys.executable, '-c', verify_code],
            capture_output=True, text=True, timeout=10
        )
        print(result.stdout)
    except Exception as e:
        print(f"⚠️ Could not verify PyTorch: {e}")

    print("=" * 60)
    print("✅ Installation complete!")
    print("=" * 60)
    print(f"Environment: {env}")
    print(f"PyTorch packages: {', '.join(pytorch_packages)}")
    if index_url:
        print(f"Index URL: {index_url}")
# Entry point: run the installer only when executed as a script,
# not when this module is imported.
if __name__ == '__main__':
    install_dependencies()
|