# Whisper German ASR Fine-Tuning Project

## Project Overview
This project fine-tunes OpenAI's Whisper model for German Automatic Speech Recognition (ASR) using the PolyAI/minds14 dataset.

## Hardware Setup
- **GPU**: NVIDIA GeForce RTX 5060 Ti (16GB VRAM)
- **CUDA**: 13.0
- **PyTorch**: 2.9.0+cu130
- **Flash Attention 2**: Enabled (v2.8.3)

## Project Structure
```
ai-career-project/
├── project1_whisper_setup.py      # Dataset download and preparation
├── project1_whisper_train.py      # Model training script
├── project1_whisper_inference.py  # Inference and testing script
├── data/
│   └── minds14_small/             # Training dataset (122 samples)
└── whisper_test_tuned/            # Fine-tuned model checkpoints
    ├── checkpoint-28/
    └── checkpoint-224/            # Final checkpoint
```

## Dataset Options

| Size | Split | Samples | Training Time | VRAM Usage | Best For |
|------|-------|---------|---------------|------------|----------|
| **Tiny** | 5% | ~30 | 30 seconds | 8-10 GB | Quick testing |
| **Small** | 20% | ~120 | 2 minutes | 10-12 GB | Experiments ⭐ |
| **Medium** | 50% | ~300 | 5-6 minutes | 12-14 GB | Good results |
| **Large** | 100% | ~600 | 10-12 minutes | 14-16 GB | Best performance |
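
For reference, a minimal sketch of how such a subset can be loaded with 🤗 Datasets. The split fraction and seed are illustrative; the actual selection logic lives in `project1_whisper_setup.py`:

```python
from datasets import Audio, load_dataset

# Illustrative only: the real logic lives in project1_whisper_setup.py.
dataset = load_dataset("PolyAI/minds14", name="de-DE", split="train")

# Keep ~20% of the data (the "Small" row above); the seed is arbitrary.
dataset = dataset.shuffle(seed=42).select(range(int(0.2 * len(dataset))))

# Whisper expects 16 kHz mono audio.
dataset = dataset.cast_column("audio", Audio(sampling_rate=16_000))
```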

## Training Results (Small Dataset)

### Configuration
- **Model**: Whisper-small (244M parameters)
- **Training samples**: 109
- **Evaluation samples**: 13
- **Batch size**: 4
- **Learning rate**: 2e-5
- **Epochs**: 8
- **Mixed precision**: BF16
- **Flash Attention 2**: Enabled
- **Gradient checkpointing**: Disabled

### Performance
- **Training time**: ~2 minutes (119 seconds)
- **Training speed**: 7.27 samples/second
- **Final training loss**: 4684.90
- **Final evaluation loss**: 2490.13

### Current Issues
⚠️ **Model Performance**: The model trained on the small dataset (109 samples) shows poor inference quality, generating repetitive outputs. This is expected with such a small dataset.

## Recommendations for Better Results

### 1. Use a Larger Dataset ⭐ **RECOMMENDED**
```bash
# Run setup with the medium or large dataset
python project1_whisper_setup.py
# Select 'medium' or 'large' when prompted
```

**Expected improvements:**
- Medium (300 samples): 5-6 minutes of training, significantly better quality
- Large (600 samples): 10-12 minutes of training, best quality

### 2. Adjust Training Parameters
For larger datasets, the training script automatically adjusts the following (a sketch of these settings appears below):
- Batch size: 4
- Gradient accumulation: 2
- Learning rate: 1e-5
- Epochs: 5
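
These values translate into `Seq2SeqTrainingArguments` roughly as follows. This is a minimal sketch, not the actual contents of `project1_whisper_train.py`; the argument values mirror the list above, and the output and logging paths are taken from the project structure:

```python
from transformers import Seq2SeqTrainingArguments

# Illustrative configuration; the real script picks values by dataset size.
training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper_test_tuned",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,   # effective batch size of 8
    learning_rate=1e-5,
    num_train_epochs=5,
    warmup_ratio=0.05,               # 3-5% warmup, see "Improved Hyperparameters"
    bf16=True,                       # BF16 mixed precision
    eval_strategy="epoch",           # older transformers: evaluation_strategy
    predict_with_generate=True,
    report_to=["tensorboard"],
    logging_dir="./logs",
)
```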

### 3. Use the Pre-trained Model for Inference
If you need immediate results, use the base Whisper model:

```python
from transformers import pipeline

# Use the base Whisper model (no fine-tuning needed)
pipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-small",
    device=0,  # use the GPU
)
result = pipe("audio.wav", generate_kwargs={"language": "german"})
print(result["text"])
```

## Recent Improvements (v2.0)

### Training Pipeline Enhancements
✅ **Fixed Trainer API Issues**
- Corrected the `eval_strategy` argument name (newer `transformers` releases renamed `evaluation_strategy`)
- Passed the processor via `processing_class` (the old `tokenizer` argument is deprecated)
- Added German language/task conditioning for proper decoder behavior
✅ **Improved Hyperparameters**
- Increased learning rates: 1e-5 to 2e-5 (was 5e-6)
- Added warmup ratio (3-5%) for better convergence
- Removed dtype conflicts (the Trainer now controls precision)
- Epochs scaled to dataset size (8-15 epochs)

✅ **Data Quality & Processing** (sketched below)
- Duration filtering (0.5s - 30s)
- Transcript length validation
- Text normalization for consistent WER computation
- Group by length for reduced padding
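
The duration and transcript filters above can be expressed as a single `datasets` filter call. A minimal sketch, assuming the MInDS-14 column names (`audio`, `transcription`) and a hypothetical 400-character transcript cap:

```python
# Assumed thresholds: durations from this README, character cap illustrative.
MIN_SECS, MAX_SECS = 0.5, 30.0
MAX_CHARS = 400

def is_valid(example):
    audio = example["audio"]
    duration = len(audio["array"]) / audio["sampling_rate"]
    text = example["transcription"].strip()
    return MIN_SECS <= duration <= MAX_SECS and 0 < len(text) <= MAX_CHARS

dataset = dataset.filter(is_valid)  # drops out-of-range or empty samples
```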

✅ **Evaluation & Monitoring**
- WER (Word Error Rate) metric with jiwer (see the sketch below)
- TensorBoard logging for all metrics
- Best model selection by WER (not just loss)
- `predict_with_generate` for proper evaluation
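
A WER metric function for the `Seq2SeqTrainer` can look roughly like this; `processor` is assumed to be the `WhisperProcessor` used during training, and the lowercase/strip step stands in for the project's text normalization:

```python
import jiwer

def compute_metrics(eval_pred):
    pred_ids, label_ids = eval_pred.predictions, eval_pred.label_ids
    # -100 marks padding in the labels; restore the pad token before decoding.
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    preds = processor.batch_decode(pred_ids, skip_special_tokens=True)
    refs = processor.batch_decode(label_ids, skip_special_tokens=True)
    # Simple normalization so WER is comparable across runs.
    preds = [p.lower().strip() for p in preds]
    refs = [r.lower().strip() for r in refs]
    return {"wer": jiwer.wer(refs, preds)}
```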

### Why Training Should Improve Now
1. **Proper evaluation**: WER tracking shows actual quality improvements
2. **Better learning rate**: a 2-4x higher LR enables faster convergence
3. **Language conditioning**: the model knows it is transcribing German (see the snippet below)
4. **Data filtering**: removes noisy/invalid samples that hurt training
5. **Best model selection**: saves the checkpoint with the lowest WER, not just the lowest loss
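
Point 3 (language conditioning) boils down to a few lines on recent `transformers` releases; older releases use `processor.get_decoder_prompt_ids(...)` and `forced_decoder_ids` instead:

```python
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")

# Tell the decoder to transcribe (not translate) German.
model.generation_config.language = "german"
model.generation_config.task = "transcribe"
```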

## Installation

### 1. Install Dependencies
```bash
pip install -r requirements.txt
```

### 2. (Optional) Install Flash Attention 2
For faster training (requires the CUDA toolkit):
```bash
pip install flash-attn --no-build-isolation
```
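
Once installed, Flash Attention 2 is selected when the model is loaded. A hedged sketch with a fallback to PyTorch's built-in SDPA kernels if `flash-attn` is missing:

```python
import torch
from transformers import WhisperForConditionalGeneration

try:
    model = WhisperForConditionalGeneration.from_pretrained(
        "openai/whisper-small",
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,  # FA2 requires fp16/bf16
    )
except (ImportError, ValueError):
    # flash-attn not installed or unsupported; use scaled-dot-product attention.
    model = WhisperForConditionalGeneration.from_pretrained(
        "openai/whisper-small", attn_implementation="sdpa"
    )
```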

## Usage

### 1. Setup Dataset
```bash
python project1_whisper_setup.py
```
Select the dataset size when prompted ('medium' or 'large' is recommended).

### 2. Train Model
```bash
python project1_whisper_train.py
```

### 3. Monitor Training with TensorBoard
In a separate terminal, start TensorBoard:
```bash
tensorboard --logdir=./logs
```
Then open http://localhost:6006 in your browser to view:
- **Training/Evaluation Loss** - track model convergence
- **WER (Word Error Rate)** - monitor transcription quality
- **Learning Rate** - visualize warmup and decay
- **Gradient Norms** - check training stability

You can also monitor GPU usage:
```bash
nvidia-smi -l 1
```

### 4. Test Model
```bash
# Test with dataset samples
python project1_whisper_inference.py --test --num-samples 10

# Transcribe specific audio files
python project1_whisper_inference.py --audio file1.wav file2.wav

# Interactive mode
python project1_whisper_inference.py --interactive
```

## Key Features

### Flash Attention 2 Integration
- **Faster training**: 10-20% speedup
- **Memory efficient**: no gradient checkpointing needed
- **Stable training**: BF16 mixed precision

### Automatic Configuration
The training script automatically adjusts parameters based on dataset size:
- Batch size and gradient accumulation
- Learning rate (1e-5 to 2e-5) and warmup ratio
- Number of epochs (8-15)
- Training time estimation

### Data Quality Filtering
- **Duration filtering**: 0.5s to 30s audio clips
- **Transcript validation**: removes empty or overly long texts
- **Quality checks**: filters invalid audio samples
- **Automatic normalization**: consistent text preprocessing

### Evaluation & Metrics
- **WER (Word Error Rate)**: primary quality metric
- **TensorBoard logging**: real-time training visualization
- **Best model selection**: automatically saves the best checkpoint by WER
- **`predict_with_generate`**: proper sequence generation for evaluation

### Flexible Dataset Handling (sketched below)
- Automatic train/validation split
- Caches processed datasets
- Supports different dataset sizes
- Progress tracking and metrics
- Group by length for efficient batching
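
The split-and-cache behavior looks roughly like this; the 10% eval fraction, seed, and on-disk path are assumptions (the path matches the project structure above), and `dataset` is the filtered dataset from the earlier sketches:

```python
from datasets import load_from_disk

# One-time: split and persist the processed dataset.
splits = dataset.train_test_split(test_size=0.1, seed=42)
splits.save_to_disk("data/minds14_small")

# Later runs: reload from the cache instead of re-downloading.
cached = load_from_disk("data/minds14_small")
train_ds, eval_ds = cached["train"], cached["test"]
```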

## Performance Optimization

### Current Optimizations
✅ Flash Attention 2 enabled
✅ BF16 mixed precision
✅ TF32 matrix operations
✅ cuDNN auto-tuning
✅ Automatic device placement
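
The TF32, cuDNN, and device-placement items above correspond to standard PyTorch switches:

```python
import torch

torch.backends.cuda.matmul.allow_tf32 = True  # TF32 matrix operations
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True         # cuDNN auto-tuning
device = "cuda" if torch.cuda.is_available() else "cpu"  # device placement
```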

### Training Speed
- **Small dataset (109 samples)**: ~2 minutes for 8 epochs
- **Estimated for medium (300 samples)**: ~5-6 minutes for 5 epochs
- **Estimated for large (600 samples)**: ~10-12 minutes for 5 epochs

## Next Steps

### Immediate Actions
1. **Retrain with a larger dataset** (medium or large) for better results
2. **Evaluate model quality** with Word Error Rate (WER) metrics
3. **Test on real-world audio** samples

### Future Improvements
1. **Use a larger Whisper model** (medium or large) for better accuracy
2. **Add data augmentation** (speed, pitch, noise)
3. **Create a web interface** for easy testing
4. **Deploy the model** as an API service
5. **Push to the Hugging Face Hub** for sharing and deployment

## Troubleshooting

### Common Issues
**1. Model generates repetitive outputs**
- **Cause**: dataset too small (< 200 samples)
- **Solution**: use the medium or large dataset

**2. Out-of-memory errors**
- **Cause**: batch size too large
- **Solution**: reduce the batch size in the training script

**3. Slow training**
- **Cause**: Flash Attention 2 not enabled
- **Solution**: verify that `flash-attn` is installed

**4. Poor transcription quality**
- **Cause**: insufficient training data
- **Solution**: use a larger dataset or more epochs

## Technical Details

### Model Architecture
- **Base model**: OpenAI Whisper-small
- **Parameters**: 244M
- **Input**: 16kHz mono audio
- **Output**: German text transcription

### Training Process
1. Load and preprocess audio (resample to 16kHz)
2. Extract mel-spectrogram features (steps 1-2 are sketched below)
3. Fine-tune the encoder-decoder with teacher forcing
4. Evaluate on the validation set each epoch
5. Save the best checkpoint by WER (not just loss)
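
Steps 1-2 map onto the `WhisperProcessor` roughly as follows; column names follow the MInDS-14 schema, and the audio is assumed to already be cast to 16 kHz:

```python
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-small")

def prepare(example):
    audio = example["audio"]
    # Log-mel spectrogram features for the encoder.
    example["input_features"] = processor(
        audio["array"], sampling_rate=audio["sampling_rate"]
    ).input_features[0]
    # Token IDs for the decoder (teacher-forcing targets).
    example["labels"] = processor.tokenizer(example["transcription"]).input_ids
    return example

dataset = dataset.map(prepare, remove_columns=dataset.column_names)
```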

### Generation Parameters
```python
model.generate(
    input_features,
    max_length=448,          # Whisper's maximum target length
    num_beams=5,             # beam search for more stable outputs
    temperature=0.0,         # ignored when do_sample=False
    do_sample=False,         # deterministic decoding
    repetition_penalty=1.2,  # discourage repeated tokens
    no_repeat_ngram_size=3,  # block repeated trigrams
)
```
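
In context, those parameters are used roughly like this. The checkpoint path follows the project structure above but is an assumption, as is `soundfile` as the audio loader; if the processor was not saved with the checkpoint, load it from `openai/whisper-small` instead:

```python
import soundfile as sf
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

CKPT = "whisper_test_tuned/checkpoint-224"  # assumed path
processor = WhisperProcessor.from_pretrained(CKPT)
model = WhisperForConditionalGeneration.from_pretrained(CKPT).to("cuda").eval()

audio, sr = sf.read("audio.wav")  # placeholder file; 16 kHz mono expected

inputs = processor(audio, sampling_rate=sr, return_tensors="pt").to("cuda")
with torch.no_grad():
    ids = model.generate(
        inputs.input_features,
        max_length=448,
        num_beams=5,
        do_sample=False,
        repetition_penalty=1.2,
        no_repeat_ngram_size=3,
        language="german",
        task="transcribe",
    )
print(processor.batch_decode(ids, skip_special_tokens=True)[0])
```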

## Resources
- **Whisper Paper**: https://arxiv.org/abs/2212.04356
- **Hugging Face Transformers**: https://huggingface.co/docs/transformers
- **Flash Attention 2**: https://github.com/Dao-AILab/flash-attention
- **Dataset**: https://huggingface.co/datasets/PolyAI/minds14

## License
This project uses the MIT License. The Whisper model is licensed under Apache 2.0.

## Contact
For questions or issues, please create an issue in the project repository.