#!/usr/bin/env python3
"""
Smart dependency installer

Detects environment and installs appropriate PyTorch version
"""

import os
import sys
import subprocess
import platform

# Seconds to wait for each `nvidia-smi` / `nvcc` probe before giving up.
_PROBE_TIMEOUT = 5

# CUDA version assumed when a GPU is visible but no version could be parsed.
_DEFAULT_CUDA_VERSION = '12.4'

# Unpinned-index requirement used for Apple Silicon, CPU-only hosts and as the
# fallback when the preferred install fails.
_GENERIC_TORCH_REQUIREMENT = 'torch>=2.2.0'

_CU121_INDEX = 'https://download.pytorch.org/whl/cu121'
_CU124_INDEX = 'https://download.pytorch.org/whl/cu124'
_CU128_NIGHTLY_INDEX = 'https://download.pytorch.org/whl/nightly/cu128'

# Detected CUDA "major.minor" -> (wheel suffix, pip index URL).
# Versions without their own wheel index reuse the nearest lower one.
_CUDA_INDEX_MAP = {
    '11.8': ('cu118', 'https://download.pytorch.org/whl/cu118'),
    '12.1': ('cu121', _CU121_INDEX),
    '12.2': ('cu121', _CU121_INDEX),  # Use 12.1 for 12.2
    '12.3': ('cu121', _CU121_INDEX),  # Use 12.1 for 12.3
    '12.4': ('cu124', _CU124_INDEX),
    '12.5': ('cu124', _CU124_INDEX),  # Use 12.4 for 12.5
    '12.6': ('cu124', _CU124_INDEX),  # Use 12.4 for 12.6
    '12.7': ('cu124', _CU124_INDEX),  # Use 12.4 for 12.7
    '12.8': ('cu128', _CU128_NIGHTLY_INDEX),  # CUDA 12.8 with sm_120 support
    '13.0': ('cu128', _CU128_NIGHTLY_INDEX),  # Use 12.8 nightly for 13.0
}

# Code executed in a child interpreter to report the installed PyTorch build.
# The f-strings are single-quoted and the inner literal double-quoted: reusing
# the enclosing quote type inside an f-string replacement field is a
# SyntaxError on interpreters older than Python 3.12.
VERIFY_SCRIPT = (
    "import torch; "
    "print(f'PyTorch: {torch.__version__}'); "
    "print(f'CUDA available: {torch.cuda.is_available()}'); "
    "print(f'CUDA version: {torch.version.cuda or \"N/A\"}')"
)


def detect_environment():
    """Detect if running on HF Spaces or local.

    Returns:
        ``'hf_spaces'`` when the ``SPACE_ID`` environment variable is set
        (to any value, including an empty string), otherwise ``'local'``.
    """
    return 'hf_spaces' if os.environ.get('SPACE_ID') is not None else 'local'


def _run_probe(command):
    """Run *command* with captured text output; return None if missing or timed out."""
    try:
        return subprocess.run(
            command, capture_output=True, text=True, timeout=_PROBE_TIMEOUT
        )
    except (FileNotFoundError, subprocess.TimeoutExpired):
        return None


def _normalize_cuda_version(raw):
    """Reduce *raw* (e.g. ``'12.1.105'``) to ``'major.minor'``; None if not numeric."""
    if not raw[:1].isdigit():
        return None
    return '.'.join(raw.split('.')[:2])


def _parse_nvcc_cuda_version(output):
    """Extract the CUDA version from ``nvcc --version`` text (e.g. "release 12.1")."""
    if 'release' not in output:
        return None
    raw = output.split('release')[1].strip().split(',')[0].strip()
    return _normalize_cuda_version(raw)


def _parse_nvidia_smi_cuda_version(output):
    """Extract the CUDA version from the first ``CUDA Version:`` line of nvidia-smi output."""
    for line in output.split('\n'):
        if 'CUDA Version:' in line:
            tokens = line.split('CUDA Version:')[1].split()
            # Only the first matching line is considered.
            return _normalize_cuda_version(tokens[0]) if tokens else None
    return None


