
Whisper Training Pipeline - Improvements Summary

Overview

This document summarizes the comprehensive improvements made to the Whisper fine-tuning pipeline to fix training issues and enable proper evaluation.

Critical Fixes

1. Trainer API Issues (Breaking Bugs)

Problem: Training was using incorrect/deprecated API parameters.

Fixes:

  • ✅ Changed `eval_strategy="epoch"` → `evaluation_strategy="epoch"`
    • Impact: Evaluation was never running during training
  • ✅ Changed `processing_class=processor` → `tokenizer=processor`
    • Impact: Tokenizer wasn't properly saved with checkpoints
  • ✅ Added `predict_with_generate=True`
    • Impact: Enables proper sequence generation for WER evaluation

2. Language/Task Conditioning (Critical for Non-English)

Problem: The model wasn't conditioned for German transcription.

Fix:

```python
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(
    language="german",
    task="transcribe"
)
model.config.suppress_tokens = []
```

Impact:

  • Model now knows it's transcribing German
  • Decoder generates German text consistently
  • Training targets are properly aligned
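Under the hood, `forced_decoder_ids` is a list of `(position, token_id)` pairs that pin the first generated tokens to the `<|de|><|transcribe|><|notimestamps|>` prompt. A minimal sketch of that structure (the helper name and token ids below are illustrative placeholders, not the real Whisper vocabulary ids):

```python
def build_forced_decoder_ids(language_id: int, task_id: int, no_timestamps_id: int):
    """Pin decoder positions 1-3 to language, task, and no-timestamps tokens.

    Position 0 is <|startoftranscript|>, which the decoder emits on its own,
    so forcing starts at position 1. Token ids here are placeholders; the
    real ids come from the Whisper tokenizer.
    """
    return [(1, language_id), (2, task_id), (3, no_timestamps_id)]

# Illustrative call -- in practice processor.get_decoder_prompt_ids does this.
forced = build_forced_decoder_ids(language_id=50261, task_id=50359,
                                  no_timestamps_id=50363)
```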

3. Hyperparameter Issues

Learning Rate (Too Conservative)

Before: 5e-6 for all dataset sizes

After:

  • Large datasets (>400): 2e-5
  • Medium datasets (100-400): 1.5e-5
  • Small datasets (<100): 1e-5

Impact: 2-4x higher learning rate enables actual learning with limited data

Warmup Strategy

Before: `warmup_steps=min(100, len(train)//10)` (could be 50%+ of training)

After: `warmup_ratio=0.03-0.05` (3-5% of total steps)

Impact: More stable warmup that scales with dataset size
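The size-based schedule above can be collapsed into a small helper. The thresholds and learning rates come from this document; mapping smaller datasets to the higher end of the 3-5% warmup range is our illustrative choice, and the function name is ours:

```python
def pick_hyperparams(num_train_samples: int) -> dict:
    """Choose LR by dataset size; warmup is a ratio of total steps (3-5%)."""
    if num_train_samples > 400:        # large datasets
        lr = 2e-5
    elif num_train_samples >= 100:     # medium datasets (100-400)
        lr = 1.5e-5
    else:                              # small datasets (<100)
        lr = 1e-5
    # Illustrative: give small datasets the longer end of the warmup range.
    warmup_ratio = 0.05 if num_train_samples < 100 else 0.03
    return {"learning_rate": lr, "warmup_ratio": warmup_ratio}
```

The returned dict can be splatted straight into `Seq2SeqTrainingArguments(**pick_hyperparams(len(train)), ...)`.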

Precision/Dtype Conflict

Before: Model loaded with `torch_dtype=torch.float16` while the Trainer used `bf16=True`.

After: Let the Trainer control precision entirely:

```python
# Model loading - no dtype specified
model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-small",
    config=config,
    device_map="auto"
)

# Trainer handles precision
bf16=torch.cuda.is_bf16_supported()
```

Impact: Eliminates dtype mismatches and training instability

4. Data Quality Filtering

Added Filters:

  • ✅ Duration: 0.5s ≤ audio ≤ 30s
  • ✅ Transcript: Not empty, 2+ chars, <500 chars
  • ✅ Audio validation: Valid array and sampling rate
  • ✅ Text normalization: Lowercase, remove punctuation, strip whitespace

Impact: Removes noisy samples that can dominate small datasets
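A minimal sketch of these filters as a predicate suitable for `datasets.Dataset.filter`, assuming MINDS14-style samples with an `audio` dict (array + sampling rate) and a `text` field; the helper names are ours:

```python
import re

def normalize_text(text: str) -> str:
    """Lowercase, remove punctuation, collapse/strip whitespace."""
    text = re.sub(r"[^\w\s]", "", text.lower())
    return re.sub(r"\s+", " ", text).strip()

def keep_sample(sample: dict) -> bool:
    """Return True if a sample passes the duration/transcript/audio filters."""
    audio = sample.get("audio") or {}
    array = audio.get("array")
    sr = audio.get("sampling_rate")
    if array is None or not sr or len(array) == 0:   # invalid audio
        return False
    duration = len(array) / sr
    if not (0.5 <= duration <= 30.0):                # 0.5s <= audio <= 30s
        return False
    text = normalize_text(sample.get("text", ""))
    return 2 <= len(text) < 500                      # 2+ chars, <500 chars
```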

5. Evaluation & Metrics

Added:

  • ✅ WER (Word Error Rate) computation with jiwer
  • ✅ Text normalization for consistent metrics
  • ✅ Best model selection by WER (not just loss)
  • ✅ `load_best_model_at_end=True`
  • ✅ `metric_for_best_model="wer"`
Impact: Can now track actual transcription quality improvements
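WER is the word-level edit distance (substitutions + insertions + deletions) divided by the number of reference words. `jiwer` computes it efficiently; the metric itself can be sketched in pure Python:

```python
def wer(reference: str, hypothesis: str) -> float:
    """Word error rate via word-level Levenshtein distance."""
    ref, hyp = reference.split(), hypothesis.split()
    # dp[i][j] = edit distance between ref[:i] and hyp[:j]
    dp = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        dp[i][0] = i
    for j in range(len(hyp) + 1):
        dp[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            dp[i][j] = min(dp[i - 1][j] + 1,         # deletion
                           dp[i][j - 1] + 1,         # insertion
                           dp[i - 1][j - 1] + cost)  # substitution
    return dp[len(ref)][len(hyp)] / max(len(ref), 1)
```

Both strings should go through the same text normalization first, otherwise casing and punctuation inflate the error rate.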

6. TensorBoard Logging

Added:

```python
report_to=["tensorboard"]
logging_dir="./logs"
logging_steps=10
logging_first_step=True
```

Metrics Logged:

  • Training/Evaluation Loss
  • WER (Word Error Rate)
  • Learning Rate schedule
  • Gradient norms
  • Training speed

Usage:

```bash
tensorboard --logdir=./logs
# Open http://localhost:6006
```

7. Additional Optimizations

  • ✅ `group_by_length=True` - Reduces padding overhead
  • ✅ `generation_max_length=448` - Full Whisper context (was 128)
  • ✅ Data filtering before preprocessing
  • ✅ Better epoch/batch size scaling by dataset size

Expected Improvements

Before (v1.0)

  • ❌ No evaluation running (API bug)
  • ❌ No language conditioning
  • ❌ LR too low (5e-6)
  • ❌ No WER tracking
  • ❌ No data filtering
  • ❌ Dtype conflicts
  • ❌ Model selection by loss only

Result: Training appeared to run but model didn't improve

After (v2.0)

  • ✅ Evaluation runs every epoch
  • ✅ German language/task conditioning
  • ✅ Proper LR (1e-5 to 2e-5)
  • ✅ WER metric tracking
  • ✅ Quality data filtering
  • ✅ Consistent precision
  • ✅ Best model by WER

Expected Result: Visible WER improvements, better transcription quality

Hugging Face Compatibility

Current Status: ✅ Fully Compatible

Using:

  • transformers.WhisperForConditionalGeneration
  • transformers.WhisperProcessor
  • transformers.Seq2SeqTrainer
  • datasets.load_dataset / load_from_disk
  • Standard HF checkpoint format

To Push to Hub:

```python
# In TrainingArguments
push_to_hub=True
hub_model_id="your-username/whisper-small-german"
hub_token="your_hf_token"

# Or manually after training
model.push_to_hub("your-username/whisper-small-german")
processor.push_to_hub("your-username/whisper-small-german")
```

GitHub Readiness

Added Files

  • ✅ `requirements.txt` - All dependencies with versions
  • ✅ Updated `README_WHISPER_PROJECT.md` - Installation, usage, TensorBoard
  • ✅ `TRAINING_IMPROVEMENTS.md` - This document

Reproducibility

  • ✅ Pinned dependency versions
  • ✅ Seed set to 42
  • ✅ Clear installation instructions
  • ✅ Dataset download script
  • ✅ Training/inference scripts

Missing (Optional)

  • .gitignore for checkpoints/logs
  • LICENSE file
  • GitHub Actions for CI/CD
  • Model card template

Data Processing vs Whisper Paper

Whisper Paper Approach

  • 30-second audio chunks
  • 80-channel log-mel spectrogram
  • 16kHz sampling rate
  • Padding/truncation to 30s

Our Implementation: ✅ Matches Paper

```python
# WhisperProcessor handles this automatically
input_features = processor(
    audio_array,           # Raw audio
    sampling_rate=16000,   # 16kHz ✅
    return_tensors="pt"
).input_features          # Returns 80x3000 mel spectrogram ✅
```

What happens:

  1. Audio resampled to 16kHz ✅
  2. Converted to 80-channel log-mel spectrogram ✅
  3. Padded/truncated to 3000 frames (30s at a 10ms hop) ✅
  4. Normalized ✅

For longer audio: Would need sliding window with stride (not needed for MINDS14)
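That sliding window would schedule overlapping 30-second chunks over the waveform; a minimal sketch of computing chunk start offsets, with a 5-second overlap as an illustrative (not pipeline-specified) choice:

```python
def window_offsets(num_samples: int, sr: int = 16000,
                   window_s: float = 30.0, overlap_s: float = 5.0):
    """Start offsets (in samples) of overlapping fixed-length windows."""
    window = int(window_s * sr)
    stride = int((window_s - overlap_s) * sr)
    if num_samples <= window:
        return [0]                        # one chunk covers everything
    offsets = list(range(0, num_samples - window, stride))
    offsets.append(num_samples - window)  # final window flush with the end
    return offsets
```

Each chunk would then be transcribed separately and the overlapping text merged, which is why this is overkill for MINDS14's short clips.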

Next Steps

Immediate

  1. Install dependencies: pip install -r requirements.txt
  2. Retrain model: python project1_whisper_train.py
  3. Monitor with TensorBoard: tensorboard --logdir=./logs
  4. Check WER improvements: Should see decreasing WER each epoch

Recommended

  1. Use medium or large dataset (300-600 samples)
  2. Monitor TensorBoard for convergence
  3. Compare WER across epochs
  4. Test on real-world German audio

Advanced

  1. Try Whisper-medium for better quality
  2. Add data augmentation (SpecAugment)
  3. Push best model to Hugging Face Hub
  4. Create demo/API endpoint

Summary

Root Causes of "No Learning":

  1. Evaluation never ran (API typo)
  2. No language conditioning for German
  3. Learning rate too conservative
  4. No quality metrics (WER)
  5. Dtype conflicts

All Fixed: Training should now show measurable WER improvements and produce usable German ASR models.