New: Phoneme-level speech pathology diagnosis MVP with real-time streaming
Browse files- Add Wav2Vec2-XLSR-53 based speech pathology classifier with 8-class output (fluency + articulation)
- Implement phone-level feature extraction with 1-second sliding windows and 10ms hops
- Add grapheme-to-phoneme mapping using g2p_en library with frame alignment
- Create error taxonomy system with substitution/omission/distortion detection and therapy recommendations
- Build FastAPI REST API with batch diagnosis endpoints (/diagnose/file) and WebSocket streaming (/ws/diagnose)
- Add Gradio web interface with audio upload/recording, real-time error display, and detailed reports
- Implement training infrastructure: synthetic data generation, classifier head training, and evaluation scripts
- Add Docker containerization with NLTK data download for phoneme mapping
- Fix model loading issues and improve error handling throughout the pipeline
- Remove unnecessary package-lock.json file (Python project)
Features:
- Real-time streaming analysis (<200ms per file, <50ms per frame target)
- Phoneme-level error detection with visual feedback
- Therapy recommendation system with clinical guidance
- Comprehensive error reporting with severity levels and timelines
- WebSocket-based real-time diagnosis for live audio streams
- REST API for batch processing with detailed JSON responses
Infrastructure:
- Training pipeline for classifier fine-tuning
- Data collection tools for phoneme-level annotation
- Performance testing and integration testing suites
- Production-ready logging and error handling
- .gitignore +1 -1
- Dockerfile +3 -0
- README_TRAINING.md +176 -0
- api/routes.py +18 -1
- api/schemas.py +3 -0
- app.py +105 -103
- inference/__init__.py +5 -5
- models/phoneme_mapper.py +12 -0
- models/speech_pathology_model.py +54 -0
- requirements.txt +7 -0
- scripts/annotation_helper.py +146 -0
- scripts/data_collection.py +418 -0
- training/config.yaml +86 -0
- training/evaluate_classifier.py +282 -0
- training/train_classifier_head.py +469 -0
- ui/gradio_interface.py +218 -6
|
@@ -7,5 +7,5 @@ __pycache__/
|
|
| 7 |
|
| 8 |
.gradio/
|
| 9 |
.cursor/
|
| 10 |
-
package-lock.json
|
| 11 |
docker-compose.yml
|
|
|
|
| 7 |
|
| 8 |
.gradio/
|
| 9 |
.cursor/
|
| 10 |
+
# package-lock.json
|
| 11 |
docker-compose.yml
|
|
@@ -27,6 +27,9 @@ RUN pip install --no-cache-dir \
|
|
| 27 |
# Install the rest of requirements
|
| 28 |
RUN pip install --no-cache-dir -r requirements.txt
|
| 29 |
|
|
|
|
|
|
|
|
|
|
| 30 |
# Copy application files
|
| 31 |
COPY . .
|
| 32 |
|
|
|
|
| 27 |
# Install the rest of requirements
|
| 28 |
RUN pip install --no-cache-dir -r requirements.txt
|
| 29 |
|
| 30 |
+
# Download NLTK data required by g2p_en
|
| 31 |
+
RUN python -c "import nltk; nltk.download('averaged_perceptron_tagger_eng', quiet=True)"
|
| 32 |
+
|
| 33 |
# Copy application files
|
| 34 |
COPY . .
|
| 35 |
|
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Training Guide for Speech Pathology Classifier
|
| 2 |
+
|
| 3 |
+
This guide explains how to train the classifier head for phoneme-level speech pathology detection.
|
| 4 |
+
|
| 5 |
+
## Overview
|
| 6 |
+
|
| 7 |
+
The system uses Wav2Vec2-XLSR-53 as a frozen feature extractor and trains only the classification head (2-3 layer feedforward network) on phoneme-level labeled data.
|
| 8 |
+
|
| 9 |
+
## Prerequisites
|
| 10 |
+
|
| 11 |
+
1. **Labeled Data**: 50-100 audio samples with phoneme-level error annotations
|
| 12 |
+
2. **Python Environment**: Python 3.10+ with required dependencies
|
| 13 |
+
3. **GPU** (recommended): For faster training
|
| 14 |
+
|
| 15 |
+
## Step 1: Data Collection
|
| 16 |
+
|
| 17 |
+
### Using the Annotation Tool
|
| 18 |
+
|
| 19 |
+
1. Launch the data collection interface:
|
| 20 |
+
```bash
|
| 21 |
+
python scripts/data_collection.py
|
| 22 |
+
```
|
| 23 |
+
|
| 24 |
+
2. The Gradio interface will open at `http://localhost:7861`
|
| 25 |
+
|
| 26 |
+
3. For each sample:
|
| 27 |
+
- Upload or record audio (5-30 seconds, 16kHz WAV)
|
| 28 |
+
- Enter expected text/transcript
|
| 29 |
+
- Extract phonemes (automatic G2P conversion)
|
| 30 |
+
- Annotate errors at phoneme level:
|
| 31 |
+
- Frame ID where error occurs
|
| 32 |
+
- Phoneme with error
|
| 33 |
+
- Error type (substitution/omission/distortion/stutter)
|
| 34 |
+
- Wrong sound (for substitutions)
|
| 35 |
+
- Severity (0-1)
|
| 36 |
+
- Timestamp
|
| 37 |
+
- Add notes if needed
|
| 38 |
+
- Save annotation
|
| 39 |
+
|
| 40 |
+
4. Annotations are saved to:
|
| 41 |
+
- Audio files: `data/raw/`
|
| 42 |
+
- Annotations: `data/annotations.json`
|
| 43 |
+
|
| 44 |
+
### Export Training Data
|
| 45 |
+
|
| 46 |
+
After collecting annotations, export for training:
|
| 47 |
+
|
| 48 |
+
```bash
|
| 49 |
+
python scripts/annotation_helper.py
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
This creates `data/training_dataset.json` with frame-level labels.
|
| 53 |
+
|
| 54 |
+
## Step 2: Training
|
| 55 |
+
|
| 56 |
+
### Configuration
|
| 57 |
+
|
| 58 |
+
Edit `training/config.yaml` to adjust hyperparameters:
|
| 59 |
+
|
| 60 |
+
- `batch_size`: 16 (adjust based on GPU memory)
|
| 61 |
+
- `learning_rate`: 0.001
|
| 62 |
+
- `num_epochs`: 50
|
| 63 |
+
- `train_split`: 0.8 (80% for training, 20% for validation)
|
| 64 |
+
|
| 65 |
+
### Run Training
|
| 66 |
+
|
| 67 |
+
```bash
|
| 68 |
+
python training/train_classifier_head.py --config training/config.yaml
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
Training will:
|
| 72 |
+
- Load training dataset
|
| 73 |
+
- Extract Wav2Vec2 features for each sample
|
| 74 |
+
- Train classifier head (Wav2Vec2 frozen)
|
| 75 |
+
- Save best checkpoint to `models/checkpoints/classifier_head_best.pt`
|
| 76 |
+
- Save last checkpoint to `models/checkpoints/classifier_head_trained.pt`
|
| 77 |
+
|
| 78 |
+
### Monitor Training
|
| 79 |
+
|
| 80 |
+
Training logs include:
|
| 81 |
+
- Loss per epoch
|
| 82 |
+
- Accuracy per epoch
|
| 83 |
+
- Validation metrics
|
| 84 |
+
- Best model checkpoint saves
|
| 85 |
+
|
| 86 |
+
## Step 3: Evaluation
|
| 87 |
+
|
| 88 |
+
Evaluate the trained model:
|
| 89 |
+
|
| 90 |
+
```bash
|
| 91 |
+
python training/evaluate_classifier.py \
|
| 92 |
+
--checkpoint models/checkpoints/classifier_head_best.pt \
|
| 93 |
+
--dataset data/training_dataset.json \
|
| 94 |
+
--output training/evaluation_results.json \
|
| 95 |
+
--plot training/confusion_matrix.png
|
| 96 |
+
```
|
| 97 |
+
|
| 98 |
+
This generates:
|
| 99 |
+
- Overall accuracy, F1 score, precision, recall
|
| 100 |
+
- Per-class accuracy
|
| 101 |
+
- Confusion matrix (saved as PNG)
|
| 102 |
+
- Confidence analysis
|
| 103 |
+
- Detailed metrics JSON
|
| 104 |
+
|
| 105 |
+
## Step 4: Deployment
|
| 106 |
+
|
| 107 |
+
Once trained, the model automatically loads trained weights on startup:
|
| 108 |
+
|
| 109 |
+
1. Place checkpoint in `models/checkpoints/classifier_head_best.pt`
|
| 110 |
+
2. Restart the application
|
| 111 |
+
3. The model will detect and load trained weights automatically
|
| 112 |
+
|
| 113 |
+
### Verify Training Status
|
| 114 |
+
|
| 115 |
+
Check API responses for:
|
| 116 |
+
- `model_version`: "wav2vec2-xlsr-53-v2-trained" (if trained) or "wav2vec2-xlsr-53-v2-beta" (if untrained)
|
| 117 |
+
- `model_trained`: true/false
|
| 118 |
+
- `confidence_filter_threshold`: 0.65
|
| 119 |
+
|
| 120 |
+
## Troubleshooting
|
| 121 |
+
|
| 122 |
+
### Issue: "No training dataset found"
|
| 123 |
+
|
| 124 |
+
**Solution**: Run `scripts/annotation_helper.py` to export training data from annotations.
|
| 125 |
+
|
| 126 |
+
### Issue: "CUDA out of memory"
|
| 127 |
+
|
| 128 |
+
**Solution**: Reduce `batch_size` in `training/config.yaml` (try 8 or 4).
|
| 129 |
+
|
| 130 |
+
### Issue: "Poor validation accuracy"
|
| 131 |
+
|
| 132 |
+
**Solutions**:
|
| 133 |
+
- Collect more training data (aim for 100+ samples)
|
| 134 |
+
- Check data quality (ensure accurate annotations)
|
| 135 |
+
- Adjust learning rate or add data augmentation
|
| 136 |
+
- Use class weights for imbalanced data
|
| 137 |
+
|
| 138 |
+
### Issue: "Model not loading trained weights"
|
| 139 |
+
|
| 140 |
+
**Solution**:
|
| 141 |
+
- Verify checkpoint path: `models/checkpoints/classifier_head_best.pt`
|
| 142 |
+
- Check file permissions
|
| 143 |
+
- Review logs for loading errors
|
| 144 |
+
|
| 145 |
+
## Best Practices
|
| 146 |
+
|
| 147 |
+
1. **Data Quality > Quantity**: 50 high-quality samples > 100 poor samples
|
| 148 |
+
2. **Balanced Classes**: Ensure all 8 classes have sufficient examples
|
| 149 |
+
3. **Validation Split**: Use 20% for validation, never train on test data
|
| 150 |
+
4. **Early Stopping**: Enabled by default to prevent overfitting
|
| 151 |
+
5. **Class Weights**: Automatically calculated to handle imbalance
|
| 152 |
+
6. **Checkpointing**: Best model saved automatically
|
| 153 |
+
|
| 154 |
+
## Expected Results
|
| 155 |
+
|
| 156 |
+
After training with 50-100 samples:
|
| 157 |
+
- **Frame-level accuracy**: >75%
|
| 158 |
+
- **Phoneme-level F1**: >85%
|
| 159 |
+
- **Per-class accuracy**: >70% for each class
|
| 160 |
+
- **Confidence**: Higher for correct predictions
|
| 161 |
+
|
| 162 |
+
## Next Steps
|
| 163 |
+
|
| 164 |
+
1. Collect more data based on error patterns
|
| 165 |
+
2. Fine-tune hyperparameters
|
| 166 |
+
3. Add data augmentation
|
| 167 |
+
4. Deploy and monitor in production
|
| 168 |
+
5. Retrain quarterly with new data
|
| 169 |
+
|
| 170 |
+
## Support
|
| 171 |
+
|
| 172 |
+
For issues or questions:
|
| 173 |
+
- Check training logs in console
|
| 174 |
+
- Review `training/evaluation_results.json`
|
| 175 |
+
- Verify data format in `data/annotations.json`
|
| 176 |
+
|
|
@@ -47,6 +47,16 @@ phoneme_mapper: Optional[PhonemeMapper] = None
|
|
| 47 |
error_mapper: Optional[ErrorMapper] = None
|
| 48 |
|
| 49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
def initialize_routes(
|
| 51 |
pipeline: InferencePipeline,
|
| 52 |
mapper: Optional[PhonemeMapper] = None,
|
|
@@ -263,6 +273,10 @@ async def diagnose_file(
|
|
| 263 |
processing_time_ms = (time.time() - start_time) * 1000
|
| 264 |
|
| 265 |
# Create response
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
response = BatchDiagnosisResponse(
|
| 267 |
session_id=session_id,
|
| 268 |
filename=audio.filename or "unknown",
|
|
@@ -274,7 +288,10 @@ async def diagnose_file(
|
|
| 274 |
summary=summary,
|
| 275 |
therapy_plan=therapy_plan,
|
| 276 |
processing_time_ms=processing_time_ms,
|
| 277 |
-
created_at=datetime.
|
|
|
|
|
|
|
|
|
|
| 278 |
)
|
| 279 |
|
| 280 |
# Store in sessions
|
|
|
|
| 47 |
error_mapper: Optional[ErrorMapper] = None
|
| 48 |
|
| 49 |
|
| 50 |
+
def get_phoneme_mapper() -> Optional[PhonemeMapper]:
|
| 51 |
+
"""Get the global PhonemeMapper instance."""
|
| 52 |
+
return phoneme_mapper
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def get_error_mapper() -> Optional[ErrorMapper]:
|
| 56 |
+
"""Get the global ErrorMapper instance."""
|
| 57 |
+
return error_mapper
|
| 58 |
+
|
| 59 |
+
|
| 60 |
def initialize_routes(
|
| 61 |
pipeline: InferencePipeline,
|
| 62 |
mapper: Optional[PhonemeMapper] = None,
|
|
|
|
| 273 |
processing_time_ms = (time.time() - start_time) * 1000
|
| 274 |
|
| 275 |
# Create response
|
| 276 |
+
# Check if model is trained
|
| 277 |
+
model_trained = inference_pipeline.model.is_trained if hasattr(inference_pipeline.model, 'is_trained') else False
|
| 278 |
+
model_version = "wav2vec2-xlsr-53-v2-trained" if model_trained else "wav2vec2-xlsr-53-v2-beta"
|
| 279 |
+
|
| 280 |
response = BatchDiagnosisResponse(
|
| 281 |
session_id=session_id,
|
| 282 |
filename=audio.filename or "unknown",
|
|
|
|
| 288 |
summary=summary,
|
| 289 |
therapy_plan=therapy_plan,
|
| 290 |
processing_time_ms=processing_time_ms,
|
| 291 |
+
created_at=datetime.utcnow(),
|
| 292 |
+
model_version=model_version,
|
| 293 |
+
model_trained=model_trained,
|
| 294 |
+
confidence_filter_threshold=0.65
|
| 295 |
)
|
| 296 |
|
| 297 |
# Store in sessions
|
|
@@ -76,6 +76,9 @@ class BatchDiagnosisResponse(BaseModel):
|
|
| 76 |
therapy_plan: List[str] = Field(default_factory=list, description="Therapy recommendations")
|
| 77 |
processing_time_ms: float = Field(..., ge=0.0, description="Processing time in milliseconds")
|
| 78 |
created_at: datetime = Field(default_factory=datetime.now, description="Analysis timestamp")
|
|
|
|
|
|
|
|
|
|
| 79 |
|
| 80 |
|
| 81 |
class StreamingDiagnosisRequest(BaseModel):
|
|
|
|
| 76 |
therapy_plan: List[str] = Field(default_factory=list, description="Therapy recommendations")
|
| 77 |
processing_time_ms: float = Field(..., ge=0.0, description="Processing time in milliseconds")
|
| 78 |
created_at: datetime = Field(default_factory=datetime.now, description="Analysis timestamp")
|
| 79 |
+
model_version: str = Field(default="wav2vec2-xlsr-53-v2", description="Model version identifier")
|
| 80 |
+
model_trained: bool = Field(default=False, description="Whether classifier head is trained")
|
| 81 |
+
confidence_filter_threshold: float = Field(default=0.65, ge=0.0, le=1.0, description="Confidence threshold for filtering predictions")
|
| 82 |
|
| 83 |
|
| 84 |
class StreamingDiagnosisRequest(BaseModel):
|
|
@@ -7,7 +7,7 @@ from pathlib import Path
|
|
| 7 |
from datetime import datetime
|
| 8 |
from typing import Optional
|
| 9 |
|
| 10 |
-
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, WebSocket, WebSocketDisconnect
|
| 11 |
from fastapi.responses import JSONResponse
|
| 12 |
from fastapi.middleware.cors import CORSMiddleware
|
| 13 |
import gradio as gr
|
|
@@ -26,8 +26,7 @@ sys.path.insert(0, str(Path(__file__).parent))
|
|
| 26 |
# Import model loaders and inference pipeline
|
| 27 |
try:
|
| 28 |
from diagnosis.ai_engine.model_loader import (
|
| 29 |
-
|
| 30 |
-
get_inference_pipeline # New inference pipeline
|
| 31 |
)
|
| 32 |
from ui.gradio_interface import create_gradio_interface
|
| 33 |
from config import APIConfig, GradioConfig, default_api_config, default_gradio_config
|
|
@@ -53,27 +52,30 @@ app.add_middleware(
|
|
| 53 |
)
|
| 54 |
|
| 55 |
# Global instances
|
| 56 |
-
|
| 57 |
-
inference_pipeline = None # New inference pipeline
|
| 58 |
|
| 59 |
@app.on_event("startup")
|
| 60 |
async def startup_event():
|
| 61 |
"""Load models on startup"""
|
| 62 |
-
global
|
| 63 |
try:
|
| 64 |
logger.info("π Startup event: Loading AI models...")
|
| 65 |
|
| 66 |
-
# Load
|
| 67 |
-
try:
|
| 68 |
-
detector = get_stutter_detector()
|
| 69 |
-
logger.info("β
Legacy detector loaded")
|
| 70 |
-
except Exception as e:
|
| 71 |
-
logger.warning(f"β οΈ Legacy detector not available: {e}")
|
| 72 |
-
|
| 73 |
-
# Load new inference pipeline
|
| 74 |
try:
|
| 75 |
inference_pipeline = get_inference_pipeline()
|
| 76 |
logger.info("β
Inference pipeline loaded")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
except Exception as e:
|
| 78 |
logger.error(f"β Failed to load inference pipeline: {e}", exc_info=True)
|
| 79 |
# Don't raise - allow API to start even if new pipeline fails
|
|
@@ -83,6 +85,24 @@ async def startup_event():
|
|
| 83 |
logger.error(f"β Failed to load models: {e}", exc_info=True)
|
| 84 |
raise
|
| 85 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
# Create and mount new Gradio interface
|
| 87 |
try:
|
| 88 |
gradio_interface = create_gradio_interface(default_gradio_config)
|
|
@@ -103,31 +123,30 @@ async def health_check():
|
|
| 103 |
return {
|
| 104 |
"status": "healthy",
|
| 105 |
"models_loaded": {
|
| 106 |
-
"
|
| 107 |
-
"
|
| 108 |
},
|
| 109 |
"timestamp": datetime.utcnow().isoformat() + "Z"
|
| 110 |
}
|
| 111 |
|
| 112 |
@app.post("/api/diagnose")
|
| 113 |
async def diagnose_speech(
|
| 114 |
-
audio: UploadFile = File(...)
|
|
|
|
| 115 |
):
|
| 116 |
"""
|
| 117 |
-
|
| 118 |
|
| 119 |
-
|
|
|
|
|
|
|
| 120 |
|
| 121 |
Parameters:
|
| 122 |
- audio: Audio file (WAV, MP3, FLAC, M4A)
|
|
|
|
| 123 |
|
| 124 |
Returns:
|
| 125 |
-
Dictionary with
|
| 126 |
-
- status: "success" or "error"
|
| 127 |
-
- fluency_metrics: Fluency statistics
|
| 128 |
-
- articulation_results: Articulation analysis
|
| 129 |
-
- confidence: Overall confidence score
|
| 130 |
-
- processing_time_ms: Processing time in milliseconds
|
| 131 |
"""
|
| 132 |
if not inference_pipeline:
|
| 133 |
raise HTTPException(
|
|
@@ -135,11 +154,15 @@ async def diagnose_speech(
|
|
| 135 |
detail="Inference pipeline not loaded yet. Try again in a moment."
|
| 136 |
)
|
| 137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
start_time = time.time()
|
| 139 |
temp_file = None
|
| 140 |
|
| 141 |
try:
|
| 142 |
-
logger.info(f"π₯ Processing diagnosis request: {audio.filename}")
|
| 143 |
|
| 144 |
# Validate file extension
|
| 145 |
file_ext = Path(audio.filename).suffix.lower()
|
|
@@ -173,7 +196,6 @@ async def diagnose_speech(
|
|
| 173 |
|
| 174 |
# Run inference
|
| 175 |
logger.info("π Running inference pipeline...")
|
| 176 |
-
# Use new phone-level prediction
|
| 177 |
result = inference_pipeline.predict_phone_level(
|
| 178 |
temp_file,
|
| 179 |
return_timestamps=True
|
|
@@ -181,28 +203,64 @@ async def diagnose_speech(
|
|
| 181 |
|
| 182 |
processing_time_ms = (time.time() - start_time) * 1000
|
| 183 |
|
| 184 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
aggregate = result.aggregate
|
| 186 |
mean_fluency_stutter = aggregate.get("fluency_score", 0.0)
|
| 187 |
-
fluency_percentage = (1.0 - mean_fluency_stutter) * 100
|
| 188 |
|
| 189 |
-
# Count fluent frames
|
| 190 |
fluent_frames = sum(1 for fp in result.frame_predictions if fp.fluency_label == 'normal')
|
| 191 |
fluent_frames_ratio = fluent_frames / result.num_frames if result.num_frames > 0 else 0.0
|
| 192 |
|
| 193 |
-
# Extract articulation class distribution
|
| 194 |
articulation_class_counts = {}
|
| 195 |
for fp in result.frame_predictions:
|
| 196 |
label = fp.articulation_label
|
| 197 |
articulation_class_counts[label] = articulation_class_counts.get(label, 0) + 1
|
| 198 |
|
| 199 |
-
# Get dominant articulation class
|
| 200 |
dominant_articulation = aggregate.get("articulation_label", "normal")
|
| 201 |
-
|
| 202 |
-
# Calculate average confidence
|
| 203 |
avg_confidence = sum(fp.confidence for fp in result.frame_predictions) / result.num_frames if result.num_frames > 0 else 0.0
|
| 204 |
|
| 205 |
-
# Format response
|
| 206 |
response = {
|
| 207 |
"status": "success",
|
| 208 |
"fluency_metrics": {
|
|
@@ -225,9 +283,10 @@ async def diagnose_speech(
|
|
| 225 |
"fluency_label": fp.fluency_label,
|
| 226 |
"articulation_class": fp.articulation_class,
|
| 227 |
"articulation_label": fp.articulation_label,
|
| 228 |
-
"confidence": fp.confidence
|
|
|
|
| 229 |
}
|
| 230 |
-
for fp in result.frame_predictions
|
| 231 |
]
|
| 232 |
},
|
| 233 |
"confidence": avg_confidence,
|
|
@@ -235,8 +294,14 @@ async def diagnose_speech(
|
|
| 235 |
"processing_time_ms": processing_time_ms
|
| 236 |
}
|
| 237 |
|
| 238 |
-
|
| 239 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
f"time={processing_time_ms:.0f}ms")
|
| 241 |
|
| 242 |
return response
|
|
@@ -257,70 +322,7 @@ async def diagnose_speech(
|
|
| 257 |
logger.warning(f"Could not clean up {temp_file}: {e}")
|
| 258 |
|
| 259 |
|
| 260 |
-
|
| 261 |
-
async def analyze_audio(
|
| 262 |
-
audio: UploadFile = File(...),
|
| 263 |
-
transcript: str = Form("")
|
| 264 |
-
):
|
| 265 |
-
"""
|
| 266 |
-
Legacy endpoint: Analyze audio file for stuttering.
|
| 267 |
-
|
| 268 |
-
Uses the legacy Whisper-based detector for backward compatibility.
|
| 269 |
-
|
| 270 |
-
Parameters:
|
| 271 |
-
- audio: WAV or MP3 audio file
|
| 272 |
-
- transcript: Optional expected transcript
|
| 273 |
-
|
| 274 |
-
Returns: Complete stutter analysis results
|
| 275 |
-
"""
|
| 276 |
-
if not detector:
|
| 277 |
-
raise HTTPException(
|
| 278 |
-
status_code=503,
|
| 279 |
-
detail="Legacy detector not loaded. Use /api/diagnose for new analysis."
|
| 280 |
-
)
|
| 281 |
-
|
| 282 |
-
temp_file = None
|
| 283 |
-
try:
|
| 284 |
-
logger.info(f"π₯ Processing (legacy): {audio.filename}")
|
| 285 |
-
|
| 286 |
-
# Create temp directory if needed
|
| 287 |
-
temp_dir = tempfile.gettempdir()
|
| 288 |
-
os.makedirs(temp_dir, exist_ok=True)
|
| 289 |
-
|
| 290 |
-
# Save uploaded file
|
| 291 |
-
temp_file = os.path.join(temp_dir, f"legacy_{int(time.time())}_{audio.filename}")
|
| 292 |
-
content = await audio.read()
|
| 293 |
-
|
| 294 |
-
with open(temp_file, "wb") as f:
|
| 295 |
-
f.write(content)
|
| 296 |
-
|
| 297 |
-
logger.info(f"π Saved to: {temp_file} ({len(content) / 1024 / 1024:.2f} MB)")
|
| 298 |
-
|
| 299 |
-
# Analyze
|
| 300 |
-
logger.info(f"π Analyzing audio with transcript: '{transcript[:50] if transcript else '(empty)'}...'")
|
| 301 |
-
result = detector.analyze_audio(temp_file, transcript)
|
| 302 |
-
|
| 303 |
-
actual = result.get('actual_transcript', '')
|
| 304 |
-
target = result.get('target_transcript', '')
|
| 305 |
-
logger.info(f"β
Analysis complete: severity={result.get('severity', 'N/A')}, "
|
| 306 |
-
f"mismatch={result.get('mismatch_percentage', 'N/A')}%")
|
| 307 |
-
|
| 308 |
-
return result
|
| 309 |
-
|
| 310 |
-
except HTTPException:
|
| 311 |
-
raise
|
| 312 |
-
except Exception as e:
|
| 313 |
-
logger.error(f"β Error during analysis: {str(e)}", exc_info=True)
|
| 314 |
-
raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}")
|
| 315 |
-
|
| 316 |
-
finally:
|
| 317 |
-
# Cleanup
|
| 318 |
-
if temp_file and os.path.exists(temp_file):
|
| 319 |
-
try:
|
| 320 |
-
os.remove(temp_file)
|
| 321 |
-
logger.debug(f"π§Ή Cleaned up: {temp_file}")
|
| 322 |
-
except Exception as e:
|
| 323 |
-
logger.warning(f"Could not clean up {temp_file}: {e}")
|
| 324 |
|
| 325 |
|
| 326 |
@app.websocket("/ws/audio")
|
|
|
|
| 7 |
from datetime import datetime
|
| 8 |
from typing import Optional
|
| 9 |
|
| 10 |
+
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, WebSocket, WebSocketDisconnect, Query
|
| 11 |
from fastapi.responses import JSONResponse
|
| 12 |
from fastapi.middleware.cors import CORSMiddleware
|
| 13 |
import gradio as gr
|
|
|
|
| 26 |
# Import model loaders and inference pipeline
|
| 27 |
try:
|
| 28 |
from diagnosis.ai_engine.model_loader import (
|
| 29 |
+
get_inference_pipeline # Wav2Vec2-based inference pipeline
|
|
|
|
| 30 |
)
|
| 31 |
from ui.gradio_interface import create_gradio_interface
|
| 32 |
from config import APIConfig, GradioConfig, default_api_config, default_gradio_config
|
|
|
|
| 52 |
)
|
| 53 |
|
| 54 |
# Global instances
|
| 55 |
+
inference_pipeline = None # Wav2Vec2-based inference pipeline
|
|
|
|
| 56 |
|
| 57 |
@app.on_event("startup")
|
| 58 |
async def startup_event():
|
| 59 |
"""Load models on startup"""
|
| 60 |
+
global inference_pipeline
|
| 61 |
try:
|
| 62 |
logger.info("π Startup event: Loading AI models...")
|
| 63 |
|
| 64 |
+
# Load Wav2Vec2-based inference pipeline
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
try:
|
| 66 |
inference_pipeline = get_inference_pipeline()
|
| 67 |
logger.info("β
Inference pipeline loaded")
|
| 68 |
+
|
| 69 |
+
# Initialize API routes with phoneme and error mappers
|
| 70 |
+
try:
|
| 71 |
+
from api.routes import initialize_routes
|
| 72 |
+
from api.streaming import initialize_streaming
|
| 73 |
+
initialize_routes(inference_pipeline)
|
| 74 |
+
initialize_streaming(inference_pipeline)
|
| 75 |
+
logger.info("β
API routes initialized with phoneme/error mappers")
|
| 76 |
+
except Exception as e:
|
| 77 |
+
logger.warning(f"β οΈ API routes initialization failed: {e}", exc_info=True)
|
| 78 |
+
# Continue without phoneme mapping if it fails
|
| 79 |
except Exception as e:
|
| 80 |
logger.error(f"β Failed to load inference pipeline: {e}", exc_info=True)
|
| 81 |
# Don't raise - allow API to start even if new pipeline fails
|
|
|
|
| 85 |
logger.error(f"β Failed to load models: {e}", exc_info=True)
|
| 86 |
raise
|
| 87 |
|
| 88 |
+
# Include API routers
|
| 89 |
+
try:
|
| 90 |
+
from api.routes import router as diagnose_router
|
| 91 |
+
app.include_router(diagnose_router)
|
| 92 |
+
logger.info("β
Diagnosis router included")
|
| 93 |
+
except Exception as e:
|
| 94 |
+
logger.warning(f"β οΈ Failed to include diagnosis router: {e}")
|
| 95 |
+
|
| 96 |
+
# Add WebSocket endpoint
|
| 97 |
+
try:
|
| 98 |
+
from api.streaming import handle_streaming_websocket
|
| 99 |
+
@app.websocket("/ws/diagnose")
|
| 100 |
+
async def websocket_diagnose(websocket: WebSocket, session_id: Optional[str] = None):
|
| 101 |
+
await handle_streaming_websocket(websocket, session_id)
|
| 102 |
+
logger.info("β
WebSocket endpoint registered")
|
| 103 |
+
except Exception as e:
|
| 104 |
+
logger.warning(f"β οΈ Failed to register WebSocket endpoint: {e}")
|
| 105 |
+
|
| 106 |
# Create and mount new Gradio interface
|
| 107 |
try:
|
| 108 |
gradio_interface = create_gradio_interface(default_gradio_config)
|
|
|
|
| 123 |
return {
|
| 124 |
"status": "healthy",
|
| 125 |
"models_loaded": {
|
| 126 |
+
"inference_pipeline": inference_pipeline is not None,
|
| 127 |
+
"model_version": "wav2vec2-xlsr-53-v2"
|
| 128 |
},
|
| 129 |
"timestamp": datetime.utcnow().isoformat() + "Z"
|
| 130 |
}
|
| 131 |
|
| 132 |
@app.post("/api/diagnose")
|
| 133 |
async def diagnose_speech(
|
| 134 |
+
audio: UploadFile = File(...),
|
| 135 |
+
text: Optional[str] = Query(None, description="Expected text/transcript for phoneme mapping (optional)")
|
| 136 |
):
|
| 137 |
"""
|
| 138 |
+
Legacy endpoint for speech diagnosis.
|
| 139 |
|
| 140 |
+
NOTE: For full phoneme-level error detection with therapy recommendations,
|
| 141 |
+
use POST /diagnose/file?text=<expected_text> instead.
|
| 142 |
+
This endpoint is maintained for backward compatibility.
|
| 143 |
|
| 144 |
Parameters:
|
| 145 |
- audio: Audio file (WAV, MP3, FLAC, M4A)
|
| 146 |
+
- text: Optional expected text for phoneme mapping
|
| 147 |
|
| 148 |
Returns:
|
| 149 |
+
Dictionary with diagnosis results (legacy format for backward compatibility)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
"""
|
| 151 |
if not inference_pipeline:
|
| 152 |
raise HTTPException(
|
|
|
|
| 154 |
detail="Inference pipeline not loaded yet. Try again in a moment."
|
| 155 |
)
|
| 156 |
|
| 157 |
+
# Import here to avoid circular imports
|
| 158 |
+
from api.routes import get_phoneme_mapper, get_error_mapper
|
| 159 |
+
from models.error_taxonomy import ErrorType
|
| 160 |
+
|
| 161 |
start_time = time.time()
|
| 162 |
temp_file = None
|
| 163 |
|
| 164 |
try:
|
| 165 |
+
logger.info(f"π₯ Processing legacy diagnosis request: {audio.filename}")
|
| 166 |
|
| 167 |
# Validate file extension
|
| 168 |
file_ext = Path(audio.filename).suffix.lower()
|
|
|
|
| 196 |
|
| 197 |
# Run inference
|
| 198 |
logger.info("π Running inference pipeline...")
|
|
|
|
| 199 |
result = inference_pipeline.predict_phone_level(
|
| 200 |
temp_file,
|
| 201 |
return_timestamps=True
|
|
|
|
| 203 |
|
| 204 |
processing_time_ms = (time.time() - start_time) * 1000
|
| 205 |
|
| 206 |
+
# Get mappers for phoneme/error processing
|
| 207 |
+
phoneme_mapper = get_phoneme_mapper()
|
| 208 |
+
error_mapper = get_error_mapper()
|
| 209 |
+
|
| 210 |
+
# Map phonemes if text provided
|
| 211 |
+
frame_phonemes = []
|
| 212 |
+
errors = []
|
| 213 |
+
if text and phoneme_mapper and error_mapper:
|
| 214 |
+
try:
|
| 215 |
+
frame_phonemes = phoneme_mapper.map_text_to_frames(
|
| 216 |
+
text,
|
| 217 |
+
num_frames=result.num_frames,
|
| 218 |
+
audio_duration=result.duration
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
# Process errors
|
| 222 |
+
for i, frame_pred in enumerate(result.frame_predictions):
|
| 223 |
+
phoneme = frame_phonemes[i] if i < len(frame_phonemes) else ''
|
| 224 |
+
class_id = frame_pred.articulation_class
|
| 225 |
+
if frame_pred.fluency_label == 'stutter':
|
| 226 |
+
class_id += 4
|
| 227 |
+
|
| 228 |
+
error_detail = error_mapper.map_classifier_output(
|
| 229 |
+
class_id=class_id,
|
| 230 |
+
confidence=frame_pred.confidence,
|
| 231 |
+
phoneme=phoneme if phoneme else 'unknown',
|
| 232 |
+
fluency_label=frame_pred.fluency_label
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
if error_detail.error_type != ErrorType.NORMAL:
|
| 236 |
+
errors.append({
|
| 237 |
+
"phoneme": error_detail.phoneme,
|
| 238 |
+
"time": frame_pred.time,
|
| 239 |
+
"error_type": error_detail.error_type.value,
|
| 240 |
+
"wrong_sound": error_detail.wrong_sound,
|
| 241 |
+
"severity": error_mapper.get_severity_level(error_detail.severity).value,
|
| 242 |
+
"therapy": error_detail.therapy
|
| 243 |
+
})
|
| 244 |
+
except Exception as e:
|
| 245 |
+
logger.warning(f"β οΈ Phoneme/error mapping failed: {e}")
|
| 246 |
+
|
| 247 |
+
# Extract metrics
|
| 248 |
aggregate = result.aggregate
|
| 249 |
mean_fluency_stutter = aggregate.get("fluency_score", 0.0)
|
| 250 |
+
fluency_percentage = (1.0 - mean_fluency_stutter) * 100
|
| 251 |
|
|
|
|
| 252 |
fluent_frames = sum(1 for fp in result.frame_predictions if fp.fluency_label == 'normal')
|
| 253 |
fluent_frames_ratio = fluent_frames / result.num_frames if result.num_frames > 0 else 0.0
|
| 254 |
|
|
|
|
| 255 |
articulation_class_counts = {}
|
| 256 |
for fp in result.frame_predictions:
|
| 257 |
label = fp.articulation_label
|
| 258 |
articulation_class_counts[label] = articulation_class_counts.get(label, 0) + 1
|
| 259 |
|
|
|
|
| 260 |
dominant_articulation = aggregate.get("articulation_label", "normal")
|
|
|
|
|
|
|
| 261 |
avg_confidence = sum(fp.confidence for fp in result.frame_predictions) / result.num_frames if result.num_frames > 0 else 0.0
|
| 262 |
|
| 263 |
+
# Format response (legacy format with optional error info)
|
| 264 |
response = {
|
| 265 |
"status": "success",
|
| 266 |
"fluency_metrics": {
|
|
|
|
| 283 |
"fluency_label": fp.fluency_label,
|
| 284 |
"articulation_class": fp.articulation_class,
|
| 285 |
"articulation_label": fp.articulation_label,
|
| 286 |
+
"confidence": fp.confidence,
|
| 287 |
+
"phoneme": frame_phonemes[i] if i < len(frame_phonemes) else ''
|
| 288 |
}
|
| 289 |
+
for i, fp in enumerate(result.frame_predictions)
|
| 290 |
]
|
| 291 |
},
|
| 292 |
"confidence": avg_confidence,
|
|
|
|
| 294 |
"processing_time_ms": processing_time_ms
|
| 295 |
}
|
| 296 |
|
| 297 |
+
# Add error info if available
|
| 298 |
+
if errors:
|
| 299 |
+
response["error_count"] = len(errors)
|
| 300 |
+
response["errors"] = errors[:10] # Limit to first 10 for legacy format
|
| 301 |
+
response["problematic_sounds"] = list(set(err["phoneme"] for err in errors if err["phoneme"]))
|
| 302 |
+
|
| 303 |
+
logger.info(f"β
Legacy diagnosis complete: fluency={response['fluency_metrics']['fluency_percentage']:.1f}%, "
|
| 304 |
+
f"errors={len(errors) if errors else 0}, "
|
| 305 |
f"time={processing_time_ms:.0f}ms")
|
| 306 |
|
| 307 |
return response
|
|
|
|
| 322 |
logger.warning(f"Could not clean up {temp_file}: {e}")
|
| 323 |
|
| 324 |
|
| 325 |
+
# Legacy /analyze endpoint removed - use /api/diagnose or /diagnose/file instead
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 326 |
|
| 327 |
|
| 328 |
@app.websocket("/ws/audio")
|
|
@@ -6,15 +6,15 @@ This package contains the inference pipeline for real-time and batch processing.
|
|
| 6 |
|
| 7 |
from .inference_pipeline import (
|
| 8 |
InferencePipeline,
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
create_inference_pipeline
|
| 12 |
)
|
| 13 |
|
| 14 |
__all__ = [
|
| 15 |
"InferencePipeline",
|
| 16 |
-
"
|
| 17 |
-
"
|
| 18 |
"create_inference_pipeline",
|
| 19 |
]
|
| 20 |
|
|
|
|
| 6 |
|
| 7 |
from .inference_pipeline import (
|
| 8 |
InferencePipeline,
|
| 9 |
+
FramePrediction,
|
| 10 |
+
PhoneLevelResult,
|
| 11 |
+
create_inference_pipeline,
|
| 12 |
)
|
| 13 |
|
| 14 |
__all__ = [
|
| 15 |
"InferencePipeline",
|
| 16 |
+
"FramePrediction",
|
| 17 |
+
"PhoneLevelResult",
|
| 18 |
"create_inference_pipeline",
|
| 19 |
]
|
| 20 |
|
|
@@ -72,6 +72,18 @@ class PhonemeMapper:
|
|
| 72 |
"g2p_en library is required. Install with: pip install g2p-en"
|
| 73 |
)
|
| 74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
try:
|
| 76 |
self.g2p = g2p_en.G2p()
|
| 77 |
logger.info("β
G2P model loaded successfully")
|
|
|
|
| 72 |
"g2p_en library is required. Install with: pip install g2p-en"
|
| 73 |
)
|
| 74 |
|
| 75 |
+
# Ensure NLTK data is available (required by g2p_en)
|
| 76 |
+
try:
|
| 77 |
+
import nltk
|
| 78 |
+
try:
|
| 79 |
+
nltk.data.find('taggers/averaged_perceptron_tagger_eng')
|
| 80 |
+
except LookupError:
|
| 81 |
+
logger.info("Downloading NLTK averaged_perceptron_tagger_eng...")
|
| 82 |
+
nltk.download('averaged_perceptron_tagger_eng', quiet=True)
|
| 83 |
+
logger.info("β
NLTK data downloaded")
|
| 84 |
+
except Exception as e:
|
| 85 |
+
logger.warning(f"β οΈ Could not download NLTK data: {e}")
|
| 86 |
+
|
| 87 |
try:
|
| 88 |
self.g2p = g2p_en.G2p()
|
| 89 |
logger.info("β
G2P model loaded successfully")
|
|
@@ -210,6 +210,7 @@ class SpeechPathologyClassifier(nn.Module):
|
|
| 210 |
|
| 211 |
self.device = torch.device(device)
|
| 212 |
self.use_fp16 = use_fp16 and device == "cuda"
|
|
|
|
| 213 |
|
| 214 |
if classifier_hidden_dims is None:
|
| 215 |
classifier_hidden_dims = [256, 128]
|
|
@@ -257,6 +258,9 @@ class SpeechPathologyClassifier(nn.Module):
|
|
| 257 |
num_articulation_classes=num_articulation_classes
|
| 258 |
)
|
| 259 |
|
|
|
|
|
|
|
|
|
|
| 260 |
# Move to device
|
| 261 |
self.wav2vec2_model = self.wav2vec2_model.to(self.device)
|
| 262 |
self.classifier_head = self.classifier_head.to(self.device)
|
|
@@ -275,6 +279,56 @@ class SpeechPathologyClassifier(nn.Module):
|
|
| 275 |
logger.error(f"β Failed to initialize model: {e}", exc_info=True)
|
| 276 |
raise RuntimeError(f"Failed to load Wav2Vec2 model: {e}") from e
|
| 277 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 278 |
def forward(
|
| 279 |
self,
|
| 280 |
input_values: torch.Tensor,
|
|
|
|
| 210 |
|
| 211 |
self.device = torch.device(device)
|
| 212 |
self.use_fp16 = use_fp16 and device == "cuda"
|
| 213 |
+
self.is_trained = False # Track if classifier is trained
|
| 214 |
|
| 215 |
if classifier_hidden_dims is None:
|
| 216 |
classifier_hidden_dims = [256, 128]
|
|
|
|
| 258 |
num_articulation_classes=num_articulation_classes
|
| 259 |
)
|
| 260 |
|
| 261 |
+
# Try to load trained weights if available (None = try default paths)
|
| 262 |
+
self._load_trained_weights(None)
|
| 263 |
+
|
| 264 |
# Move to device
|
| 265 |
self.wav2vec2_model = self.wav2vec2_model.to(self.device)
|
| 266 |
self.classifier_head = self.classifier_head.to(self.device)
|
|
|
|
| 279 |
logger.error(f"β Failed to initialize model: {e}", exc_info=True)
|
| 280 |
raise RuntimeError(f"Failed to load Wav2Vec2 model: {e}") from e
|
| 281 |
|
| 282 |
+
def _load_trained_weights(self, model_path: Optional[str] = None):
|
| 283 |
+
"""
|
| 284 |
+
Load trained classifier head weights if available.
|
| 285 |
+
|
| 286 |
+
Args:
|
| 287 |
+
model_path: Optional path to model checkpoint. If None, tries default checkpoint paths.
|
| 288 |
+
"""
|
| 289 |
+
from pathlib import Path
|
| 290 |
+
|
| 291 |
+
checkpoint_paths = []
|
| 292 |
+
|
| 293 |
+
# Add user-provided path
|
| 294 |
+
if model_path:
|
| 295 |
+
checkpoint_paths.append(Path(model_path))
|
| 296 |
+
|
| 297 |
+
# Add default checkpoint paths
|
| 298 |
+
checkpoint_paths.extend([
|
| 299 |
+
Path("models/checkpoints/classifier_head_best.pt"),
|
| 300 |
+
Path("models/checkpoints/classifier_head_trained.pt")
|
| 301 |
+
])
|
| 302 |
+
|
| 303 |
+
for checkpoint_path in checkpoint_paths:
|
| 304 |
+
if checkpoint_path.exists():
|
| 305 |
+
try:
|
| 306 |
+
checkpoint = torch.load(checkpoint_path, map_location=self.device)
|
| 307 |
+
|
| 308 |
+
# Handle both full checkpoint dict and state_dict directly
|
| 309 |
+
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
|
| 310 |
+
state_dict = checkpoint['model_state_dict']
|
| 311 |
+
epoch = checkpoint.get('epoch', 'unknown')
|
| 312 |
+
val_acc = checkpoint.get('val_accuracy', 'unknown')
|
| 313 |
+
else:
|
| 314 |
+
state_dict = checkpoint
|
| 315 |
+
epoch = 'unknown'
|
| 316 |
+
val_acc = 'unknown'
|
| 317 |
+
|
| 318 |
+
self.classifier_head.load_state_dict(state_dict)
|
| 319 |
+
logger.info(f"β
Loaded trained classifier head from {checkpoint_path}")
|
| 320 |
+
logger.info(f" Epoch: {epoch}, Validation Accuracy: {val_acc}")
|
| 321 |
+
self.is_trained = True
|
| 322 |
+
return
|
| 323 |
+
except Exception as e:
|
| 324 |
+
logger.warning(f"β οΈ Could not load checkpoint {checkpoint_path}: {e}")
|
| 325 |
+
continue
|
| 326 |
+
|
| 327 |
+
# No trained weights found
|
| 328 |
+
logger.warning("β οΈ No trained classifier weights found. Using untrained head (beta mode)")
|
| 329 |
+
logger.warning(" To train the classifier, run: python training/train_classifier_head.py")
|
| 330 |
+
self.is_trained = False
|
| 331 |
+
|
| 332 |
def forward(
|
| 333 |
self,
|
| 334 |
input_values: torch.Tensor,
|
|
@@ -5,6 +5,7 @@ torchaudio>=2.6.0
|
|
| 5 |
transformers>=4.57.3,<5.0
|
| 6 |
numpy>=1.24.0,<2.0.0
|
| 7 |
protobuf>=3.20.0
|
|
|
|
| 8 |
|
| 9 |
# Audio Processing
|
| 10 |
librosa>=0.10.0
|
|
@@ -25,6 +26,12 @@ gradio==6.1.0
|
|
| 25 |
# Logging
|
| 26 |
python-json-logger>=2.0.0
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
# Optional: Legacy/Advanced features
|
| 29 |
openai-whisper>=20230314
|
| 30 |
praat-parselmouth>=0.4.3
|
|
|
|
| 5 |
transformers>=4.57.3,<5.0
|
| 6 |
numpy>=1.24.0,<2.0.0
|
| 7 |
protobuf>=3.20.0
|
| 8 |
+
g2p-en>=2.1.0
|
| 9 |
|
| 10 |
# Audio Processing
|
| 11 |
librosa>=0.10.0
|
|
|
|
| 26 |
# Logging
|
| 27 |
python-json-logger>=2.0.0
|
| 28 |
|
| 29 |
+
# Training dependencies
|
| 30 |
+
pyyaml>=6.0
|
| 31 |
+
scikit-learn>=1.3.0
|
| 32 |
+
matplotlib>=3.7.0
|
| 33 |
+
seaborn>=0.12.0
|
| 34 |
+
|
| 35 |
# Optional: Legacy/Advanced features
|
| 36 |
openai-whisper>=20230314
|
| 37 |
praat-parselmouth>=0.4.3
|
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Annotation Helper Utilities
|
| 3 |
+
|
| 4 |
+
Helper functions for phoneme-level annotation tasks.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
import logging
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import List, Dict, Any, Optional
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def load_annotations(annotations_file: Path = Path("data/annotations.json")) -> List[Dict[str, Any]]:
|
| 17 |
+
"""Load annotations from JSON file."""
|
| 18 |
+
if not annotations_file.exists():
|
| 19 |
+
logger.warning(f"Annotations file not found: {annotations_file}")
|
| 20 |
+
return []
|
| 21 |
+
|
| 22 |
+
try:
|
| 23 |
+
with open(annotations_file, 'r', encoding='utf-8') as f:
|
| 24 |
+
return json.load(f)
|
| 25 |
+
except Exception as e:
|
| 26 |
+
logger.error(f"Failed to load annotations: {e}")
|
| 27 |
+
return []
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def save_annotations(annotations: List[Dict[str, Any]], annotations_file: Path = Path("data/annotations.json")):
|
| 31 |
+
"""Save annotations to JSON file."""
|
| 32 |
+
annotations_file.parent.mkdir(parents=True, exist_ok=True)
|
| 33 |
+
|
| 34 |
+
with open(annotations_file, 'w', encoding='utf-8') as f:
|
| 35 |
+
json.dump(annotations, f, indent=2, ensure_ascii=False)
|
| 36 |
+
|
| 37 |
+
logger.info(f"Saved {len(annotations)} annotations to {annotations_file}")
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def get_annotation_statistics(annotations: List[Dict[str, Any]]) -> Dict[str, Any]:
|
| 41 |
+
"""Calculate statistics from annotations."""
|
| 42 |
+
total_samples = len(annotations)
|
| 43 |
+
total_errors = sum(a.get('total_errors', 0) for a in annotations)
|
| 44 |
+
|
| 45 |
+
error_types = {
|
| 46 |
+
'substitution': 0,
|
| 47 |
+
'omission': 0,
|
| 48 |
+
'distortion': 0,
|
| 49 |
+
'stutter': 0,
|
| 50 |
+
'normal': 0
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
phoneme_errors = {}
|
| 54 |
+
|
| 55 |
+
for ann in annotations:
|
| 56 |
+
for err in ann.get('phoneme_errors', []):
|
| 57 |
+
err_type = err.get('error_type', 'normal')
|
| 58 |
+
error_types[err_type] = error_types.get(err_type, 0) + 1
|
| 59 |
+
|
| 60 |
+
phoneme = err.get('phoneme', 'unknown')
|
| 61 |
+
if phoneme not in phoneme_errors:
|
| 62 |
+
phoneme_errors[phoneme] = 0
|
| 63 |
+
phoneme_errors[phoneme] += 1
|
| 64 |
+
|
| 65 |
+
return {
|
| 66 |
+
'total_samples': total_samples,
|
| 67 |
+
'total_errors': total_errors,
|
| 68 |
+
'error_types': error_types,
|
| 69 |
+
'phoneme_errors': phoneme_errors,
|
| 70 |
+
'avg_errors_per_sample': total_errors / total_samples if total_samples > 0 else 0.0
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def export_for_training(
|
| 75 |
+
annotations: List[Dict[str, Any]],
|
| 76 |
+
output_file: Path = Path("data/training_dataset.json")
|
| 77 |
+
) -> Dict[str, Any]:
|
| 78 |
+
"""Export annotations in training-ready format."""
|
| 79 |
+
training_data = []
|
| 80 |
+
|
| 81 |
+
for ann in annotations:
|
| 82 |
+
audio_file = ann.get('audio_file')
|
| 83 |
+
expected_text = ann.get('expected_text', '')
|
| 84 |
+
duration = ann.get('duration', 0.0)
|
| 85 |
+
|
| 86 |
+
# Create frame-level labels
|
| 87 |
+
num_frames = int((duration * 1000) / 20) # 20ms frames
|
| 88 |
+
frame_labels = [0] * num_frames # 0 = normal
|
| 89 |
+
|
| 90 |
+
# Map errors to frames
|
| 91 |
+
for err in ann.get('phoneme_errors', []):
|
| 92 |
+
frame_id = err.get('frame_id', 0)
|
| 93 |
+
err_type = err.get('error_type', 'normal')
|
| 94 |
+
|
| 95 |
+
# Map to 8-class system
|
| 96 |
+
class_id = {
|
| 97 |
+
'normal': 0,
|
| 98 |
+
'substitution': 1,
|
| 99 |
+
'omission': 2,
|
| 100 |
+
'distortion': 3,
|
| 101 |
+
'stutter': 4
|
| 102 |
+
}.get(err_type, 0)
|
| 103 |
+
|
| 104 |
+
# Check if stutter + articulation error
|
| 105 |
+
if err_type != 'normal' and err_type != 'stutter':
|
| 106 |
+
# Check if there's also stutter
|
| 107 |
+
if any(e.get('error_type') == 'stutter' for e in ann.get('phoneme_errors', [])
|
| 108 |
+
if e.get('frame_id') == frame_id):
|
| 109 |
+
class_id += 4 # Add 4 for stutter classes (5-7)
|
| 110 |
+
|
| 111 |
+
if 0 <= frame_id < num_frames:
|
| 112 |
+
frame_labels[frame_id] = class_id
|
| 113 |
+
|
| 114 |
+
training_data.append({
|
| 115 |
+
'audio_file': audio_file,
|
| 116 |
+
'expected_text': expected_text,
|
| 117 |
+
'duration': duration,
|
| 118 |
+
'num_frames': num_frames,
|
| 119 |
+
'frame_labels': frame_labels,
|
| 120 |
+
'phoneme_errors': ann.get('phoneme_errors', [])
|
| 121 |
+
})
|
| 122 |
+
|
| 123 |
+
output_file.parent.mkdir(parents=True, exist_ok=True)
|
| 124 |
+
with open(output_file, 'w', encoding='utf-8') as f:
|
| 125 |
+
json.dump(training_data, f, indent=2, ensure_ascii=False)
|
| 126 |
+
|
| 127 |
+
logger.info(f"Exported {len(training_data)} samples for training to {output_file}")
|
| 128 |
+
|
| 129 |
+
return {
|
| 130 |
+
'samples': len(training_data),
|
| 131 |
+
'output_file': str(output_file)
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
if __name__ == "__main__":
|
| 136 |
+
# Example usage
|
| 137 |
+
annotations = load_annotations()
|
| 138 |
+
stats = get_annotation_statistics(annotations)
|
| 139 |
+
|
| 140 |
+
print(f"Total samples: {stats['total_samples']}")
|
| 141 |
+
print(f"Total errors: {stats['total_errors']}")
|
| 142 |
+
print(f"Error types: {stats['error_types']}")
|
| 143 |
+
|
| 144 |
+
if annotations:
|
| 145 |
+
export_for_training(annotations)
|
| 146 |
+
|
|
@@ -0,0 +1,418 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Data Collection Tool for Speech Pathology Annotation
|
| 3 |
+
|
| 4 |
+
This module provides a Gradio-based interface for collecting and annotating
|
| 5 |
+
phoneme-level speech pathology data. Clinicians can record or upload audio,
|
| 6 |
+
then annotate errors at the phoneme level with timestamps.
|
| 7 |
+
|
| 8 |
+
Usage:
|
| 9 |
+
python scripts/data_collection.py
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import logging
|
| 13 |
+
import os
|
| 14 |
+
import json
|
| 15 |
+
import time
|
| 16 |
+
import tempfile
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from typing import Optional, List, Dict, Any, Tuple
|
| 19 |
+
from datetime import datetime
|
| 20 |
+
import numpy as np
|
| 21 |
+
|
| 22 |
+
import gradio as gr
|
| 23 |
+
import librosa
|
| 24 |
+
import soundfile as sf
|
| 25 |
+
|
| 26 |
+
from models.phoneme_mapper import PhonemeMapper
|
| 27 |
+
from models.error_taxonomy import ErrorType, SeverityLevel
|
| 28 |
+
|
| 29 |
+
logging.basicConfig(level=logging.INFO)
|
| 30 |
+
logger = logging.getLogger(__name__)
|
| 31 |
+
|
| 32 |
+
# Configuration
|
| 33 |
+
DATA_DIR = Path("data/raw")
|
| 34 |
+
ANNOTATIONS_FILE = Path("data/annotations.json")
|
| 35 |
+
SAMPLE_RATE = 16000
|
| 36 |
+
FRAME_DURATION_MS = 20
|
| 37 |
+
|
| 38 |
+
# Ensure directories exist
|
| 39 |
+
DATA_DIR.mkdir(parents=True, exist_ok=True)
|
| 40 |
+
ANNOTATIONS_FILE.parent.mkdir(parents=True, exist_ok=True)
|
| 41 |
+
|
| 42 |
+
# Load existing annotations
|
| 43 |
+
annotations_db: List[Dict[str, Any]] = []
|
| 44 |
+
if ANNOTATIONS_FILE.exists():
|
| 45 |
+
try:
|
| 46 |
+
with open(ANNOTATIONS_FILE, 'r', encoding='utf-8') as f:
|
| 47 |
+
annotations_db = json.load(f)
|
| 48 |
+
logger.info(f"β
Loaded {len(annotations_db)} existing annotations")
|
| 49 |
+
except Exception as e:
|
| 50 |
+
logger.warning(f"β οΈ Could not load annotations: {e}")
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def save_audio_file(audio_data: Optional[Tuple[int, np.ndarray]], filename: str) -> Optional[str]:
|
| 54 |
+
"""Save uploaded/recorded audio to file."""
|
| 55 |
+
if audio_data is None:
|
| 56 |
+
return None
|
| 57 |
+
|
| 58 |
+
sample_rate, audio_array = audio_data
|
| 59 |
+
|
| 60 |
+
# Resample to 16kHz if needed
|
| 61 |
+
if sample_rate != SAMPLE_RATE:
|
| 62 |
+
audio_array = librosa.resample(
|
| 63 |
+
audio_array.astype(np.float32),
|
| 64 |
+
orig_sr=sample_rate,
|
| 65 |
+
target_sr=SAMPLE_RATE
|
| 66 |
+
)
|
| 67 |
+
sample_rate = SAMPLE_RATE
|
| 68 |
+
|
| 69 |
+
# Normalize
|
| 70 |
+
if np.max(np.abs(audio_array)) > 0:
|
| 71 |
+
audio_array = audio_array / np.max(np.abs(audio_array))
|
| 72 |
+
|
| 73 |
+
# Save to data/raw
|
| 74 |
+
output_path = DATA_DIR / filename
|
| 75 |
+
sf.write(str(output_path), audio_array, sample_rate)
|
| 76 |
+
logger.info(f"β
Saved audio to {output_path}")
|
| 77 |
+
|
| 78 |
+
return str(output_path)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def get_phoneme_list(text: str) -> List[str]:
|
| 82 |
+
"""Convert text to phoneme list using PhonemeMapper."""
|
| 83 |
+
try:
|
| 84 |
+
mapper = PhonemeMapper(
|
| 85 |
+
frame_duration_ms=FRAME_DURATION_MS,
|
| 86 |
+
sample_rate=SAMPLE_RATE
|
| 87 |
+
)
|
| 88 |
+
phonemes = mapper.g2p.convert(text)
|
| 89 |
+
return [p for p in phonemes if p.strip()] if phonemes else []
|
| 90 |
+
except Exception as e:
|
| 91 |
+
logger.error(f"β G2P conversion failed: {e}")
|
| 92 |
+
return []
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def calculate_frame_count(audio_path: str) -> int:
|
| 96 |
+
"""Calculate number of frames for audio file."""
|
| 97 |
+
try:
|
| 98 |
+
duration = librosa.get_duration(path=audio_path)
|
| 99 |
+
frames = int((duration * 1000) / FRAME_DURATION_MS)
|
| 100 |
+
return max(1, frames)
|
| 101 |
+
except Exception as e:
|
| 102 |
+
logger.error(f"β Could not calculate frames: {e}")
|
| 103 |
+
return 0
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def save_annotation(
|
| 107 |
+
audio_path: str,
|
| 108 |
+
expected_text: str,
|
| 109 |
+
phoneme_errors: List[Dict[str, Any]],
|
| 110 |
+
annotator_name: str,
|
| 111 |
+
notes: str
|
| 112 |
+
) -> Dict[str, Any]:
|
| 113 |
+
"""Save annotation to database."""
|
| 114 |
+
try:
|
| 115 |
+
duration = librosa.get_duration(path=audio_path)
|
| 116 |
+
|
| 117 |
+
annotation = {
|
| 118 |
+
'id': f"annot_{int(time.time())}",
|
| 119 |
+
'audio_file': audio_path,
|
| 120 |
+
'expected_text': expected_text,
|
| 121 |
+
'duration': float(duration),
|
| 122 |
+
'annotator': annotator_name,
|
| 123 |
+
'notes': notes,
|
| 124 |
+
'created_at': datetime.utcnow().isoformat() + "Z",
|
| 125 |
+
'phoneme_errors': phoneme_errors,
|
| 126 |
+
'total_errors': len(phoneme_errors),
|
| 127 |
+
'error_types': {
|
| 128 |
+
'substitution': sum(1 for e in phoneme_errors if e.get('error_type') == 'substitution'),
|
| 129 |
+
'omission': sum(1 for e in phoneme_errors if e.get('error_type') == 'omission'),
|
| 130 |
+
'distortion': sum(1 for e in phoneme_errors if e.get('error_type') == 'distortion'),
|
| 131 |
+
'stutter': sum(1 for e in phoneme_errors if e.get('error_type') == 'stutter'),
|
| 132 |
+
}
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
annotations_db.append(annotation)
|
| 136 |
+
|
| 137 |
+
# Save to file
|
| 138 |
+
with open(ANNOTATIONS_FILE, 'w', encoding='utf-8') as f:
|
| 139 |
+
json.dump(annotations_db, f, indent=2, ensure_ascii=False)
|
| 140 |
+
|
| 141 |
+
logger.info(f"β
Saved annotation {annotation['id']} with {len(phoneme_errors)} errors")
|
| 142 |
+
|
| 143 |
+
return {
|
| 144 |
+
'status': 'success',
|
| 145 |
+
'annotation_id': annotation['id'],
|
| 146 |
+
'total_errors': len(phoneme_errors),
|
| 147 |
+
'message': f"β
Annotation saved! Total annotations: {len(annotations_db)}"
|
| 148 |
+
}
|
| 149 |
+
except Exception as e:
|
| 150 |
+
logger.error(f"β Failed to save annotation: {e}", exc_info=True)
|
| 151 |
+
return {
|
| 152 |
+
'status': 'error',
|
| 153 |
+
'message': f"β Failed to save: {str(e)}"
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def create_annotation_interface():
|
| 158 |
+
"""Create Gradio interface for data collection."""
|
| 159 |
+
|
| 160 |
+
with gr.Blocks(title="Speech Pathology Data Collection", theme=gr.themes.Soft()) as interface:
|
| 161 |
+
gr.Markdown("""
|
| 162 |
+
# π€ Speech Pathology Data Collection Tool
|
| 163 |
+
|
| 164 |
+
**Purpose:** Collect and annotate phoneme-level speech pathology data for training.
|
| 165 |
+
|
| 166 |
+
**Instructions:**
|
| 167 |
+
1. Upload or record audio (5-30 seconds, 16kHz WAV)
|
| 168 |
+
2. Enter expected text/transcript
|
| 169 |
+
3. Review phoneme list
|
| 170 |
+
4. Annotate errors at phoneme level
|
| 171 |
+
5. Save annotation
|
| 172 |
+
""")
|
| 173 |
+
|
| 174 |
+
with gr.Row():
|
| 175 |
+
with gr.Column(scale=1):
|
| 176 |
+
gr.Markdown("### π₯ Audio Input")
|
| 177 |
+
|
| 178 |
+
audio_input = gr.Audio(
|
| 179 |
+
type="numpy",
|
| 180 |
+
label="Record or Upload Audio",
|
| 181 |
+
sources=["microphone", "upload"],
|
| 182 |
+
format="wav"
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
expected_text = gr.Textbox(
|
| 186 |
+
label="Expected Text/Transcript",
|
| 187 |
+
placeholder="Enter the expected text that should be spoken",
|
| 188 |
+
lines=3
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
phoneme_display = gr.Textbox(
|
| 192 |
+
label="Phonemes (G2P)",
|
| 193 |
+
lines=5,
|
| 194 |
+
interactive=False,
|
| 195 |
+
info="Phonemes extracted from expected text"
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
btn_get_phonemes = gr.Button("π Extract Phonemes", variant="secondary")
|
| 199 |
+
|
| 200 |
+
with gr.Column(scale=1):
|
| 201 |
+
gr.Markdown("### βοΈ Annotation")
|
| 202 |
+
|
| 203 |
+
annotator_name = gr.Textbox(
|
| 204 |
+
label="Annotator Name",
|
| 205 |
+
placeholder="Your name",
|
| 206 |
+
value="clinician"
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
error_frame_id = gr.Number(
|
| 210 |
+
label="Frame ID (0-based)",
|
| 211 |
+
value=0,
|
| 212 |
+
precision=0,
|
| 213 |
+
info="Frame number where error occurs"
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
error_phoneme = gr.Textbox(
|
| 217 |
+
label="Phoneme with Error",
|
| 218 |
+
placeholder="/r/",
|
| 219 |
+
info="The phoneme that has an error"
|
| 220 |
+
)
|
| 221 |
+
|
| 222 |
+
error_type = gr.Dropdown(
|
| 223 |
+
label="Error Type",
|
| 224 |
+
choices=["normal", "substitution", "omission", "distortion", "stutter"],
|
| 225 |
+
value="normal",
|
| 226 |
+
info="Type of error detected"
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
wrong_sound = gr.Textbox(
|
| 230 |
+
label="Wrong Sound (if substitution)",
|
| 231 |
+
placeholder="/w/",
|
| 232 |
+
info="What sound was produced instead (for substitutions)"
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
error_severity = gr.Slider(
|
| 236 |
+
label="Severity (0-1)",
|
| 237 |
+
minimum=0.0,
|
| 238 |
+
maximum=1.0,
|
| 239 |
+
value=0.5,
|
| 240 |
+
step=0.1,
|
| 241 |
+
info="Severity of the error"
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
error_timestamp = gr.Number(
|
| 245 |
+
label="Timestamp (seconds)",
|
| 246 |
+
value=0.0,
|
| 247 |
+
precision=2,
|
| 248 |
+
info="Time in audio where error occurs"
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
btn_add_error = gr.Button("β Add Error", variant="primary")
|
| 252 |
+
|
| 253 |
+
errors_list = gr.Dataframe(
|
| 254 |
+
label="Annotated Errors",
|
| 255 |
+
headers=["Frame", "Phoneme", "Type", "Wrong Sound", "Severity", "Time"],
|
| 256 |
+
interactive=False,
|
| 257 |
+
wrap=True
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
notes = gr.Textbox(
|
| 261 |
+
label="Notes",
|
| 262 |
+
placeholder="Additional notes about this sample",
|
| 263 |
+
lines=3
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
btn_save = gr.Button("πΎ Save Annotation", variant="primary", size="lg")
|
| 267 |
+
|
| 268 |
+
output_status = gr.Textbox(
|
| 269 |
+
label="Status",
|
| 270 |
+
interactive=False,
|
| 271 |
+
lines=3
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
# Statistics panel
|
| 275 |
+
with gr.Row():
|
| 276 |
+
gr.Markdown("### π Statistics")
|
| 277 |
+
stats_display = gr.Markdown("**Total Annotations:** 0 | **Total Errors:** 0")
|
| 278 |
+
|
| 279 |
+
# Event handlers
|
| 280 |
+
errors_data = gr.State(value=[])
|
| 281 |
+
|
| 282 |
+
def extract_phonemes(text: str) -> str:
|
| 283 |
+
"""Extract phonemes from text."""
|
| 284 |
+
if not text:
|
| 285 |
+
return "Enter expected text first"
|
| 286 |
+
phonemes = get_phoneme_list(text)
|
| 287 |
+
return " ".join([f"/{p}/" for p in phonemes]) if phonemes else "No phonemes found"
|
| 288 |
+
|
| 289 |
+
def add_error(
|
| 290 |
+
frame_id: int,
|
| 291 |
+
phoneme: str,
|
| 292 |
+
error_type: str,
|
| 293 |
+
wrong_sound: str,
|
| 294 |
+
severity: float,
|
| 295 |
+
timestamp: float,
|
| 296 |
+
current_errors: List[Dict]
|
| 297 |
+
) -> Tuple[List[Dict], gr.Dataframe]:
|
| 298 |
+
"""Add an error to the list."""
|
| 299 |
+
error = {
|
| 300 |
+
'frame_id': int(frame_id),
|
| 301 |
+
'phoneme': phoneme.strip(),
|
| 302 |
+
'error_type': error_type,
|
| 303 |
+
'wrong_sound': wrong_sound.strip() if wrong_sound else None,
|
| 304 |
+
'severity': float(severity),
|
| 305 |
+
'timestamp': float(timestamp),
|
| 306 |
+
'confidence': 1.0 # Manual annotation is always confident
|
| 307 |
+
}
|
| 308 |
+
|
| 309 |
+
new_errors = current_errors + [error]
|
| 310 |
+
|
| 311 |
+
# Create dataframe
|
| 312 |
+
df_data = [
|
| 313 |
+
[
|
| 314 |
+
e['frame_id'],
|
| 315 |
+
e['phoneme'],
|
| 316 |
+
e['error_type'],
|
| 317 |
+
e.get('wrong_sound', 'N/A'),
|
| 318 |
+
f"{e['severity']:.2f}",
|
| 319 |
+
f"{e['timestamp']:.2f}s"
|
| 320 |
+
]
|
| 321 |
+
for e in new_errors
|
| 322 |
+
]
|
| 323 |
+
|
| 324 |
+
return new_errors, df_data
|
| 325 |
+
|
| 326 |
+
def save_annotation_handler(
|
| 327 |
+
audio_data: Optional[Tuple[int, np.ndarray]],
|
| 328 |
+
expected_text: str,
|
| 329 |
+
errors: List[Dict],
|
| 330 |
+
annotator: str,
|
| 331 |
+
notes: str
|
| 332 |
+
) -> str:
|
| 333 |
+
"""Handle annotation saving."""
|
| 334 |
+
if audio_data is None:
|
| 335 |
+
return "β Please provide audio first"
|
| 336 |
+
|
| 337 |
+
if not expected_text:
|
| 338 |
+
return "β Please provide expected text"
|
| 339 |
+
|
| 340 |
+
# Save audio
|
| 341 |
+
filename = f"sample_{int(time.time())}.wav"
|
| 342 |
+
audio_path = save_audio_file(audio_data, filename)
|
| 343 |
+
|
| 344 |
+
if not audio_path:
|
| 345 |
+
return "β Failed to save audio file"
|
| 346 |
+
|
| 347 |
+
# Save annotation
|
| 348 |
+
result = save_annotation(
|
| 349 |
+
audio_path=audio_path,
|
| 350 |
+
expected_text=expected_text,
|
| 351 |
+
phoneme_errors=errors,
|
| 352 |
+
annotator_name=annotator,
|
| 353 |
+
notes=notes
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
return result.get('message', 'Unknown status')
|
| 357 |
+
|
| 358 |
+
def update_stats() -> str:
|
| 359 |
+
"""Update statistics display."""
|
| 360 |
+
total_annotations = len(annotations_db)
|
| 361 |
+
total_errors = sum(a.get('total_errors', 0) for a in annotations_db)
|
| 362 |
+
|
| 363 |
+
error_breakdown = {}
|
| 364 |
+
for ann in annotations_db:
|
| 365 |
+
for err_type, count in ann.get('error_types', {}).items():
|
| 366 |
+
error_breakdown[err_type] = error_breakdown.get(err_type, 0) + count
|
| 367 |
+
|
| 368 |
+
stats_text = f"""
|
| 369 |
+
**Total Annotations:** {total_annotations} | **Total Errors:** {total_errors}
|
| 370 |
+
|
| 371 |
+
**Error Breakdown:**
|
| 372 |
+
- Substitution: {error_breakdown.get('substitution', 0)}
|
| 373 |
+
- Omission: {error_breakdown.get('omission', 0)}
|
| 374 |
+
- Distortion: {error_breakdown.get('distortion', 0)}
|
| 375 |
+
- Stutter: {error_breakdown.get('stutter', 0)}
|
| 376 |
+
"""
|
| 377 |
+
return stats_text
|
| 378 |
+
|
| 379 |
+
# Wire up events
|
| 380 |
+
btn_get_phonemes.click(
|
| 381 |
+
fn=extract_phonemes,
|
| 382 |
+
inputs=[expected_text],
|
| 383 |
+
outputs=[phoneme_display]
|
| 384 |
+
)
|
| 385 |
+
|
| 386 |
+
btn_add_error.click(
|
| 387 |
+
fn=add_error,
|
| 388 |
+
inputs=[
|
| 389 |
+
error_frame_id,
|
| 390 |
+
error_phoneme,
|
| 391 |
+
error_type,
|
| 392 |
+
wrong_sound,
|
| 393 |
+
error_severity,
|
| 394 |
+
error_timestamp,
|
| 395 |
+
errors_data
|
| 396 |
+
],
|
| 397 |
+
outputs=[errors_data, errors_list]
|
| 398 |
+
)
|
| 399 |
+
|
| 400 |
+
btn_save.click(
|
| 401 |
+
fn=save_annotation_handler,
|
| 402 |
+
inputs=[audio_input, expected_text, errors_data, annotator_name, notes],
|
| 403 |
+
outputs=[output_status]
|
| 404 |
+
).then(
|
| 405 |
+
fn=update_stats,
|
| 406 |
+
outputs=[stats_display]
|
| 407 |
+
)
|
| 408 |
+
|
| 409 |
+
# Load stats on startup
|
| 410 |
+
interface.load(fn=update_stats, outputs=[stats_display])
|
| 411 |
+
|
| 412 |
+
return interface
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
if __name__ == "__main__":
|
| 416 |
+
interface = create_annotation_interface()
|
| 417 |
+
interface.launch(server_name="0.0.0.0", server_port=7861, share=False)
|
| 418 |
+
|
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Training Configuration for Classifier Head
|
| 2 |
+
|
| 3 |
+
# Data Configuration
|
| 4 |
+
data:
|
| 5 |
+
annotations_file: "data/annotations.json"
|
| 6 |
+
training_dataset: "data/training_dataset.json"
|
| 7 |
+
train_split: 0.8
|
| 8 |
+
val_split: 0.2
|
| 9 |
+
test_split: 0.0 # Use validation as test if test_split is 0
|
| 10 |
+
random_seed: 42
|
| 11 |
+
|
| 12 |
+
# Model Configuration
|
| 13 |
+
model:
|
| 14 |
+
input_dim: 1024 # Wav2Vec2-XLSR-53 feature dimension
|
| 15 |
+
hidden_dims: [512, 256] # Shared layers
|
| 16 |
+
dropout: 0.1
|
| 17 |
+
num_classes: 8 # 8-class output (fluency + articulation combined)
|
| 18 |
+
use_pretrained_head: false # Set to true after first training
|
| 19 |
+
|
| 20 |
+
# Training Configuration
|
| 21 |
+
training:
|
| 22 |
+
batch_size: 16
|
| 23 |
+
num_epochs: 50
|
| 24 |
+
learning_rate: 0.001
|
| 25 |
+
weight_decay: 0.0001
|
| 26 |
+
gradient_clip_norm: 1.0
|
| 27 |
+
|
| 28 |
+
# Loss Configuration
|
| 29 |
+
loss:
|
| 30 |
+
type: "cross_entropy" # or "focal" for imbalanced data
|
| 31 |
+
class_weights: null # Auto-calculate from data if null
|
| 32 |
+
focal_alpha: 0.25
|
| 33 |
+
focal_gamma: 2.0
|
| 34 |
+
|
| 35 |
+
# Optimizer
|
| 36 |
+
optimizer: "adam"
|
| 37 |
+
adam_betas: [0.9, 0.999]
|
| 38 |
+
|
| 39 |
+
# Scheduler
|
| 40 |
+
scheduler: "reduce_on_plateau"
|
| 41 |
+
scheduler_patience: 5
|
| 42 |
+
scheduler_factor: 0.5
|
| 43 |
+
scheduler_min_lr: 0.00001
|
| 44 |
+
|
| 45 |
+
# Early Stopping
|
| 46 |
+
early_stopping:
|
| 47 |
+
enabled: true
|
| 48 |
+
patience: 10
|
| 49 |
+
min_delta: 0.001
|
| 50 |
+
monitor: "val_loss"
|
| 51 |
+
|
| 52 |
+
# Data Augmentation
|
| 53 |
+
augmentation:
|
| 54 |
+
enabled: false # Enable after initial training
|
| 55 |
+
time_stretch: [0.9, 1.1]
|
| 56 |
+
noise_injection: 0.01
|
| 57 |
+
pitch_shift: [-2, 2] # semitones
|
| 58 |
+
|
| 59 |
+
# Validation Configuration
|
| 60 |
+
validation:
|
| 61 |
+
metrics: ["accuracy", "f1_score", "precision", "recall", "confusion_matrix"]
|
| 62 |
+
per_class_metrics: true
|
| 63 |
+
save_predictions: true
|
| 64 |
+
|
| 65 |
+
# Checkpoint Configuration
|
| 66 |
+
checkpoint:
|
| 67 |
+
save_dir: "models/checkpoints"
|
| 68 |
+
save_best: true
|
| 69 |
+
save_last: true
|
| 70 |
+
save_frequency: 5 # Save every N epochs
|
| 71 |
+
filename: "classifier_head_trained.pt"
|
| 72 |
+
best_filename: "classifier_head_best.pt"
|
| 73 |
+
|
| 74 |
+
# Logging Configuration
|
| 75 |
+
logging:
|
| 76 |
+
log_dir: "training/logs"
|
| 77 |
+
tensorboard: false # Enable if tensorboard installed
|
| 78 |
+
wandb: false # Enable if wandb installed
|
| 79 |
+
log_frequency: 10 # Log every N batches
|
| 80 |
+
|
| 81 |
+
# Device Configuration
|
| 82 |
+
device:
|
| 83 |
+
use_cuda: true
|
| 84 |
+
cuda_device: 0
|
| 85 |
+
mixed_precision: false # Use FP16 training
|
| 86 |
+
|
|
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Evaluation Script for Trained Classifier Head
|
| 3 |
+
|
| 4 |
+
Evaluates the trained classifier on test/validation data and generates
|
| 5 |
+
comprehensive metrics including per-class accuracy, confusion matrix, etc.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python training/evaluate_classifier.py --checkpoint models/checkpoints/classifier_head_best.pt
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import logging
|
| 12 |
+
import argparse
|
| 13 |
+
import json
|
| 14 |
+
import yaml
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from typing import Dict, List, Any
|
| 17 |
+
import sys
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import numpy as np
|
| 21 |
+
from sklearn.metrics import (
|
| 22 |
+
accuracy_score, f1_score, precision_score, recall_score,
|
| 23 |
+
confusion_matrix, classification_report
|
| 24 |
+
)
|
| 25 |
+
import matplotlib.pyplot as plt
|
| 26 |
+
import seaborn as sns
|
| 27 |
+
|
| 28 |
+
# Add project root to path
|
| 29 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 30 |
+
|
| 31 |
+
from training.train_classifier_head import PhonemeDataset, collate_fn
|
| 32 |
+
from torch.utils.data import DataLoader
|
| 33 |
+
from models.speech_pathology_model import SpeechPathologyClassifier
|
| 34 |
+
from models.phoneme_mapper import PhonemeMapper
|
| 35 |
+
from inference.inference_pipeline import InferencePipeline
|
| 36 |
+
from config import default_audio_config, default_model_config, default_inference_config
|
| 37 |
+
|
| 38 |
+
logging.basicConfig(level=logging.INFO)
|
| 39 |
+
logger = logging.getLogger(__name__)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def load_trained_model(checkpoint_path: Path, config_path: Path = Path("training/config.yaml")) -> torch.nn.Module:
|
| 43 |
+
"""Load trained classifier head from checkpoint."""
|
| 44 |
+
# Load config
|
| 45 |
+
with open(config_path, 'r') as f:
|
| 46 |
+
config = yaml.safe_load(f)
|
| 47 |
+
|
| 48 |
+
# Initialize inference pipeline
|
| 49 |
+
inference_pipeline = InferencePipeline(
|
| 50 |
+
audio_config=default_audio_config,
|
| 51 |
+
model_config=default_model_config,
|
| 52 |
+
inference_config=default_inference_config
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
model = inference_pipeline.model
|
| 56 |
+
|
| 57 |
+
# Load checkpoint
|
| 58 |
+
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
| 59 |
+
model.classifier_head.load_state_dict(checkpoint['model_state_dict'])
|
| 60 |
+
|
| 61 |
+
logger.info(f"β
Loaded checkpoint from epoch {checkpoint.get('epoch', 'unknown')}")
|
| 62 |
+
logger.info(f" Validation loss: {checkpoint.get('val_loss', 'unknown'):.4f}")
|
| 63 |
+
logger.info(f" Validation accuracy: {checkpoint.get('val_accuracy', 'unknown'):.4f}")
|
| 64 |
+
|
| 65 |
+
return model
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def evaluate_model(
|
| 69 |
+
model: torch.nn.Module,
|
| 70 |
+
dataloader: DataLoader,
|
| 71 |
+
device: torch.device,
|
| 72 |
+
class_names: List[str]
|
| 73 |
+
) -> Dict[str, Any]:
|
| 74 |
+
"""Evaluate model and return comprehensive metrics."""
|
| 75 |
+
model.eval()
|
| 76 |
+
|
| 77 |
+
all_preds = []
|
| 78 |
+
all_labels = []
|
| 79 |
+
all_probs = []
|
| 80 |
+
|
| 81 |
+
with torch.no_grad():
|
| 82 |
+
for batch in dataloader:
|
| 83 |
+
features = batch['features'].to(device)
|
| 84 |
+
labels = batch['labels'].to(device)
|
| 85 |
+
|
| 86 |
+
batch_size, seq_len, feat_dim = features.shape
|
| 87 |
+
features_flat = features.view(-1, feat_dim)
|
| 88 |
+
labels_flat = labels.view(-1)
|
| 89 |
+
|
| 90 |
+
# Forward pass
|
| 91 |
+
shared_features = model.classifier_head.shared_layers(features_flat)
|
| 92 |
+
logits = model.classifier_head.full_head(shared_features)
|
| 93 |
+
probs = torch.softmax(logits, dim=-1)
|
| 94 |
+
|
| 95 |
+
preds = torch.argmax(logits, dim=-1).cpu().numpy()
|
| 96 |
+
all_preds.extend(preds)
|
| 97 |
+
all_labels.extend(labels_flat.cpu().numpy())
|
| 98 |
+
all_probs.extend(probs.cpu().numpy())
|
| 99 |
+
|
| 100 |
+
# Calculate metrics
|
| 101 |
+
accuracy = accuracy_score(all_labels, all_preds)
|
| 102 |
+
f1_macro = f1_score(all_labels, all_preds, average='macro', zero_division=0)
|
| 103 |
+
f1_weighted = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
|
| 104 |
+
precision_macro = precision_score(all_labels, all_preds, average='macro', zero_division=0)
|
| 105 |
+
recall_macro = recall_score(all_labels, all_preds, average='macro', zero_division=0)
|
| 106 |
+
|
| 107 |
+
# Per-class metrics
|
| 108 |
+
cm = confusion_matrix(all_labels, all_preds, labels=list(range(len(class_names))))
|
| 109 |
+
|
| 110 |
+
# Per-class accuracy
|
| 111 |
+
per_class_accuracy = cm.diagonal() / cm.sum(axis=1)
|
| 112 |
+
per_class_accuracy = np.nan_to_num(per_class_accuracy) # Handle division by zero
|
| 113 |
+
|
| 114 |
+
# Classification report
|
| 115 |
+
report = classification_report(
|
| 116 |
+
all_labels, all_preds,
|
| 117 |
+
target_names=class_names,
|
| 118 |
+
output_dict=True,
|
| 119 |
+
zero_division=0
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
# Confidence analysis
|
| 123 |
+
all_probs = np.array(all_probs)
|
| 124 |
+
max_probs = np.max(all_probs, axis=1)
|
| 125 |
+
correct_mask = np.array(all_preds) == np.array(all_labels)
|
| 126 |
+
|
| 127 |
+
avg_confidence_correct = np.mean(max_probs[correct_mask]) if np.any(correct_mask) else 0.0
|
| 128 |
+
avg_confidence_incorrect = np.mean(max_probs[~correct_mask]) if np.any(~correct_mask) else 0.0
|
| 129 |
+
|
| 130 |
+
return {
|
| 131 |
+
'overall_accuracy': float(accuracy),
|
| 132 |
+
'f1_macro': float(f1_macro),
|
| 133 |
+
'f1_weighted': float(f1_weighted),
|
| 134 |
+
'precision_macro': float(precision_macro),
|
| 135 |
+
'recall_macro': float(recall_macro),
|
| 136 |
+
'confusion_matrix': cm.tolist(),
|
| 137 |
+
'per_class_accuracy': per_class_accuracy.tolist(),
|
| 138 |
+
'classification_report': report,
|
| 139 |
+
'confidence': {
|
| 140 |
+
'avg_correct': float(avg_confidence_correct),
|
| 141 |
+
'avg_incorrect': float(avg_confidence_incorrect),
|
| 142 |
+
'confidence_distribution': {
|
| 143 |
+
'mean': float(np.mean(max_probs)),
|
| 144 |
+
'std': float(np.std(max_probs)),
|
| 145 |
+
'min': float(np.min(max_probs)),
|
| 146 |
+
'max': float(np.max(max_probs))
|
| 147 |
+
}
|
| 148 |
+
},
|
| 149 |
+
'num_samples': len(all_labels)
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def plot_confusion_matrix(cm: np.ndarray, class_names: List[str], output_path: Path):
|
| 154 |
+
"""Plot and save confusion matrix."""
|
| 155 |
+
plt.figure(figsize=(10, 8))
|
| 156 |
+
sns.heatmap(
|
| 157 |
+
cm,
|
| 158 |
+
annot=True,
|
| 159 |
+
fmt='d',
|
| 160 |
+
cmap='Blues',
|
| 161 |
+
xticklabels=class_names,
|
| 162 |
+
yticklabels=class_names
|
| 163 |
+
)
|
| 164 |
+
plt.title('Confusion Matrix')
|
| 165 |
+
plt.ylabel('True Label')
|
| 166 |
+
plt.xlabel('Predicted Label')
|
| 167 |
+
plt.tight_layout()
|
| 168 |
+
plt.savefig(output_path)
|
| 169 |
+
logger.info(f"β
Saved confusion matrix to {output_path}")
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def main():
|
| 173 |
+
parser = argparse.ArgumentParser(description="Evaluate trained classifier")
|
| 174 |
+
parser.add_argument('--checkpoint', type=str, required=True,
|
| 175 |
+
help='Path to checkpoint file')
|
| 176 |
+
parser.add_argument('--config', type=str, default='training/config.yaml',
|
| 177 |
+
help='Path to config file')
|
| 178 |
+
parser.add_argument('--dataset', type=str, default='data/training_dataset.json',
|
| 179 |
+
help='Path to evaluation dataset')
|
| 180 |
+
parser.add_argument('--output', type=str, default='training/evaluation_results.json',
|
| 181 |
+
help='Path to save evaluation results')
|
| 182 |
+
parser.add_argument('--plot', type=str, default='training/confusion_matrix.png',
|
| 183 |
+
help='Path to save confusion matrix plot')
|
| 184 |
+
args = parser.parse_args()
|
| 185 |
+
|
| 186 |
+
# Load config
|
| 187 |
+
with open(args.config, 'r') as f:
|
| 188 |
+
config = yaml.safe_load(f)
|
| 189 |
+
|
| 190 |
+
# Set device
|
| 191 |
+
device = torch.device('cuda' if torch.cuda.is_available() and config['device']['use_cuda'] else 'cpu')
|
| 192 |
+
logger.info(f"Using device: {device}")
|
| 193 |
+
|
| 194 |
+
# Load model
|
| 195 |
+
checkpoint_path = Path(args.checkpoint)
|
| 196 |
+
if not checkpoint_path.exists():
|
| 197 |
+
logger.error(f"Checkpoint not found: {checkpoint_path}")
|
| 198 |
+
return
|
| 199 |
+
|
| 200 |
+
model = load_trained_model(checkpoint_path, Path(args.config))
|
| 201 |
+
model = model.to(device)
|
| 202 |
+
|
| 203 |
+
# Load evaluation dataset
|
| 204 |
+
dataset_path = Path(args.dataset)
|
| 205 |
+
if not dataset_path.exists():
|
| 206 |
+
logger.error(f"Dataset not found: {dataset_path}")
|
| 207 |
+
return
|
| 208 |
+
|
| 209 |
+
with open(dataset_path, 'r') as f:
|
| 210 |
+
eval_data = json.load(f)
|
| 211 |
+
|
| 212 |
+
logger.info(f"Loaded {len(eval_data)} evaluation samples")
|
| 213 |
+
|
| 214 |
+
# Create dataset and dataloader
|
| 215 |
+
inference_pipeline = InferencePipeline(
|
| 216 |
+
audio_config=default_audio_config,
|
| 217 |
+
model_config=default_model_config,
|
| 218 |
+
inference_config=default_inference_config
|
| 219 |
+
)
|
| 220 |
+
|
| 221 |
+
phoneme_mapper = PhonemeMapper(frame_duration_ms=20, sample_rate=16000)
|
| 222 |
+
|
| 223 |
+
from training.train_classifier_head import PhonemeDataset
|
| 224 |
+
dataset = PhonemeDataset(eval_data, inference_pipeline, phoneme_mapper)
|
| 225 |
+
|
| 226 |
+
dataloader = DataLoader(
|
| 227 |
+
dataset,
|
| 228 |
+
batch_size=config['training']['batch_size'],
|
| 229 |
+
shuffle=False,
|
| 230 |
+
collate_fn=collate_fn
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
# Class names
|
| 234 |
+
class_names = [
|
| 235 |
+
"Normal",
|
| 236 |
+
"Substitution",
|
| 237 |
+
"Omission",
|
| 238 |
+
"Distortion",
|
| 239 |
+
"Normal+Stutter",
|
| 240 |
+
"Substitution+Stutter",
|
| 241 |
+
"Omission+Stutter",
|
| 242 |
+
"Distortion+Stutter"
|
| 243 |
+
]
|
| 244 |
+
|
| 245 |
+
# Evaluate
|
| 246 |
+
logger.info("Evaluating model...")
|
| 247 |
+
metrics = evaluate_model(model, dataloader, device, class_names)
|
| 248 |
+
|
| 249 |
+
# Print results
|
| 250 |
+
logger.info("\n" + "="*50)
|
| 251 |
+
logger.info("EVALUATION RESULTS")
|
| 252 |
+
logger.info("="*50)
|
| 253 |
+
logger.info(f"Overall Accuracy: {metrics['overall_accuracy']:.4f}")
|
| 254 |
+
logger.info(f"F1 Score (macro): {metrics['f1_macro']:.4f}")
|
| 255 |
+
logger.info(f"F1 Score (weighted): {metrics['f1_weighted']:.4f}")
|
| 256 |
+
logger.info(f"Precision (macro): {metrics['precision_macro']:.4f}")
|
| 257 |
+
logger.info(f"Recall (macro): {metrics['recall_macro']:.4f}")
|
| 258 |
+
logger.info(f"\nPer-Class Accuracy:")
|
| 259 |
+
for i, (name, acc) in enumerate(zip(class_names, metrics['per_class_accuracy'])):
|
| 260 |
+
logger.info(f" {name}: {acc:.4f}")
|
| 261 |
+
logger.info(f"\nConfidence Analysis:")
|
| 262 |
+
logger.info(f" Avg confidence (correct): {metrics['confidence']['avg_correct']:.4f}")
|
| 263 |
+
logger.info(f" Avg confidence (incorrect): {metrics['confidence']['avg_incorrect']:.4f}")
|
| 264 |
+
|
| 265 |
+
# Save results
|
| 266 |
+
output_path = Path(args.output)
|
| 267 |
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 268 |
+
with open(output_path, 'w') as f:
|
| 269 |
+
json.dump(metrics, f, indent=2)
|
| 270 |
+
logger.info(f"\nβ
Saved evaluation results to {output_path}")
|
| 271 |
+
|
| 272 |
+
# Plot confusion matrix
|
| 273 |
+
if args.plot:
|
| 274 |
+
plot_path = Path(args.plot)
|
| 275 |
+
plot_path.parent.mkdir(parents=True, exist_ok=True)
|
| 276 |
+
cm = np.array(metrics['confusion_matrix'])
|
| 277 |
+
plot_confusion_matrix(cm, class_names, plot_path)
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
if __name__ == "__main__":
|
| 281 |
+
main()
|
| 282 |
+
|
|
@@ -0,0 +1,469 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Training Script for Speech Pathology Classifier Head
|
| 3 |
+
|
| 4 |
+
This script fine-tunes the classification head on phoneme-level labeled data.
|
| 5 |
+
Wav2Vec2 encoder is frozen; only the classifier head is trained.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python training/train_classifier_head.py --config training/config.yaml
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import logging
|
| 12 |
+
import os
|
| 13 |
+
import sys
|
| 14 |
+
import json
|
| 15 |
+
import yaml
|
| 16 |
+
import argparse
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from typing import Dict, List, Tuple, Optional, Any
|
| 19 |
+
from datetime import datetime
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
import torch.optim as optim
|
| 24 |
+
from torch.utils.data import Dataset, DataLoader, random_split
|
| 25 |
+
import numpy as np
|
| 26 |
+
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix
|
| 27 |
+
import librosa
|
| 28 |
+
import soundfile as sf
|
| 29 |
+
|
| 30 |
+
# Add project root to path
|
| 31 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 32 |
+
|
| 33 |
+
from models.speech_pathology_model import SpeechPathologyClassifier, MultiTaskClassifierHead
|
| 34 |
+
from models.phoneme_mapper import PhonemeMapper
|
| 35 |
+
from inference.inference_pipeline import InferencePipeline
|
| 36 |
+
from config import default_audio_config, default_model_config, default_inference_config
|
| 37 |
+
|
| 38 |
+
logging.basicConfig(
|
| 39 |
+
level=logging.INFO,
|
| 40 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
| 41 |
+
)
|
| 42 |
+
logger = logging.getLogger(__name__)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class PhonemeDataset(Dataset):
|
| 46 |
+
"""Dataset for phoneme-level speech pathology training."""
|
| 47 |
+
|
| 48 |
+
def __init__(
|
| 49 |
+
self,
|
| 50 |
+
training_data: List[Dict[str, Any]],
|
| 51 |
+
inference_pipeline: InferencePipeline,
|
| 52 |
+
phoneme_mapper: PhonemeMapper
|
| 53 |
+
):
|
| 54 |
+
"""
|
| 55 |
+
Initialize dataset.
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
training_data: List of training samples with frame labels
|
| 59 |
+
inference_pipeline: Pipeline for extracting Wav2Vec2 features
|
| 60 |
+
phoneme_mapper: Mapper for phoneme alignment
|
| 61 |
+
"""
|
| 62 |
+
self.training_data = training_data
|
| 63 |
+
self.inference_pipeline = inference_pipeline
|
| 64 |
+
self.phoneme_mapper = phoneme_mapper
|
| 65 |
+
|
| 66 |
+
logger.info(f"Initialized dataset with {len(training_data)} samples")
|
| 67 |
+
|
| 68 |
+
def __len__(self) -> int:
|
| 69 |
+
return len(self.training_data)
|
| 70 |
+
|
| 71 |
+
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
|
| 72 |
+
"""Get a training sample."""
|
| 73 |
+
sample = self.training_data[idx]
|
| 74 |
+
audio_file = sample['audio_file']
|
| 75 |
+
frame_labels = sample['frame_labels']
|
| 76 |
+
|
| 77 |
+
# Load audio
|
| 78 |
+
try:
|
| 79 |
+
audio, sr = librosa.load(audio_file, sr=16000)
|
| 80 |
+
except Exception as e:
|
| 81 |
+
logger.error(f"Failed to load {audio_file}: {e}")
|
| 82 |
+
# Return dummy data
|
| 83 |
+
return {
|
| 84 |
+
'features': torch.zeros(1, 1024),
|
| 85 |
+
'labels': torch.tensor([0], dtype=torch.long),
|
| 86 |
+
'valid': torch.tensor(False)
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
# Extract Wav2Vec2 features
|
| 90 |
+
try:
|
| 91 |
+
frame_features, frame_times = self.inference_pipeline.get_phone_level_features(audio)
|
| 92 |
+
|
| 93 |
+
# Align labels with features
|
| 94 |
+
num_features = len(frame_features)
|
| 95 |
+
num_labels = len(frame_labels)
|
| 96 |
+
|
| 97 |
+
# Pad or truncate labels to match features
|
| 98 |
+
if num_labels < num_features:
|
| 99 |
+
frame_labels = frame_labels + [0] * (num_features - num_labels)
|
| 100 |
+
elif num_labels > num_features:
|
| 101 |
+
frame_labels = frame_labels[:num_features]
|
| 102 |
+
|
| 103 |
+
# Convert to tensors
|
| 104 |
+
features_tensor = frame_features # Already a tensor
|
| 105 |
+
labels_tensor = torch.tensor(frame_labels[:num_features], dtype=torch.long)
|
| 106 |
+
|
| 107 |
+
return {
|
| 108 |
+
'features': features_tensor,
|
| 109 |
+
'labels': labels_tensor,
|
| 110 |
+
'valid': torch.tensor(True)
|
| 111 |
+
}
|
| 112 |
+
except Exception as e:
|
| 113 |
+
logger.error(f"Failed to extract features from {audio_file}: {e}")
|
| 114 |
+
return {
|
| 115 |
+
'features': torch.zeros(1, 1024),
|
| 116 |
+
'labels': torch.tensor([0], dtype=torch.long),
|
| 117 |
+
'valid': torch.tensor(False)
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
|
| 122 |
+
"""Collate function for DataLoader."""
|
| 123 |
+
# Filter out invalid samples
|
| 124 |
+
valid_batch = [b for b in batch if b['valid'].item()]
|
| 125 |
+
|
| 126 |
+
if not valid_batch:
|
| 127 |
+
# Return dummy batch
|
| 128 |
+
return {
|
| 129 |
+
'features': torch.zeros(1, 1, 1024),
|
| 130 |
+
'labels': torch.zeros(1, 1, dtype=torch.long)
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
# Stack features and labels
|
| 134 |
+
features_list = []
|
| 135 |
+
labels_list = []
|
| 136 |
+
|
| 137 |
+
for item in valid_batch:
|
| 138 |
+
features_list.append(item['features'])
|
| 139 |
+
labels_list.append(item['labels'])
|
| 140 |
+
|
| 141 |
+
# Pad to same length
|
| 142 |
+
max_len = max(f.shape[0] for f in features_list)
|
| 143 |
+
|
| 144 |
+
padded_features = []
|
| 145 |
+
padded_labels = []
|
| 146 |
+
|
| 147 |
+
for feat, lab in zip(features_list, labels_list):
|
| 148 |
+
if feat.shape[0] < max_len:
|
| 149 |
+
padding = max_len - feat.shape[0]
|
| 150 |
+
feat = torch.cat([feat, torch.zeros(padding, feat.shape[1])])
|
| 151 |
+
lab = torch.cat([lab, torch.zeros(padding, dtype=torch.long)])
|
| 152 |
+
padded_features.append(feat)
|
| 153 |
+
padded_labels.append(lab)
|
| 154 |
+
|
| 155 |
+
return {
|
| 156 |
+
'features': torch.stack(padded_features),
|
| 157 |
+
'labels': torch.stack(padded_labels)
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def calculate_class_weights(dataset: PhonemeDataset) -> torch.Tensor:
|
| 162 |
+
"""Calculate class weights for imbalanced data."""
|
| 163 |
+
all_labels = []
|
| 164 |
+
for i in range(len(dataset)):
|
| 165 |
+
sample = dataset[i]
|
| 166 |
+
if sample['valid'].item():
|
| 167 |
+
all_labels.extend(sample['labels'].tolist())
|
| 168 |
+
|
| 169 |
+
if not all_labels:
|
| 170 |
+
return torch.ones(8)
|
| 171 |
+
|
| 172 |
+
unique, counts = np.unique(all_labels, return_counts=True)
|
| 173 |
+
total = len(all_labels)
|
| 174 |
+
|
| 175 |
+
weights = torch.ones(8)
|
| 176 |
+
for cls, count in zip(unique, counts):
|
| 177 |
+
if count > 0:
|
| 178 |
+
weights[int(cls)] = total / (8 * count) # Inverse frequency weighting
|
| 179 |
+
|
| 180 |
+
logger.info(f"Class weights: {weights.tolist()}")
|
| 181 |
+
return weights
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def train_epoch(
|
| 185 |
+
model: nn.Module,
|
| 186 |
+
dataloader: DataLoader,
|
| 187 |
+
optimizer: optim.Optimizer,
|
| 188 |
+
criterion: nn.Module,
|
| 189 |
+
device: torch.device,
|
| 190 |
+
epoch: int
|
| 191 |
+
) -> Dict[str, float]:
|
| 192 |
+
"""Train for one epoch."""
|
| 193 |
+
model.train()
|
| 194 |
+
total_loss = 0.0
|
| 195 |
+
all_preds = []
|
| 196 |
+
all_labels = []
|
| 197 |
+
|
| 198 |
+
for batch_idx, batch in enumerate(dataloader):
|
| 199 |
+
features = batch['features'].to(device) # (batch, seq_len, 1024)
|
| 200 |
+
labels = batch['labels'].to(device) # (batch, seq_len)
|
| 201 |
+
|
| 202 |
+
# Flatten for processing
|
| 203 |
+
batch_size, seq_len, feat_dim = features.shape
|
| 204 |
+
features_flat = features.view(-1, feat_dim) # (batch * seq_len, 1024)
|
| 205 |
+
labels_flat = labels.view(-1) # (batch * seq_len)
|
| 206 |
+
|
| 207 |
+
# Forward pass
|
| 208 |
+
optimizer.zero_grad()
|
| 209 |
+
|
| 210 |
+
# Get predictions from full_head
|
| 211 |
+
shared_features = model.classifier_head.shared_layers(features_flat)
|
| 212 |
+
logits = model.classifier_head.full_head(shared_features) # (batch * seq_len, 8)
|
| 213 |
+
|
| 214 |
+
# Calculate loss
|
| 215 |
+
loss = criterion(logits, labels_flat)
|
| 216 |
+
|
| 217 |
+
# Backward pass
|
| 218 |
+
loss.backward()
|
| 219 |
+
torch.nn.utils.clip_grad_norm_(model.classifier_head.parameters(), max_norm=1.0)
|
| 220 |
+
optimizer.step()
|
| 221 |
+
|
| 222 |
+
# Metrics
|
| 223 |
+
total_loss += loss.item()
|
| 224 |
+
preds = torch.argmax(logits, dim=-1).cpu().numpy()
|
| 225 |
+
all_preds.extend(preds)
|
| 226 |
+
all_labels.extend(labels_flat.cpu().numpy())
|
| 227 |
+
|
| 228 |
+
if batch_idx % 10 == 0:
|
| 229 |
+
logger.info(f"Epoch {epoch}, Batch {batch_idx}/{len(dataloader)}, Loss: {loss.item():.4f}")
|
| 230 |
+
|
| 231 |
+
avg_loss = total_loss / len(dataloader)
|
| 232 |
+
accuracy = accuracy_score(all_labels, all_preds)
|
| 233 |
+
|
| 234 |
+
return {
|
| 235 |
+
'loss': avg_loss,
|
| 236 |
+
'accuracy': accuracy
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def validate(
|
| 241 |
+
model: nn.Module,
|
| 242 |
+
dataloader: DataLoader,
|
| 243 |
+
criterion: nn.Module,
|
| 244 |
+
device: torch.device
|
| 245 |
+
) -> Dict[str, float]:
|
| 246 |
+
"""Validate model."""
|
| 247 |
+
model.eval()
|
| 248 |
+
total_loss = 0.0
|
| 249 |
+
all_preds = []
|
| 250 |
+
all_labels = []
|
| 251 |
+
|
| 252 |
+
with torch.no_grad():
|
| 253 |
+
for batch in dataloader:
|
| 254 |
+
features = batch['features'].to(device)
|
| 255 |
+
labels = batch['labels'].to(device)
|
| 256 |
+
|
| 257 |
+
batch_size, seq_len, feat_dim = features.shape
|
| 258 |
+
features_flat = features.view(-1, feat_dim)
|
| 259 |
+
labels_flat = labels.view(-1)
|
| 260 |
+
|
| 261 |
+
# Forward pass
|
| 262 |
+
shared_features = model.classifier_head.shared_layers(features_flat)
|
| 263 |
+
logits = model.classifier_head.full_head(shared_features)
|
| 264 |
+
|
| 265 |
+
loss = criterion(logits, labels_flat)
|
| 266 |
+
total_loss += loss.item()
|
| 267 |
+
|
| 268 |
+
preds = torch.argmax(logits, dim=-1).cpu().numpy()
|
| 269 |
+
all_preds.extend(preds)
|
| 270 |
+
all_labels.extend(labels_flat.cpu().numpy())
|
| 271 |
+
|
| 272 |
+
avg_loss = total_loss / len(dataloader)
|
| 273 |
+
accuracy = accuracy_score(all_labels, all_preds)
|
| 274 |
+
f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
|
| 275 |
+
precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
|
| 276 |
+
recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
|
| 277 |
+
|
| 278 |
+
# Per-class metrics
|
| 279 |
+
cm = confusion_matrix(all_labels, all_preds, labels=list(range(8)))
|
| 280 |
+
|
| 281 |
+
return {
|
| 282 |
+
'loss': avg_loss,
|
| 283 |
+
'accuracy': accuracy,
|
| 284 |
+
'f1_score': f1,
|
| 285 |
+
'precision': precision,
|
| 286 |
+
'recall': recall,
|
| 287 |
+
'confusion_matrix': cm.tolist()
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def main():
|
| 292 |
+
parser = argparse.ArgumentParser(description="Train classifier head")
|
| 293 |
+
parser.add_argument('--config', type=str, default='training/config.yaml',
|
| 294 |
+
help='Path to config file')
|
| 295 |
+
parser.add_argument('--resume', type=str, default=None,
|
| 296 |
+
help='Resume from checkpoint')
|
| 297 |
+
args = parser.parse_args()
|
| 298 |
+
|
| 299 |
+
# Load config
|
| 300 |
+
with open(args.config, 'r') as f:
|
| 301 |
+
config = yaml.safe_load(f)
|
| 302 |
+
|
| 303 |
+
# Set device
|
| 304 |
+
device = torch.device('cuda' if torch.cuda.is_available() and config['device']['use_cuda'] else 'cpu')
|
| 305 |
+
logger.info(f"Using device: {device}")
|
| 306 |
+
|
| 307 |
+
# Load training data
|
| 308 |
+
training_file = Path(config['data']['training_dataset'])
|
| 309 |
+
if not training_file.exists():
|
| 310 |
+
logger.error(f"Training dataset not found: {training_file}")
|
| 311 |
+
logger.info("Run scripts/annotation_helper.py to export training data first")
|
| 312 |
+
return
|
| 313 |
+
|
| 314 |
+
with open(training_file, 'r') as f:
|
| 315 |
+
training_data = json.load(f)
|
| 316 |
+
|
| 317 |
+
logger.info(f"Loaded {len(training_data)} training samples")
|
| 318 |
+
|
| 319 |
+
# Initialize inference pipeline for feature extraction
|
| 320 |
+
inference_pipeline = InferencePipeline(
|
| 321 |
+
audio_config=default_audio_config,
|
| 322 |
+
model_config=default_model_config,
|
| 323 |
+
inference_config=default_inference_config
|
| 324 |
+
)
|
| 325 |
+
|
| 326 |
+
# Initialize phoneme mapper
|
| 327 |
+
phoneme_mapper = PhonemeMapper(
|
| 328 |
+
frame_duration_ms=20,
|
| 329 |
+
sample_rate=16000
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
# Create dataset
|
| 333 |
+
dataset = PhonemeDataset(training_data, inference_pipeline, phoneme_mapper)
|
| 334 |
+
|
| 335 |
+
# Split dataset
|
| 336 |
+
train_size = int(config['data']['train_split'] * len(dataset))
|
| 337 |
+
val_size = len(dataset) - train_size
|
| 338 |
+
|
| 339 |
+
train_dataset, val_dataset = random_split(
|
| 340 |
+
dataset,
|
| 341 |
+
[train_size, val_size],
|
| 342 |
+
generator=torch.Generator().manual_seed(config['data']['random_seed'])
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
+
logger.info(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")
|
| 346 |
+
|
| 347 |
+
# Create data loaders
|
| 348 |
+
train_loader = DataLoader(
|
| 349 |
+
train_dataset,
|
| 350 |
+
batch_size=config['training']['batch_size'],
|
| 351 |
+
shuffle=True,
|
| 352 |
+
collate_fn=collate_fn
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
val_loader = DataLoader(
|
| 356 |
+
val_dataset,
|
| 357 |
+
batch_size=config['training']['batch_size'],
|
| 358 |
+
shuffle=False,
|
| 359 |
+
collate_fn=collate_fn
|
| 360 |
+
)
|
| 361 |
+
|
| 362 |
+
# Load model
|
| 363 |
+
model = inference_pipeline.model
|
| 364 |
+
model.train() # Set to training mode
|
| 365 |
+
|
| 366 |
+
# Freeze Wav2Vec2 (should already be frozen, but ensure it)
|
| 367 |
+
for param in model.wav2vec2_model.parameters():
|
| 368 |
+
param.requires_grad = False
|
| 369 |
+
|
| 370 |
+
# Unfreeze classifier head
|
| 371 |
+
for param in model.classifier_head.parameters():
|
| 372 |
+
param.requires_grad = True
|
| 373 |
+
|
| 374 |
+
logger.info("Model prepared: Wav2Vec2 frozen, classifier head trainable")
|
| 375 |
+
|
| 376 |
+
# Calculate class weights
|
| 377 |
+
class_weights = calculate_class_weights(dataset)
|
| 378 |
+
class_weights = class_weights.to(device)
|
| 379 |
+
|
| 380 |
+
# Loss function
|
| 381 |
+
if config['training']['loss']['type'] == 'cross_entropy':
|
| 382 |
+
criterion = nn.CrossEntropyLoss(weight=class_weights)
|
| 383 |
+
else:
|
| 384 |
+
# Focal loss implementation would go here
|
| 385 |
+
criterion = nn.CrossEntropyLoss(weight=class_weights)
|
| 386 |
+
|
| 387 |
+
# Optimizer
|
| 388 |
+
optimizer = optim.Adam(
|
| 389 |
+
model.classifier_head.parameters(),
|
| 390 |
+
lr=config['training']['learning_rate'],
|
| 391 |
+
weight_decay=config['training']['weight_decay']
|
| 392 |
+
)
|
| 393 |
+
|
| 394 |
+
# Scheduler
|
| 395 |
+
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
| 396 |
+
optimizer,
|
| 397 |
+
mode='min',
|
| 398 |
+
factor=config['training']['scheduler_factor'],
|
| 399 |
+
patience=config['training']['scheduler_patience'],
|
| 400 |
+
min_lr=config['training']['scheduler_min_lr']
|
| 401 |
+
)
|
| 402 |
+
|
| 403 |
+
# Training loop
|
| 404 |
+
best_val_loss = float('inf')
|
| 405 |
+
patience_counter = 0
|
| 406 |
+
|
| 407 |
+
checkpoint_dir = Path(config['checkpoint']['save_dir'])
|
| 408 |
+
checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
| 409 |
+
|
| 410 |
+
for epoch in range(config['training']['num_epochs']):
|
| 411 |
+
logger.info(f"\n{'='*50}")
|
| 412 |
+
logger.info(f"Epoch {epoch+1}/{config['training']['num_epochs']}")
|
| 413 |
+
logger.info(f"{'='*50}")
|
| 414 |
+
|
| 415 |
+
# Train
|
| 416 |
+
train_metrics = train_epoch(model, train_loader, optimizer, criterion, device, epoch+1)
|
| 417 |
+
logger.info(f"Train - Loss: {train_metrics['loss']:.4f}, Accuracy: {train_metrics['accuracy']:.4f}")
|
| 418 |
+
|
| 419 |
+
# Validate
|
| 420 |
+
val_metrics = validate(model, val_loader, criterion, device)
|
| 421 |
+
logger.info(f"Val - Loss: {val_metrics['loss']:.4f}, Accuracy: {val_metrics['accuracy']:.4f}, "
|
| 422 |
+
f"F1: {val_metrics['f1_score']:.4f}")
|
| 423 |
+
|
| 424 |
+
# Scheduler step
|
| 425 |
+
scheduler.step(val_metrics['loss'])
|
| 426 |
+
|
| 427 |
+
# Save checkpoint
|
| 428 |
+
if config['checkpoint']['save_best'] and val_metrics['loss'] < best_val_loss:
|
| 429 |
+
best_val_loss = val_metrics['loss']
|
| 430 |
+
checkpoint_path = checkpoint_dir / config['checkpoint']['best_filename']
|
| 431 |
+
torch.save({
|
| 432 |
+
'epoch': epoch,
|
| 433 |
+
'model_state_dict': model.classifier_head.state_dict(),
|
| 434 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 435 |
+
'val_loss': val_metrics['loss'],
|
| 436 |
+
'val_accuracy': val_metrics['accuracy'],
|
| 437 |
+
'config': config
|
| 438 |
+
}, checkpoint_path)
|
| 439 |
+
logger.info(f"β
Saved best checkpoint to {checkpoint_path}")
|
| 440 |
+
patience_counter = 0
|
| 441 |
+
else:
|
| 442 |
+
patience_counter += 1
|
| 443 |
+
|
| 444 |
+
# Early stopping
|
| 445 |
+
if config['training']['early_stopping']['enabled']:
|
| 446 |
+
if patience_counter >= config['training']['early_stopping']['patience']:
|
| 447 |
+
logger.info(f"Early stopping triggered after {epoch+1} epochs")
|
| 448 |
+
break
|
| 449 |
+
|
| 450 |
+
# Save last checkpoint
|
| 451 |
+
if config['checkpoint']['save_last'] and (epoch + 1) % config['checkpoint']['save_frequency'] == 0:
|
| 452 |
+
checkpoint_path = checkpoint_dir / config['checkpoint']['filename']
|
| 453 |
+
torch.save({
|
| 454 |
+
'epoch': epoch,
|
| 455 |
+
'model_state_dict': model.classifier_head.state_dict(),
|
| 456 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
| 457 |
+
'val_loss': val_metrics['loss'],
|
| 458 |
+
'val_accuracy': val_metrics['accuracy'],
|
| 459 |
+
'config': config
|
| 460 |
+
}, checkpoint_path)
|
| 461 |
+
logger.info(f"Saved checkpoint to {checkpoint_path}")
|
| 462 |
+
|
| 463 |
+
logger.info("\nβ
Training complete!")
|
| 464 |
+
logger.info(f"Best validation loss: {best_val_loss:.4f}")
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
if __name__ == "__main__":
|
| 468 |
+
main()
|
| 469 |
+
|
|
@@ -17,6 +17,8 @@ import numpy as np
|
|
| 17 |
import gradio as gr
|
| 18 |
|
| 19 |
from diagnosis.ai_engine.model_loader import get_inference_pipeline
|
|
|
|
|
|
|
| 20 |
from config import GradioConfig, default_gradio_config
|
| 21 |
|
| 22 |
logger = logging.getLogger(__name__)
|
|
@@ -83,8 +85,9 @@ def format_articulation_issues(articulation_scores: list) -> str:
|
|
| 83 |
|
| 84 |
def analyze_speech(
|
| 85 |
audio_input: Optional[Tuple[int, np.ndarray]],
|
| 86 |
-
audio_file: Optional[str]
|
| 87 |
-
|
|
|
|
| 88 |
"""
|
| 89 |
Analyze speech audio for fluency and articulation issues.
|
| 90 |
|
|
@@ -167,6 +170,72 @@ def analyze_speech(
|
|
| 167 |
except: pass
|
| 168 |
# #endregion
|
| 169 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
# Calculate processing time
|
| 171 |
processing_time_ms = (time.time() - start_time) * 1000
|
| 172 |
|
|
@@ -240,7 +309,121 @@ def analyze_speech(
|
|
| 240 |
</div>
|
| 241 |
"""
|
| 242 |
|
| 243 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 244 |
json_output = {
|
| 245 |
"status": "success",
|
| 246 |
"fluency_metrics": {
|
|
@@ -259,6 +442,18 @@ def analyze_speech(
|
|
| 259 |
"confidence": avg_confidence,
|
| 260 |
"confidence_percentage": confidence_percentage,
|
| 261 |
"processing_time_ms": processing_time_ms,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
"frame_predictions": [
|
| 263 |
{
|
| 264 |
"time": fp.time,
|
|
@@ -266,9 +461,10 @@ def analyze_speech(
|
|
| 266 |
"fluency_label": fp.fluency_label,
|
| 267 |
"articulation_class": fp.articulation_class,
|
| 268 |
"articulation_label": fp.articulation_label,
|
| 269 |
-
"confidence": fp.confidence
|
|
|
|
| 270 |
}
|
| 271 |
-
for fp in result.frame_predictions[:20] # First 20 frames for preview
|
| 272 |
]
|
| 273 |
}
|
| 274 |
|
|
@@ -289,6 +485,7 @@ def analyze_speech(
|
|
| 289 |
articulation_text,
|
| 290 |
confidence_html,
|
| 291 |
processing_time_html,
|
|
|
|
| 292 |
json_output
|
| 293 |
)
|
| 294 |
|
|
@@ -302,11 +499,13 @@ def analyze_speech(
|
|
| 302 |
except: pass
|
| 303 |
# #endregion
|
| 304 |
error_html = f"<p style='color: red;'>β Error: {str(e)}</p>"
|
|
|
|
| 305 |
return (
|
| 306 |
error_html,
|
| 307 |
f"Error: {str(e)}",
|
| 308 |
"N/A",
|
| 309 |
"N/A",
|
|
|
|
| 310 |
{"error": str(e), "status": "error"}
|
| 311 |
)
|
| 312 |
|
|
@@ -370,6 +569,13 @@ def create_gradio_interface(gradio_config: Optional[GradioConfig] = None) -> gr.
|
|
| 370 |
format="wav"
|
| 371 |
)
|
| 372 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 373 |
analyze_btn = gr.Button(
|
| 374 |
"π Analyze Speech",
|
| 375 |
variant="primary",
|
|
@@ -409,6 +615,11 @@ def create_gradio_interface(gradio_config: Optional[GradioConfig] = None) -> gr.
|
|
| 409 |
elem_classes=["output-box"]
|
| 410 |
)
|
| 411 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 412 |
json_output = gr.JSON(
|
| 413 |
label="Detailed Results (JSON)",
|
| 414 |
elem_classes=["output-box"]
|
|
@@ -417,12 +628,13 @@ def create_gradio_interface(gradio_config: Optional[GradioConfig] = None) -> gr.
|
|
| 417 |
# Set up event handlers
|
| 418 |
analyze_btn.click(
|
| 419 |
fn=analyze_speech,
|
| 420 |
-
inputs=[audio_mic, audio_file],
|
| 421 |
outputs=[
|
| 422 |
fluency_output,
|
| 423 |
articulation_output,
|
| 424 |
confidence_output,
|
| 425 |
processing_time_output,
|
|
|
|
| 426 |
json_output
|
| 427 |
]
|
| 428 |
)
|
|
|
|
| 17 |
import gradio as gr
|
| 18 |
|
| 19 |
from diagnosis.ai_engine.model_loader import get_inference_pipeline
|
| 20 |
+
from api.routes import get_phoneme_mapper, get_error_mapper
|
| 21 |
+
from models.error_taxonomy import ErrorType, SeverityLevel
|
| 22 |
from config import GradioConfig, default_gradio_config
|
| 23 |
|
| 24 |
logger = logging.getLogger(__name__)
|
|
|
|
| 85 |
|
| 86 |
def analyze_speech(
|
| 87 |
audio_input: Optional[Tuple[int, np.ndarray]],
|
| 88 |
+
audio_file: Optional[str],
|
| 89 |
+
expected_text: Optional[str] = None
|
| 90 |
+
) -> Tuple[str, str, str, str, str, Dict[str, Any]]:
|
| 91 |
"""
|
| 92 |
Analyze speech audio for fluency and articulation issues.
|
| 93 |
|
|
|
|
| 170 |
except: pass
|
| 171 |
# #endregion
|
| 172 |
|
| 173 |
+
# Get phoneme and error mappers
|
| 174 |
+
phoneme_mapper = get_phoneme_mapper()
|
| 175 |
+
error_mapper = get_error_mapper()
|
| 176 |
+
|
| 177 |
+
# Map phonemes to frames if text provided
|
| 178 |
+
frame_phonemes = []
|
| 179 |
+
if expected_text and phoneme_mapper:
|
| 180 |
+
try:
|
| 181 |
+
frame_phonemes = phoneme_mapper.map_text_to_frames(
|
| 182 |
+
expected_text,
|
| 183 |
+
num_frames=result.num_frames,
|
| 184 |
+
audio_duration=result.duration
|
| 185 |
+
)
|
| 186 |
+
logger.info(f"β
Mapped {len(frame_phonemes)} phonemes to frames")
|
| 187 |
+
except Exception as e:
|
| 188 |
+
logger.warning(f"β οΈ Phoneme mapping failed: {e}")
|
| 189 |
+
frame_phonemes = [''] * result.num_frames
|
| 190 |
+
else:
|
| 191 |
+
frame_phonemes = [''] * result.num_frames
|
| 192 |
+
|
| 193 |
+
# Process errors with error mapper
|
| 194 |
+
errors = []
|
| 195 |
+
error_table_rows = []
|
| 196 |
+
|
| 197 |
+
for i, frame_pred in enumerate(result.frame_predictions):
|
| 198 |
+
phoneme = frame_phonemes[i] if i < len(frame_phonemes) else ''
|
| 199 |
+
|
| 200 |
+
# Map classifier output to error detail (8-class system)
|
| 201 |
+
class_id = frame_pred.articulation_class
|
| 202 |
+
if frame_pred.fluency_label == 'stutter':
|
| 203 |
+
class_id += 4 # Add 4 for stutter classes (4-7)
|
| 204 |
+
|
| 205 |
+
# Get error detail
|
| 206 |
+
if error_mapper:
|
| 207 |
+
try:
|
| 208 |
+
error_detail = error_mapper.map_classifier_output(
|
| 209 |
+
class_id=class_id,
|
| 210 |
+
confidence=frame_pred.confidence,
|
| 211 |
+
phoneme=phoneme if phoneme else 'unknown',
|
| 212 |
+
fluency_label=frame_pred.fluency_label
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
if error_detail.error_type != ErrorType.NORMAL:
|
| 216 |
+
errors.append((i, frame_pred.time, error_detail))
|
| 217 |
+
|
| 218 |
+
# Add to error table
|
| 219 |
+
severity_level = error_mapper.get_severity_level(error_detail.severity)
|
| 220 |
+
severity_color = {
|
| 221 |
+
SeverityLevel.NONE: "green",
|
| 222 |
+
SeverityLevel.LOW: "orange",
|
| 223 |
+
SeverityLevel.MEDIUM: "orange",
|
| 224 |
+
SeverityLevel.HIGH: "red"
|
| 225 |
+
}.get(severity_level, "gray")
|
| 226 |
+
|
| 227 |
+
error_table_rows.append({
|
| 228 |
+
"phoneme": error_detail.phoneme,
|
| 229 |
+
"time": f"{frame_pred.time:.2f}s",
|
| 230 |
+
"error_type": error_detail.error_type.value,
|
| 231 |
+
"wrong_sound": error_detail.wrong_sound or "N/A",
|
| 232 |
+
"severity": severity_level.value,
|
| 233 |
+
"severity_color": severity_color,
|
| 234 |
+
"therapy": error_detail.therapy[:80] + "..." if len(error_detail.therapy) > 80 else error_detail.therapy
|
| 235 |
+
})
|
| 236 |
+
except Exception as e:
|
| 237 |
+
logger.warning(f"Error mapping failed for frame {i}: {e}")
|
| 238 |
+
|
| 239 |
# Calculate processing time
|
| 240 |
processing_time_ms = (time.time() - start_time) * 1000
|
| 241 |
|
|
|
|
| 309 |
</div>
|
| 310 |
"""
|
| 311 |
|
| 312 |
+
# Format error table with summary of problematic sounds
|
| 313 |
+
if error_table_rows:
|
| 314 |
+
# Group errors by phoneme to show which sounds have issues
|
| 315 |
+
phoneme_errors = {}
|
| 316 |
+
for row in error_table_rows:
|
| 317 |
+
phoneme = row['phoneme']
|
| 318 |
+
if phoneme not in phoneme_errors:
|
| 319 |
+
phoneme_errors[phoneme] = {
|
| 320 |
+
'count': 0,
|
| 321 |
+
'types': set(),
|
| 322 |
+
'severity': 'low',
|
| 323 |
+
'examples': []
|
| 324 |
+
}
|
| 325 |
+
phoneme_errors[phoneme]['count'] += 1
|
| 326 |
+
phoneme_errors[phoneme]['types'].add(row['error_type'])
|
| 327 |
+
if row['severity'] in ['high', 'medium']:
|
| 328 |
+
phoneme_errors[phoneme]['severity'] = row['severity']
|
| 329 |
+
if len(phoneme_errors[phoneme]['examples']) < 2:
|
| 330 |
+
phoneme_errors[phoneme]['examples'].append(row)
|
| 331 |
+
|
| 332 |
+
# Create summary section
|
| 333 |
+
problematic_sounds = sorted(phoneme_errors.keys())
|
| 334 |
+
summary_html = f"""
|
| 335 |
+
<div style='background-color: #fff3cd; border: 2px solid #ffc107; border-radius: 8px; padding: 15px; margin-bottom: 20px;'>
|
| 336 |
+
<h3 style='color: #856404; margin-top: 0;'>β οΈ Problematic Sounds Detected</h3>
|
| 337 |
+
<p style='color: #856404; font-size: 14px; margin-bottom: 10px;'>
|
| 338 |
+
<strong>{len(problematic_sounds)} sound(s) with issues:</strong> {', '.join([f'<strong style="color: red;">/{p}/</strong>' for p in problematic_sounds[:10]])}
|
| 339 |
+
{f'<span style="color: #666;">(+{len(problematic_sounds) - 10} more)</span>' if len(problematic_sounds) > 10 else ''}
|
| 340 |
+
</p>
|
| 341 |
+
<div style='display: flex; flex-wrap: wrap; gap: 10px;'>
|
| 342 |
+
"""
|
| 343 |
+
|
| 344 |
+
for phoneme in problematic_sounds[:10]:
|
| 345 |
+
error_info = phoneme_errors[phoneme]
|
| 346 |
+
severity_color = 'red' if error_info['severity'] == 'high' else 'orange' if error_info['severity'] == 'medium' else '#666'
|
| 347 |
+
summary_html += f"""
|
| 348 |
+
<div style='background-color: white; border: 1px solid {severity_color}; border-radius: 4px; padding: 8px; min-width: 120px;'>
|
| 349 |
+
<strong style='color: {severity_color}; font-size: 18px;'>/{phoneme}/</strong>
|
| 350 |
+
<div style='font-size: 12px; color: #666;'>
|
| 351 |
+
{error_info['count']} error(s)<br/>
|
| 352 |
+
Types: {', '.join(error_info['types'])}
|
| 353 |
+
</div>
|
| 354 |
+
</div>
|
| 355 |
+
"""
|
| 356 |
+
|
| 357 |
+
summary_html += """
|
| 358 |
+
</div>
|
| 359 |
+
</div>
|
| 360 |
+
"""
|
| 361 |
+
|
| 362 |
+
# Create detailed error table
|
| 363 |
+
error_table_html = summary_html + """
|
| 364 |
+
<h4 style='color: #333; margin-top: 20px;'>π Detailed Error Report</h4>
|
| 365 |
+
<table style='width: 100%; border-collapse: collapse; margin: 10px 0; font-size: 13px;'>
|
| 366 |
+
<thead>
|
| 367 |
+
<tr style='background-color: #f0f0f0;'>
|
| 368 |
+
<th style='padding: 10px; border: 1px solid #ddd; text-align: left; background-color: #e8e8e8;'>Sound</th>
|
| 369 |
+
<th style='padding: 10px; border: 1px solid #ddd; text-align: left; background-color: #e8e8e8;'>Time</th>
|
| 370 |
+
<th style='padding: 10px; border: 1px solid #ddd; text-align: left; background-color: #e8e8e8;'>Error Type</th>
|
| 371 |
+
<th style='padding: 10px; border: 1px solid #ddd; text-align: left; background-color: #e8e8e8;'>Wrong Sound</th>
|
| 372 |
+
<th style='padding: 10px; border: 1px solid #ddd; text-align: left; background-color: #e8e8e8;'>Severity</th>
|
| 373 |
+
<th style='padding: 10px; border: 1px solid #ddd; text-align: left; background-color: #e8e8e8;'>Therapy Recommendation</th>
|
| 374 |
+
</tr>
|
| 375 |
+
</thead>
|
| 376 |
+
<tbody>
|
| 377 |
+
"""
|
| 378 |
+
|
| 379 |
+
for row in error_table_rows[:20]: # Limit to first 20 errors
|
| 380 |
+
severity_bg = {
|
| 381 |
+
'high': '#ffebee',
|
| 382 |
+
'medium': '#fff3e0',
|
| 383 |
+
'low': '#f3e5f5',
|
| 384 |
+
'none': '#e8f5e9'
|
| 385 |
+
}.get(row['severity'], '#f5f5f5')
|
| 386 |
+
|
| 387 |
+
error_table_html += f"""
|
| 388 |
+
<tr style='background-color: {severity_bg};'>
|
| 389 |
+
<td style='padding: 10px; border: 1px solid #ddd;'>
|
| 390 |
+
<strong style='color: {row['severity_color']}; font-size: 16px;'>/{row['phoneme']}/</strong>
|
| 391 |
+
</td>
|
| 392 |
+
<td style='padding: 10px; border: 1px solid #ddd;'>{row['time']}</td>
|
| 393 |
+
<td style='padding: 10px; border: 1px solid #ddd;'>
|
| 394 |
+
<span style='background-color: {row['severity_color']}; color: white; padding: 3px 8px; border-radius: 3px; font-size: 11px;'>
|
| 395 |
+
{row['error_type'].upper()}
|
| 396 |
+
</span>
|
| 397 |
+
</td>
|
| 398 |
+
<td style='padding: 10px; border: 1px solid #ddd;'>
|
| 399 |
+
{f"<strong style='color: red;'>/{row['wrong_sound']}/</strong>" if row['wrong_sound'] != 'N/A' else '<span style="color: #999;">N/A</span>'}
|
| 400 |
+
</td>
|
| 401 |
+
<td style='padding: 10px; border: 1px solid #ddd;'>
|
| 402 |
+
<strong style='color: {row['severity_color']};'>{row['severity'].upper()}</strong>
|
| 403 |
+
</td>
|
| 404 |
+
<td style='padding: 10px; border: 1px solid #ddd; font-size: 12px;'>{row['therapy']}</td>
|
| 405 |
+
</tr>
|
| 406 |
+
"""
|
| 407 |
+
|
| 408 |
+
error_table_html += """
|
| 409 |
+
</tbody>
|
| 410 |
+
</table>
|
| 411 |
+
"""
|
| 412 |
+
|
| 413 |
+
if len(error_table_rows) > 20:
|
| 414 |
+
error_table_html += f"<p style='color: #666; font-size: 12px; margin-top: 10px;'>π Showing first 20 of <strong>{len(error_table_rows)}</strong> total errors detected</p>"
|
| 415 |
+
else:
|
| 416 |
+
error_table_html = """
|
| 417 |
+
<div style='background-color: #d4edda; border: 2px solid #28a745; border-radius: 8px; padding: 20px; text-align: center;'>
|
| 418 |
+
<h3 style='color: #155724; margin-top: 0;'>β
No Errors Detected</h3>
|
| 419 |
+
<p style='color: #155724; font-size: 16px;'>
|
| 420 |
+
All sounds/phonemes were produced correctly!<br/>
|
| 421 |
+
<span style='font-size: 14px; color: #666;'>Great job! π</span>
|
| 422 |
+
</p>
|
| 423 |
+
</div>
|
| 424 |
+
"""
|
| 425 |
+
|
| 426 |
+
# Create JSON output with errors
|
| 427 |
json_output = {
|
| 428 |
"status": "success",
|
| 429 |
"fluency_metrics": {
|
|
|
|
| 442 |
"confidence": avg_confidence,
|
| 443 |
"confidence_percentage": confidence_percentage,
|
| 444 |
"processing_time_ms": processing_time_ms,
|
| 445 |
+
"error_count": len(errors),
|
| 446 |
+
"errors": [
|
| 447 |
+
{
|
| 448 |
+
"phoneme": err[2].phoneme,
|
| 449 |
+
"time": err[1],
|
| 450 |
+
"error_type": err[2].error_type.value,
|
| 451 |
+
"wrong_sound": err[2].wrong_sound,
|
| 452 |
+
"severity": error_mapper.get_severity_level(err[2].severity).value if error_mapper else "unknown",
|
| 453 |
+
"therapy": err[2].therapy
|
| 454 |
+
}
|
| 455 |
+
for err in errors[:20]
|
| 456 |
+
] if errors else [],
|
| 457 |
"frame_predictions": [
|
| 458 |
{
|
| 459 |
"time": fp.time,
|
|
|
|
| 461 |
"fluency_label": fp.fluency_label,
|
| 462 |
"articulation_class": fp.articulation_class,
|
| 463 |
"articulation_label": fp.articulation_label,
|
| 464 |
+
"confidence": fp.confidence,
|
| 465 |
+
"phoneme": frame_phonemes[i] if i < len(frame_phonemes) else ''
|
| 466 |
}
|
| 467 |
+
for i, fp in enumerate(result.frame_predictions[:20]) # First 20 frames for preview
|
| 468 |
]
|
| 469 |
}
|
| 470 |
|
|
|
|
| 485 |
articulation_text,
|
| 486 |
confidence_html,
|
| 487 |
processing_time_html,
|
| 488 |
+
error_table_html,
|
| 489 |
json_output
|
| 490 |
)
|
| 491 |
|
|
|
|
| 499 |
except: pass
|
| 500 |
# #endregion
|
| 501 |
error_html = f"<p style='color: red;'>β Error: {str(e)}</p>"
|
| 502 |
+
error_table_html = "<p style='color: #999;'>No error details available</p>"
|
| 503 |
return (
|
| 504 |
error_html,
|
| 505 |
f"Error: {str(e)}",
|
| 506 |
"N/A",
|
| 507 |
"N/A",
|
| 508 |
+
error_table_html,
|
| 509 |
{"error": str(e), "status": "error"}
|
| 510 |
)
|
| 511 |
|
|
|
|
| 569 |
format="wav"
|
| 570 |
)
|
| 571 |
|
| 572 |
+
expected_text_input = gr.Textbox(
|
| 573 |
+
label="Expected Text (Optional)",
|
| 574 |
+
placeholder="Enter the expected text/transcript for phoneme mapping",
|
| 575 |
+
lines=2,
|
| 576 |
+
info="Provide the expected text to enable phoneme-level error detection"
|
| 577 |
+
)
|
| 578 |
+
|
| 579 |
analyze_btn = gr.Button(
|
| 580 |
"π Analyze Speech",
|
| 581 |
variant="primary",
|
|
|
|
| 615 |
elem_classes=["output-box"]
|
| 616 |
)
|
| 617 |
|
| 618 |
+
error_table_output = gr.HTML(
|
| 619 |
+
label="Error Details",
|
| 620 |
+
elem_classes=["output-box"]
|
| 621 |
+
)
|
| 622 |
+
|
| 623 |
json_output = gr.JSON(
|
| 624 |
label="Detailed Results (JSON)",
|
| 625 |
elem_classes=["output-box"]
|
|
|
|
| 628 |
# Set up event handlers
|
| 629 |
analyze_btn.click(
|
| 630 |
fn=analyze_speech,
|
| 631 |
+
inputs=[audio_mic, audio_file, expected_text_input],
|
| 632 |
outputs=[
|
| 633 |
fluency_output,
|
| 634 |
articulation_output,
|
| 635 |
confidence_output,
|
| 636 |
processing_time_output,
|
| 637 |
+
error_table_output,
|
| 638 |
json_output
|
| 639 |
]
|
| 640 |
)
|