File size: 7,564 Bytes
5554ef1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
#!/usr/bin/env python3
"""
Whisper Fine-tuning Setup
Purpose: Fine-tune Whisper-small on German data
GPU: RTX 5060 Ti optimized
"""

import torch
import sys
from pathlib import Path

def check_environment():
    """Verify all dependencies are installed"""
    banner = "=" * 60
    print(banner)
    print("ENVIRONMENT CHECK")
    print(banner)
    
    # Core framework info first, then the GPU details when CUDA is present.
    cuda_ok = torch.cuda.is_available()
    print(f"βœ“ PyTorch: {torch.__version__}")
    print(f"βœ“ CUDA available: {cuda_ok}")
    
    if cuda_ok:
        props = torch.cuda.get_device_properties(0)
        print(f"βœ“ GPU: {torch.cuda.get_device_name(0)}")
        print(f"βœ“ CUDA Capability: {torch.cuda.get_device_capability(0)}")
        print(f"βœ“ VRAM: {props.total_memory / 1e9:.1f} GB")
    
    # Probe each optional dependency; report and bail out on the first miss.
    try:
        from transformers import AutoModel  # noqa: F401 — import check only
    except ImportError:
        print("βœ— Transformers: NOT INSTALLED")
        return False
    print("βœ“ Transformers: Installed")
    
    try:
        from datasets import load_dataset  # noqa: F401 — import check only
    except ImportError:
        print("βœ— Datasets: NOT INSTALLED")
        return False
    print("βœ“ Datasets: Installed")
    
    try:
        import librosa  # noqa: F401 — import check only
    except ImportError:
        print("βœ— Librosa: NOT INSTALLED")
        return False
    print("βœ“ Librosa: Installed")
    
    print("\nβœ… All checks passed! Ready to start.\n")
    return True

def download_data():
    """Interactively select, download, and cache the PolyAI/minds14 German subset.

    Prompts the user for a dataset size, loads the matching split — from a
    local cache under ./data when one exists, otherwise from the Hugging
    Face Hub — caches fresh downloads to disk, and returns the dataset.

    Returns:
        The loaded `datasets.Dataset` instance.

    Raises:
        Exception: re-raised from `load_dataset` when the download fails,
            after printing troubleshooting guidance.
    """
    import os

    print("\n" + "=" * 60)
    print("DATASET CONFIGURATION")
    print("=" * 60)
    
    # Dataset size options with estimated training times on RTX 5060 Ti.
    # NOTE(review): the per-split sample counts below are rough estimates
    # for the de-DE subset — confirm against the dataset card.
    DATASET_OPTIONS = {
        'tiny': {
            'split': "train[:5%]",  # ~30 samples
            'estimated_time': "2-5 minutes",
            'vram': "8-10 GB"
        },
        'small': {
            'split': "train[:20%]",  # ~120 samples
            'estimated_time': "10-15 minutes",
            'vram': "10-12 GB"
        },
        'medium': {
            'split': "train[:50%]",  # ~300 samples
            'estimated_time': "30-45 minutes",
            'vram': "12-14 GB"
        },
        'large': {
            'split': "train",  # Full dataset (600+ samples)
            'estimated_time': "1-2 hours",
            'vram': "14-16 GB"
        }
    }
    
    print("\nAvailable dataset sizes:")
    for size, info in DATASET_OPTIONS.items():
        print(f"- {size}: {info['split']} (est. {info['estimated_time']}, {info['vram']} VRAM)")
    
    # Empty input falls through to the default via `or`.
    user_choice = input("\nSelect dataset size [tiny/small/medium/large] (default: small): ").lower() or 'small'
    
    if user_choice not in DATASET_OPTIONS:
        print(f"Invalid choice '{user_choice}'. Defaulting to 'small'.")
        user_choice = 'small'
        
    dataset_config = DATASET_OPTIONS[user_choice]
    print(f"\nUsing {user_choice} dataset ({dataset_config['split']})")
    print(f"Estimated training time: {dataset_config['estimated_time']}")
    print(f"Estimated VRAM usage: {dataset_config['vram']}")
    
    # Per-size cache directory so different splits don't clobber each other.
    dataset_path = f"./data/minds14_{user_choice}"
    os.makedirs("./data", exist_ok=True)
    
    # Prefer the local cache; fall back to a fresh download if it is corrupt.
    if os.path.exists(dataset_path):
        print("\nFound existing dataset, loading from local storage...")
        try:
            from datasets import load_from_disk
            dataset = load_from_disk(dataset_path)
            print(f"\nβœ“ Loaded dataset from {dataset_path}")
            print(f"  Number of samples: {len(dataset)}")
            return dataset
        except Exception as e:
            print(f"\n⚠️  Could not load from local storage: {e}")
            print("Attempting to download again...")
    
    try:
        from datasets import load_dataset
        print("\nLoading PolyAI/minds14 dataset...")
        
        dataset = load_dataset(
            "PolyAI/minds14",
            "de-DE",  # German subset
            split=dataset_config['split']  # Use selected split
        )
        
        print("\nβœ“ Successfully loaded test dataset")
        print(f"  Number of samples: {len(dataset)}")
        print(f"  Features: {dataset.features}")
        
        # Cache to disk so subsequent runs skip the network entirely.
        dataset.save_to_disk(dataset_path)
        print(f"\nβœ“ Dataset saved to {dataset_path}")
        
        return dataset
        
    except Exception as e:
        # Surface troubleshooting guidance, then propagate so the caller
        # decides whether the failure is fatal.
        print("\n❌ Failed to load test dataset. Here are some options:")
        print("\n1. CHECK YOUR INTERNET CONNECTION")
        print("   - Make sure you have a stable internet connection")
        print("   - Try using a VPN if you're in a restricted region")
        print("\n2. TRY MANUAL DOWNLOAD")
        print("   - Visit: https://huggingface.co/datasets/PolyAI/minds14")
        print("   - Follow the instructions to download the dataset")
        print("   - Place the downloaded files in the './data' directory")
        print("\n3. TRY A DIFFERENT DATASET")
        print("   - Let me know if you'd like to try a different dataset")
        print("\nError details:", str(e))
        raise

