# simple-chat / setup.py
# feat: Enable RTX 5080 GPU support with PyTorch nightly (CUDA 12.8) — commit 6612ab5
#!/usr/bin/env python3
"""
Smart dependency installer
Detects environment and installs appropriate PyTorch version
"""
import os
import sys
import subprocess
import platform
def detect_environment():
    """Return 'hf_spaces' when running on Hugging Face Spaces, else 'local'.

    HF Spaces always exports the SPACE_ID environment variable, so its
    presence is used as the detection signal.
    """
    return 'hf_spaces' if 'SPACE_ID' in os.environ else 'local'
def detect_gpu_info():
    """Detect the NVIDIA GPU model and CUDA version on this machine.

    Probes `nvidia-smi` for the GPU name; if a GPU is found, tries `nvcc
    --version` for the toolkit release, then falls back to parsing the
    "CUDA Version:" field of the plain `nvidia-smi` banner (the driver's
    supported CUDA version), and finally defaults to '12.4'.

    Returns:
        tuple: (gpu_model, cuda_version) — both are strings, or both None
        when no NVIDIA GPU / driver is available.
    """
    gpu_model = None
    cuda_version = None
    try:
        # Try nvidia-smi first
        result = subprocess.run(
            ['nvidia-smi', '--query-gpu=gpu_name', '--format=csv,noheader'],
            capture_output=True,
            text=True,
            timeout=5
        )
        if result.returncode == 0:
            gpu_model = result.stdout.strip()
            print(f" Detected GPU: {gpu_model}")
            # Try to get CUDA version from nvcc
            try:
                nvcc_result = subprocess.run(
                    ['nvcc', '--version'],
                    capture_output=True,
                    text=True,
                    timeout=5
                )
                if nvcc_result.returncode == 0:
                    output = nvcc_result.stdout
                    # Parse CUDA version (e.g., "release 12.1")
                    if 'release' in output:
                        version = output.split('release')[1].strip().split(',')[0].strip()
                        major_minor = '.'.join(version.split('.')[:2])
                        print(f" Detected CUDA version: {major_minor}")
                        cuda_version = major_minor
            except (FileNotFoundError, subprocess.TimeoutExpired):
                pass
            # If nvcc not found, try to get CUDA version from nvidia-smi output
            if not cuda_version:
                banner = subprocess.run(
                    ['nvidia-smi'],
                    capture_output=True,
                    text=True,
                    timeout=5
                )
                for line in banner.stdout.split('\n'):
                    if 'CUDA Version:' in line:
                        version = line.split('CUDA Version:')[1].strip().split()[0]
                        major_minor = '.'.join(version.split('.')[:2])
                        print(f" Detected CUDA version from nvidia-smi: {major_minor}")
                        cuda_version = major_minor
                        break
            # BUGFIX: only apply the CUDA 12.4 default when a GPU actually was
            # detected (returncode == 0); previously the default could leak
            # out even when nvidia-smi failed and no GPU exists.
            if not cuda_version:
                print(" NVIDIA GPU detected but CUDA version unknown, using CUDA 12.4")
                cuda_version = '12.4'
    except (FileNotFoundError, subprocess.TimeoutExpired):
        # No nvidia-smi on PATH (or it hung): no NVIDIA GPU available.
        pass
    return gpu_model, cuda_version
def requires_pytorch_2_6(gpu_model):
    """Check if GPU requires PyTorch 2.6.0+ (for Blackwell/compute capability 12.0+)"""
    if not gpu_model:
        return False
    name = gpu_model.lower()
    # Blackwell GPUs (RTX 50xx series) require PyTorch 2.6.0+
    for marker in ('rtx 50', 'rtx50', '5080', '5090', '5070'):
        if marker in name:
            return True
    return False
