# ASR-finetuning / project1_whisper_setup.py
# Commit 5554ef1 by saadmannan: "HF space application - exclude binary PDFs"
#!/usr/bin/env python3
"""
Whisper Fine-tuning Setup
Purpose: Fine-tune Whisper-small on German data
GPU: RTX 5060 Ti optimized
"""
import torch
import sys
from pathlib import Path
def check_environment(packages=None):
    """Verify that all dependencies are installed and report the GPU setup.

    Prints the PyTorch version and CUDA availability. When CUDA is available,
    it also prints the name, compute capability and total memory of device 0.
    It then tries to import every required package and reports each missing
    one, so a single run lists everything that still needs installing.

    Args:
        packages: Optional iterable of ``(label, module_name, attribute)``
            tuples. ``attribute`` may be None. Otherwise it must also resolve
            on the module, mirroring ``from module_name import attribute``.
            Defaults to Transformers (``AutoModel``), Datasets
            (``load_dataset``) and Librosa.

    Returns:
        True if every package (and requested attribute) imported,
        False otherwise.
    """
    import importlib

    if packages is None:
        packages = (
            ("Transformers", "transformers", "AutoModel"),
            ("Datasets", "datasets", "load_dataset"),
            ("Librosa", "librosa", None),
        )

    print("=" * 60)
    print("ENVIRONMENT CHECK")
    print("=" * 60)

    # PyTorch / GPU information (torch itself is imported at module level).
    cuda_available = torch.cuda.is_available()
    print(f"✓ PyTorch: {torch.__version__}")
    print(f"✓ CUDA available: {cuda_available}")
    if cuda_available:
        print(f"✓ GPU: {torch.cuda.get_device_name(0)}")
        print(f"✓ CUDA Capability: {torch.cuda.get_device_capability(0)}")
        print(f"✓ VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # Check every package instead of stopping at the first failure, so the
    # user sees the complete list of what is missing.
    missing = []
    for label, module_name, attribute in packages:
        try:
            module = importlib.import_module(module_name)
            if attribute is not None:
                getattr(module, attribute)
        except (ImportError, AttributeError):
            # These are the failures `from module import attribute` reports
            # as ImportError.
            print(f"✗ {label}: NOT INSTALLED")
            missing.append(label)
        else:
            print(f"✓ {label}: Installed")

    if missing:
        return False

    print("\n✅ All checks passed! Ready to start.\n")
    return True
def download_data(size=None):
    """Choose a PolyAI/minds14 (de-DE) split size, then load or download it.

    Each size is cached in its own directory, ``./data/minds14_<size>``. If
    that directory exists it is loaded with ``load_from_disk``. If it is
    missing or cannot be loaded, the split is downloaded with
    ``load_dataset`` and written there with ``save_to_disk``.

    Args:
        size: Optional size key: 'tiny', 'small', 'medium' or 'large'.
            When None (the default) the user is prompted. Case and
            surrounding whitespace are ignored. An empty or unknown value
            selects 'small'. If stdin is not interactive (EOF), 'small' is
            used instead of crashing.

    Returns:
        The dataset object returned by ``load_from_disk`` or ``load_dataset``.

    Raises:
        Exception: Whatever the download/save step raised, re-raised after
            printing troubleshooting hints.
    """
    import os

    print("\n" + "=" * 60)
    print("DATASET CONFIGURATION")
    print("=" * 60)

    default_size = 'small'
    # Dataset size options with estimated training times on RTX 5060 Ti.
    # Sample counts in the comments are approximate.
    dataset_options = {
        'tiny': {
            'split': "train[:5%]",  # ~30 samples
            'estimated_time': "2-5 minutes",
            'vram': "8-10 GB",
        },
        'small': {
            'split': "train[:20%]",  # ~120 samples
            'estimated_time': "10-15 minutes",
            'vram': "10-12 GB",
        },
        'medium': {
            'split': "train[:50%]",  # ~300 samples
            'estimated_time': "30-45 minutes",
            'vram': "12-14 GB",
        },
        'large': {
            'split': "train",  # Full dataset (600+ samples)
            'estimated_time': "1-2 hours",
            'vram': "14-16 GB",
        },
    }

    print("\nAvailable dataset sizes:")
    for name, info in dataset_options.items():
        print(f"- {name}: {info['split']} (est. {info['estimated_time']}, {info['vram']} VRAM)")

    if size is None:
        try:
            size = input(
                f"\nSelect dataset size [{'/'.join(dataset_options)}] "
                f"(default: {default_size}): "
            )
        except EOFError:
            # No interactive stdin (piped input, CI, hosted app): use the default.
            print()
            size = ""

    choice = size.strip().lower() or default_size
    if choice not in dataset_options:
        print(f"Invalid choice '{choice}'. Defaulting to '{default_size}'.")
        choice = default_size

    dataset_config = dataset_options[choice]
    print(f"\nUsing {choice} dataset ({dataset_config['split']})")
    print(f"Estimated training time: {dataset_config['estimated_time']}")
    print(f"Estimated VRAM usage: {dataset_config['vram']}")

    dataset_path = f"./data/minds14_{choice}"
    os.makedirs("./data", exist_ok=True)

    # Reuse a previously saved copy when there is one.
    if os.path.exists(dataset_path):
        print("\nFound existing dataset, loading from local storage...")
        try:
            from datasets import load_from_disk
            dataset = load_from_disk(dataset_path)
        except Exception as e:
            # Unreadable local copy: fall through and download it again.
            print(f"\n⚠️ Could not load from local storage: {e}")
            print("Attempting to download again...")
        else:
            print(f"\n✓ Loaded dataset from {dataset_path}")
            print(f" Number of samples: {len(dataset)}")
            return dataset

    try:
        from datasets import load_dataset
        print("\nLoading PolyAI/minds14 dataset...")
        # NOTE(review): minds14 may be a script-based Hub dataset; recent
        # `datasets` releases may need different loading options — confirm.
        dataset = load_dataset(
            "PolyAI/minds14",
            "de-DE",  # German subset
            split=dataset_config['split'],  # Selected slice of the train split
        )
        print("\n✓ Successfully loaded test dataset")
        print(f" Number of samples: {len(dataset)}")
        print(f" Features: {dataset.features}")
        # Save locally so the next run can skip the download.
        dataset.save_to_disk(dataset_path)
        print(f"\n✓ Dataset saved to {dataset_path}")
        return dataset
    except Exception as e:
        print("\n❌ Failed to load test dataset. Here are some options:")
        print("\n1. CHECK YOUR INTERNET CONNECTION")
        print(" - Make sure you have a stable internet connection")
        print(" - Try using a VPN if you're in a restricted region")
        print("\n2. TRY MANUAL DOWNLOAD")
        print(" - Visit: https://huggingface.co/datasets/PolyAI/minds14")
        print(" - Follow the instructions to download the dataset")
        print(" - Place the downloaded files in the './data' directory")
        print("\n3. TRY A DIFFERENT DATASET")
        print(" - Let me know if you'd like to try a different dataset")
        print("\nError details:", str(e))
        raise
def optimize_settings():
    """Turn on PyTorch speed settings intended for the RTX 5060 Ti.

    Sets float32 matmul precision to 'high' and enables TF32 matmuls, which
    lets PyTorch use TensorFloat-32 on hardware that supports it. Also enables
    the cuDNN autotuner (``cudnn.benchmark``). These are global PyTorch flags
    and are set unconditionally, even when CUDA is unavailable.

    Returns:
        Always True.
    """
    print("=" * 60)
    print("OPTIMIZING FOR RTX 5060 Ti")
    print("=" * 60)

    # Enable optimizations
    torch.set_float32_matmul_precision('high')
    torch.backends.cuda.matmul.allow_tf32 = True
    # NOTE(review): cudnn.benchmark pays off when input shapes stay fixed;
    # confirm the training inputs are padded to a constant length.
    torch.backends.cudnn.benchmark = True

    print("✓ torch.set_float32_matmul_precision('high')")
    print("✓ torch.backends.cuda.matmul.allow_tf32 = True")
    print("✓ torch.backends.cudnn.benchmark = True")
    print("\nThese settings will:")
    print(" • Use Tensor Float 32 (TF32) for faster matrix operations")
    print(" • Enable cuDNN auto-tuning for optimal kernel selection")
    # NOTE(review): the speedup figure below is an unmeasured estimate.
    print(" • Expected speedup: 10-20%")
    return True
def _guess_dataset_dir(dataset, data_dir="./data"):
    """Best-effort guess of the ``<data_dir>/minds14_<size>`` dir behind *dataset*.

    A dataset loaded with ``load_from_disk`` is read from its directory, so
    its Arrow cache files live inside that directory. A freshly downloaded
    dataset was just written with ``save_to_disk``, making its directory the
    most recently modified one. Returns None when no candidate directory
    exists.
    """
    import os

    candidates = [
        f"{data_dir}/minds14_{size}"
        for size in ("large", "medium", "small", "tiny")
        if os.path.isdir(f"{data_dir}/minds14_{size}")
    ]
    if not candidates:
        return None

    # Prefer the directory that actually holds the dataset's Arrow files.
    # (A DatasetDict exposes cache_files as a dict; its keys are skipped here.)
    cache_dirs = set()
    for entry in getattr(dataset, "cache_files", None) or []:
        filename = entry.get("filename") if isinstance(entry, dict) else None
        if filename:
            cache_dirs.add(os.path.realpath(os.path.dirname(filename)))
    for candidate in candidates:
        if os.path.realpath(candidate) in cache_dirs:
            return candidate

    # Otherwise assume the newest directory is the one save_to_disk just wrote.
    return max(candidates, key=os.path.getmtime)


def main():
    """Run the setup: environment check, PyTorch tuning, then dataset download.

    Returns:
        True when setup completed. False when the environment check or the
        dataset download failed; the script's exit status is derived from
        this value.
    """
    print("\n" + "=" * 60)
    print("WHISPER FINE-TUNING SETUP")
    print("Project: Multilingual ASR for German")
    print("GPU: RTX 5060 Ti (16GB VRAM)")
    print("=" * 60 + "\n")

    # Check environment
    if not check_environment():
        print("❌ Environment check failed. Please install missing packages.")
        return False

    # Optimize settings
    optimize_settings()

    # Download data (download_data prompts the user for a dataset size)
    try:
        dataset = download_data()
    except Exception as e:
        print(f"⚠️ Data download failed: {e}")
        print("You can retry later with: python project1_whisper_setup.py")
        return False

    # download_data() returns only the dataset, not the chosen size, so infer
    # the directory it came from. Scanning large->tiny, as before, would report
    # the largest existing copy even when a smaller one was just selected.
    dataset_path = _guess_dataset_dir(dataset) or "./data/minds14_small"

    print("\n" + "=" * 60)
    print("✅ SETUP COMPLETE!")
    print("=" * 60)
    print("\nNext steps:")
    print(f"1. Review the dataset in {dataset_path}/")
    print("2. Run: python project1_whisper_train.py")
    print("3. Fine-tuning will begin (expect 2-3 days on RTX 5060 Ti)")
    print("=" * 60 + "\n")
    return True
if __name__ == "__main__":
    # Map main()'s success flag onto a conventional exit status: 0 on success, 1 on failure.
    sys.exit(int(not main()))