def detect_gpu_info():
    """Detect GPU model and CUDA version.

    Queries ``nvidia-smi`` for the GPU name, then looks for a CUDA version
    first in ``nvcc --version`` and then in the plain ``nvidia-smi`` output.
    Progress messages are printed to stdout.

    Returns:
        A ``(gpu_model, cuda_version)`` tuple. Both are ``None`` when
        ``nvidia-smi`` is missing, times out, or exits non-zero. Otherwise
        ``gpu_model`` is the stripped ``nvidia-smi`` output and
        ``cuda_version`` is a ``'major.minor'`` string, defaulting to
        ``'12.4'`` when no version could be determined.
    """
    name_query = _run_probe(
        ['nvidia-smi', '--query-gpu=gpu_name', '--format=csv,noheader']
    )
    if name_query is None or name_query.returncode != 0:
        return None, None

    gpu_model = name_query.stdout.strip()
    print(f" Detected GPU: {gpu_model}")

    cuda_version = None

    # Try to get CUDA version from nvcc
    nvcc = _run_probe(['nvcc', '--version'])
    if nvcc is not None and nvcc.returncode == 0:
        cuda_version = _parse_nvcc_cuda_version(nvcc.stdout)
        if cuda_version:
            print(f" Detected CUDA version: {cuda_version}")

    # If nvcc gave nothing, try to get CUDA version from nvidia-smi output
    if not cuda_version:
        banner = _run_probe(['nvidia-smi'])
        if banner is not None:
            cuda_version = _parse_nvidia_smi_cuda_version(banner.stdout)
            if cuda_version:
                print(f" Detected CUDA version from nvidia-smi: {cuda_version}")

    # GPU detected but CUDA version unknown: fall back to the default. Each
    # probe handles its own failures, so a timeout above still reaches here.
    if not cuda_version:
        print(
            " NVIDIA GPU detected but CUDA version unknown, "
            f"using CUDA {_DEFAULT_CUDA_VERSION}"
        )
        cuda_version = _DEFAULT_CUDA_VERSION

    return gpu_model, cuda_version


def requires_pytorch_2_6(gpu_model):
    """Check if GPU requires PyTorch 2.6.0+ (for Blackwell/compute capability 12.0+).

    Args:
        gpu_model: GPU name string as reported by ``nvidia-smi``, or ``None``.

    Returns:
        ``True`` when the lower-cased name contains one of the RTX 50xx
        markers listed below; ``False`` otherwise or when *gpu_model* is
        empty/``None``.
    """
    if not gpu_model:
        return False

    # Blackwell GPUs (RTX 50xx series) require PyTorch 2.6.0+
    blackwell_markers = ('rtx 50', 'rtx50', '5080', '5090', '5070')
    name = gpu_model.lower()
    return any(marker in name for marker in blackwell_markers)


def get_pytorch_install_command(env):
    """Get appropriate PyTorch install command for environment.

    Args:
        env: ``'hf_spaces'`` or any other value (treated as a local install).

    Returns:
        A ``(packages, index_url)`` tuple where ``packages`` is a list of pip
        requirement strings and ``index_url`` is a pip ``--index-url`` value
        or ``None`` to use pip's default index.
    """
    if env == 'hf_spaces':
        # ZeroGPU compatible version
        return (['torch==2.2.0'], None)

    system = platform.system()

    # Check if Apple Silicon
    if system == 'Darwin' and platform.machine() == 'arm64':
        print(" Detected Apple Silicon, installing PyTorch with MPS support")
        return ([_GENERIC_TORCH_REQUIREMENT], None)

    # Other systems, default to CPU
    if system not in ('Linux', 'Windows'):
        return ([_GENERIC_TORCH_REQUIREMENT], None)

    # Check for CUDA on Linux/Windows
    gpu_model, cuda_version = detect_gpu_info()
    if not cuda_version:
        print(" No CUDA detected, installing CPU-only PyTorch")
        return ([_GENERIC_TORCH_REQUIREMENT], None)

    cuda_packages = ['torch', 'torchvision', 'torchaudio']

    if requires_pytorch_2_6(gpu_model):
        print(f" ✅ Detected Blackwell GPU ({gpu_model})")
        print(" Installing PyTorch nightly with CUDA 12.8 support (sm_120 compatible)")
        print(" Note: RTX 5080 requires PyTorch built with CUDA 12.8+ for full support")
        # Use nightly build for Blackwell GPU support with CUDA 12.8
        return (cuda_packages, _CU128_NIGHTLY_INDEX)

    # Map CUDA version to PyTorch index URL; unknown versions use cu124.
    cuda_suffix, index_url = _CUDA_INDEX_MAP.get(
        cuda_version, ('cu124', _CU124_INDEX)
    )
    print(f" Installing PyTorch with CUDA {cuda_version} support ({cuda_suffix})")
    return (cuda_packages, index_url)


