Whisper German ASR Fine-Tuning Project
Project Overview
This project fine-tunes OpenAI's Whisper model for German Automatic Speech Recognition (ASR) using the PolyAI/minds14 dataset.
Hardware Setup
- GPU: NVIDIA GeForce RTX 5060 Ti (16GB VRAM)
- CUDA: 13.0
- PyTorch: 2.9.0+cu130
- Flash Attention 2: Enabled (v2.8.3)
Project Structure
```
ai-career-project/
├── project1_whisper_setup.py      # Dataset download and preparation
├── project1_whisper_train.py     # Model training script
├── project1_whisper_inference.py # Inference and testing script
├── data/
│   └── minds14_small/            # Training dataset (122 samples)
└── whisper_test_tuned/           # Fine-tuned model checkpoints
    ├── checkpoint-28/
    └── checkpoint-224/           # Final checkpoint
```
Dataset Options
| Size | Split | Samples | Training Time | VRAM Usage | Best For |
|---|---|---|---|---|---|
| Tiny | 5% | ~30 | 30 seconds | 8-10 GB | Quick testing |
| Small | 20% | ~120 | 2 minutes | 10-12 GB | Experiments ✅ |
| Medium | 50% | ~300 | 5-6 minutes | 12-14 GB | Good results |
| Large | 100% | ~600 | 10-12 minutes | 14-16 GB | Best performance |
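As a rough illustration of how the setup step could load one of these fractions, the sketch below pulls the German MINDS-14 split and keeps a slice of it. The function name and slicing logic are assumptions for illustration; the actual code lives in project1_whisper_setup.py.

```python
# Illustrative sketch only; the real logic is in project1_whisper_setup.py.
from datasets import Audio, load_dataset

def load_minds14_subset(fraction: float = 0.2):
    # German configuration of PolyAI/minds14 (~600 samples total)
    ds = load_dataset("PolyAI/minds14", name="de-DE", split="train")
    # Whisper expects 16 kHz mono audio
    ds = ds.cast_column("audio", Audio(sampling_rate=16_000))
    # Keep the requested fraction, e.g. 0.2 -> the "Small" row above
    n = int(len(ds) * fraction)
    return ds.shuffle(seed=42).select(range(n))

subset = load_minds14_subset(0.2)
print(f"Loaded {len(subset)} samples")
```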
Training Results (Small Dataset)
Configuration
- Model: Whisper-small (244M parameters)
- Training samples: 109
- Evaluation samples: 13
- Batch size: 4
- Learning rate: 2e-5
- Epochs: 8
- Mixed precision: BF16
- Flash Attention 2: Enabled
- Gradient checkpointing: Disabled (a training-arguments sketch follows this list)
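As a sketch, the configuration above maps to `Seq2SeqTrainingArguments` roughly as follows. Values mirror this README; the actual script may set additional options.

```python
from transformers import Seq2SeqTrainingArguments

# Sketch matching the configuration listed above; not the script verbatim.
training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper_test_tuned",
    per_device_train_batch_size=4,
    learning_rate=2e-5,
    num_train_epochs=8,
    bf16=True,                     # BF16 mixed precision
    gradient_checkpointing=False,  # disabled; Flash Attention 2 saves memory instead
    eval_strategy="epoch",
    save_strategy="epoch",
    predict_with_generate=True,    # generate sequences so WER can be computed
    metric_for_best_model="wer",
    greater_is_better=False,       # lower WER is better
    load_best_model_at_end=True,
    logging_dir="./logs",          # TensorBoard log directory
    report_to=["tensorboard"],
)
```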
Performance
- Training time: ~2 minutes (119 seconds)
- Training speed: 7.27 samples/second
- Final training loss: 4684.90
- Final evaluation loss: 2490.13
Current Issues
⚠️ Model Performance: The model trained on the small dataset (109 samples) shows poor inference quality, generating repetitive outputs. This is expected with such a small dataset.
Recommendations for Better Results
1. Use Larger Dataset ✅ RECOMMENDED
```bash
# Run setup with medium or large dataset
python project1_whisper_setup.py
# Select 'medium' or 'large' when prompted
```
Expected improvements:
- Medium (300 samples): 5-6 minutes training, significantly better quality
- Large (600 samples): 10-12 minutes training, best quality
2. Adjust Training Parameters
For larger datasets, the training script automatically adjusts the following (see the sketch after this list):
- Batch size: 4
- Gradient accumulation: 2
- Learning rate: 1e-5
- Epochs: 5
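A hypothetical sketch of that size-based switch is below; the names and thresholds are illustrative, and the real logic is in project1_whisper_train.py.

```python
# Hypothetical auto-configuration table mirroring the values in this README.
AUTO_CONFIG = {
    # size: (batch size, gradient accumulation, learning rate, epochs)
    "small":  (4, 1, 2e-5, 8),
    "medium": (4, 2, 1e-5, 5),
    "large":  (4, 2, 1e-5, 5),
}

def pick_config(num_samples: int):
    if num_samples < 200:
        return AUTO_CONFIG["small"]
    if num_samples < 450:
        return AUTO_CONFIG["medium"]
    return AUTO_CONFIG["large"]
```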
3. Use Pre-trained Model for Inference
If you need immediate results, use the base Whisper model:
```python
from transformers import pipeline

# Use the base Whisper model (no fine-tuning needed)
pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-small",
    device=0,  # use the GPU
)

result = pipe("audio.wav", generate_kwargs={"language": "german"})
print(result["text"])
```
Recent Improvements (v2.0)
Training Pipeline Enhancements
✅ Fixed Trainer API Issues
- Corrected the `eval_strategy` parameter (was `evaluation_strategy`)
- Fixed the `processing_class` parameter (was `tokenizer`)
- Added German language/task conditioning for proper decoder behavior (sketched below)
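The conditioning typically looks like the sketch below, which sets the language and task on both the processor and the generation config. This uses the standard transformers Whisper API, shown as an illustration rather than the script's exact code.

```python
from transformers import WhisperForConditionalGeneration, WhisperProcessor

processor = WhisperProcessor.from_pretrained(
    "openai/whisper-small", language="german", task="transcribe"
)
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")

# Tell the decoder up front that it is transcribing German, so it neither
# guesses the language nor falls back to translation.
model.generation_config.language = "german"
model.generation_config.task = "transcribe"
```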
✅ Improved Hyperparameters
- Increased learning rates: 1e-5 to 2e-5 (was 5e-6)
- Added warmup ratio (3-5%) for better convergence
- Removed dtype conflicts (let Trainer control precision)
- Optimized epochs by dataset size (8-15 epochs)
✅ Data Quality & Processing (see the sketch after this list)
- Duration filtering (0.5s - 30s)
- Transcript length validation
- Text normalization for consistent WER computation
- Group by length for reduced padding
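A minimal sketch of those filters, assuming the MINDS-14 column names (`audio`, `transcription`); the transcript-length bound is a hypothetical placeholder:

```python
MIN_SECONDS, MAX_SECONDS = 0.5, 30.0
MAX_TRANSCRIPT_CHARS = 400  # hypothetical bound; the script's value may differ

def is_valid_sample(example) -> bool:
    audio = example["audio"]
    duration = len(audio["array"]) / audio["sampling_rate"]
    text = example["transcription"].strip()
    return (MIN_SECONDS <= duration <= MAX_SECONDS
            and 0 < len(text) <= MAX_TRANSCRIPT_CHARS)

dataset = dataset.filter(is_valid_sample)
```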
✅ Evaluation & Monitoring (a WER-metric sketch follows this list)
- WER (Word Error Rate) metric with jiwer
- TensorBoard logging for all metrics
- Best model selection by WER (not just loss)
- Predict with generate for proper evaluation
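A sketch of the WER metric using jiwer, which the list above names. It assumes the `processor` from earlier; the `normalize` helper is a hypothetical stand-in for the script's text normalization.

```python
import jiwer

def normalize(text: str) -> str:
    # Hypothetical normalization: lowercase and collapse whitespace.
    return " ".join(text.lower().split())

def compute_metrics(eval_pred):
    pred_ids, label_ids = eval_pred.predictions, eval_pred.label_ids
    # -100 is the ignore index for padded label positions
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    preds = processor.batch_decode(pred_ids, skip_special_tokens=True)
    refs = processor.batch_decode(label_ids, skip_special_tokens=True)
    wer = jiwer.wer([normalize(r) for r in refs],
                    [normalize(p) for p in preds])
    return {"wer": wer}
```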
Why Training Should Improve Now
- Proper evaluation: WER tracking shows actual quality improvements
- Better learning rate: 2-4x higher LR enables faster convergence
- Language conditioning: Model knows it's transcribing German
- Data filtering: Removes noisy/invalid samples that hurt training
- Best model selection: Saves checkpoint with lowest WER, not just loss
Installation
1. Install Dependencies
```bash
pip install -r requirements.txt
```
2. (Optional) Install Flash Attention 2
For faster training (requires CUDA toolkit):
```bash
pip install flash-attn --no-build-isolation
```
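Installing the package alone is not enough; the model also needs to be loaded with the Flash Attention 2 implementation. A sketch (flash-attn kernels require half precision, so BF16 is passed to match the training setup):

```python
import torch
from transformers import WhisperForConditionalGeneration

# Flash Attention 2 requires FP16/BF16 and a supported GPU.
model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-small",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
```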
Usage
1. Setup Dataset
```bash
python project1_whisper_setup.py
```
Select dataset size when prompted (recommend 'medium' or 'large')
2. Train Model
```bash
python project1_whisper_train.py
```
3. Monitor Training with TensorBoard
In a separate terminal, start TensorBoard:
```bash
tensorboard --logdir=./logs
```
Then open http://localhost:6006 in your browser to view:
- Training/Evaluation Loss - Track model convergence
- WER (Word Error Rate) - Monitor transcription quality
- Learning Rate - Visualize warmup and decay
- Gradient Norms - Check training stability
You can also monitor GPU usage:
```bash
nvidia-smi -l 1
```
4. Test Model
```bash
# Test with dataset samples
python project1_whisper_inference.py --test --num-samples 10

# Transcribe specific audio files
python project1_whisper_inference.py --audio file1.wav file2.wav

# Interactive mode
python project1_whisper_inference.py --interactive
```
Key Features
Flash Attention 2 Integration
- Faster training: 10-20% speedup
- Memory efficient: No gradient checkpointing needed
- Stable training: BF16 mixed precision
Automatic Configuration
The training script automatically adjusts parameters based on dataset size:
- Batch size and gradient accumulation
- Learning rate (1e-5 to 2e-5) and warmup ratio
- Number of epochs (8-15)
- Training time estimation
Data Quality Filtering
- Duration filtering: 0.5s to 30s audio clips
- Transcript validation: Removes empty or too-long texts
- Quality checks: Filters invalid audio samples
- Automatic normalization: Consistent text preprocessing
Evaluation & Metrics
- WER (Word Error Rate): Primary quality metric
- TensorBoard logging: Real-time training visualization
- Best model selection: Automatically saves best checkpoint by WER
- Predict with generate: Proper sequence generation for evaluation
Flexible Dataset Handling
- Automatic train/validation split
- Caches processed datasets
- Supports different dataset sizes
- Progress tracking and metrics
- Group by length for efficient batching
Performance Optimization
Current Optimizations
- ✅ Flash Attention 2 enabled
- ✅ BF16 mixed precision
- ✅ TF32 matrix operations
- ✅ cuDNN auto-tuning
- ✅ Automatic device placement
Training Speed
- Small dataset (109 samples): ~2 minutes for 8 epochs
- Estimated for medium (300 samples): ~5-6 minutes for 5 epochs
- Estimated for large (600 samples): ~10-12 minutes for 5 epochs
Next Steps
Immediate Actions
- Retrain with larger dataset (medium or large) for better results
- Evaluate model quality with Word Error Rate (WER) metrics
- Test on real-world audio samples
Future Improvements
- Use larger Whisper model (medium or large) for better accuracy
- Add data augmentation (speed, pitch, noise)
- Create web interface for easy testing
- Deploy model as API service
- Push to Hugging Face Hub for sharing and deployment
Troubleshooting
Common Issues
1. Model generates repetitive outputs
- Cause: Dataset too small (< 200 samples)
- Solution: Use medium or large dataset
2. Out of memory errors
- Cause: Batch size too large
- Solution: Reduce batch size in training script
3. Slow training
- Cause: Flash Attention 2 not enabled
- Solution: Verify that `flash-attn` is installed
4. Poor transcription quality
- Cause: Insufficient training data
- Solution: Use larger dataset or more epochs
Technical Details
Model Architecture
- Base model: OpenAI Whisper-small
- Parameters: 244M
- Input: 16kHz mono audio
- Output: German text transcription
Training Process
- Load and preprocess audio (resample to 16kHz)
- Extract mel-spectrogram features
- Fine-tune encoder-decoder with teacher forcing
- Evaluate on validation set each epoch
- Save the best checkpoint based on WER (steps 1-3 are sketched below)
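Steps 1-3 are typically a single `map` over the dataset. A sketch assuming the MINDS-14 column names and the `processor` from earlier:

```python
def prepare_example(example):
    audio = example["audio"]  # resampled to 16 kHz during setup
    features = processor(
        audio["array"],
        sampling_rate=audio["sampling_rate"],
        text=example["transcription"],
    )
    example["input_features"] = features.input_features[0]  # log-mel spectrogram
    example["labels"] = features.labels                     # token ids for teacher forcing
    return example

dataset = dataset.map(prepare_example, remove_columns=dataset.column_names)
```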
Generation Parameters
```python
predicted_ids = model.generate(
    input_features,
    max_length=448,           # Whisper's maximum target sequence length
    num_beams=5,              # beam search for higher-quality transcripts
    temperature=0.0,          # inert here, since do_sample=False
    do_sample=False,          # deterministic decoding
    repetition_penalty=1.2,   # discourage repeated tokens
    no_repeat_ngram_size=3,   # block repeated trigrams
)
```
Resources
- Whisper Paper: https://arxiv.org/abs/2212.04356
- Hugging Face Transformers: https://huggingface.co/docs/transformers
- Flash Attention 2: https://github.com/Dao-AILab/flash-attention
- Dataset: https://huggingface.co/datasets/PolyAI/minds14
License
This project uses the MIT License. The Whisper model is licensed under Apache 2.0.
Contact
For questions or issues, please create an issue in the project repository.