File size: 8,671 Bytes
51c066f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c96300
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51c066f
 
2c96300
51c066f
2c96300
 
 
51c066f
2c96300
 
 
 
 
 
 
 
 
 
 
 
 
6612ab5
 
 
 
 
2c96300
 
 
 
 
 
 
 
 
 
 
6612ab5
 
2c96300
 
 
 
 
 
 
 
51c066f
2c96300
 
51c066f
 
 
 
2c96300
 
 
 
 
 
51c066f
2c96300
51c066f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2c96300
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51c066f
2c96300
 
 
 
 
51c066f
 
 
 
2c96300
 
 
 
 
 
 
 
 
 
 
 
 
 
51c066f
2c96300
51c066f
2c96300
 
 
51c066f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
#!/usr/bin/env python3
"""
Smart dependency installer
Detects environment and installs appropriate PyTorch version
"""

import os
import platform
import re
import subprocess
import sys

def detect_environment():
    """Return ``'hf_spaces'`` when running on Hugging Face Spaces, else ``'local'``.

    Hugging Face Spaces is recognized by the presence of the ``SPACE_ID``
    environment variable.
    """
    return 'hf_spaces' if 'SPACE_ID' in os.environ else 'local'

def detect_gpu_info():
    """Detect the local NVIDIA GPU model and a CUDA version string.

    Probes ``nvidia-smi`` for the GPU name, then ``nvcc --version`` for the
    toolkit version, then falls back to the driver-reported version in plain
    ``nvidia-smi`` output, and finally defaults to '12.4' when a GPU exists
    but no version could be determined.

    Returns:
        tuple: ``(gpu_model, cuda_version)`` — both ``str`` or both ``None``
        when no NVIDIA driver/tooling is available. ``cuda_version`` is a
        "major.minor" string.
    """
    cuda_version = None

    # Step 1: GPU model via nvidia-smi. If the binary is missing or hangs,
    # there is no usable NVIDIA setup at all.
    try:
        result = subprocess.run(
            ['nvidia-smi', '--query-gpu=gpu_name', '--format=csv,noheader'],
            capture_output=True,
            text=True,
            timeout=5
        )
    except (FileNotFoundError, subprocess.TimeoutExpired):
        return None, None

    if result.returncode != 0:
        return None, None

    gpu_model = result.stdout.strip()
    print(f"   Detected GPU: {gpu_model}")

    # Step 2: CUDA toolkit version from nvcc (output contains e.g. "release 12.1").
    try:
        nvcc_result = subprocess.run(
            ['nvcc', '--version'],
            capture_output=True,
            text=True,
            timeout=5
        )
        if nvcc_result.returncode == 0 and 'release' in nvcc_result.stdout:
            version = nvcc_result.stdout.split('release')[1].strip().split(',')[0].strip()
            cuda_version = '.'.join(version.split('.')[:2])
            print(f"   Detected CUDA version: {cuda_version}")
    except (FileNotFoundError, subprocess.TimeoutExpired):
        pass

    # Step 3: fall back to the driver-reported "CUDA Version:" in nvidia-smi.
    # Bug fix: this probe now has its own try/except and a returncode check,
    # so a failure here no longer aborts the whole function (which previously
    # skipped the 12.4 default below and returned cuda_version=None).
    if not cuda_version:
        try:
            result = subprocess.run(
                ['nvidia-smi'],
                capture_output=True,
                text=True,
                timeout=5
            )
            if result.returncode == 0:
                for line in result.stdout.split('\n'):
                    if 'CUDA Version:' in line:
                        version = line.split('CUDA Version:')[1].strip().split()[0]
                        cuda_version = '.'.join(version.split('.')[:2])
                        print(f"   Detected CUDA version from nvidia-smi: {cuda_version}")
                        break
        except (FileNotFoundError, subprocess.TimeoutExpired):
            pass

    # Step 4: GPU present but version unknown — assume a recent CUDA.
    if not cuda_version:
        print("   NVIDIA GPU detected but CUDA version unknown, using CUDA 12.4")
        cuda_version = '12.4'

    return gpu_model, cuda_version

def requires_pytorch_2_6(gpu_model):
    """Check if GPU requires PyTorch 2.6.0+ (for Blackwell/compute capability 12.0+)"""
    if not gpu_model:
        return False

    # Blackwell GPUs (RTX 50xx series) require PyTorch 2.6.0+
    blackwell_gpus = ['rtx 50', 'rtx50', '5080', '5090', '5070']
    gpu_lower = gpu_model.lower()
    return any(model in gpu_lower for model in blackwell_gpus)

