Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Smart dependency installer | |
| Detects environment and installs appropriate PyTorch version | |
| """ | |
| import os | |
| import sys | |
| import subprocess | |
| import platform | |
def detect_environment():
    """Report which runtime the installer is executing in.

    Returns:
        'hf_spaces' when the SPACE_ID environment variable is set (to any
        value, including an empty string); otherwise 'local'.
    """
    if 'SPACE_ID' in os.environ:
        return 'hf_spaces'
    return 'local'
def _run_probe(command):
    """Run a detection command, returning its CompletedProcess or None.

    None means the command could not be launched (missing binary, permission
    error, any other OSError) or did not finish within 5 seconds.
    """
    try:
        return subprocess.run(command, capture_output=True, text=True, timeout=5)
    except (OSError, subprocess.TimeoutExpired):
        return None


def _major_minor(version):
    """Reduce a version such as '12.4.1' to '12.4'; None if it is not numeric."""
    parts = version.split('.')[:2]
    if all(part.isdigit() for part in parts):
        return '.'.join(parts)
    return None


def _cuda_from_nvcc(output):
    """Extract the CUDA version from `nvcc --version` output (e.g. 'release 12.1')."""
    if 'release' not in output:
        return None
    version = output.split('release')[1].strip().split(',')[0].strip()
    return _major_minor(version)


def _cuda_from_smi(output):
    """Extract the CUDA version from the first 'CUDA Version:' line of nvidia-smi output."""
    for line in output.split('\n'):
        if 'CUDA Version:' in line:
            fields = line.split('CUDA Version:')[1].split()
            # Only the first matching line is considered; an empty or
            # non-numeric value (e.g. 'N/A') counts as "unknown".
            return _major_minor(fields[0]) if fields else None
    return None


def detect_gpu_info():
    """Detect the NVIDIA GPU model and the CUDA version to target.

    Queries `nvidia-smi` for the GPU name. When that succeeds, the CUDA
    version is read from `nvcc --version`, then from the plain `nvidia-smi`
    output, and finally defaults to '12.4'.

    Returns:
        A (gpu_model, cuda_version) tuple. Both are None when the GPU name
        query cannot be run, times out, or exits non-zero. Otherwise
        gpu_model is the stripped stdout of the name query and cuda_version
        is a 'major.minor' string.
    """
    name_query = _run_probe(
        ['nvidia-smi', '--query-gpu=gpu_name', '--format=csv,noheader']
    )
    if name_query is None or name_query.returncode != 0:
        return None, None

    gpu_model = name_query.stdout.strip()
    print(f" Detected GPU: {gpu_model}")

    # From here on a GPU is known to exist, so a failing probe must only
    # move on to the next source of the CUDA version, never drop the GPU.
    cuda_version = None
    nvcc = _run_probe(['nvcc', '--version'])
    if nvcc is not None and nvcc.returncode == 0:
        cuda_version = _cuda_from_nvcc(nvcc.stdout)
        if cuda_version:
            print(f" Detected CUDA version: {cuda_version}")

    if not cuda_version:
        smi = _run_probe(['nvidia-smi'])
        if smi is not None:
            cuda_version = _cuda_from_smi(smi.stdout)
            if cuda_version:
                print(f" Detected CUDA version from nvidia-smi: {cuda_version}")

    if not cuda_version:
        print(" NVIDIA GPU detected but CUDA version unknown, using CUDA 12.4")
        cuda_version = '12.4'

    return gpu_model, cuda_version
def requires_pytorch_2_6(gpu_model):
    """Tell whether the GPU name looks like an RTX 50-series (Blackwell) card.

    The check is a case-insensitive substring match against a fixed list of
    markers. A missing or empty model name yields False.
    """
    if not gpu_model:
        return False

    name = gpu_model.lower()
    # Markers for RTX 50xx cards, which this installer treats as Blackwell.
    for marker in ('rtx 50', 'rtx50', '5080', '5090', '5070'):
        if marker in name:
            return True
    return False
def get_pytorch_install_command(env):
    """Choose the PyTorch packages and wheel index for an environment.

    Args:
        env: 'hf_spaces' for Hugging Face Spaces; any other value is treated
            as a local machine.

    Returns:
        A (packages, index_url) tuple: the pip requirement strings to
        install, and the URL to pass as --index-url, or None to use pip's
        default index.
    """
    if env == 'hf_spaces':
        # Pinned version; noted by the original author as ZeroGPU compatible.
        return (['torch==2.2.0'], None)

    system = platform.system()

    if system == 'Darwin' and platform.machine() == 'arm64':
        print(" Detected Apple Silicon, installing PyTorch with MPS support")
        return (['torch>=2.2.0'], None)

    if system not in ('Linux', 'Windows'):
        # Other systems: default package from pip's default index.
        return (['torch>=2.2.0'], None)

    gpu_model, cuda_version = detect_gpu_info()
    if not cuda_version:
        print(" No CUDA detected, installing CPU-only PyTorch")
        return (['torch>=2.2.0'], None)

    nightly_cu128 = 'https://download.pytorch.org/whl/nightly/cu128'

    if requires_pytorch_2_6(gpu_model):
        print(f" Detected Blackwell GPU ({gpu_model})")
        print(" Installing PyTorch nightly with CUDA 12.8 support (sm_120 compatible)")
        # Worded generically: the detected card may be any 50-series model.
        print(" Note: Blackwell GPUs require PyTorch built with CUDA 12.8+ for full support")
        return (['torch', 'torchvision', 'torchaudio'], nightly_cu128)

    # Detected CUDA version -> (wheel suffix, index URL). Versions without a
    # dedicated wheel index reuse the closest lower one.
    cuda_map = {
        '11.8': ('cu118', 'https://download.pytorch.org/whl/cu118'),
        '12.1': ('cu121', 'https://download.pytorch.org/whl/cu121'),
        '12.2': ('cu121', 'https://download.pytorch.org/whl/cu121'),
        '12.3': ('cu121', 'https://download.pytorch.org/whl/cu121'),
        '12.4': ('cu124', 'https://download.pytorch.org/whl/cu124'),
        '12.5': ('cu124', 'https://download.pytorch.org/whl/cu124'),
        '12.6': ('cu124', 'https://download.pytorch.org/whl/cu124'),
        '12.7': ('cu124', 'https://download.pytorch.org/whl/cu124'),
        '12.8': ('cu128', nightly_cu128),
        '13.0': ('cu128', nightly_cu128),
    }
    # Any version not listed above falls back to the cu124 wheels.
    cuda_suffix, index_url = cuda_map.get(
        cuda_version, ('cu124', 'https://download.pytorch.org/whl/cu124')
    )
    print(f" Installing PyTorch with CUDA {cuda_version} support ({cuda_suffix})")
    return (['torch', 'torchvision', 'torchaudio'], index_url)
def _print_banner(message):
    """Print *message* framed by two 60-character rules."""
    rule = "=" * 60
    print(rule)
    print(message)
    print(rule)


def _pip_install(requirements, index_url=None):
    """Run `pip install --upgrade` for *requirements* with the current interpreter.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
    """
    command = [sys.executable, '-m', 'pip', 'install', '--upgrade', *requirements]
    if index_url:
        command += ['--index-url', index_url]
    subprocess.check_call(command)


def install_dependencies():
    """Install PyTorch and the remaining dependencies for this environment.

    Steps: detect the environment, install the selected PyTorch build
    (falling back to plain `torch>=2.2.0` if that fails), install the pinned
    base dependencies, run a best-effort import check of torch in a
    subprocess, then print a summary of what was installed.

    Raises:
        subprocess.CalledProcessError: if the fallback PyTorch install or the
            base dependency install fails.
    """
    env = detect_environment()
    _print_banner(f"Detected environment: {env}")

    pytorch_packages, index_url = get_pytorch_install_command(env)

    # Base dependencies (PyTorch is handled separately above).
    base_deps = [
        'gradio==5.49.1',
        'transformers==4.57.1',
        'safetensors==0.6.2',
        'accelerate==0.26.1',
        'sentencepiece==0.2.0',
        'protobuf==4.25.1',
        'huggingface-hub>=0.19.0',
        'python-dotenv==1.0.0',
    ]
    # The `spaces` package is only added when running on HF Spaces.
    if env == 'hf_spaces':
        base_deps.append('spaces')

    _print_banner("Installing PyTorch...")
    try:
        _pip_install(pytorch_packages, index_url)
        print("[OK] PyTorch installed successfully!")
    except subprocess.CalledProcessError as e:
        print(f"[FAIL] PyTorch installation failed: {e}")
        print(" Falling back to CPU-only PyTorch...")
        # Record what is actually installed so the summary below is accurate.
        pytorch_packages, index_url = ['torch>=2.2.0'], None
        _pip_install(pytorch_packages)

    _print_banner(f"Installing remaining dependencies ({len(base_deps)} packages)...")
    _pip_install(base_deps)

    _print_banner("Verifying PyTorch installation...")
    # Plain print() calls with single-quoted literals: reusing the enclosing
    # quote character inside an f-string is a SyntaxError before Python 3.12.
    verify_code = (
        "import torch; "
        "print('PyTorch:', torch.__version__); "
        "print('CUDA available:', torch.cuda.is_available()); "
        "print('CUDA version:', torch.version.cuda or 'N/A')"
    )
    try:
        # Generous timeout: the first import of torch can be slow.
        result = subprocess.run(
            [sys.executable, '-c', verify_code],
            capture_output=True,
            text=True,
            timeout=60,
        )
    except (OSError, subprocess.TimeoutExpired) as e:
        print(f"[WARN] Could not verify PyTorch: {e}")
    else:
        print(result.stdout)
        if result.returncode != 0:
            # Surface the import error instead of silently printing nothing.
            print(f"[WARN] Could not verify PyTorch: {result.stderr.strip()}")

    _print_banner("[OK] Installation complete!")
    print(f"Environment: {env}")
    print(f"PyTorch packages: {', '.join(pytorch_packages)}")
    if index_url:
        print(f"Index URL: {index_url}")
if __name__ == "__main__":
    # Run the installer only when executed as a script, not when imported.
    install_dependencies()