# Training Guide for Speech Pathology Classifier
This guide explains how to train the classifier head for phoneme-level speech pathology detection.
## Overview
The system uses Wav2Vec2-XLSR-53 as a frozen feature extractor and trains only the classification head (2-3 layer feedforward network) on phoneme-level labeled data.
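The frozen-extractor-plus-trainable-head wiring can be sketched as follows. This is an illustrative sketch, not the project's actual code: 1024 is the XLSR-53 hidden size, while the hidden width of 256 and the class count of 8 are assumptions for illustration.

```python
import torch
import torch.nn as nn

FEATURE_DIM = 1024   # Wav2Vec2-XLSR-53 hidden size
NUM_CLASSES = 8      # assumed number of error classes

class ClassifierHead(nn.Module):
    """Small 2-layer feedforward head over frame-level Wav2Vec2 features."""
    def __init__(self, in_dim=FEATURE_DIM, hidden=256, num_classes=NUM_CLASSES):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, features):
        # features: (batch, frames, in_dim) -> logits: (batch, frames, num_classes)
        return self.net(features)

head = ClassifierHead()
frames = torch.randn(2, 50, FEATURE_DIM)  # stand-in for extracted features
logits = head(frames)
```

Only the head's parameters are optimized; the Wav2Vec2 encoder never receives gradients.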
## Prerequisites
- **Labeled Data**: 50-100 audio samples with phoneme-level error annotations
- **Python Environment**: Python 3.10+ with required dependencies
- **GPU (recommended)**: For faster training
## Step 1: Data Collection

### Using the Annotation Tool
Launch the data collection interface:

```bash
python scripts/data_collection.py
```

The Gradio interface will open at http://localhost:7861. For each sample:
- Upload or record audio (5-30 seconds, 16 kHz WAV)
- Enter the expected text/transcript
- Extract phonemes (automatic G2P conversion)
- Annotate errors at the phoneme level:
  - Frame ID where the error occurs
  - Phoneme with the error
  - Error type (substitution/omission/distortion/stutter)
  - Wrong sound (for substitutions)
  - Severity (0-1)
  - Timestamp
- Add notes if needed
- Save the annotation
Annotations are saved to:

- Audio files: `data/raw/`
- Annotations: `data/annotations.json`
### Export Training Data
After collecting annotations, export for training:
```bash
python scripts/annotation_helper.py
```
This creates `data/training_dataset.json` with frame-level labels.
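For orientation, a single frame-level record might look like the following. The field names here are illustrative guesses based on the annotation fields above; check your exported `data/training_dataset.json` for the actual schema.

```python
import json

# Hypothetical example of one frame-level record; actual field names
# in the exported dataset may differ.
record = json.loads("""
{
  "audio_path": "data/raw/sample_001.wav",
  "frame_id": 42,
  "phoneme": "s",
  "error_type": "substitution",
  "wrong_sound": "th",
  "severity": 0.7,
  "timestamp": 1.34
}
""")
```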
## Step 2: Training

### Configuration

Edit `training/config.yaml` to adjust hyperparameters:
```yaml
batch_size: 16        # adjust based on GPU memory
learning_rate: 0.001
num_epochs: 50
train_split: 0.8      # 80% for training, 20% for validation
```
### Run Training

```bash
python training/train_classifier_head.py --config training/config.yaml
```
Training will:

- Load the training dataset
- Extract Wav2Vec2 features for each sample
- Train the classifier head (Wav2Vec2 stays frozen)
- Save the best checkpoint to `models/checkpoints/classifier_head_best.pt`
- Save the last checkpoint to `models/checkpoints/classifier_head_trained.pt`
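The freeze-then-train pattern behind these steps can be sketched as below. The modules and shapes are placeholders standing in for the real Wav2Vec2 encoder and classifier head, not the script's actual code.

```python
import torch
import torch.nn as nn

# Placeholders: in the real script, `backbone` is the Wav2Vec2-XLSR-53
# encoder and `head` is the classifier head.
backbone = nn.Linear(64, 1024)   # pretend feature extractor
head = nn.Linear(1024, 8)        # pretend classifier head

for p in backbone.parameters():  # freeze the feature extractor
    p.requires_grad = False
backbone.eval()

optimizer = torch.optim.Adam(head.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

x = torch.randn(16, 64)                  # dummy batch of inputs
y = torch.randint(0, 8, (16,))           # dummy frame labels

with torch.no_grad():                    # no gradients through the backbone
    features = backbone(x)
loss = criterion(head(features), y)
optimizer.zero_grad()
loss.backward()                          # gradients flow only into the head
optimizer.step()
```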
### Monitor Training
Training logs include:
- Loss per epoch
- Accuracy per epoch
- Validation metrics
- Best model checkpoint saves
## Step 3: Evaluation
Evaluate the trained model:
```bash
python training/evaluate_classifier.py \
  --checkpoint models/checkpoints/classifier_head_best.pt \
  --dataset data/training_dataset.json \
  --output training/evaluation_results.json \
  --plot training/confusion_matrix.png
```
This generates:
- Overall accuracy, F1 score, precision, recall
- Per-class accuracy
- Confusion matrix (saved as PNG)
- Confidence analysis
- Detailed metrics JSON
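For intuition on the reported numbers, per-class precision, recall, and F1 can be computed like this (toy labels only, not real evaluation output):

```python
# precision = TP / (TP + FP), recall = TP / (TP + FN),
# F1 = 2 * precision * recall / (precision + recall), all for one class.
def per_class_f1(y_true, y_pred, cls):
    tp = sum(t == cls and p == cls for t, p in zip(y_true, y_pred))
    fp = sum(t != cls and p == cls for t, p in zip(y_true, y_pred))
    fn = sum(t == cls and p != cls for t, p in zip(y_true, y_pred))
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall else 0.0)
    return precision, recall, f1

y_true = ["substitution", "omission", "substitution", "stutter"]
y_pred = ["substitution", "substitution", "substitution", "stutter"]
precision, recall, f1 = per_class_f1(y_true, y_pred, "substitution")
```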
## Step 4: Deployment
Once trained, the model automatically loads trained weights on startup:

- Place the checkpoint at `models/checkpoints/classifier_head_best.pt`
- Restart the application
- The model will detect and load the trained weights automatically
### Verify Training Status
Check API responses for:

- `model_version`: `"wav2vec2-xlsr-53-v2-trained"` (if trained) or `"wav2vec2-xlsr-53-v2-beta"` (if untrained)
- `model_trained`: `true`/`false`
- `confidence_filter_threshold`: `0.65`
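A minimal status check over these fields might look like the following; the exact response payload shape is an assumption, only the field names come from the list above.

```python
# Decide whether the deployed model is running trained weights, based on
# the status fields documented above (payload shape is assumed).
def is_trained(status: dict) -> bool:
    return (status.get("model_trained") is True
            and status.get("model_version", "").endswith("-trained"))

response = {
    "model_version": "wav2vec2-xlsr-53-v2-trained",
    "model_trained": True,
    "confidence_filter_threshold": 0.65,
}
```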
## Troubleshooting

**Issue**: "No training dataset found"

**Solution**: Run `scripts/annotation_helper.py` to export training data from annotations.

**Issue**: "CUDA out of memory"

**Solution**: Reduce `batch_size` in `training/config.yaml` (try 8 or 4).
**Issue**: "Poor validation accuracy"

**Solutions**:

- Collect more training data (aim for 100+ samples)
- Check data quality (ensure annotations are accurate)
- Adjust the learning rate or add data augmentation
- Use class weights for imbalanced data
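One common way to weight classes is by inverse frequency, sketched below; whether the training script uses exactly this scheme is an assumption.

```python
from collections import Counter

# Inverse-frequency class weights: rarer classes get larger weights,
# so the loss does not ignore under-represented error types.
def class_weights(labels):
    counts = Counter(labels)
    total = len(labels)
    return {cls: total / (len(counts) * n) for cls, n in counts.items()}

weights = class_weights(["substitution"] * 8 + ["omission"] * 2)
```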
**Issue**: "Model not loading trained weights"

**Solutions**:

- Verify the checkpoint path: `models/checkpoints/classifier_head_best.pt`
- Check file permissions
- Review logs for loading errors
## Best Practices

- **Data Quality > Quantity**: 50 high-quality samples beat 100 poor ones
- **Balanced Classes**: Ensure all 8 classes have sufficient examples
- **Validation Split**: Use 20% for validation; never train on test data
- **Early Stopping**: Enabled by default to prevent overfitting
- **Class Weights**: Calculated automatically to handle imbalance
- **Checkpointing**: The best model is saved automatically
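The early-stopping behavior mentioned above is typically patience-based: stop once validation loss has not improved for a fixed number of epochs. A sketch, with an illustrative patience value (the real setting lives in `training/config.yaml`):

```python
# Stop when validation loss has not beaten the previous best for
# `patience` consecutive epochs. patience=3 is illustrative.
def should_stop(val_losses, patience=3):
    if len(val_losses) <= patience:
        return False
    best_before = min(val_losses[:-patience])
    return all(loss >= best_before for loss in val_losses[-patience:])
```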
## Expected Results
After training with 50-100 samples:
- Frame-level accuracy: >75%
- Phoneme-level F1: >85%
- Per-class accuracy: >70% for each class
- Confidence: Higher for correct predictions
## Next Steps
- Collect more data based on error patterns
- Fine-tune hyperparameters
- Add data augmentation
- Deploy and monitor in production
- Retrain quarterly with new data
## Support

For issues or questions:

- Check training logs in the console
- Review `training/evaluation_results.json`
- Verify the data format in `data/annotations.json`