anfastech committed on
Commit 1cd6149 · 1 Parent(s): 278e294

New: Phoneme-level speech pathology diagnosis MVP with real-time streaming


- Add Wav2Vec2-XLSR-53 based speech pathology classifier with 8-class output (fluency + articulation)
- Implement phone-level feature extraction with 1-second sliding windows and 10ms hops
- Add grapheme-to-phoneme mapping using g2p_en library with frame alignment
- Create error taxonomy system with substitution/omission/distortion detection and therapy recommendations
- Build FastAPI REST API with batch diagnosis endpoints (/diagnose/file) and WebSocket streaming (/ws/diagnose)
- Add Gradio web interface with audio upload/recording, real-time error display, and detailed reports
- Implement training infrastructure: synthetic data generation, classifier head training, and evaluation scripts
- Add Docker containerization with NLTK data download for phoneme mapping
- Fix model loading issues and improve error handling throughout the pipeline
- Remove unnecessary package-lock.json file (Python project)

Features:
- Real-time streaming analysis (<200ms per file, <50ms per frame target)
- Phoneme-level error detection with visual feedback
- Therapy recommendation system with clinical guidance
- Comprehensive error reporting with severity levels and timelines
- WebSocket-based real-time diagnosis for live audio streams
- REST API for batch processing with detailed JSON responses

Infrastructure:
- Training pipeline for classifier fine-tuning
- Data collection tools for phoneme-level annotation
- Performance testing and integration testing suites
- Production-ready logging and error handling
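The 1-second sliding windows with 10 ms hops mentioned in the commit message can be sketched as follows. This is an illustration only; `frame_audio` and its parameters are hypothetical helpers, not the repo's actual API:

```python
import numpy as np

def frame_audio(samples: np.ndarray, sr: int = 16000,
                window_s: float = 1.0, hop_s: float = 0.01) -> np.ndarray:
    """Slice audio into overlapping 1 s windows advanced by a 10 ms hop."""
    win = int(window_s * sr)   # 16000 samples per window at 16 kHz
    hop = int(hop_s * sr)      # 160 samples per hop
    if len(samples) < win:
        samples = np.pad(samples, (0, win - len(samples)))  # pad short clips
    starts = range(0, len(samples) - win + 1, hop)
    return np.stack([samples[s:s + win] for s in starts])

frames = frame_audio(np.zeros(32000))  # 2 s of silence at 16 kHz
```

Each row of `frames` is one classifier input; 2 s of audio yields (32000 − 16000) // 160 + 1 = 101 windows.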

.gitignore CHANGED
@@ -7,5 +7,5 @@ __pycache__/
 
 .gradio/
 .cursor/
-package-lock.json
+# package-lock.json
 docker-compose.yml
Dockerfile CHANGED
@@ -27,6 +27,9 @@ RUN pip install --no-cache-dir \
 # Install the rest of requirements
 RUN pip install --no-cache-dir -r requirements.txt
 
+# Download NLTK data required by g2p_en
+RUN python -c "import nltk; nltk.download('averaged_perceptron_tagger_eng', quiet=True)"
+
 # Copy application files
 COPY . .
 
README_TRAINING.md ADDED
@@ -0,0 +1,176 @@
+# Training Guide for Speech Pathology Classifier
+
+This guide explains how to train the classifier head for phoneme-level speech pathology detection.
+
+## Overview
+
+The system uses Wav2Vec2-XLSR-53 as a frozen feature extractor and trains only the classification head (a 2-3 layer feedforward network) on phoneme-level labeled data.
+
+## Prerequisites
+
+1. **Labeled Data**: 50-100 audio samples with phoneme-level error annotations
+2. **Python Environment**: Python 3.10+ with required dependencies
+3. **GPU** (recommended): For faster training
+
+## Step 1: Data Collection
+
+### Using the Annotation Tool
+
+1. Launch the data collection interface:
+```bash
+python scripts/data_collection.py
+```
+
+2. The Gradio interface will open at `http://localhost:7861`
+
+3. For each sample:
+   - Upload or record audio (5-30 seconds, 16kHz WAV)
+   - Enter the expected text/transcript
+   - Extract phonemes (automatic G2P conversion)
+   - Annotate errors at the phoneme level:
+     - Frame ID where the error occurs
+     - Phoneme with the error
+     - Error type (substitution/omission/distortion/stutter)
+     - Wrong sound (for substitutions)
+     - Severity (0-1)
+     - Timestamp
+   - Add notes if needed
+   - Save the annotation
+
+4. Annotations are saved to:
+   - Audio files: `data/raw/`
+   - Annotations: `data/annotations.json`
+
+### Export Training Data
+
+After collecting annotations, export for training:
+
+```bash
+python scripts/annotation_helper.py
+```
+
+This creates `data/training_dataset.json` with frame-level labels.
+
+## Step 2: Training
+
+### Configuration
+
+Edit `training/config.yaml` to adjust hyperparameters:
+
+- `batch_size`: 16 (adjust based on GPU memory)
+- `learning_rate`: 0.001
+- `num_epochs`: 50
+- `train_split`: 0.8 (80% for training, 20% for validation)
+
+### Run Training
+
+```bash
+python training/train_classifier_head.py --config training/config.yaml
+```
+
+Training will:
+- Load the training dataset
+- Extract Wav2Vec2 features for each sample
+- Train the classifier head (Wav2Vec2 frozen)
+- Save the best checkpoint to `models/checkpoints/classifier_head_best.pt`
+- Save the last checkpoint to `models/checkpoints/classifier_head_trained.pt`
+
+### Monitor Training
+
+Training logs include:
+- Loss per epoch
+- Accuracy per epoch
+- Validation metrics
+- Best model checkpoint saves
+
+## Step 3: Evaluation
+
+Evaluate the trained model:
+
+```bash
+python training/evaluate_classifier.py \
+    --checkpoint models/checkpoints/classifier_head_best.pt \
+    --dataset data/training_dataset.json \
+    --output training/evaluation_results.json \
+    --plot training/confusion_matrix.png
+```
+
+This generates:
+- Overall accuracy, F1 score, precision, recall
+- Per-class accuracy
+- Confusion matrix (saved as PNG)
+- Confidence analysis
+- Detailed metrics JSON
+
+## Step 4: Deployment
+
+Once trained, the model automatically loads trained weights on startup:
+
+1. Place the checkpoint at `models/checkpoints/classifier_head_best.pt`
+2. Restart the application
+3. The model will detect and load the trained weights automatically
+
+### Verify Training Status
+
+Check API responses for:
+- `model_version`: "wav2vec2-xlsr-53-v2-trained" (if trained) or "wav2vec2-xlsr-53-v2-beta" (if untrained)
+- `model_trained`: true/false
+- `confidence_filter_threshold`: 0.65
+
+## Troubleshooting
+
+### Issue: "No training dataset found"
+
+**Solution**: Run `scripts/annotation_helper.py` to export training data from annotations.
+
+### Issue: "CUDA out of memory"
+
+**Solution**: Reduce `batch_size` in `training/config.yaml` (try 8 or 4).
+
+### Issue: "Poor validation accuracy"
+
+**Solutions**:
+- Collect more training data (aim for 100+ samples)
+- Check data quality (ensure accurate annotations)
+- Adjust the learning rate or add data augmentation
+- Use class weights for imbalanced data
+
+### Issue: "Model not loading trained weights"
+
+**Solution**:
+- Verify the checkpoint path: `models/checkpoints/classifier_head_best.pt`
+- Check file permissions
+- Review logs for loading errors
+
+## Best Practices
+
+1. **Data Quality > Quantity**: 50 high-quality samples beat 100 poor ones
+2. **Balanced Classes**: Ensure all 8 classes have sufficient examples
+3. **Validation Split**: Use 20% for validation, never train on test data
+4. **Early Stopping**: Enabled by default to prevent overfitting
+5. **Class Weights**: Automatically calculated to handle imbalance
+6. **Checkpointing**: Best model saved automatically
+
+## Expected Results
+
+After training with 50-100 samples:
+- **Frame-level accuracy**: >75%
+- **Phoneme-level F1**: >85%
+- **Per-class accuracy**: >70% for each class
+- **Confidence**: Higher for correct predictions
+
+## Next Steps
+
+1. Collect more data based on error patterns
+2. Fine-tune hyperparameters
+3. Add data augmentation
+4. Deploy and monitor in production
+5. Retrain quarterly with new data
+
+## Support
+
+For issues or questions:
+- Check the training logs in the console
+- Review `training/evaluation_results.json`
+- Verify the data format in `data/annotations.json`
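The guide above says class weights are "automatically calculated to handle imbalance". A common way to do that is inverse-frequency weighting, sketched below; the exact formula used by `train_classifier_head.py` is not shown in this commit, so this is an assumption:

```python
from collections import Counter

def inverse_frequency_weights(labels, num_classes=8):
    """Weight each class by the inverse of its frequency, so rare error
    classes contribute more to the loss. Unseen classes get weight 0."""
    counts = Counter(labels)
    total = len(labels)
    return [total / (num_classes * counts[c]) if counts[c] else 0.0
            for c in range(num_classes)]

# Toy 2-class example: class 1 is 3x rarer than class 0
weights = inverse_frequency_weights([0, 0, 0, 1], num_classes=2)
```

Such a weight vector is typically passed to the loss function (e.g. a weighted cross-entropy) during training.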
api/routes.py CHANGED
@@ -47,6 +47,16 @@ phoneme_mapper: Optional[PhonemeMapper] = None
 error_mapper: Optional[ErrorMapper] = None
 
 
+def get_phoneme_mapper() -> Optional[PhonemeMapper]:
+    """Get the global PhonemeMapper instance."""
+    return phoneme_mapper
+
+
+def get_error_mapper() -> Optional[ErrorMapper]:
+    """Get the global ErrorMapper instance."""
+    return error_mapper
+
+
 def initialize_routes(
     pipeline: InferencePipeline,
     mapper: Optional[PhonemeMapper] = None,
@@ -263,6 +273,10 @@ async def diagnose_file(
     processing_time_ms = (time.time() - start_time) * 1000
 
     # Create response
+    # Check if model is trained
+    model_trained = inference_pipeline.model.is_trained if hasattr(inference_pipeline.model, 'is_trained') else False
+    model_version = "wav2vec2-xlsr-53-v2-trained" if model_trained else "wav2vec2-xlsr-53-v2-beta"
+
     response = BatchDiagnosisResponse(
         session_id=session_id,
         filename=audio.filename or "unknown",
@@ -274,7 +288,10 @@ async def diagnose_file(
         summary=summary,
         therapy_plan=therapy_plan,
         processing_time_ms=processing_time_ms,
-        created_at=datetime.now()
+        created_at=datetime.utcnow(),
+        model_version=model_version,
+        model_trained=model_trained,
+        confidence_filter_threshold=0.65
     )
 
     # Store in sessions
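The `confidence_filter_threshold=0.65` field added above implies low-confidence frame predictions are dropped before error reporting. A minimal sketch of that filtering, with a hypothetical `filter_by_confidence` helper (not the repo's actual code):

```python
# Threshold matching the confidence_filter_threshold field in the response
CONFIDENCE_FILTER_THRESHOLD = 0.65

def filter_by_confidence(frames, threshold=CONFIDENCE_FILTER_THRESHOLD):
    """Keep only frame predictions at or above the confidence threshold."""
    return [f for f in frames if f["confidence"] >= threshold]

kept = filter_by_confidence([
    {"phoneme": "S", "confidence": 0.9},   # kept
    {"phoneme": "R", "confidence": 0.4},   # dropped as too uncertain
])
```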
api/schemas.py CHANGED
@@ -76,6 +76,9 @@ class BatchDiagnosisResponse(BaseModel):
     therapy_plan: List[str] = Field(default_factory=list, description="Therapy recommendations")
     processing_time_ms: float = Field(..., ge=0.0, description="Processing time in milliseconds")
     created_at: datetime = Field(default_factory=datetime.now, description="Analysis timestamp")
+    model_version: str = Field(default="wav2vec2-xlsr-53-v2", description="Model version identifier")
+    model_trained: bool = Field(default=False, description="Whether the classifier head is trained")
+    confidence_filter_threshold: float = Field(default=0.65, ge=0.0, le=1.0, description="Confidence threshold for filtering predictions")
 
 
 class StreamingDiagnosisRequest(BaseModel):
app.py CHANGED
@@ -7,7 +7,7 @@ from pathlib import Path
 from datetime import datetime
 from typing import Optional
 
-from fastapi import FastAPI, UploadFile, File, Form, HTTPException, WebSocket, WebSocketDisconnect
+from fastapi import FastAPI, UploadFile, File, Form, HTTPException, WebSocket, WebSocketDisconnect, Query
 from fastapi.responses import JSONResponse
 from fastapi.middleware.cors import CORSMiddleware
 import gradio as gr
@@ -26,8 +26,7 @@ sys.path.insert(0, str(Path(__file__).parent))
 # Import model loaders and inference pipeline
 try:
     from diagnosis.ai_engine.model_loader import (
-        get_stutter_detector,  # Legacy detector
-        get_inference_pipeline  # New inference pipeline
+        get_inference_pipeline  # Wav2Vec2-based inference pipeline
     )
     from ui.gradio_interface import create_gradio_interface
     from config import APIConfig, GradioConfig, default_api_config, default_gradio_config
@@ -53,27 +52,30 @@ app.add_middleware(
 )
 
 # Global instances
-detector = None  # Legacy detector
-inference_pipeline = None  # New inference pipeline
+inference_pipeline = None  # Wav2Vec2-based inference pipeline
 
 @app.on_event("startup")
 async def startup_event():
     """Load models on startup"""
-    global detector, inference_pipeline
+    global inference_pipeline
     try:
         logger.info("🚀 Startup event: Loading AI models...")
 
-        # Load legacy detector (for backward compatibility)
-        try:
-            detector = get_stutter_detector()
-            logger.info("✅ Legacy detector loaded")
-        except Exception as e:
-            logger.warning(f"⚠️ Legacy detector not available: {e}")
-
-        # Load new inference pipeline
+        # Load Wav2Vec2-based inference pipeline
         try:
             inference_pipeline = get_inference_pipeline()
             logger.info("✅ Inference pipeline loaded")
+
+            # Initialize API routes with phoneme and error mappers
+            try:
+                from api.routes import initialize_routes
+                from api.streaming import initialize_streaming
+                initialize_routes(inference_pipeline)
+                initialize_streaming(inference_pipeline)
+                logger.info("✅ API routes initialized with phoneme/error mappers")
+            except Exception as e:
+                logger.warning(f"⚠️ API routes initialization failed: {e}", exc_info=True)
+                # Continue without phoneme mapping if it fails
         except Exception as e:
             logger.error(f"❌ Failed to load inference pipeline: {e}", exc_info=True)
             # Don't raise - allow API to start even if new pipeline fails
@@ -83,6 +85,24 @@ async def startup_event():
         logger.error(f"❌ Failed to load models: {e}", exc_info=True)
         raise
 
+# Include API routers
+try:
+    from api.routes import router as diagnose_router
+    app.include_router(diagnose_router)
+    logger.info("✅ Diagnosis router included")
+except Exception as e:
+    logger.warning(f"⚠️ Failed to include diagnosis router: {e}")
+
+# Add WebSocket endpoint
+try:
+    from api.streaming import handle_streaming_websocket
+
+    @app.websocket("/ws/diagnose")
+    async def websocket_diagnose(websocket: WebSocket, session_id: Optional[str] = None):
+        await handle_streaming_websocket(websocket, session_id)
+    logger.info("✅ WebSocket endpoint registered")
+except Exception as e:
+    logger.warning(f"⚠️ Failed to register WebSocket endpoint: {e}")
+
 # Create and mount new Gradio interface
 try:
     gradio_interface = create_gradio_interface(default_gradio_config)
@@ -103,31 +123,30 @@ async def health_check():
     return {
         "status": "healthy",
         "models_loaded": {
-            "legacy_detector": detector is not None,
-            "inference_pipeline": inference_pipeline is not None
+            "inference_pipeline": inference_pipeline is not None,
+            "model_version": "wav2vec2-xlsr-53-v2"
         },
         "timestamp": datetime.utcnow().isoformat() + "Z"
     }
 
 @app.post("/api/diagnose")
 async def diagnose_speech(
-    audio: UploadFile = File(...)
+    audio: UploadFile = File(...),
+    text: Optional[str] = Query(None, description="Expected text/transcript for phoneme mapping (optional)")
 ):
     """
-    Diagnose speech for fluency and articulation issues.
+    Legacy endpoint for speech diagnosis.
 
-    Uses the new Wav2Vec2-XLSR-53 inference pipeline for phone-level analysis.
+    NOTE: For full phoneme-level error detection with therapy recommendations,
+    use POST /diagnose/file?text=<expected_text> instead.
+    This endpoint is maintained for backward compatibility.
 
     Parameters:
     - audio: Audio file (WAV, MP3, FLAC, M4A)
+    - text: Optional expected text for phoneme mapping
 
     Returns:
-    Dictionary with:
-    - status: "success" or "error"
-    - fluency_metrics: Fluency statistics
-    - articulation_results: Articulation analysis
-    - confidence: Overall confidence score
-    - processing_time_ms: Processing time in milliseconds
+    Dictionary with diagnosis results (legacy format for backward compatibility)
     """
     if not inference_pipeline:
         raise HTTPException(
@@ -135,11 +154,15 @@ async def diagnose_speech(
             detail="Inference pipeline not loaded yet. Try again in a moment."
         )
 
+    # Import here to avoid circular imports
+    from api.routes import get_phoneme_mapper, get_error_mapper
+    from models.error_taxonomy import ErrorType
+
     start_time = time.time()
     temp_file = None
 
     try:
-        logger.info(f"📥 Processing diagnosis request: {audio.filename}")
+        logger.info(f"📥 Processing legacy diagnosis request: {audio.filename}")
 
         # Validate file extension
         file_ext = Path(audio.filename).suffix.lower()
@@ -173,7 +196,6 @@ async def diagnose_speech(
 
         # Run inference
         logger.info("🔄 Running inference pipeline...")
-        # Use new phone-level prediction
         result = inference_pipeline.predict_phone_level(
             temp_file,
             return_timestamps=True
@@ -181,28 +203,64 @@ async def diagnose_speech(
 
         processing_time_ms = (time.time() - start_time) * 1000
 
-        # Extract metrics from new PhoneLevelResult format
+        # Get mappers for phoneme/error processing
+        phoneme_mapper = get_phoneme_mapper()
+        error_mapper = get_error_mapper()
+
+        # Map phonemes if text provided
+        frame_phonemes = []
+        errors = []
+        if text and phoneme_mapper and error_mapper:
+            try:
+                frame_phonemes = phoneme_mapper.map_text_to_frames(
+                    text,
+                    num_frames=result.num_frames,
+                    audio_duration=result.duration
+                )
+
+                # Process errors
+                for i, frame_pred in enumerate(result.frame_predictions):
+                    phoneme = frame_phonemes[i] if i < len(frame_phonemes) else ''
+                    class_id = frame_pred.articulation_class
+                    if frame_pred.fluency_label == 'stutter':
+                        class_id += 4
+
+                    error_detail = error_mapper.map_classifier_output(
+                        class_id=class_id,
+                        confidence=frame_pred.confidence,
+                        phoneme=phoneme if phoneme else 'unknown',
+                        fluency_label=frame_pred.fluency_label
+                    )
+
+                    if error_detail.error_type != ErrorType.NORMAL:
+                        errors.append({
+                            "phoneme": error_detail.phoneme,
+                            "time": frame_pred.time,
+                            "error_type": error_detail.error_type.value,
+                            "wrong_sound": error_detail.wrong_sound,
+                            "severity": error_mapper.get_severity_level(error_detail.severity).value,
+                            "therapy": error_detail.therapy
+                        })
+            except Exception as e:
+                logger.warning(f"⚠️ Phoneme/error mapping failed: {e}")
+
+        # Extract metrics
         aggregate = result.aggregate
         mean_fluency_stutter = aggregate.get("fluency_score", 0.0)
-        fluency_percentage = (1.0 - mean_fluency_stutter) * 100  # Convert stutter prob to fluency percentage
+        fluency_percentage = (1.0 - mean_fluency_stutter) * 100
 
-        # Count fluent frames
         fluent_frames = sum(1 for fp in result.frame_predictions if fp.fluency_label == 'normal')
         fluent_frames_ratio = fluent_frames / result.num_frames if result.num_frames > 0 else 0.0
 
-        # Extract articulation class distribution
        articulation_class_counts = {}
         for fp in result.frame_predictions:
             label = fp.articulation_label
             articulation_class_counts[label] = articulation_class_counts.get(label, 0) + 1
 
-        # Get dominant articulation class
         dominant_articulation = aggregate.get("articulation_label", "normal")
-
-        # Calculate average confidence
         avg_confidence = sum(fp.confidence for fp in result.frame_predictions) / result.num_frames if result.num_frames > 0 else 0.0
 
-        # Format response
+        # Format response (legacy format with optional error info)
         response = {
             "status": "success",
             "fluency_metrics": {
@@ -225,9 +283,10 @@ async def diagnose_speech(
                     "fluency_label": fp.fluency_label,
                     "articulation_class": fp.articulation_class,
                     "articulation_label": fp.articulation_label,
-                    "confidence": fp.confidence
+                    "confidence": fp.confidence,
+                    "phoneme": frame_phonemes[i] if i < len(frame_phonemes) else ''
                 }
-                for fp in result.frame_predictions
+                for i, fp in enumerate(result.frame_predictions)
             ]
         },
         "confidence": avg_confidence,
@@ -235,8 +294,14 @@ async def diagnose_speech(
             "processing_time_ms": processing_time_ms
         }
 
-        logger.info(f"✅ Diagnosis complete: fluency={response['fluency_metrics']['fluency_percentage']:.1f}%, "
-                    f"confidence={response['confidence_percentage']:.1f}%, "
+        # Add error info if available
+        if errors:
+            response["error_count"] = len(errors)
+            response["errors"] = errors[:10]  # Limit to first 10 for legacy format
+            response["problematic_sounds"] = list(set(err["phoneme"] for err in errors if err["phoneme"]))
+
+        logger.info(f"✅ Legacy diagnosis complete: fluency={response['fluency_metrics']['fluency_percentage']:.1f}%, "
+                    f"errors={len(errors) if errors else 0}, "
                     f"time={processing_time_ms:.0f}ms")
 
         return response
@@ -257,70 +322,7 @@ async def diagnose_speech(
             logger.warning(f"Could not clean up {temp_file}: {e}")
 
 
-@app.post("/analyze")
-async def analyze_audio(
-    audio: UploadFile = File(...),
-    transcript: str = Form("")
-):
-    """
-    Legacy endpoint: Analyze audio file for stuttering.
-
-    Uses the legacy Whisper-based detector for backward compatibility.
-
-    Parameters:
-    - audio: WAV or MP3 audio file
-    - transcript: Optional expected transcript
-
-    Returns: Complete stutter analysis results
-    """
-    if not detector:
-        raise HTTPException(
-            status_code=503,
-            detail="Legacy detector not loaded. Use /api/diagnose for new analysis."
-        )
-
-    temp_file = None
-    try:
-        logger.info(f"📥 Processing (legacy): {audio.filename}")
-
-        # Create temp directory if needed
-        temp_dir = tempfile.gettempdir()
-        os.makedirs(temp_dir, exist_ok=True)
-
-        # Save uploaded file
-        temp_file = os.path.join(temp_dir, f"legacy_{int(time.time())}_{audio.filename}")
-        content = await audio.read()
-
-        with open(temp_file, "wb") as f:
-            f.write(content)
-
-        logger.info(f"📂 Saved to: {temp_file} ({len(content) / 1024 / 1024:.2f} MB)")
-
-        # Analyze
-        logger.info(f"🔄 Analyzing audio with transcript: '{transcript[:50] if transcript else '(empty)'}...'")
-        result = detector.analyze_audio(temp_file, transcript)
-
-        actual = result.get('actual_transcript', '')
-        target = result.get('target_transcript', '')
-        logger.info(f"✅ Analysis complete: severity={result.get('severity', 'N/A')}, "
-                    f"mismatch={result.get('mismatch_percentage', 'N/A')}%")
-
-        return result
-
-    except HTTPException:
-        raise
-    except Exception as e:
-        logger.error(f"❌ Error during analysis: {str(e)}", exc_info=True)
-        raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}")
-
-    finally:
-        # Cleanup
-        if temp_file and os.path.exists(temp_file):
-            try:
-                os.remove(temp_file)
-                logger.debug(f"🧹 Cleaned up: {temp_file}")
-            except Exception as e:
-                logger.warning(f"Could not clean up {temp_file}: {e}")
+# Legacy /analyze endpoint removed - use /api/diagnose or /diagnose/file instead
 
 
 @app.websocket("/ws/audio")
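The `class_id += 4` step in the diff above implies the 8-class output is four articulation classes doubled by a binary fluency label. A minimal sketch of that folding; the label names and their ordering are assumptions, not taken from the repo:

```python
# Assumed ordering of the 4 articulation classes (illustrative only)
ARTICULATION_LABELS = ["normal", "substitution", "omission", "distortion"]

def combined_class_id(articulation_class: int, fluency_label: str) -> int:
    """Fold a 4-way articulation class and a binary fluency label into 0-7,
    mirroring the `class_id += 4` offset applied for stuttered frames."""
    return articulation_class + (4 if fluency_label == "stutter" else 0)
```

Under this scheme, ids 0-3 cover fluent frames and 4-7 cover stuttered frames, giving all eight combinations.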
inference/__init__.py CHANGED
@@ -6,15 +6,15 @@ This package contains the inference pipeline for real-time and batch processing.
 
 from .inference_pipeline import (
     InferencePipeline,
-    PredictionResult,
-    BatchPredictionResult,
-    create_inference_pipeline
+    FramePrediction,
+    PhoneLevelResult,
+    create_inference_pipeline,
 )
 
 __all__ = [
     "InferencePipeline",
-    "PredictionResult",
-    "BatchPredictionResult",
+    "FramePrediction",
+    "PhoneLevelResult",
     "create_inference_pipeline",
 ]
models/phoneme_mapper.py CHANGED
@@ -72,6 +72,18 @@ class PhonemeMapper:
                 "g2p_en library is required. Install with: pip install g2p-en"
             )
 
+        # Ensure NLTK data is available (required by g2p_en)
+        try:
+            import nltk
+            try:
+                nltk.data.find('taggers/averaged_perceptron_tagger_eng')
+            except LookupError:
+                logger.info("Downloading NLTK averaged_perceptron_tagger_eng...")
+                nltk.download('averaged_perceptron_tagger_eng', quiet=True)
+                logger.info("✅ NLTK data downloaded")
+        except Exception as e:
+            logger.warning(f"⚠️ Could not download NLTK data: {e}")
+
         try:
             self.g2p = g2p_en.G2p()
             logger.info("✅ G2P model loaded successfully")
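The mapper's job (per the commit message) is aligning G2P phonemes to classifier frames. One plausible uniform-duration alignment is sketched below; `spread_phonemes` is hypothetical and the repo's actual `map_text_to_frames` may use timestamps instead:

```python
def spread_phonemes(phonemes, num_frames):
    """Assign each frame the phoneme covering its position, assuming every
    phoneme spans an equal share of the audio. Illustration only."""
    if not phonemes or num_frames <= 0:
        return [''] * max(num_frames, 0)
    return [phonemes[min(i * len(phonemes) // num_frames, len(phonemes) - 1)]
            for i in range(num_frames)]

aligned = spread_phonemes(['HH', 'AH', 'L', 'OW'], 8)  # 4 phonemes over 8 frames
```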
models/speech_pathology_model.py CHANGED
@@ -210,6 +210,7 @@ class SpeechPathologyClassifier(nn.Module):
210
 
211
  self.device = torch.device(device)
212
  self.use_fp16 = use_fp16 and device == "cuda"
 
213
 
214
  if classifier_hidden_dims is None:
215
  classifier_hidden_dims = [256, 128]
@@ -257,6 +258,9 @@ class SpeechPathologyClassifier(nn.Module):
257
  num_articulation_classes=num_articulation_classes
258
  )
259
 
 
 
 
260
  # Move to device
261
  self.wav2vec2_model = self.wav2vec2_model.to(self.device)
  self.classifier_head = self.classifier_head.to(self.device)
@@ -275,6 +279,56 @@ class SpeechPathologyClassifier(nn.Module):
  logger.error(f"❌ Failed to initialize model: {e}", exc_info=True)
  raise RuntimeError(f"Failed to load Wav2Vec2 model: {e}") from e

  def forward(
      self,
      input_values: torch.Tensor,

  self.device = torch.device(device)
  self.use_fp16 = use_fp16 and device == "cuda"
+ self.is_trained = False  # Track if classifier is trained

  if classifier_hidden_dims is None:
      classifier_hidden_dims = [256, 128]

  num_articulation_classes=num_articulation_classes
  )

+ # Try to load trained weights if available (None = try default paths)
+ self._load_trained_weights(None)
+
  # Move to device
  self.wav2vec2_model = self.wav2vec2_model.to(self.device)
  self.classifier_head = self.classifier_head.to(self.device)

  logger.error(f"❌ Failed to initialize model: {e}", exc_info=True)
  raise RuntimeError(f"Failed to load Wav2Vec2 model: {e}") from e

+ def _load_trained_weights(self, model_path: Optional[str] = None):
+     """
+     Load trained classifier head weights if available.
+
+     Args:
+         model_path: Optional path to model checkpoint. If None, tries default checkpoint paths.
+     """
+     from pathlib import Path
+
+     checkpoint_paths = []
+
+     # Add user-provided path
+     if model_path:
+         checkpoint_paths.append(Path(model_path))
+
+     # Add default checkpoint paths
+     checkpoint_paths.extend([
+         Path("models/checkpoints/classifier_head_best.pt"),
+         Path("models/checkpoints/classifier_head_trained.pt")
+     ])
+
+     for checkpoint_path in checkpoint_paths:
+         if checkpoint_path.exists():
+             try:
+                 checkpoint = torch.load(checkpoint_path, map_location=self.device)
+
+                 # Handle both full checkpoint dict and state_dict directly
+                 if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
+                     state_dict = checkpoint['model_state_dict']
+                     epoch = checkpoint.get('epoch', 'unknown')
+                     val_acc = checkpoint.get('val_accuracy', 'unknown')
+                 else:
+                     state_dict = checkpoint
+                     epoch = 'unknown'
+                     val_acc = 'unknown'
+
+                 self.classifier_head.load_state_dict(state_dict)
+                 logger.info(f"✅ Loaded trained classifier head from {checkpoint_path}")
+                 logger.info(f"   Epoch: {epoch}, Validation Accuracy: {val_acc}")
+                 self.is_trained = True
+                 return
+             except Exception as e:
+                 logger.warning(f"⚠️ Could not load checkpoint {checkpoint_path}: {e}")
+                 continue
+
+     # No trained weights found
+     logger.warning("⚠️ No trained classifier weights found. Using untrained head (beta mode)")
+     logger.warning("   To train the classifier, run: python training/train_classifier_head.py")
+     self.is_trained = False
+
  def forward(
      self,
      input_values: torch.Tensor,
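`_load_trained_weights` accepts either a full checkpoint dict (with `model_state_dict`, `epoch`, `val_accuracy` keys) or a bare state_dict. That normalization step can be sketched on its own, without torch; `normalize_checkpoint` is a hypothetical helper name, not part of the repository:

```python
def normalize_checkpoint(checkpoint):
    """Mirror the dict-vs-state_dict handling in _load_trained_weights.

    Accepts either {'model_state_dict': ..., 'epoch': ..., 'val_accuracy': ...}
    or a bare state_dict, and returns (state_dict, epoch, val_accuracy),
    with 'unknown' standing in for absent metadata.
    """
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        return (checkpoint['model_state_dict'],
                checkpoint.get('epoch', 'unknown'),
                checkpoint.get('val_accuracy', 'unknown'))
    return checkpoint, 'unknown', 'unknown'
```

Keeping the fallback tolerant means checkpoints saved with plain `torch.save(model.state_dict(), path)` still load, at the cost of losing the epoch/accuracy metadata in the log lines.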
requirements.txt CHANGED
@@ -5,6 +5,7 @@ torchaudio>=2.6.0
  transformers>=4.57.3,<5.0
  numpy>=1.24.0,<2.0.0
  protobuf>=3.20.0
+ g2p-en>=2.1.0

  # Audio Processing
  librosa>=0.10.0
@@ -25,6 +26,12 @@ gradio==6.1.0
  # Logging
  python-json-logger>=2.0.0

+ # Training dependencies
+ pyyaml>=6.0
+ scikit-learn>=1.3.0
+ matplotlib>=3.7.0
+ seaborn>=0.12.0
+
  # Optional: Legacy/Advanced features
  openai-whisper>=20230314
  praat-parselmouth>=0.4.3
scripts/annotation_helper.py ADDED
@@ -0,0 +1,146 @@
+ """
+ Annotation Helper Utilities
+
+ Helper functions for phoneme-level annotation tasks.
+ """
+
+ import json
+ import logging
+ from pathlib import Path
+ from typing import List, Dict, Any, Optional
+ import numpy as np
+
+ logger = logging.getLogger(__name__)
+
+
+ def load_annotations(annotations_file: Path = Path("data/annotations.json")) -> List[Dict[str, Any]]:
+     """Load annotations from JSON file."""
+     if not annotations_file.exists():
+         logger.warning(f"Annotations file not found: {annotations_file}")
+         return []
+
+     try:
+         with open(annotations_file, 'r', encoding='utf-8') as f:
+             return json.load(f)
+     except Exception as e:
+         logger.error(f"Failed to load annotations: {e}")
+         return []
+
+
+ def save_annotations(annotations: List[Dict[str, Any]], annotations_file: Path = Path("data/annotations.json")):
+     """Save annotations to JSON file."""
+     annotations_file.parent.mkdir(parents=True, exist_ok=True)
+
+     with open(annotations_file, 'w', encoding='utf-8') as f:
+         json.dump(annotations, f, indent=2, ensure_ascii=False)
+
+     logger.info(f"Saved {len(annotations)} annotations to {annotations_file}")
+
+
+ def get_annotation_statistics(annotations: List[Dict[str, Any]]) -> Dict[str, Any]:
+     """Calculate statistics from annotations."""
+     total_samples = len(annotations)
+     total_errors = sum(a.get('total_errors', 0) for a in annotations)
+
+     error_types = {
+         'substitution': 0,
+         'omission': 0,
+         'distortion': 0,
+         'stutter': 0,
+         'normal': 0
+     }
+
+     phoneme_errors = {}
+
+     for ann in annotations:
+         for err in ann.get('phoneme_errors', []):
+             err_type = err.get('error_type', 'normal')
+             error_types[err_type] = error_types.get(err_type, 0) + 1
+
+             phoneme = err.get('phoneme', 'unknown')
+             if phoneme not in phoneme_errors:
+                 phoneme_errors[phoneme] = 0
+             phoneme_errors[phoneme] += 1
+
+     return {
+         'total_samples': total_samples,
+         'total_errors': total_errors,
+         'error_types': error_types,
+         'phoneme_errors': phoneme_errors,
+         'avg_errors_per_sample': total_errors / total_samples if total_samples > 0 else 0.0
+     }
+
+
+ def export_for_training(
+     annotations: List[Dict[str, Any]],
+     output_file: Path = Path("data/training_dataset.json")
+ ) -> Dict[str, Any]:
+     """Export annotations in training-ready format."""
+     training_data = []
+
+     for ann in annotations:
+         audio_file = ann.get('audio_file')
+         expected_text = ann.get('expected_text', '')
+         duration = ann.get('duration', 0.0)
+
+         # Create frame-level labels
+         num_frames = int((duration * 1000) / 20)  # 20ms frames
+         frame_labels = [0] * num_frames  # 0 = normal
+
+         # Map errors to frames
+         for err in ann.get('phoneme_errors', []):
+             frame_id = err.get('frame_id', 0)
+             err_type = err.get('error_type', 'normal')
+
+             # Map to 8-class system
+             class_id = {
+                 'normal': 0,
+                 'substitution': 1,
+                 'omission': 2,
+                 'distortion': 3,
+                 'stutter': 4
+             }.get(err_type, 0)
+
+             # Check if stutter + articulation error
+             if err_type != 'normal' and err_type != 'stutter':
+                 # Check if there's also stutter
+                 if any(e.get('error_type') == 'stutter' for e in ann.get('phoneme_errors', [])
+                        if e.get('frame_id') == frame_id):
+                     class_id += 4  # Add 4 for stutter classes (5-7)
+
+             if 0 <= frame_id < num_frames:
+                 frame_labels[frame_id] = class_id
+
+         training_data.append({
+             'audio_file': audio_file,
+             'expected_text': expected_text,
+             'duration': duration,
+             'num_frames': num_frames,
+             'frame_labels': frame_labels,
+             'phoneme_errors': ann.get('phoneme_errors', [])
+         })
+
+     output_file.parent.mkdir(parents=True, exist_ok=True)
+     with open(output_file, 'w', encoding='utf-8') as f:
+         json.dump(training_data, f, indent=2, ensure_ascii=False)
+
+     logger.info(f"Exported {len(training_data)} samples for training to {output_file}")
+
+     return {
+         'samples': len(training_data),
+         'output_file': str(output_file)
+     }
+
+
+ if __name__ == "__main__":
+     # Example usage
+     annotations = load_annotations()
+     stats = get_annotation_statistics(annotations)
+
+     print(f"Total samples: {stats['total_samples']}")
+     print(f"Total errors: {stats['total_errors']}")
+     print(f"Error types: {stats['error_types']}")
+
+     if annotations:
+         export_for_training(annotations)
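`export_for_training` folds articulation and stutter annotations into a single 8-class frame label, where classes 5-7 are the articulation errors co-occurring with a stutter. The mapping can be sketched in isolation; `frame_class_id` is a hypothetical function name assuming the class layout used in that script:

```python
# Class layout from export_for_training:
# 0 normal, 1 substitution, 2 omission, 3 distortion, 4 stutter,
# 5-7 = substitution/omission/distortion that co-occur with a stutter.
ERROR_CLASS = {'normal': 0, 'substitution': 1, 'omission': 2,
               'distortion': 3, 'stutter': 4}

def frame_class_id(error_type, co_occurring_stutter=False):
    """Map an annotated error type on one frame to the 8-class label."""
    class_id = ERROR_CLASS.get(error_type, 0)
    # Only articulation errors get the +4 stutter offset; a lone stutter
    # stays class 4 and a normal frame stays class 0.
    if error_type not in ('normal', 'stutter') and co_occurring_stutter:
        class_id += 4
    return class_id
```

This keeps a plain stutter and a stuttered substitution distinguishable in the frame labels without a second output head.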
scripts/data_collection.py ADDED
@@ -0,0 +1,418 @@
+ """
+ Data Collection Tool for Speech Pathology Annotation
+
+ This module provides a Gradio-based interface for collecting and annotating
+ phoneme-level speech pathology data. Clinicians can record or upload audio,
+ then annotate errors at the phoneme level with timestamps.
+
+ Usage:
+     python scripts/data_collection.py
+ """
+
+ import logging
+ import os
+ import json
+ import time
+ import tempfile
+ from pathlib import Path
+ from typing import Optional, List, Dict, Any, Tuple
+ from datetime import datetime
+ import numpy as np
+
+ import gradio as gr
+ import librosa
+ import soundfile as sf
+
+ from models.phoneme_mapper import PhonemeMapper
+ from models.error_taxonomy import ErrorType, SeverityLevel
+
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Configuration
+ DATA_DIR = Path("data/raw")
+ ANNOTATIONS_FILE = Path("data/annotations.json")
+ SAMPLE_RATE = 16000
+ FRAME_DURATION_MS = 20
+
+ # Ensure directories exist
+ DATA_DIR.mkdir(parents=True, exist_ok=True)
+ ANNOTATIONS_FILE.parent.mkdir(parents=True, exist_ok=True)
+
+ # Load existing annotations
+ annotations_db: List[Dict[str, Any]] = []
+ if ANNOTATIONS_FILE.exists():
+     try:
+         with open(ANNOTATIONS_FILE, 'r', encoding='utf-8') as f:
+             annotations_db = json.load(f)
+         logger.info(f"✅ Loaded {len(annotations_db)} existing annotations")
+     except Exception as e:
+         logger.warning(f"⚠️ Could not load annotations: {e}")
+
+
+ def save_audio_file(audio_data: Optional[Tuple[int, np.ndarray]], filename: str) -> Optional[str]:
+     """Save uploaded/recorded audio to file."""
+     if audio_data is None:
+         return None
+
+     sample_rate, audio_array = audio_data
+
+     # Downmix to mono if Gradio returned multi-channel audio
+     if audio_array.ndim > 1:
+         audio_array = audio_array.mean(axis=1)
+
+     # Resample to 16kHz if needed
+     if sample_rate != SAMPLE_RATE:
+         audio_array = librosa.resample(
+             audio_array.astype(np.float32),
+             orig_sr=sample_rate,
+             target_sr=SAMPLE_RATE
+         )
+         sample_rate = SAMPLE_RATE
+
+     # Normalize
+     if np.max(np.abs(audio_array)) > 0:
+         audio_array = audio_array / np.max(np.abs(audio_array))
+
+     # Save to data/raw
+     output_path = DATA_DIR / filename
+     sf.write(str(output_path), audio_array, sample_rate)
+     logger.info(f"✅ Saved audio to {output_path}")
+
+     return str(output_path)
+
+
+ def get_phoneme_list(text: str) -> List[str]:
+     """Convert text to phoneme list using PhonemeMapper."""
+     try:
+         mapper = PhonemeMapper(
+             frame_duration_ms=FRAME_DURATION_MS,
+             sample_rate=SAMPLE_RATE
+         )
+         phonemes = mapper.g2p.convert(text)
+         return [p for p in phonemes if p.strip()] if phonemes else []
+     except Exception as e:
+         logger.error(f"❌ G2P conversion failed: {e}")
+         return []
+
+
+ def calculate_frame_count(audio_path: str) -> int:
+     """Calculate number of frames for audio file."""
+     try:
+         duration = librosa.get_duration(path=audio_path)
+         frames = int((duration * 1000) / FRAME_DURATION_MS)
+         return max(1, frames)
+     except Exception as e:
+         logger.error(f"❌ Could not calculate frames: {e}")
+         return 0
+
+
+ def save_annotation(
+     audio_path: str,
+     expected_text: str,
+     phoneme_errors: List[Dict[str, Any]],
+     annotator_name: str,
+     notes: str
+ ) -> Dict[str, Any]:
+     """Save annotation to database."""
+     try:
+         duration = librosa.get_duration(path=audio_path)
+
+         annotation = {
+             'id': f"annot_{int(time.time())}",
+             'audio_file': audio_path,
+             'expected_text': expected_text,
+             'duration': float(duration),
+             'annotator': annotator_name,
+             'notes': notes,
+             'created_at': datetime.utcnow().isoformat() + "Z",
+             'phoneme_errors': phoneme_errors,
+             'total_errors': len(phoneme_errors),
+             'error_types': {
+                 'substitution': sum(1 for e in phoneme_errors if e.get('error_type') == 'substitution'),
+                 'omission': sum(1 for e in phoneme_errors if e.get('error_type') == 'omission'),
+                 'distortion': sum(1 for e in phoneme_errors if e.get('error_type') == 'distortion'),
+                 'stutter': sum(1 for e in phoneme_errors if e.get('error_type') == 'stutter'),
+             }
+         }
+
+         annotations_db.append(annotation)
+
+         # Save to file
+         with open(ANNOTATIONS_FILE, 'w', encoding='utf-8') as f:
+             json.dump(annotations_db, f, indent=2, ensure_ascii=False)
+
+         logger.info(f"✅ Saved annotation {annotation['id']} with {len(phoneme_errors)} errors")
+
+         return {
+             'status': 'success',
+             'annotation_id': annotation['id'],
+             'total_errors': len(phoneme_errors),
+             'message': f"✅ Annotation saved! Total annotations: {len(annotations_db)}"
+         }
+     except Exception as e:
+         logger.error(f"❌ Failed to save annotation: {e}", exc_info=True)
+         return {
+             'status': 'error',
+             'message': f"❌ Failed to save: {str(e)}"
+         }
+
+
+ def create_annotation_interface():
+     """Create Gradio interface for data collection."""
+
+     with gr.Blocks(title="Speech Pathology Data Collection", theme=gr.themes.Soft()) as interface:
+         gr.Markdown("""
+         # 🎤 Speech Pathology Data Collection Tool
+
+         **Purpose:** Collect and annotate phoneme-level speech pathology data for training.
+
+         **Instructions:**
+         1. Upload or record audio (5-30 seconds, 16kHz WAV)
+         2. Enter expected text/transcript
+         3. Review phoneme list
+         4. Annotate errors at phoneme level
+         5. Save annotation
+         """)
+
+         with gr.Row():
+             with gr.Column(scale=1):
+                 gr.Markdown("### 📥 Audio Input")
+
+                 audio_input = gr.Audio(
+                     type="numpy",
+                     label="Record or Upload Audio",
+                     sources=["microphone", "upload"],
+                     format="wav"
+                 )
+
+                 expected_text = gr.Textbox(
+                     label="Expected Text/Transcript",
+                     placeholder="Enter the expected text that should be spoken",
+                     lines=3
+                 )
+
+                 phoneme_display = gr.Textbox(
+                     label="Phonemes (G2P)",
+                     lines=5,
+                     interactive=False,
+                     info="Phonemes extracted from expected text"
+                 )
+
+                 btn_get_phonemes = gr.Button("🔍 Extract Phonemes", variant="secondary")
+
+             with gr.Column(scale=1):
+                 gr.Markdown("### ✏️ Annotation")
+
+                 annotator_name = gr.Textbox(
+                     label="Annotator Name",
+                     placeholder="Your name",
+                     value="clinician"
+                 )
+
+                 error_frame_id = gr.Number(
+                     label="Frame ID (0-based)",
+                     value=0,
+                     precision=0,
+                     info="Frame number where error occurs"
+                 )
+
+                 error_phoneme = gr.Textbox(
+                     label="Phoneme with Error",
+                     placeholder="/r/",
+                     info="The phoneme that has an error"
+                 )
+
+                 error_type = gr.Dropdown(
+                     label="Error Type",
+                     choices=["normal", "substitution", "omission", "distortion", "stutter"],
+                     value="normal",
+                     info="Type of error detected"
+                 )
+
+                 wrong_sound = gr.Textbox(
+                     label="Wrong Sound (if substitution)",
+                     placeholder="/w/",
+                     info="What sound was produced instead (for substitutions)"
+                 )
+
+                 error_severity = gr.Slider(
+                     label="Severity (0-1)",
+                     minimum=0.0,
+                     maximum=1.0,
+                     value=0.5,
+                     step=0.1,
+                     info="Severity of the error"
+                 )
+
+                 error_timestamp = gr.Number(
+                     label="Timestamp (seconds)",
+                     value=0.0,
+                     precision=2,
+                     info="Time in audio where error occurs"
+                 )
+
+                 btn_add_error = gr.Button("➕ Add Error", variant="primary")
+
+                 errors_list = gr.Dataframe(
+                     label="Annotated Errors",
+                     headers=["Frame", "Phoneme", "Type", "Wrong Sound", "Severity", "Time"],
+                     interactive=False,
+                     wrap=True
+                 )
+
+                 notes = gr.Textbox(
+                     label="Notes",
+                     placeholder="Additional notes about this sample",
+                     lines=3
+                 )
+
+                 btn_save = gr.Button("💾 Save Annotation", variant="primary", size="lg")
+
+                 output_status = gr.Textbox(
+                     label="Status",
+                     interactive=False,
+                     lines=3
+                 )
+
+         # Statistics panel
+         with gr.Row():
+             gr.Markdown("### 📊 Statistics")
+             stats_display = gr.Markdown("**Total Annotations:** 0 | **Total Errors:** 0")
+
+         # Event handlers
+         errors_data = gr.State(value=[])
+
+         def extract_phonemes(text: str) -> str:
+             """Extract phonemes from text."""
+             if not text:
+                 return "Enter expected text first"
+             phonemes = get_phoneme_list(text)
+             return " ".join([f"/{p}/" for p in phonemes]) if phonemes else "No phonemes found"
+
+         def add_error(
+             frame_id: int,
+             phoneme: str,
+             error_type: str,
+             wrong_sound: str,
+             severity: float,
+             timestamp: float,
+             current_errors: List[Dict]
+         ) -> Tuple[List[Dict], List[List[Any]]]:
+             """Add an error to the list and rebuild the display table."""
+             error = {
+                 'frame_id': int(frame_id),
+                 'phoneme': phoneme.strip(),
+                 'error_type': error_type,
+                 'wrong_sound': wrong_sound.strip() if wrong_sound else None,
+                 'severity': float(severity),
+                 'timestamp': float(timestamp),
+                 'confidence': 1.0  # Manual annotation is always confident
+             }
+
+             new_errors = current_errors + [error]
+
+             # Create dataframe rows
+             df_data = [
+                 [
+                     e['frame_id'],
+                     e['phoneme'],
+                     e['error_type'],
+                     e.get('wrong_sound', 'N/A'),
+                     f"{e['severity']:.2f}",
+                     f"{e['timestamp']:.2f}s"
+                 ]
+                 for e in new_errors
+             ]
+
+             return new_errors, df_data
+
+         def save_annotation_handler(
+             audio_data: Optional[Tuple[int, np.ndarray]],
+             expected_text: str,
+             errors: List[Dict],
+             annotator: str,
+             notes: str
+         ) -> str:
+             """Handle annotation saving."""
+             if audio_data is None:
+                 return "❌ Please provide audio first"
+
+             if not expected_text:
+                 return "❌ Please provide expected text"
+
+             # Save audio
+             filename = f"sample_{int(time.time())}.wav"
+             audio_path = save_audio_file(audio_data, filename)
+
+             if not audio_path:
+                 return "❌ Failed to save audio file"
+
+             # Save annotation
+             result = save_annotation(
+                 audio_path=audio_path,
+                 expected_text=expected_text,
+                 phoneme_errors=errors,
+                 annotator_name=annotator,
+                 notes=notes
+             )
+
+             return result.get('message', 'Unknown status')
+
+         def update_stats() -> str:
+             """Update statistics display."""
+             total_annotations = len(annotations_db)
+             total_errors = sum(a.get('total_errors', 0) for a in annotations_db)
+
+             error_breakdown = {}
+             for ann in annotations_db:
+                 for err_type, count in ann.get('error_types', {}).items():
+                     error_breakdown[err_type] = error_breakdown.get(err_type, 0) + count
+
+             stats_text = f"""
+ **Total Annotations:** {total_annotations} | **Total Errors:** {total_errors}
+
+ **Error Breakdown:**
+ - Substitution: {error_breakdown.get('substitution', 0)}
+ - Omission: {error_breakdown.get('omission', 0)}
+ - Distortion: {error_breakdown.get('distortion', 0)}
+ - Stutter: {error_breakdown.get('stutter', 0)}
+ """
+             return stats_text
+
+         # Wire up events
+         btn_get_phonemes.click(
+             fn=extract_phonemes,
+             inputs=[expected_text],
+             outputs=[phoneme_display]
+         )
+
+         btn_add_error.click(
+             fn=add_error,
+             inputs=[
+                 error_frame_id,
+                 error_phoneme,
+                 error_type,
+                 wrong_sound,
+                 error_severity,
+                 error_timestamp,
+                 errors_data
+             ],
+             outputs=[errors_data, errors_list]
+         )
+
+         btn_save.click(
+             fn=save_annotation_handler,
+             inputs=[audio_input, expected_text, errors_data, annotator_name, notes],
+             outputs=[output_status]
+         ).then(
+             fn=update_stats,
+             outputs=[stats_display]
+         )
+
+         # Load stats on startup
+         interface.load(fn=update_stats, outputs=[stats_display])
+
+     return interface
+
+
+ if __name__ == "__main__":
+     interface = create_annotation_interface()
+     interface.launch(server_name="0.0.0.0", server_port=7861, share=False)
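`save_audio_file` peak-normalizes before writing, and `calculate_frame_count` derives the 20 ms frame count from the clip duration. Both steps can be sketched with numpy alone; `peak_normalize` and `frame_count` are hypothetical names standing in for the script's inline logic:

```python
import numpy as np

FRAME_DURATION_MS = 20  # same 20 ms hop used throughout the pipeline

def peak_normalize(audio):
    """Scale so the largest absolute sample is 1.0; leave silence untouched."""
    peak = np.max(np.abs(audio))
    return audio / peak if peak > 0 else audio

def frame_count(duration_s):
    """One frame per 20 ms, with a floor of one frame per clip."""
    return max(1, int((duration_s * 1000) / FRAME_DURATION_MS))
```

The `max(1, ...)` floor matters for very short recordings, which would otherwise produce zero frames and an empty label vector.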
training/config.yaml ADDED
@@ -0,0 +1,86 @@
+ # Training Configuration for Classifier Head
+
+ # Data Configuration
+ data:
+   annotations_file: "data/annotations.json"
+   training_dataset: "data/training_dataset.json"
+   train_split: 0.8
+   val_split: 0.2
+   test_split: 0.0  # Use validation as test if test_split is 0
+   random_seed: 42
+
+ # Model Configuration
+ model:
+   input_dim: 1024  # Wav2Vec2-XLSR-53 feature dimension
+   hidden_dims: [512, 256]  # Shared layers
+   dropout: 0.1
+   num_classes: 8  # 8-class output (fluency + articulation combined)
+   use_pretrained_head: false  # Set to true after first training
+
+ # Training Configuration
+ training:
+   batch_size: 16
+   num_epochs: 50
+   learning_rate: 0.001
+   weight_decay: 0.0001
+   gradient_clip_norm: 1.0
+
+ # Loss Configuration
+ loss:
+   type: "cross_entropy"  # or "focal" for imbalanced data
+   class_weights: null  # Auto-calculate from data if null
+   focal_alpha: 0.25
+   focal_gamma: 2.0
+
+ # Optimizer
+ optimizer: "adam"
+ adam_betas: [0.9, 0.999]
+
+ # Scheduler
+ scheduler: "reduce_on_plateau"
+ scheduler_patience: 5
+ scheduler_factor: 0.5
+ scheduler_min_lr: 0.00001
+
+ # Early Stopping
+ early_stopping:
+   enabled: true
+   patience: 10
+   min_delta: 0.001
+   monitor: "val_loss"
+
+ # Data Augmentation
+ augmentation:
+   enabled: false  # Enable after initial training
+   time_stretch: [0.9, 1.1]
+   noise_injection: 0.01
+   pitch_shift: [-2, 2]  # semitones
+
+ # Validation Configuration
+ validation:
+   metrics: ["accuracy", "f1_score", "precision", "recall", "confusion_matrix"]
+   per_class_metrics: true
+   save_predictions: true
+
+ # Checkpoint Configuration
+ checkpoint:
+   save_dir: "models/checkpoints"
+   save_best: true
+   save_last: true
+   save_frequency: 5  # Save every N epochs
+   filename: "classifier_head_trained.pt"
+   best_filename: "classifier_head_best.pt"
+
+ # Logging Configuration
+ logging:
+   log_dir: "training/logs"
+   tensorboard: false  # Enable if tensorboard installed
+   wandb: false  # Enable if wandb installed
+   log_frequency: 10  # Log every N batches
+
+ # Device Configuration
+ device:
+   use_cuda: true
+   cuda_device: 0
+   mixed_precision: false  # Use FP16 training
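The `early_stopping` block above (patience 10, min_delta 0.001, monitoring `val_loss`) describes a standard patience counter. A minimal sketch of how a training loop might consume those values; `EarlyStopping` is a hypothetical helper, not code from the repository:

```python
class EarlyStopping:
    """Stop training when val_loss fails to improve by min_delta
    for `patience` consecutive epochs (mirrors config.yaml defaults)."""

    def __init__(self, patience=10, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float('inf')
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's val_loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0  # real improvement resets the counter
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience
```

Requiring improvement by at least `min_delta` (not just any decrease) keeps noise-level fluctuations in validation loss from resetting the patience counter.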
training/evaluate_classifier.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Evaluation Script for Trained Classifier Head
3
+
4
+ Evaluates the trained classifier on test/validation data and generates
5
+ comprehensive metrics including per-class accuracy, confusion matrix, etc.
6
+
7
+ Usage:
8
+ python training/evaluate_classifier.py --checkpoint models/checkpoints/classifier_head_best.pt
9
+ """
10
+
11
+ import logging
12
+ import argparse
13
+ import json
14
+ import yaml
15
+ from pathlib import Path
16
+ from typing import Dict, List, Any
17
+ import sys
18
+
19
+ import torch
20
+ import numpy as np
21
+ from sklearn.metrics import (
22
+ accuracy_score, f1_score, precision_score, recall_score,
23
+ confusion_matrix, classification_report
24
+ )
25
+ import matplotlib.pyplot as plt
26
+ import seaborn as sns
27
+
28
+ # Add project root to path
29
+ sys.path.insert(0, str(Path(__file__).parent.parent))
30
+
31
+ from training.train_classifier_head import PhonemeDataset, collate_fn
32
+ from torch.utils.data import DataLoader
33
+ from models.speech_pathology_model import SpeechPathologyClassifier
34
+ from models.phoneme_mapper import PhonemeMapper
35
+ from inference.inference_pipeline import InferencePipeline
36
+ from config import default_audio_config, default_model_config, default_inference_config
37
+
38
+ logging.basicConfig(level=logging.INFO)
39
+ logger = logging.getLogger(__name__)
40
+
41
+
42
+ def load_trained_model(checkpoint_path: Path, config_path: Path = Path("training/config.yaml")) -> torch.nn.Module:
43
+ """Load trained classifier head from checkpoint."""
44
+ # Load config
45
+ with open(config_path, 'r') as f:
46
+ config = yaml.safe_load(f)
47
+
48
+ # Initialize inference pipeline
49
+ inference_pipeline = InferencePipeline(
50
+ audio_config=default_audio_config,
51
+ model_config=default_model_config,
52
+ inference_config=default_inference_config
53
+ )
54
+
55
+ model = inference_pipeline.model
56
+
57
+ # Load checkpoint
58
+ checkpoint = torch.load(checkpoint_path, map_location='cpu')
59
+ model.classifier_head.load_state_dict(checkpoint['model_state_dict'])
60
+
61
+ logger.info(f"βœ… Loaded checkpoint from epoch {checkpoint.get('epoch', 'unknown')}")
62
+ logger.info(f" Validation loss: {checkpoint.get('val_loss', 'unknown'):.4f}")
63
+ logger.info(f" Validation accuracy: {checkpoint.get('val_accuracy', 'unknown'):.4f}")
64
+
65
+ return model
66
+
67
+
68
+ def evaluate_model(
69
+ model: torch.nn.Module,
70
+ dataloader: DataLoader,
71
+ device: torch.device,
72
+ class_names: List[str]
73
+ ) -> Dict[str, Any]:
74
+ """Evaluate model and return comprehensive metrics."""
75
+ model.eval()
76
+
77
+ all_preds = []
78
+ all_labels = []
79
+ all_probs = []
80
+
81
+ with torch.no_grad():
82
+ for batch in dataloader:
83
+ features = batch['features'].to(device)
84
+ labels = batch['labels'].to(device)
85
+
86
+ batch_size, seq_len, feat_dim = features.shape
87
+ features_flat = features.view(-1, feat_dim)
88
+ labels_flat = labels.view(-1)
89
+
90
+ # Forward pass
91
+ shared_features = model.classifier_head.shared_layers(features_flat)
92
+ logits = model.classifier_head.full_head(shared_features)
93
+ probs = torch.softmax(logits, dim=-1)
94
+
95
+ preds = torch.argmax(logits, dim=-1).cpu().numpy()
96
+ all_preds.extend(preds)
97
+ all_labels.extend(labels_flat.cpu().numpy())
98
+ all_probs.extend(probs.cpu().numpy())
99
+
100
+ # Calculate metrics
101
+ accuracy = accuracy_score(all_labels, all_preds)
102
+ f1_macro = f1_score(all_labels, all_preds, average='macro', zero_division=0)
103
+ f1_weighted = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
104
+ precision_macro = precision_score(all_labels, all_preds, average='macro', zero_division=0)
105
+ recall_macro = recall_score(all_labels, all_preds, average='macro', zero_division=0)
106
+
107
+ # Per-class metrics
108
+ cm = confusion_matrix(all_labels, all_preds, labels=list(range(len(class_names))))
109
+
110
+ # Per-class accuracy
111
+ per_class_accuracy = cm.diagonal() / cm.sum(axis=1)
112
+ per_class_accuracy = np.nan_to_num(per_class_accuracy) # Handle division by zero
113
+
114
+ # Classification report
115
+ report = classification_report(
116
+ all_labels, all_preds,
117
+ target_names=class_names,
118
+ output_dict=True,
119
+ zero_division=0
120
+ )
121
+
122
+ # Confidence analysis
123
+ all_probs = np.array(all_probs)
124
+ max_probs = np.max(all_probs, axis=1)
125
+ correct_mask = np.array(all_preds) == np.array(all_labels)
126
+
127
+ avg_confidence_correct = np.mean(max_probs[correct_mask]) if np.any(correct_mask) else 0.0
128
+ avg_confidence_incorrect = np.mean(max_probs[~correct_mask]) if np.any(~correct_mask) else 0.0
129
+
130
+ return {
131
+ 'overall_accuracy': float(accuracy),
132
+ 'f1_macro': float(f1_macro),
133
+ 'f1_weighted': float(f1_weighted),
134
+ 'precision_macro': float(precision_macro),
135
+ 'recall_macro': float(recall_macro),
136
+ 'confusion_matrix': cm.tolist(),
137
+ 'per_class_accuracy': per_class_accuracy.tolist(),
138
+ 'classification_report': report,
139
+ 'confidence': {
140
+ 'avg_correct': float(avg_confidence_correct),
141
+ 'avg_incorrect': float(avg_confidence_incorrect),
142
+ 'confidence_distribution': {
143
+ 'mean': float(np.mean(max_probs)),
144
+ 'std': float(np.std(max_probs)),
145
+ 'min': float(np.min(max_probs)),
146
+ 'max': float(np.max(max_probs))
147
+ }
148
+ },
149
+ 'num_samples': len(all_labels)
150
+ }
151
+
152
+
153
+ def plot_confusion_matrix(cm: np.ndarray, class_names: List[str], output_path: Path):
154
+ """Plot and save confusion matrix."""
155
+ plt.figure(figsize=(10, 8))
156
+ sns.heatmap(
157
+ cm,
158
+ annot=True,
159
+ fmt='d',
160
+ cmap='Blues',
161
+ xticklabels=class_names,
162
+ yticklabels=class_names
163
+ )
164
+ plt.title('Confusion Matrix')
165
+ plt.ylabel('True Label')
166
+ plt.xlabel('Predicted Label')
167
+ plt.tight_layout()
168
+ plt.savefig(output_path)
169
+ logger.info(f"βœ… Saved confusion matrix to {output_path}")
170
+
171
+
172
+ def main():
173
+ parser = argparse.ArgumentParser(description="Evaluate trained classifier")
174
+ parser.add_argument('--checkpoint', type=str, required=True,
175
+ help='Path to checkpoint file')
176
+ parser.add_argument('--config', type=str, default='training/config.yaml',
177
+ help='Path to config file')
178
+ parser.add_argument('--dataset', type=str, default='data/training_dataset.json',
179
+ help='Path to evaluation dataset')
180
+ parser.add_argument('--output', type=str, default='training/evaluation_results.json',
181
+ help='Path to save evaluation results')
182
+ parser.add_argument('--plot', type=str, default='training/confusion_matrix.png',
183
+                        help='Path to save confusion matrix plot')
+    args = parser.parse_args()
+
+    # Load config
+    with open(args.config, 'r') as f:
+        config = yaml.safe_load(f)
+
+    # Set device
+    device = torch.device('cuda' if torch.cuda.is_available() and config['device']['use_cuda'] else 'cpu')
+    logger.info(f"Using device: {device}")
+
+    # Load model
+    checkpoint_path = Path(args.checkpoint)
+    if not checkpoint_path.exists():
+        logger.error(f"Checkpoint not found: {checkpoint_path}")
+        return
+
+    model = load_trained_model(checkpoint_path, Path(args.config))
+    model = model.to(device)
+
+    # Load evaluation dataset
+    dataset_path = Path(args.dataset)
+    if not dataset_path.exists():
+        logger.error(f"Dataset not found: {dataset_path}")
+        return
+
+    with open(dataset_path, 'r') as f:
+        eval_data = json.load(f)
+
+    logger.info(f"Loaded {len(eval_data)} evaluation samples")
+
+    # Create dataset and dataloader
+    inference_pipeline = InferencePipeline(
+        audio_config=default_audio_config,
+        model_config=default_model_config,
+        inference_config=default_inference_config
+    )
+
+    phoneme_mapper = PhonemeMapper(frame_duration_ms=20, sample_rate=16000)
+
+    from training.train_classifier_head import PhonemeDataset
+    dataset = PhonemeDataset(eval_data, inference_pipeline, phoneme_mapper)
+
+    dataloader = DataLoader(
+        dataset,
+        batch_size=config['training']['batch_size'],
+        shuffle=False,
+        collate_fn=collate_fn
+    )
+
+    # Class names
+    class_names = [
+        "Normal",
+        "Substitution",
+        "Omission",
+        "Distortion",
+        "Normal+Stutter",
+        "Substitution+Stutter",
+        "Omission+Stutter",
+        "Distortion+Stutter"
+    ]
+
+    # Evaluate
+    logger.info("Evaluating model...")
+    metrics = evaluate_model(model, dataloader, device, class_names)
+
+    # Print results
+    logger.info("\n" + "="*50)
+    logger.info("EVALUATION RESULTS")
+    logger.info("="*50)
+    logger.info(f"Overall Accuracy: {metrics['overall_accuracy']:.4f}")
+    logger.info(f"F1 Score (macro): {metrics['f1_macro']:.4f}")
+    logger.info(f"F1 Score (weighted): {metrics['f1_weighted']:.4f}")
+    logger.info(f"Precision (macro): {metrics['precision_macro']:.4f}")
+    logger.info(f"Recall (macro): {metrics['recall_macro']:.4f}")
+    logger.info("\nPer-Class Accuracy:")
+    for name, acc in zip(class_names, metrics['per_class_accuracy']):
+        logger.info(f"  {name}: {acc:.4f}")
+    logger.info("\nConfidence Analysis:")
+    logger.info(f"  Avg confidence (correct): {metrics['confidence']['avg_correct']:.4f}")
+    logger.info(f"  Avg confidence (incorrect): {metrics['confidence']['avg_incorrect']:.4f}")
+
+    # Save results
+    output_path = Path(args.output)
+    output_path.parent.mkdir(parents=True, exist_ok=True)
+    with open(output_path, 'w') as f:
+        json.dump(metrics, f, indent=2)
+    logger.info(f"\n✅ Saved evaluation results to {output_path}")
+
+    # Plot confusion matrix
+    if args.plot:
+        plot_path = Path(args.plot)
+        plot_path.parent.mkdir(parents=True, exist_ok=True)
+        cm = np.array(metrics['confusion_matrix'])
+        plot_confusion_matrix(cm, class_names, plot_path)
+
+
+if __name__ == "__main__":
+    main()
+
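The evaluation script reports overall accuracy plus per-class accuracy derived from the confusion matrix. As a standalone sketch (not code from this repo) of how those two numbers relate, with a toy 3-class matrix:

```python
import numpy as np

# Toy confusion matrix: rows = true class, cols = predicted class
cm = np.array([
    [8, 1, 1],   # class 0: 10 samples, 8 predicted correctly
    [2, 6, 2],   # class 1: 10 samples, 6 predicted correctly
    [0, 0, 10],  # class 2: 10 samples, all predicted correctly
])

overall_accuracy = np.trace(cm) / cm.sum()          # diagonal hits / all samples
per_class_accuracy = np.diag(cm) / cm.sum(axis=1)   # per-class recall

print(overall_accuracy)     # 0.8
print(per_class_accuracy)   # [0.8 0.6 1. ]
```

The same decomposition applies unchanged to the 8Γ—8 matrix this script saves in `metrics['confusion_matrix']`.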
training/train_classifier_head.py ADDED
@@ -0,0 +1,469 @@
+"""
+Training Script for Speech Pathology Classifier Head
+
+This script fine-tunes the classification head on phoneme-level labeled data.
+Wav2Vec2 encoder is frozen; only the classifier head is trained.
+
+Usage:
+    python training/train_classifier_head.py --config training/config.yaml
+"""
+
+import logging
+import os
+import sys
+import json
+import yaml
+import argparse
+from pathlib import Path
+from typing import Dict, List, Tuple, Optional, Any
+from datetime import datetime
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import Dataset, DataLoader, random_split
+import numpy as np
+from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix
+import librosa
+import soundfile as sf
+
+# Add project root to path
+sys.path.insert(0, str(Path(__file__).parent.parent))
+
+from models.speech_pathology_model import SpeechPathologyClassifier, MultiTaskClassifierHead
+from models.phoneme_mapper import PhonemeMapper
+from inference.inference_pipeline import InferencePipeline
+from config import default_audio_config, default_model_config, default_inference_config
+
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+)
+logger = logging.getLogger(__name__)
+
+
+class PhonemeDataset(Dataset):
+    """Dataset for phoneme-level speech pathology training."""
+
+    def __init__(
+        self,
+        training_data: List[Dict[str, Any]],
+        inference_pipeline: InferencePipeline,
+        phoneme_mapper: PhonemeMapper
+    ):
+        """
+        Initialize dataset.
+
+        Args:
+            training_data: List of training samples with frame labels
+            inference_pipeline: Pipeline for extracting Wav2Vec2 features
+            phoneme_mapper: Mapper for phoneme alignment
+        """
+        self.training_data = training_data
+        self.inference_pipeline = inference_pipeline
+        self.phoneme_mapper = phoneme_mapper
+
+        logger.info(f"Initialized dataset with {len(training_data)} samples")
+
+    def __len__(self) -> int:
+        return len(self.training_data)
+
+    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
+        """Get a training sample."""
+        sample = self.training_data[idx]
+        audio_file = sample['audio_file']
+        frame_labels = sample['frame_labels']
+
+        # Load audio
+        try:
+            audio, sr = librosa.load(audio_file, sr=16000)
+        except Exception as e:
+            logger.error(f"Failed to load {audio_file}: {e}")
+            # Return dummy data
+            return {
+                'features': torch.zeros(1, 1024),
+                'labels': torch.tensor([0], dtype=torch.long),
+                'valid': torch.tensor(False)
+            }
+
+        # Extract Wav2Vec2 features
+        try:
+            frame_features, frame_times = self.inference_pipeline.get_phone_level_features(audio)
+
+            # Align labels with features
+            num_features = len(frame_features)
+            num_labels = len(frame_labels)
+
+            # Pad or truncate labels to match features
+            if num_labels < num_features:
+                frame_labels = frame_labels + [0] * (num_features - num_labels)
+            elif num_labels > num_features:
+                frame_labels = frame_labels[:num_features]
+
+            # Convert to tensors
+            features_tensor = frame_features  # Already a tensor
+            labels_tensor = torch.tensor(frame_labels[:num_features], dtype=torch.long)
+
+            return {
+                'features': features_tensor,
+                'labels': labels_tensor,
+                'valid': torch.tensor(True)
+            }
+        except Exception as e:
+            logger.error(f"Failed to extract features from {audio_file}: {e}")
+            return {
+                'features': torch.zeros(1, 1024),
+                'labels': torch.tensor([0], dtype=torch.long),
+                'valid': torch.tensor(False)
+            }
+
+
+def collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
+    """Collate function for DataLoader."""
+    # Filter out invalid samples
+    valid_batch = [b for b in batch if b['valid'].item()]
+
+    if not valid_batch:
+        # Return dummy batch
+        return {
+            'features': torch.zeros(1, 1, 1024),
+            'labels': torch.zeros(1, 1, dtype=torch.long)
+        }
+
+    # Stack features and labels
+    features_list = []
+    labels_list = []
+
+    for item in valid_batch:
+        features_list.append(item['features'])
+        labels_list.append(item['labels'])
+
+    # Pad to same length
+    max_len = max(f.shape[0] for f in features_list)
+
+    padded_features = []
+    padded_labels = []
+
+    for feat, lab in zip(features_list, labels_list):
+        if feat.shape[0] < max_len:
+            padding = max_len - feat.shape[0]
+            feat = torch.cat([feat, torch.zeros(padding, feat.shape[1])])
+            lab = torch.cat([lab, torch.zeros(padding, dtype=torch.long)])
+        padded_features.append(feat)
+        padded_labels.append(lab)
+
+    return {
+        'features': torch.stack(padded_features),
+        'labels': torch.stack(padded_labels)
+    }
+
+
+def calculate_class_weights(dataset: PhonemeDataset) -> torch.Tensor:
+    """Calculate class weights for imbalanced data."""
+    all_labels = []
+    for i in range(len(dataset)):
+        sample = dataset[i]
+        if sample['valid'].item():
+            all_labels.extend(sample['labels'].tolist())
+
+    if not all_labels:
+        return torch.ones(8)
+
+    unique, counts = np.unique(all_labels, return_counts=True)
+    total = len(all_labels)
+
+    weights = torch.ones(8)
+    for cls, count in zip(unique, counts):
+        if count > 0:
+            weights[int(cls)] = total / (8 * count)  # Inverse frequency weighting
+
+    logger.info(f"Class weights: {weights.tolist()}")
+    return weights
+
+
+def train_epoch(
+    model: nn.Module,
+    dataloader: DataLoader,
+    optimizer: optim.Optimizer,
+    criterion: nn.Module,
+    device: torch.device,
+    epoch: int
+) -> Dict[str, float]:
+    """Train for one epoch."""
+    model.train()
+    total_loss = 0.0
+    all_preds = []
+    all_labels = []
+
+    for batch_idx, batch in enumerate(dataloader):
+        features = batch['features'].to(device)  # (batch, seq_len, 1024)
+        labels = batch['labels'].to(device)  # (batch, seq_len)
+
+        # Flatten for processing
+        batch_size, seq_len, feat_dim = features.shape
+        features_flat = features.view(-1, feat_dim)  # (batch * seq_len, 1024)
+        labels_flat = labels.view(-1)  # (batch * seq_len)
+
+        # Forward pass
+        optimizer.zero_grad()
+
+        # Get predictions from full_head
+        shared_features = model.classifier_head.shared_layers(features_flat)
+        logits = model.classifier_head.full_head(shared_features)  # (batch * seq_len, 8)
+
+        # Calculate loss
+        loss = criterion(logits, labels_flat)
+
+        # Backward pass
+        loss.backward()
+        torch.nn.utils.clip_grad_norm_(model.classifier_head.parameters(), max_norm=1.0)
+        optimizer.step()
+
+        # Metrics
+        total_loss += loss.item()
+        preds = torch.argmax(logits, dim=-1).cpu().numpy()
+        all_preds.extend(preds)
+        all_labels.extend(labels_flat.cpu().numpy())
+
+        if batch_idx % 10 == 0:
+            logger.info(f"Epoch {epoch}, Batch {batch_idx}/{len(dataloader)}, Loss: {loss.item():.4f}")
+
+    avg_loss = total_loss / len(dataloader)
+    accuracy = accuracy_score(all_labels, all_preds)
+
+    return {
+        'loss': avg_loss,
+        'accuracy': accuracy
+    }
+
+
+def validate(
+    model: nn.Module,
+    dataloader: DataLoader,
+    criterion: nn.Module,
+    device: torch.device
+) -> Dict[str, float]:
+    """Validate model."""
+    model.eval()
+    total_loss = 0.0
+    all_preds = []
+    all_labels = []
+
+    with torch.no_grad():
+        for batch in dataloader:
+            features = batch['features'].to(device)
+            labels = batch['labels'].to(device)
+
+            batch_size, seq_len, feat_dim = features.shape
+            features_flat = features.view(-1, feat_dim)
+            labels_flat = labels.view(-1)
+
+            # Forward pass
+            shared_features = model.classifier_head.shared_layers(features_flat)
+            logits = model.classifier_head.full_head(shared_features)
+
+            loss = criterion(logits, labels_flat)
+            total_loss += loss.item()
+
+            preds = torch.argmax(logits, dim=-1).cpu().numpy()
+            all_preds.extend(preds)
+            all_labels.extend(labels_flat.cpu().numpy())
+
+    avg_loss = total_loss / len(dataloader)
+    accuracy = accuracy_score(all_labels, all_preds)
+    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
+    precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
+    recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
+
+    # Per-class metrics
+    cm = confusion_matrix(all_labels, all_preds, labels=list(range(8)))
+
+    return {
+        'loss': avg_loss,
+        'accuracy': accuracy,
+        'f1_score': f1,
+        'precision': precision,
+        'recall': recall,
+        'confusion_matrix': cm.tolist()
+    }
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Train classifier head")
+    parser.add_argument('--config', type=str, default='training/config.yaml',
+                        help='Path to config file')
+    parser.add_argument('--resume', type=str, default=None,
+                        help='Resume from checkpoint')
+    args = parser.parse_args()
+
+    # Load config
+    with open(args.config, 'r') as f:
+        config = yaml.safe_load(f)
+
+    # Set device
+    device = torch.device('cuda' if torch.cuda.is_available() and config['device']['use_cuda'] else 'cpu')
+    logger.info(f"Using device: {device}")
+
+    # Load training data
+    training_file = Path(config['data']['training_dataset'])
+    if not training_file.exists():
+        logger.error(f"Training dataset not found: {training_file}")
+        logger.info("Run scripts/annotation_helper.py to export training data first")
+        return
+
+    with open(training_file, 'r') as f:
+        training_data = json.load(f)
+
+    logger.info(f"Loaded {len(training_data)} training samples")
+
+    # Initialize inference pipeline for feature extraction
+    inference_pipeline = InferencePipeline(
+        audio_config=default_audio_config,
+        model_config=default_model_config,
+        inference_config=default_inference_config
+    )
+
+    # Initialize phoneme mapper
+    phoneme_mapper = PhonemeMapper(
+        frame_duration_ms=20,
+        sample_rate=16000
+    )
+
+    # Create dataset
+    dataset = PhonemeDataset(training_data, inference_pipeline, phoneme_mapper)
+
+    # Split dataset
+    train_size = int(config['data']['train_split'] * len(dataset))
+    val_size = len(dataset) - train_size
+
+    train_dataset, val_dataset = random_split(
+        dataset,
+        [train_size, val_size],
+        generator=torch.Generator().manual_seed(config['data']['random_seed'])
+    )
+
+    logger.info(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")
+
+    # Create data loaders
+    train_loader = DataLoader(
+        train_dataset,
+        batch_size=config['training']['batch_size'],
+        shuffle=True,
+        collate_fn=collate_fn
+    )
+
+    val_loader = DataLoader(
+        val_dataset,
+        batch_size=config['training']['batch_size'],
+        shuffle=False,
+        collate_fn=collate_fn
+    )
+
+    # Load model
+    model = inference_pipeline.model
+    model.train()  # Set to training mode
+
+    # Freeze Wav2Vec2 (should already be frozen, but ensure it)
+    for param in model.wav2vec2_model.parameters():
+        param.requires_grad = False
+
+    # Unfreeze classifier head
+    for param in model.classifier_head.parameters():
+        param.requires_grad = True
+
+    logger.info("Model prepared: Wav2Vec2 frozen, classifier head trainable")
+
+    # Calculate class weights
+    class_weights = calculate_class_weights(dataset)
+    class_weights = class_weights.to(device)
+
+    # Loss function
+    if config['training']['loss']['type'] == 'cross_entropy':
+        criterion = nn.CrossEntropyLoss(weight=class_weights)
+    else:
+        # Focal loss implementation would go here
+        criterion = nn.CrossEntropyLoss(weight=class_weights)
+
+    # Optimizer
+    optimizer = optim.Adam(
+        model.classifier_head.parameters(),
+        lr=config['training']['learning_rate'],
+        weight_decay=config['training']['weight_decay']
+    )
+
+    # Scheduler
+    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
+        optimizer,
+        mode='min',
+        factor=config['training']['scheduler_factor'],
+        patience=config['training']['scheduler_patience'],
+        min_lr=config['training']['scheduler_min_lr']
+    )
+
+    # Training loop
+    best_val_loss = float('inf')
+    patience_counter = 0
+
+    checkpoint_dir = Path(config['checkpoint']['save_dir'])
+    checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+    for epoch in range(config['training']['num_epochs']):
+        logger.info(f"\n{'='*50}")
+        logger.info(f"Epoch {epoch+1}/{config['training']['num_epochs']}")
+        logger.info(f"{'='*50}")
+
+        # Train
+        train_metrics = train_epoch(model, train_loader, optimizer, criterion, device, epoch+1)
+        logger.info(f"Train - Loss: {train_metrics['loss']:.4f}, Accuracy: {train_metrics['accuracy']:.4f}")
+
+        # Validate
+        val_metrics = validate(model, val_loader, criterion, device)
+        logger.info(f"Val - Loss: {val_metrics['loss']:.4f}, Accuracy: {val_metrics['accuracy']:.4f}, "
+                    f"F1: {val_metrics['f1_score']:.4f}")
+
+        # Scheduler step
+        scheduler.step(val_metrics['loss'])
+
+        # Save checkpoint
+        if config['checkpoint']['save_best'] and val_metrics['loss'] < best_val_loss:
+            best_val_loss = val_metrics['loss']
+            checkpoint_path = checkpoint_dir / config['checkpoint']['best_filename']
+            torch.save({
+                'epoch': epoch,
+                'model_state_dict': model.classifier_head.state_dict(),
+                'optimizer_state_dict': optimizer.state_dict(),
+                'val_loss': val_metrics['loss'],
+                'val_accuracy': val_metrics['accuracy'],
+                'config': config
+            }, checkpoint_path)
+            logger.info(f"✅ Saved best checkpoint to {checkpoint_path}")
+            patience_counter = 0
+        else:
+            patience_counter += 1
+
+        # Early stopping
+        if config['training']['early_stopping']['enabled']:
+            if patience_counter >= config['training']['early_stopping']['patience']:
+                logger.info(f"Early stopping triggered after {epoch+1} epochs")
+                break
+
+        # Save last checkpoint
+        if config['checkpoint']['save_last'] and (epoch + 1) % config['checkpoint']['save_frequency'] == 0:
+            checkpoint_path = checkpoint_dir / config['checkpoint']['filename']
+            torch.save({
+                'epoch': epoch,
+                'model_state_dict': model.classifier_head.state_dict(),
+                'optimizer_state_dict': optimizer.state_dict(),
+                'val_loss': val_metrics['loss'],
+                'val_accuracy': val_metrics['accuracy'],
+                'config': config
+            }, checkpoint_path)
+            logger.info(f"Saved checkpoint to {checkpoint_path}")
+
+    logger.info("\n✅ Training complete!")
+    logger.info(f"Best validation loss: {best_val_loss:.4f}")
+
+
+if __name__ == "__main__":
+    main()
+
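The inverse-frequency weighting in `calculate_class_weights` above (`weight[c] = total / (num_classes * count[c])`) can be checked in isolation. A minimal sketch with a synthetic label distribution, independent of the dataset class:

```python
import numpy as np

def inverse_frequency_weights(labels, num_classes=8):
    """weight[c] = total / (num_classes * count[c]); unseen classes keep weight 1."""
    weights = np.ones(num_classes)
    unique, counts = np.unique(labels, return_counts=True)
    total = len(labels)
    for cls, count in zip(unique, counts):
        weights[int(cls)] = total / (num_classes * count)
    return weights

# 80 "Normal" frames, 16 "Substitution", 4 "Omission": rarer classes get larger weights
labels = [0] * 80 + [1] * 16 + [2] * 4
w = inverse_frequency_weights(labels)
print(w[:3])  # [0.15625 0.78125 3.125  ]
```

These weights are exactly what gets passed to `nn.CrossEntropyLoss(weight=...)`, so the loss pays roughly equal total attention to each class despite the "Normal" frames dominating the data.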
ui/gradio_interface.py CHANGED
@@ -17,6 +17,8 @@ import numpy as np
 import gradio as gr
 
 from diagnosis.ai_engine.model_loader import get_inference_pipeline
+from api.routes import get_phoneme_mapper, get_error_mapper
+from models.error_taxonomy import ErrorType, SeverityLevel
 from config import GradioConfig, default_gradio_config
 
 logger = logging.getLogger(__name__)
@@ -83,8 +85,9 @@ def format_articulation_issues(articulation_scores: list) -> str:
 
 def analyze_speech(
     audio_input: Optional[Tuple[int, np.ndarray]],
-    audio_file: Optional[str]
-) -> Tuple[str, str, str, str, Dict[str, Any]]:
+    audio_file: Optional[str],
+    expected_text: Optional[str] = None
+) -> Tuple[str, str, str, str, str, Dict[str, Any]]:
     """
     Analyze speech audio for fluency and articulation issues.
 
@@ -167,6 +170,72 @@ def analyze_speech(
         except: pass
         # #endregion
 
+        # Get phoneme and error mappers
+        phoneme_mapper = get_phoneme_mapper()
+        error_mapper = get_error_mapper()
+
+        # Map phonemes to frames if text provided
+        frame_phonemes = []
+        if expected_text and phoneme_mapper:
+            try:
+                frame_phonemes = phoneme_mapper.map_text_to_frames(
+                    expected_text,
+                    num_frames=result.num_frames,
+                    audio_duration=result.duration
+                )
+                logger.info(f"✅ Mapped {len(frame_phonemes)} phonemes to frames")
+            except Exception as e:
+                logger.warning(f"⚠️ Phoneme mapping failed: {e}")
+                frame_phonemes = [''] * result.num_frames
+        else:
+            frame_phonemes = [''] * result.num_frames
+
+        # Process errors with error mapper
+        errors = []
+        error_table_rows = []
+
+        for i, frame_pred in enumerate(result.frame_predictions):
+            phoneme = frame_phonemes[i] if i < len(frame_phonemes) else ''
+
+            # Map classifier output to error detail (8-class system)
+            class_id = frame_pred.articulation_class
+            if frame_pred.fluency_label == 'stutter':
+                class_id += 4  # Add 4 for stutter classes (4-7)
+
+            # Get error detail
+            if error_mapper:
+                try:
+                    error_detail = error_mapper.map_classifier_output(
+                        class_id=class_id,
+                        confidence=frame_pred.confidence,
+                        phoneme=phoneme if phoneme else 'unknown',
+                        fluency_label=frame_pred.fluency_label
+                    )
+
+                    if error_detail.error_type != ErrorType.NORMAL:
+                        errors.append((i, frame_pred.time, error_detail))
+
+                        # Add to error table
+                        severity_level = error_mapper.get_severity_level(error_detail.severity)
+                        severity_color = {
+                            SeverityLevel.NONE: "green",
+                            SeverityLevel.LOW: "orange",
+                            SeverityLevel.MEDIUM: "orange",
+                            SeverityLevel.HIGH: "red"
+                        }.get(severity_level, "gray")
+
+                        error_table_rows.append({
+                            "phoneme": error_detail.phoneme,
+                            "time": f"{frame_pred.time:.2f}s",
+                            "error_type": error_detail.error_type.value,
+                            "wrong_sound": error_detail.wrong_sound or "N/A",
+                            "severity": severity_level.value,
+                            "severity_color": severity_color,
+                            "therapy": error_detail.therapy[:80] + "..." if len(error_detail.therapy) > 80 else error_detail.therapy
+                        })
+                except Exception as e:
+                    logger.warning(f"Error mapping failed for frame {i}: {e}")
+
         # Calculate processing time
         processing_time_ms = (time.time() - start_time) * 1000
 
@@ -240,7 +309,121 @@ def analyze_speech(
         </div>
         """
 
-        # Create JSON output
+        # Format error table with summary of problematic sounds
+        if error_table_rows:
+            # Group errors by phoneme to show which sounds have issues
+            phoneme_errors = {}
+            for row in error_table_rows:
+                phoneme = row['phoneme']
+                if phoneme not in phoneme_errors:
+                    phoneme_errors[phoneme] = {
+                        'count': 0,
+                        'types': set(),
+                        'severity': 'low',
+                        'examples': []
+                    }
+                phoneme_errors[phoneme]['count'] += 1
+                phoneme_errors[phoneme]['types'].add(row['error_type'])
+                if row['severity'] in ['high', 'medium']:
+                    phoneme_errors[phoneme]['severity'] = row['severity']
+                if len(phoneme_errors[phoneme]['examples']) < 2:
+                    phoneme_errors[phoneme]['examples'].append(row)
+
+            # Create summary section
+            problematic_sounds = sorted(phoneme_errors.keys())
+            summary_html = f"""
+            <div style='background-color: #fff3cd; border: 2px solid #ffc107; border-radius: 8px; padding: 15px; margin-bottom: 20px;'>
+                <h3 style='color: #856404; margin-top: 0;'>⚠️ Problematic Sounds Detected</h3>
+                <p style='color: #856404; font-size: 14px; margin-bottom: 10px;'>
+                    <strong>{len(problematic_sounds)} sound(s) with issues:</strong> {', '.join([f'<strong style="color: red;">/{p}/</strong>' for p in problematic_sounds[:10]])}
+                    {f'<span style="color: #666;">(+{len(problematic_sounds) - 10} more)</span>' if len(problematic_sounds) > 10 else ''}
+                </p>
+                <div style='display: flex; flex-wrap: wrap; gap: 10px;'>
+            """
+
+            for phoneme in problematic_sounds[:10]:
+                error_info = phoneme_errors[phoneme]
+                severity_color = 'red' if error_info['severity'] == 'high' else 'orange' if error_info['severity'] == 'medium' else '#666'
+                summary_html += f"""
+                <div style='background-color: white; border: 1px solid {severity_color}; border-radius: 4px; padding: 8px; min-width: 120px;'>
+                    <strong style='color: {severity_color}; font-size: 18px;'>/{phoneme}/</strong>
+                    <div style='font-size: 12px; color: #666;'>
+                        {error_info['count']} error(s)<br/>
+                        Types: {', '.join(error_info['types'])}
+                    </div>
+                </div>
+                """
+
+            summary_html += """
+                </div>
+            </div>
+            """
+
+            # Create detailed error table
+            error_table_html = summary_html + """
+            <h4 style='color: #333; margin-top: 20px;'>📋 Detailed Error Report</h4>
+            <table style='width: 100%; border-collapse: collapse; margin: 10px 0; font-size: 13px;'>
+                <thead>
+                    <tr style='background-color: #f0f0f0;'>
+                        <th style='padding: 10px; border: 1px solid #ddd; text-align: left; background-color: #e8e8e8;'>Sound</th>
+                        <th style='padding: 10px; border: 1px solid #ddd; text-align: left; background-color: #e8e8e8;'>Time</th>
+                        <th style='padding: 10px; border: 1px solid #ddd; text-align: left; background-color: #e8e8e8;'>Error Type</th>
+                        <th style='padding: 10px; border: 1px solid #ddd; text-align: left; background-color: #e8e8e8;'>Wrong Sound</th>
+                        <th style='padding: 10px; border: 1px solid #ddd; text-align: left; background-color: #e8e8e8;'>Severity</th>
+                        <th style='padding: 10px; border: 1px solid #ddd; text-align: left; background-color: #e8e8e8;'>Therapy Recommendation</th>
+                    </tr>
+                </thead>
+                <tbody>
+            """
+
+            for row in error_table_rows[:20]:  # Limit to first 20 errors
+                severity_bg = {
+                    'high': '#ffebee',
+                    'medium': '#fff3e0',
+                    'low': '#f3e5f5',
+                    'none': '#e8f5e9'
+                }.get(row['severity'], '#f5f5f5')
+
+                error_table_html += f"""
+                <tr style='background-color: {severity_bg};'>
+                    <td style='padding: 10px; border: 1px solid #ddd;'>
+                        <strong style='color: {row['severity_color']}; font-size: 16px;'>/{row['phoneme']}/</strong>
+                    </td>
+                    <td style='padding: 10px; border: 1px solid #ddd;'>{row['time']}</td>
+                    <td style='padding: 10px; border: 1px solid #ddd;'>
+                        <span style='background-color: {row['severity_color']}; color: white; padding: 3px 8px; border-radius: 3px; font-size: 11px;'>
+                            {row['error_type'].upper()}
+                        </span>
+                    </td>
+                    <td style='padding: 10px; border: 1px solid #ddd;'>
+                        {f"<strong style='color: red;'>/{row['wrong_sound']}/</strong>" if row['wrong_sound'] != 'N/A' else '<span style="color: #999;">N/A</span>'}
+                    </td>
+                    <td style='padding: 10px; border: 1px solid #ddd;'>
+                        <strong style='color: {row['severity_color']};'>{row['severity'].upper()}</strong>
+                    </td>
+                    <td style='padding: 10px; border: 1px solid #ddd; font-size: 12px;'>{row['therapy']}</td>
+                </tr>
+                """
+
+            error_table_html += """
+                </tbody>
+            </table>
+            """
+
+            if len(error_table_rows) > 20:
+                error_table_html += f"<p style='color: #666; font-size: 12px; margin-top: 10px;'>📊 Showing first 20 of <strong>{len(error_table_rows)}</strong> total errors detected</p>"
+        else:
+            error_table_html = """
+            <div style='background-color: #d4edda; border: 2px solid #28a745; border-radius: 8px; padding: 20px; text-align: center;'>
+                <h3 style='color: #155724; margin-top: 0;'>✅ No Errors Detected</h3>
+                <p style='color: #155724; font-size: 16px;'>
+                    All sounds/phonemes were produced correctly!<br/>
+                    <span style='font-size: 14px; color: #666;'>Great job! 🎉</span>
+                </p>
+            </div>
+            """
+
+        # Create JSON output with errors
         json_output = {
             "status": "success",
             "fluency_metrics": {
@@ -259,6 +442,18 @@ def analyze_speech(
             "confidence": avg_confidence,
             "confidence_percentage": confidence_percentage,
             "processing_time_ms": processing_time_ms,
+            "error_count": len(errors),
+            "errors": [
+                {
+                    "phoneme": err[2].phoneme,
+                    "time": err[1],
+                    "error_type": err[2].error_type.value,
+                    "wrong_sound": err[2].wrong_sound,
+                    "severity": error_mapper.get_severity_level(err[2].severity).value if error_mapper else "unknown",
+                    "therapy": err[2].therapy
+                }
+                for err in errors[:20]
+            ] if errors else [],
             "frame_predictions": [
                 {
                     "time": fp.time,
@@ -266,9 +461,10 @@ def analyze_speech(
                     "fluency_label": fp.fluency_label,
                     "articulation_class": fp.articulation_class,
                     "articulation_label": fp.articulation_label,
-                    "confidence": fp.confidence
+                    "confidence": fp.confidence,
+                    "phoneme": frame_phonemes[i] if i < len(frame_phonemes) else ''
                 }
-                for fp in result.frame_predictions[:20]  # First 20 frames for preview
+                for i, fp in enumerate(result.frame_predictions[:20])  # First 20 frames for preview
             ]
         }
 
@@ -289,6 +485,7 @@ def analyze_speech(
             articulation_text,
            confidence_html,
            processing_time_html,
+            error_table_html,
            json_output
        )
 
@@ -302,11 +499,13 @@ def analyze_speech(
        except: pass
        # #endregion
        error_html = f"<p style='color: red;'>❌ Error: {str(e)}</p>"
+        error_table_html = "<p style='color: #999;'>No error details available</p>"
        return (
            error_html,
            f"Error: {str(e)}",
            "N/A",
            "N/A",
+            error_table_html,
            {"error": str(e), "status": "error"}
        )
 
@@ -370,6 +569,13 @@ def create_gradio_interface(gradio_config: Optional[GradioConfig] = None) -> gr.
                    format="wav"
                )
 
+                expected_text_input = gr.Textbox(
+                    label="Expected Text (Optional)",
+                    placeholder="Enter the expected text/transcript for phoneme mapping",
+                    lines=2,
+                    info="Provide the expected text to enable phoneme-level error detection"
+                )
+
                analyze_btn = gr.Button(
                    "🔍 Analyze Speech",
                    variant="primary",
@@ -409,6 +615,11 @@ def create_gradio_interface(gradio_config: Optional[GradioConfig] = None) -> gr.
                    elem_classes=["output-box"]
                )
 
+                error_table_output = gr.HTML(
+                    label="Error Details",
+                    elem_classes=["output-box"]
+                )
+
                json_output = gr.JSON(
                    label="Detailed Results (JSON)",
                    elem_classes=["output-box"]
@@ -417,12 +628,13 @@ def create_gradio_interface(gradio_config: Optional[GradioConfig] = None) -> gr.
        # Set up event handlers
        analyze_btn.click(
            fn=analyze_speech,
-            inputs=[audio_mic, audio_file],
+            inputs=[audio_mic, audio_file, expected_text_input],
            outputs=[
                fluency_output,
                articulation_output,
                confidence_output,
                processing_time_output,
+                error_table_output,
                json_output
            ]
        )
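The Gradio handler above packs the 4-way articulation class and the binary fluency label into a single 8-class id by adding 4 when the frame is a stutter. That mapping (and its inverse) can be sketched as a standalone helper; this is a hypothetical illustration mirroring the class list from the training scripts, and the `"fluent"` label name is an assumption (the source only checks for `'stutter'`):

```python
ARTICULATION = ["Normal", "Substitution", "Omission", "Distortion"]

def to_class_id(articulation_class: int, fluency_label: str) -> int:
    """Combine a 0-3 articulation class and a fluency label into a 0-7 id."""
    return articulation_class + (4 if fluency_label == "stutter" else 0)

def from_class_id(class_id: int) -> tuple:
    """Recover (articulation_label, fluency_label) from the combined id."""
    return ARTICULATION[class_id % 4], ("stutter" if class_id >= 4 else "fluent")

print(to_class_id(1, "stutter"))  # 5, i.e. "Substitution+Stutter"
print(from_class_id(6))           # ('Omission', 'stutter')
```

The round trip is lossless, which is why the evaluation script can report per-class metrics on the combined 8-class labels without storing the two factors separately.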