
Training Guide for Speech Pathology Classifier

This guide explains how to train the classifier head for phoneme-level speech pathology detection.

Overview

The system uses Wav2Vec2-XLSR-53 as a frozen feature extractor and trains only the classification head (2-3 layer feedforward network) on phoneme-level labeled data.
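A minimal sketch of such a head, assuming XLSR-53's 1024-dimensional frame features and the 8 error classes mentioned under Best Practices (the hidden width of 256 is illustrative, not taken from the repo):

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0)

class ClassifierHead:
    """2-layer feedforward head applied to frozen Wav2Vec2 frame features."""

    def __init__(self, in_dim=1024, hidden=256, num_classes=8, seed=0):
        rng = np.random.default_rng(seed)
        self.w1 = rng.normal(0, 0.02, (in_dim, hidden))
        self.b1 = np.zeros(hidden)
        self.w2 = rng.normal(0, 0.02, (hidden, num_classes))
        self.b2 = np.zeros(num_classes)

    def forward(self, frames):
        # frames: (T, in_dim) -- one feature vector per Wav2Vec2 frame
        h = relu(frames @ self.w1 + self.b1)
        return h @ self.w2 + self.b2  # (T, num_classes) logits, one per frame
```

Only these head parameters are updated during training; the Wav2Vec2 backbone stays fixed.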

Prerequisites

  1. Labeled Data: 50-100 audio samples with phoneme-level error annotations
  2. Python Environment: Python 3.10+ with required dependencies
  3. GPU (recommended): For faster training

Step 1: Data Collection

Using the Annotation Tool

  1. Launch the data collection interface:
python scripts/data_collection.py
  2. The Gradio interface will open at http://localhost:7861

  3. For each sample:

    • Upload or record audio (5-30 seconds, 16kHz WAV)
    • Enter expected text/transcript
    • Extract phonemes (automatic G2P conversion)
    • Annotate errors at phoneme level:
      • Frame ID where error occurs
      • Phoneme with error
      • Error type (substitution/omission/distortion/stutter)
      • Wrong sound (for substitutions)
      • Severity (0-1)
      • Timestamp
    • Add notes if needed
    • Save annotation
  4. Annotations are saved to:

    • Audio files: data/raw/
    • Annotations: data/annotations.json
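An annotation entry might look like the following; the field names here are illustrative, so check data/annotations.json for the actual schema used by the tool:

```json
{
  "audio_file": "data/raw/sample_001.wav",
  "expected_text": "the red ball",
  "errors": [
    {
      "frame_id": 42,
      "phoneme": "R",
      "error_type": "substitution",
      "wrong_sound": "W",
      "severity": 0.7,
      "timestamp": 1.25
    }
  ],
  "notes": "consistent /r/ -> /w/ gliding"
}
```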

Export Training Data

After collecting annotations, export them for training:

python scripts/annotation_helper.py

This creates data/training_dataset.json with frame-level labels.

Step 2: Training

Configuration

Edit training/config.yaml to adjust hyperparameters:

  • batch_size: 16 (adjust based on GPU memory)
  • learning_rate: 0.001
  • num_epochs: 50
  • train_split: 0.8 (80% for training, 20% for validation)
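The four settings above correspond to a config fragment like this (minimal; your training/config.yaml may contain additional keys):

```yaml
batch_size: 16        # reduce to 8 or 4 if you hit CUDA OOM
learning_rate: 0.001
num_epochs: 50
train_split: 0.8      # 80% train / 20% validation
```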

Run Training

python training/train_classifier_head.py --config training/config.yaml

Training will:

  • Load training dataset
  • Extract Wav2Vec2 features for each sample
  • Train classifier head (Wav2Vec2 frozen)
  • Save best checkpoint to models/checkpoints/classifier_head_best.pt
  • Save last checkpoint to models/checkpoints/classifier_head_trained.pt
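The head-only training step above can be sketched as follows, with random tensors standing in for cached Wav2Vec2 features (the 1024-feature / 8-class dimensions are assumptions; the real script reads them from the dataset and config):

```python
import torch
import torch.nn as nn

# Stand-in for frozen Wav2Vec2 output: in practice the extractor runs once
# per clip and the (T, 1024) frame features are cached, never backpropped.
features = torch.randn(16, 50, 1024)    # batch of 16 clips, 50 frames each
labels = torch.randint(0, 8, (16, 50))  # frame-level labels, 8 classes

head = nn.Sequential(nn.Linear(1024, 256), nn.ReLU(), nn.Linear(256, 8))
opt = torch.optim.Adam(head.parameters(), lr=1e-3)  # only head params update
loss_fn = nn.CrossEntropyLoss()

for epoch in range(3):                  # num_epochs=50 in the real config
    opt.zero_grad()
    logits = head(features)             # (16, 50, 8)
    loss = loss_fn(logits.reshape(-1, 8), labels.reshape(-1))
    loss.backward()
    opt.step()
```

Because the backbone is frozen, only the small head's gradients are computed, which is why training is feasible on 50-100 samples.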

Monitor Training

Training logs include:

  • Loss per epoch
  • Accuracy per epoch
  • Validation metrics
  • Best model checkpoint saves

Step 3: Evaluation

Evaluate the trained model:

python training/evaluate_classifier.py \
    --checkpoint models/checkpoints/classifier_head_best.pt \
    --dataset data/training_dataset.json \
    --output training/evaluation_results.json \
    --plot training/confusion_matrix.png

This generates:

  • Overall accuracy, F1 score, precision, recall
  • Per-class accuracy
  • Confusion matrix (saved as PNG)
  • Confidence analysis
  • Detailed metrics JSON
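For a rough sense of what the evaluation computes, here is a stdlib-only sketch of the confusion matrix and per-class accuracy (the actual script also produces F1, precision, recall, and the PNG plot):

```python
def confusion_and_per_class_acc(y_true, y_pred, num_classes=8):
    """Rows are true classes, columns are predicted classes."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    per_class = {}
    for c in range(num_classes):
        row_total = sum(cm[c])
        # None means the class never appeared in y_true
        per_class[c] = cm[c][c] / row_total if row_total else None
    return cm, per_class
```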

Step 4: Deployment

Once trained, the model automatically loads trained weights on startup:

  1. Place checkpoint in models/checkpoints/classifier_head_best.pt
  2. Restart the application
  3. The model will detect and load trained weights automatically
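The startup lookup can be sketched like this (the actual weight loading lives in the application code; the torch.load of the state dict is implied but omitted here):

```python
from pathlib import Path

def resolve_checkpoint(path="models/checkpoints/classifier_head_best.pt"):
    """Return the checkpoint path if trained weights exist, else None
    (the app then falls back to the untrained head)."""
    p = Path(path)
    return p if p.is_file() else None
```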

Verify Training Status

Check API responses for:

  • model_version: "wav2vec2-xlsr-53-v2-trained" (if trained) or "wav2vec2-xlsr-53-v2-beta" (if untrained)
  • model_trained: true/false
  • confidence_filter_threshold: 0.65

Troubleshooting

Issue: "No training dataset found"

Solution: Run scripts/annotation_helper.py to export training data from annotations.

Issue: "CUDA out of memory"

Solution: Reduce batch_size in training/config.yaml (try 8 or 4).

Issue: "Poor validation accuracy"

Solutions:

  • Collect more training data (aim for 100+ samples)
  • Check data quality (ensure accurate annotations)
  • Adjust learning rate or add data augmentation
  • Use class weights for imbalanced data
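One common way to derive class weights is inverse-frequency weighting; this is a sketch of that idea, and the training script may use a different formula:

```python
from collections import Counter

def inverse_frequency_weights(labels, num_classes=8):
    """Weight each class by total / (num_classes * count) so rare error
    types contribute more to the loss."""
    counts = Counter(labels)
    total = len(labels)
    return [total / (num_classes * counts[c]) if counts[c] else 0.0
            for c in range(num_classes)]
```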

Issue: "Model not loading trained weights"

Solution:

  • Verify checkpoint path: models/checkpoints/classifier_head_best.pt
  • Check file permissions
  • Review logs for loading errors

Best Practices

  1. Data Quality > Quantity: 50 high-quality samples > 100 poor samples
  2. Balanced Classes: Ensure all 8 classes have sufficient examples
  3. Validation Split: Use 20% for validation, never train on test data
  4. Early Stopping: Enabled by default to prevent overfitting
  5. Class Weights: Automatically calculated to handle imbalance
  6. Checkpointing: Best model saved automatically

Expected Results

After training with 50-100 samples:

  • Frame-level accuracy: >75%
  • Phoneme-level F1: >85%
  • Per-class accuracy: >70% for each class
  • Confidence: Higher for correct predictions

Next Steps

  1. Collect more data based on error patterns
  2. Fine-tune hyperparameters
  3. Add data augmentation
  4. Deploy and monitor in production
  5. Retrain quarterly with new data

Support

For issues or questions:

  • Check training logs in console
  • Review training/evaluation_results.json
  • Verify data format in data/annotations.json