def get_pytorch_install_command(env):
    """Pick the PyTorch packages and (optional) pip index URL for *env*.

    Args:
        env: 'hf_spaces' or 'local' (see detect_environment).

    Returns:
        tuple: (list of package specs, index_url or None).
    """
    if env == 'hf_spaces':
        # ZeroGPU compatible version
        return (['torch==2.2.0'], None)

    # Local environment
    system = platform.system()

    # Apple Silicon gets the default PyPI wheel, which ships MPS support.
    if system == 'Darwin' and platform.machine() == 'arm64':
        print(" Detected Apple Silicon, installing PyTorch with MPS support")
        return (['torch>=2.2.0'], None)

    if system not in ('Linux', 'Windows'):
        # Other systems, default to CPU
        return (['torch>=2.2.0'], None)

    # Linux/Windows: probe for an NVIDIA GPU and its CUDA version.
    gpu_model, cuda_version = detect_gpu_info()
    if not cuda_version:
        print(" No CUDA detected, installing CPU-only PyTorch")
        return (['torch>=2.2.0'], None)

    # Blackwell GPUs need a PyTorch built against CUDA 12.8 (sm_120).
    if requires_pytorch_2_6(gpu_model):
        print(f" βœ… Detected Blackwell GPU ({gpu_model})")
        print(" Installing PyTorch nightly with CUDA 12.8 support (sm_120 compatible)")
        print(" Note: RTX 5080 requires PyTorch built with CUDA 12.8+ for full support")
        # Use nightly build for Blackwell GPU support with CUDA 12.8
        return (['torch', 'torchvision', 'torchaudio'], 'https://download.pytorch.org/whl/nightly/cu128')

    # Map CUDA version to PyTorch index URL; unlisted versions fall back to cu124.
    cuda_map = {
        '11.8': ('cu118', 'https://download.pytorch.org/whl/cu118'),
        '12.1': ('cu121', 'https://download.pytorch.org/whl/cu121'),
        '12.2': ('cu121', 'https://download.pytorch.org/whl/cu121'),  # Use 12.1 for 12.2
        '12.3': ('cu121', 'https://download.pytorch.org/whl/cu121'),  # Use 12.1 for 12.3
        '12.4': ('cu124', 'https://download.pytorch.org/whl/cu124'),
        '12.5': ('cu124', 'https://download.pytorch.org/whl/cu124'),  # Use 12.4 for 12.5
        '12.6': ('cu124', 'https://download.pytorch.org/whl/cu124'),  # Use 12.4 for 12.6
        '12.7': ('cu124', 'https://download.pytorch.org/whl/cu124'),  # Use 12.4 for 12.7
        '12.8': ('cu128', 'https://download.pytorch.org/whl/nightly/cu128'),  # CUDA 12.8 with sm_120 support
        '13.0': ('cu128', 'https://download.pytorch.org/whl/nightly/cu128'),  # Use 12.8 nightly for 13.0
    }
    cuda_suffix, index_url = cuda_map.get(cuda_version, ('cu124', 'https://download.pytorch.org/whl/cu124'))
    print(f" Installing PyTorch with CUDA {cuda_version} support ({cuda_suffix})")
    return (['torch', 'torchvision', 'torchaudio'], index_url)
def install_dependencies():
    """Install an environment-appropriate PyTorch build plus the app dependencies.

    Steps: detect environment, pick the PyTorch packages/index, pip-install
    PyTorch (falling back to CPU-only on failure), pip-install the remaining
    pinned dependencies, then verify the torch import in a subprocess.
    """
    env = detect_environment()
    print("=" * 60)
    print(f"πŸ” Detected environment: {env}")
    print("=" * 60)
    # Get PyTorch installation command
    pytorch_packages, index_url = get_pytorch_install_command(env)
    # Base dependencies (excluding PyTorch)
    base_deps = [
        'gradio==5.49.1',
        'transformers==4.57.1',
        'safetensors==0.6.2',
        'accelerate==0.26.1',
        'sentencepiece==0.2.0',
        'protobuf==4.25.1',
        'huggingface-hub>=0.19.0',
        'python-dotenv==1.0.0',
    ]
    # Add spaces for HF Spaces only
    if env == 'hf_spaces':
        base_deps.append('spaces')
    print("=" * 60)
    print("πŸ“¦ Installing PyTorch...")
    print("=" * 60)
    # Install PyTorch (with optional index URL for CUDA)
    pytorch_cmd = [sys.executable, '-m', 'pip', 'install', '--upgrade'] + pytorch_packages
    if index_url:
        pytorch_cmd.extend(['--index-url', index_url])
        # BUGFIX: the nightly indexes (e.g. .../whl/nightly/cu128 used for
        # Blackwell GPUs) only publish pre-release wheels; without --pre, pip
        # rejects them all and the install fails, silently falling back to
        # CPU-only PyTorch below.
        if 'nightly' in index_url:
            pytorch_cmd.append('--pre')
    try:
        subprocess.check_call(pytorch_cmd)
        print("βœ… PyTorch installed successfully!")
    except subprocess.CalledProcessError as e:
        print(f"❌ PyTorch installation failed: {e}")
        print(" Falling back to CPU-only PyTorch...")
        subprocess.check_call([
            sys.executable, '-m', 'pip', 'install', '--upgrade', 'torch>=2.2.0'
        ])
    print("=" * 60)
    print(f"πŸ“¦ Installing remaining dependencies ({len(base_deps)} packages)...")
    print("=" * 60)
    # Install remaining dependencies
    subprocess.check_call([
        sys.executable, '-m', 'pip', 'install', '--upgrade'
    ] + base_deps)
    # Verify PyTorch installation in a fresh interpreter (so the just-installed
    # wheel is picked up rather than any previously imported torch).
    print("=" * 60)
    print("πŸ” Verifying PyTorch installation...")
    print("=" * 60)
    try:
        result = subprocess.run([
            sys.executable, '-c',
            'import torch; print(f"PyTorch: {torch.__version__}"); print(f"CUDA available: {torch.cuda.is_available()}"); print(f"CUDA version: {torch.version.cuda if torch.version.cuda else \"N/A\"}")'
        ], capture_output=True, text=True, timeout=10)
        print(result.stdout)
    except Exception as e:
        # Verification is best-effort; a slow import should not fail the setup.
        print(f"⚠️ Could not verify PyTorch: {e}")
    print("=" * 60)
    print("βœ… Installation complete!")
    print("=" * 60)
    print(f"Environment: {env}")
    print(f"PyTorch packages: {', '.join(pytorch_packages)}")
    if index_url:
        print(f"Index URL: {index_url}")
# Run the installer only when executed as a script, not on import.
if __name__ == '__main__':
    install_dependencies()