def _print_banner(message):
    """Print *message* framed by two 60-character ``=`` rules."""
    rule = "=" * 60
    print(rule)
    print(message)
    print(rule)


def _pip_install(requirements, index_url=None):
    """Run ``pip install --upgrade`` for *requirements*; raise CalledProcessError on failure."""
    command = [sys.executable, '-m', 'pip', 'install', '--upgrade'] + list(requirements)
    if index_url:
        command.extend(['--index-url', index_url])
    subprocess.check_call(command)


def _verify_pytorch():
    """Import torch in a child interpreter and print its version/CUDA report."""
    try:
        result = subprocess.run(
            [sys.executable, '-c', VERIFY_SCRIPT],
            capture_output=True,
            text=True,
            timeout=10,
        )
    except (OSError, subprocess.SubprocessError) as e:
        print(f"⚠️ Could not verify PyTorch: {e}")
        return

    print(result.stdout)
    if result.returncode != 0:
        # Surface the child's error instead of silently printing nothing.
        print(f"⚠️ Could not verify PyTorch: {result.stderr.strip()}")


def install_dependencies():
    """Install dependencies based on environment.

    Installs PyTorch using the packages/index chosen by
    ``get_pytorch_install_command`` (falling back to ``torch>=2.2.0`` from
    pip's default index if that fails), then the remaining pinned
    dependencies, then prints a verification report and a summary.

    Raises:
        subprocess.CalledProcessError: if the fallback PyTorch install or the
            remaining-dependencies install exits non-zero.
    """
    env = detect_environment()
    _print_banner(f"🔍 Detected environment: {env}")

    # Get PyTorch installation command
    pytorch_packages, index_url = get_pytorch_install_command(env)

    # Base dependencies (excluding PyTorch)
    base_deps = [
        'gradio==5.49.1',
        'transformers==4.57.1',
        'safetensors==0.6.2',
        'accelerate==0.26.1',
        'sentencepiece==0.2.0',
        'protobuf==4.25.1',
        'huggingface-hub>=0.19.0',
        'python-dotenv==1.0.0',
    ]

    # Add spaces for HF Spaces only
    if env == 'hf_spaces':
        base_deps.append('spaces')

    _print_banner("📦 Installing PyTorch...")

    # Install PyTorch (with optional index URL for CUDA)
    try:
        _pip_install(pytorch_packages, index_url)
        print("✅ PyTorch installed successfully!")
    except subprocess.CalledProcessError as e:
        print(f"❌ PyTorch installation failed: {e}")
        print(" Falling back to CPU-only PyTorch...")
        # Record what is actually being installed so the final summary is accurate.
        pytorch_packages, index_url = [_GENERIC_TORCH_REQUIREMENT], None
        _pip_install(pytorch_packages)

    _print_banner(f"📦 Installing remaining dependencies ({len(base_deps)} packages)...")

    # Install remaining dependencies
    _pip_install(base_deps)

    # Verify PyTorch installation
    _print_banner("🔍 Verifying PyTorch installation...")
    _verify_pytorch()

    _print_banner("✅ Installation complete!")
    print(f"Environment: {env}")
    print(f"PyTorch packages: {', '.join(pytorch_packages)}")
    if index_url:
        print(f"Index URL: {index_url}")


if __name__ == '__main__':
    install_dependencies()