def get_pytorch_install_command(env):
    """Return ``(packages, index_url)`` describing how to pip-install PyTorch.

    Args:
        env: 'hf_spaces' or 'local' (see ``detect_environment``).

    Returns:
        tuple: a list of pip requirement strings and an alternate package
        index URL, or ``None`` to use the default index (PyPI).
    """
    # Hugging Face Spaces (ZeroGPU) needs a pinned, known-compatible build.
    if env == 'hf_spaces':
        return (['torch==2.2.0'], None)

    system = platform.system()

    # Apple Silicon: the stock wheel already ships MPS support.
    if system == 'Darwin' and platform.machine() == 'arm64':
        print("   Detected Apple Silicon, installing PyTorch with MPS support")
        return (['torch>=2.2.0'], None)

    # Anything that is not Linux/Windows gets the plain CPU wheel.
    if system not in ('Linux', 'Windows'):
        return (['torch>=2.2.0'], None)

    gpu_model, cuda_version = detect_gpu_info()

    if not cuda_version:
        print("   No CUDA detected, installing CPU-only PyTorch")
        return (['torch>=2.2.0'], None)

    # Blackwell GPUs (sm_120) are only supported by nightly cu128 wheels.
    if requires_pytorch_2_6(gpu_model):
        print(f"   βœ… Detected Blackwell GPU ({gpu_model})")
        print(f"   Installing PyTorch nightly with CUDA 12.8 support (sm_120 compatible)")
        print(f"   Note: RTX 5080 requires PyTorch built with CUDA 12.8+ for full support")
        return (['torch', 'torchvision', 'torchaudio'],
                'https://download.pytorch.org/whl/nightly/cu128')

    # Nearest-match table: detected CUDA version -> (wheel suffix, index URL).
    cuda_map = {
        '11.8': ('cu118', 'https://download.pytorch.org/whl/cu118'),
        '12.1': ('cu121', 'https://download.pytorch.org/whl/cu121'),
        '12.2': ('cu121', 'https://download.pytorch.org/whl/cu121'),  # Use 12.1 for 12.2
        '12.3': ('cu121', 'https://download.pytorch.org/whl/cu121'),  # Use 12.1 for 12.3
        '12.4': ('cu124', 'https://download.pytorch.org/whl/cu124'),
        '12.5': ('cu124', 'https://download.pytorch.org/whl/cu124'),  # Use 12.4 for 12.5
        '12.6': ('cu124', 'https://download.pytorch.org/whl/cu124'),  # Use 12.4 for 12.6
        '12.7': ('cu124', 'https://download.pytorch.org/whl/cu124'),  # Use 12.4 for 12.7
        '12.8': ('cu128', 'https://download.pytorch.org/whl/nightly/cu128'),  # CUDA 12.8 with sm_120 support
        '13.0': ('cu128', 'https://download.pytorch.org/whl/nightly/cu128'),  # Use 12.8 nightly for 13.0
    }
    cuda_suffix, index_url = cuda_map.get(
        cuda_version, ('cu124', 'https://download.pytorch.org/whl/cu124'))
    print(f"   Installing PyTorch with CUDA {cuda_version} support ({cuda_suffix})")
    return (['torch', 'torchvision', 'torchaudio'], index_url)

def install_dependencies():
    """Install PyTorch and the remaining app dependencies for this environment.

    Detects the environment, picks the right PyTorch build (falling back to
    the CPU wheel if that install fails), installs the pinned base packages,
    and finally runs a subprocess probe to report the installed torch/CUDA
    versions.
    """
    env = detect_environment()

    def banner(text):
        # Uniform section separator used throughout the install log.
        print("=" * 60)
        print(text)
        print("=" * 60)

    banner(f"πŸ” Detected environment: {env}")

    # Get PyTorch installation command
    pytorch_packages, index_url = get_pytorch_install_command(env)

    # Base dependencies (excluding PyTorch)
    base_deps = [
        'gradio==5.49.1',
        'transformers==4.57.1',
        'safetensors==0.6.2',
        'accelerate==0.26.1',
        'sentencepiece==0.2.0',
        'protobuf==4.25.1',
        'huggingface-hub>=0.19.0',
        'python-dotenv==1.0.0',
    ]

    # The `spaces` package is only needed on HF Spaces (ZeroGPU decorators).
    if env == 'hf_spaces':
        base_deps.append('spaces')

    banner(f"πŸ“¦ Installing PyTorch...")

    # Shared pip invocation prefix for every install below.
    pip_base = [sys.executable, '-m', 'pip', 'install', '--upgrade']

    pytorch_cmd = pip_base + pytorch_packages
    if index_url:
        # CUDA builds live on PyTorch's own wheel index, not PyPI.
        pytorch_cmd += ['--index-url', index_url]

    try:
        subprocess.check_call(pytorch_cmd)
        print("βœ… PyTorch installed successfully!")
    except subprocess.CalledProcessError as e:
        # Best effort: a CUDA wheel that fails to install falls back to CPU.
        print(f"❌ PyTorch installation failed: {e}")
        print("   Falling back to CPU-only PyTorch...")
        subprocess.check_call(pip_base + ['torch>=2.2.0'])

    banner(f"πŸ“¦ Installing remaining dependencies ({len(base_deps)} packages)...")
    subprocess.check_call(pip_base + base_deps)

    banner("πŸ” Verifying PyTorch installation...")
    # One-liner executed in a fresh interpreter so the just-installed torch
    # is picked up (this process may have an older torch already imported).
    probe = (
        'import torch; '
        'print(f"PyTorch: {torch.__version__}"); '
        'print(f"CUDA available: {torch.cuda.is_available()}"); '
        'print(f"CUDA version: {torch.version.cuda if torch.version.cuda else \"N/A\"}")'
    )
    try:
        result = subprocess.run(
            [sys.executable, '-c', probe],
            capture_output=True, text=True, timeout=10)
        print(result.stdout)
    except Exception as e:
        print(f"⚠️  Could not verify PyTorch: {e}")

    banner("βœ… Installation complete!")
    print(f"Environment: {env}")
    print(f"PyTorch packages: {', '.join(pytorch_packages)}")
    if index_url:
        print(f"Index URL: {index_url}")

# Entry point: run the full install when executed directly as a script.
if __name__ == '__main__':
    install_dependencies()