def optimize_settings():
    """Configure PyTorch for RTX 5060 Ti"""
    header = "=" * 60
    print(header)
    print("OPTIMIZING FOR RTX 5060 Ti")
    print(header)
    
    # Apply the performance knobs, then echo each one back to the user.
    torch.set_float32_matmul_precision('high')
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
    
    applied = (
        "βœ“ torch.set_float32_matmul_precision('high')",
        "βœ“ torch.backends.cuda.matmul.allow_tf32 = True",
        "βœ“ torch.backends.cudnn.benchmark = True",
    )
    for line in applied:
        print(line)
    print("\nThese settings will:")
    print("  β€’ Use Tensor Float 32 (TF32) for faster matrix operations")
    print("  β€’ Enable cuDNN auto-tuning for optimal kernel selection")
    print("  β€’ Expected speedup: 10-20%")
    
    return True

def main():
    """Run the full setup: environment check, GPU tuning, dataset download.

    Returns:
        bool: True when every step succeeded, False otherwise.
    """
    print("\n" + "=" * 60)
    print("WHISPER FINE-TUNING SETUP")
    print("Project: Multilingual ASR for German")
    print("GPU: RTX 5060 Ti (16GB VRAM)")
    print("=" * 60 + "\n")
    
    # Abort early if required packages are missing.
    if not check_environment():
        print("❌ Environment check failed. Please install missing packages.")
        return False
    
    # GPU/TF32/cuDNN tuning (always succeeds; returns True).
    optimize_settings()
    
    try:
        dataset = download_data()
        # Locate the cache directory download_data() just wrote. Picking the
        # most recently modified candidate (rather than the first match in a
        # fixed order) avoids reporting a stale directory left over from an
        # earlier run with a different size selection.
        import os
        dataset_path = "./data/minds14_small"  # Default
        candidates = [
            f"./data/minds14_{size}"
            for size in ('tiny', 'small', 'medium', 'large')
        ]
        existing = [p for p in candidates if os.path.exists(p)]
        if existing:
            dataset_path = max(existing, key=os.path.getmtime)
    except Exception as e:
        print(f"⚠️  Data download failed: {e}")
        print("You can retry later with: python project1_whisper_setup.py")
        return False
    
    print("\n" + "=" * 60)
    print("βœ… SETUP COMPLETE!")
    print("=" * 60)
    print("\nNext steps:")
    print(f"1. Review the dataset in {dataset_path}/")
    print("2. Run: python project1_whisper_train.py")
    print("3. Fine-tuning will begin (expect 2-3 days on RTX 5060 Ti)")
    print("=" * 60 + "\n")
    
    return True

if __name__ == "__main__":
    # Exit code 0 on success, 1 on any failure, so shell scripts can chain on it.
    success = main()
    sys.exit(0 if success else 1)