anfastech commited on
Commit
278e294
·
1 Parent(s): 3d66487

New: implemented many changes; phone-level detection (~10% complete) is now working.

Browse files
Docs/QUICK_START.md CHANGED
@@ -327,6 +327,7 @@ with Pool(4) as pool:
327
  ## API Reference
328
 
329
  ### Main Method
 
330
  ```python
331
  analyze_audio(
332
  audio_path: str, # Path to .wav file
@@ -336,6 +337,7 @@ analyze_audio(
336
  ```
337
 
338
  ### Utility Methods
 
339
  ```python
340
  # Phonetic similarity (0-1)
341
  _calculate_phonetic_similarity(char1: str, char2: str) -> float
 
327
  ## API Reference
328
 
329
  ### Main Method
330
+
331
  ```python
332
  analyze_audio(
333
  audio_path: str, # Path to .wav file
 
337
  ```
338
 
339
  ### Utility Methods
340
+
341
  ```python
342
  # Phonetic similarity (0-1)
343
  _calculate_phonetic_similarity(char1: str, char2: str) -> float
api/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ """
2
+ API module for speech pathology diagnosis endpoints.
3
+ """
4
+
api/routes.py ADDED
@@ -0,0 +1,389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ REST API routes for Speech Pathology Diagnosis.
3
+
4
+ This module provides FastAPI endpoints for batch file analysis,
5
+ session management, and health checks.
6
+ """
7
+
8
+ import logging
9
+ import os
10
+ import time
11
+ import tempfile
12
+ import uuid
13
+ from pathlib import Path
14
+ from typing import Optional, List, Dict, Any
15
+ from datetime import datetime
16
+
17
+ from fastapi import APIRouter, UploadFile, File, HTTPException, Query
18
+ from fastapi.responses import JSONResponse
19
+
20
+ from api.schemas import (
21
+ BatchDiagnosisResponse,
22
+ FrameDiagnosis,
23
+ ErrorReport,
24
+ SummaryMetrics,
25
+ SessionListResponse,
26
+ HealthResponse,
27
+ ErrorDetailSchema,
28
+ FluencyInfo,
29
+ ArticulationInfo
30
+ )
31
+ from models.phoneme_mapper import PhonemeMapper
32
+ from models.error_taxonomy import ErrorMapper, ErrorType, SeverityLevel
33
+ from inference.inference_pipeline import InferencePipeline
34
+ from config import AudioConfig, default_audio_config
35
+
36
+ logger = logging.getLogger(__name__)
37
+
38
+ # Create router
39
+ router = APIRouter(prefix="/diagnose", tags=["diagnosis"])
40
+
41
+ # In-memory session storage (in production, use Redis or database)
42
+ sessions: Dict[str, BatchDiagnosisResponse] = {}
43
+
44
+ # Global instances (will be injected)
45
+ inference_pipeline: Optional[InferencePipeline] = None
46
+ phoneme_mapper: Optional[PhonemeMapper] = None
47
+ error_mapper: Optional[ErrorMapper] = None
48
+
49
+
50
def initialize_routes(
    pipeline: InferencePipeline,
    mapper: Optional[PhonemeMapper] = None,
    error_mapper_instance: Optional[ErrorMapper] = None
):
    """
    Initialize routes with dependencies.

    Args:
        pipeline: InferencePipeline instance
        mapper: Optional PhonemeMapper instance. BUG FIX: previously this
            argument was silently ignored and a fresh PhonemeMapper was
            always constructed; it is now honoured when provided.
        error_mapper_instance: Optional ErrorMapper instance (same fix).
    """
    global inference_pipeline, phoneme_mapper, error_mapper

    inference_pipeline = pipeline

    if mapper is not None:
        # Use the caller-supplied mapper instead of ignoring it.
        phoneme_mapper = mapper
    else:
        try:
            phoneme_mapper = PhonemeMapper(
                frame_duration_ms=default_audio_config.chunk_duration_ms,
                sample_rate=default_audio_config.sample_rate
            )
            logger.info("✅ PhonemeMapper initialized")
        except Exception as e:
            # Phoneme mapping is optional; routes degrade gracefully without it.
            logger.warning(f"⚠️ PhonemeMapper not available: {e}")
            phoneme_mapper = None

    if error_mapper_instance is not None:
        # Use the caller-supplied error mapper instead of ignoring it.
        error_mapper = error_mapper_instance
    else:
        try:
            error_mapper = ErrorMapper()
            logger.info("✅ ErrorMapper initialized")
        except Exception as e:
            logger.error(f"❌ ErrorMapper failed to initialize: {e}")
            error_mapper = None
85
+
86
+
87
@router.post("/file", response_model=BatchDiagnosisResponse)
async def diagnose_file(
    audio: UploadFile = File(...),
    text: Optional[str] = Query(None, description="Expected text/transcript for phoneme mapping"),
    session_id: Optional[str] = Query(None, description="Optional session ID")
):
    """
    Analyze an uploaded audio file for speech pathology errors.

    Performs complete phoneme-level analysis:
    - Extracts Wav2Vec2 features
    - Classifies fluency and articulation per frame
    - Maps phonemes to frames (when ``text`` is provided)
    - Detects errors and generates therapy recommendations

    Args:
        audio: Audio file (WAV, MP3, etc.)
        text: Optional expected text for phoneme mapping
        session_id: Optional session ID (auto-generated if not provided)

    Returns:
        BatchDiagnosisResponse with detailed error analysis

    Raises:
        HTTPException: 503 when the pipeline is not loaded, 500 on failure.
    """
    if inference_pipeline is None:
        raise HTTPException(status_code=503, detail="Inference pipeline not loaded")

    start_time = time.time()

    # Generate session ID when the caller did not supply one.
    if not session_id:
        session_id = str(uuid.uuid4())

    temp_file = None
    try:
        # SECURITY FIX: strip any directory components from the client-supplied
        # filename; a name such as "../../x" could otherwise escape the temp
        # directory (path traversal).
        safe_name = os.path.basename(audio.filename or "upload.wav")
        temp_dir = tempfile.gettempdir()
        os.makedirs(temp_dir, exist_ok=True)
        temp_file = os.path.join(temp_dir, f"diagnosis_{session_id}_{safe_name}")

        # Persist the upload so the pipeline can read it from disk.
        content = await audio.read()
        with open(temp_file, "wb") as f:
            f.write(content)

        file_size_mb = len(content) / 1024 / 1024
        logger.info(f"📂 Saved file: {temp_file} ({file_size_mb:.2f} MB)")

        logger.info("🔄 Running phone-level inference...")
        result = inference_pipeline.predict_phone_level(
            temp_file,
            return_timestamps=True
        )

        # Map phonemes to frames when a transcript is available; otherwise use
        # empty phoneme labels for every frame.
        if text and phoneme_mapper:
            try:
                frame_phonemes = phoneme_mapper.map_text_to_frames(
                    text,
                    num_frames=result.num_frames,
                    audio_duration=result.duration
                )
                logger.info(f"✅ Mapped {len(frame_phonemes)} phonemes to frames")
            except Exception as e:
                logger.warning(f"⚠️ Phoneme mapping failed: {e}, using empty phonemes")
                frame_phonemes = [''] * result.num_frames
        else:
            frame_phonemes = [''] * result.num_frames
            if not text:
                logger.warning("⚠️ No text provided, phoneme mapping skipped")

        frame_diagnoses: List[FrameDiagnosis] = []
        error_reports: List[ErrorReport] = []
        error_count = 0

        for i, frame_pred in enumerate(result.frame_predictions):
            phoneme = frame_phonemes[i] if i < len(frame_phonemes) else ''

            # Combine fluency and articulation into the 8-class system:
            # class = articulation_class (0-3), +4 when the frame stutters.
            class_id = frame_pred.articulation_class
            if frame_pred.fluency_label == 'stutter':
                class_id += 4

            error_detail = None
            if error_mapper:
                try:
                    error_detail_obj = error_mapper.map_classifier_output(
                        class_id=class_id,
                        confidence=frame_pred.confidence,
                        phoneme=phoneme if phoneme else 'unknown',
                        fluency_label=frame_pred.fluency_label
                    )
                    # Track which frame produced this error.
                    error_detail_obj.frame_indices = [i]

                    if error_detail_obj.error_type != ErrorType.NORMAL:
                        error_detail = ErrorDetailSchema(
                            phoneme=error_detail_obj.phoneme,
                            error_type=error_detail_obj.error_type.value,
                            wrong_sound=error_detail_obj.wrong_sound,
                            severity=error_detail_obj.severity,
                            confidence=error_detail_obj.confidence,
                            therapy=error_detail_obj.therapy,
                            frame_indices=[i]
                        )
                        error_count += 1

                        severity_level = error_mapper.get_severity_level(error_detail_obj.severity)
                        error_reports.append(ErrorReport(
                            frame_id=i,
                            timestamp=frame_pred.time,
                            phoneme=error_detail_obj.phoneme,
                            error=error_detail,
                            severity_level=severity_level.value
                        ))
                except Exception as e:
                    # Per-frame mapping failures are non-fatal; keep analyzing.
                    logger.warning(f"Error mapping failed for frame {i}: {e}")

            severity_level_str = "none"
            if error_detail and error_mapper:
                severity_level_str = error_mapper.get_severity_level(error_detail.severity).value

            frame_diagnoses.append(FrameDiagnosis(
                frame_id=i,
                timestamp=frame_pred.time,
                phoneme=phoneme if phoneme else 'unknown',
                fluency=FluencyInfo(
                    label=frame_pred.fluency_label,
                    confidence=frame_pred.fluency_prob if frame_pred.fluency_label == 'stutter' else (1.0 - frame_pred.fluency_prob)
                ),
                articulation=ArticulationInfo(
                    label=frame_pred.articulation_label,
                    confidence=frame_pred.confidence,
                    class_id=frame_pred.articulation_class
                ),
                error=error_detail,
                severity_level=severity_level_str,
                confidence=frame_pred.confidence
            ))

        # Summary metrics: fluency_prob is the stutter probability, so the
        # per-frame fluency score is its complement.
        fluency_scores = [1.0 - fp.fluency_prob for fp in result.frame_predictions]
        avg_fluency = sum(fluency_scores) / len(fluency_scores) if fluency_scores else 0.0

        # Articulation score: fraction of frames classified as normal.
        normal_frames = sum(1 for fp in result.frame_predictions if fp.articulation_class == 0)
        articulation_score = normal_frames / result.num_frames if result.num_frames > 0 else 0.0

        summary = SummaryMetrics(
            fluency_score=avg_fluency,
            fluency_percentage=avg_fluency * 100.0,
            articulation_score=articulation_score,
            error_count=error_count,
            error_rate=error_count / result.num_frames if result.num_frames > 0 else 0.0
        )

        # Therapy plan: de-duplicated recommendations, first-seen order.
        therapy_plan = []
        if error_mapper:
            seen_therapies = set()
            for error_report in error_reports:
                if error_report.error.therapy and error_report.error.therapy not in seen_therapies:
                    therapy_plan.append(error_report.error.therapy)
                    seen_therapies.add(error_report.error.therapy)

        processing_time_ms = (time.time() - start_time) * 1000

        response = BatchDiagnosisResponse(
            session_id=session_id,
            filename=audio.filename or "unknown",
            duration=result.duration,
            total_frames=result.num_frames,
            error_count=error_count,
            errors=error_reports,
            frame_diagnoses=frame_diagnoses,
            summary=summary,
            therapy_plan=therapy_plan,
            processing_time_ms=processing_time_ms,
            created_at=datetime.now()
        )

        # Cache for later retrieval via GET /diagnose/results/{session_id}.
        sessions[session_id] = response

        logger.info(f"✅ Diagnosis complete: {error_count} errors, {processing_time_ms:.0f}ms")

        return response

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"❌ Diagnosis failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Diagnosis failed: {str(e)}")

    finally:
        # Always remove the temp file, even on failure.
        if temp_file and os.path.exists(temp_file):
            try:
                os.remove(temp_file)
                logger.debug(f"🧹 Cleaned up: {temp_file}")
            except Exception as e:
                logger.warning(f"Could not clean up {temp_file}: {e}")
+
302
+
303
@router.get("/results/{session_id}", response_model=BatchDiagnosisResponse)
async def get_results(session_id: str):
    """
    Return the cached diagnosis results for a session.

    Args:
        session_id: Session identifier

    Returns:
        BatchDiagnosisResponse

    Raises:
        HTTPException: 404 when no cached result exists for the session.
    """
    cached = sessions.get(session_id)
    if cached is None:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    return cached
318
+
319
+
320
@router.get("/results", response_model=SessionListResponse)
async def list_results(limit: int = Query(10, ge=1, le=100)):
    """
    List cached diagnosis sessions.

    Args:
        limit: Maximum number of sessions to return (1-100)

    Returns:
        SessionListResponse with per-session metadata; ``total`` reflects
        the full cache size, not the truncated list.
    """
    entries = [
        {
            "session_id": sid,
            "filename": cached.filename,
            "duration": cached.duration,
            "error_count": cached.error_count,
            "created_at": cached.created_at.isoformat(),
            "processing_time_ms": cached.processing_time_ms
        }
        for sid, cached in list(sessions.items())[:limit]
    ]

    return SessionListResponse(
        sessions=entries,
        total=len(sessions)
    )
346
+
347
+
348
@router.delete("/results/{session_id}")
async def delete_results(session_id: str):
    """
    Delete the cached diagnosis results for a session.

    Args:
        session_id: Session identifier

    Returns:
        Success message dict.

    Raises:
        HTTPException: 404 when the session is not cached.
    """
    try:
        del sessions[session_id]
    except KeyError:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    logger.info(f"🗑️ Deleted session: {session_id}")
    return {"status": "success", "message": f"Session {session_id} deleted"}
366
+
367
+
368
@router.get("/health", response_model=HealthResponse)
async def health_check():
    """
    Health check endpoint.

    Reports "healthy" when the inference pipeline is loaded, otherwise
    "degraded". Uptime is measured from the first health check call.

    Returns:
        HealthResponse with service status
    """
    # First-call timestamp is stored on the function object itself.
    # FIX: removed the redundant function-local `import time` (the module
    # already imports time) and simplified the getattr/hasattr dance.
    if not hasattr(health_check, '_start_time'):
        health_check._start_time = time.time()

    uptime = time.time() - health_check._start_time

    return HealthResponse(
        status="healthy" if inference_pipeline is not None else "degraded",
        version="2.0.0",
        model_loaded=inference_pipeline is not None,
        uptime_seconds=uptime
    )
389
+
api/schemas.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Pydantic schemas for Speech Pathology Diagnosis API.
3
+
4
+ This module defines request and response models for REST API and WebSocket endpoints.
5
+ """
6
+
7
+ from typing import List, Optional, Dict, Any
8
+ from pydantic import BaseModel, Field
9
+ from datetime import datetime
10
+
11
+
12
class FluencyInfo(BaseModel):
    """Per-frame fluency classification (binary: normal vs. stutter)."""
    label: str = Field(..., description="Fluency label: 'normal' or 'stutter'")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence score (0-1)")


class ArticulationInfo(BaseModel):
    """Per-frame articulation classification (4-way)."""
    label: str = Field(..., description="Articulation label: 'normal', 'substitution', 'omission', 'distortion'")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence score (0-1)")
    class_id: int = Field(..., ge=0, le=3, description="Class ID: 0=normal, 1=substitution, 2=omission, 3=distortion")


class ErrorDetailSchema(BaseModel):
    """Detailed description of a single detected speech error, including a therapy hint."""
    phoneme: str = Field(..., description="Expected phoneme symbol")
    error_type: str = Field(..., description="Error type: normal, substitution, omission, distortion")
    wrong_sound: Optional[str] = Field(None, description="For substitutions, the incorrect phoneme produced")
    severity: float = Field(..., ge=0.0, le=1.0, description="Severity score (0-1)")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Model confidence (0-1)")
    therapy: str = Field(..., description="Therapy recommendation")
    frame_indices: List[int] = Field(default_factory=list, description="Frame indices where error occurs")


class FrameDiagnosis(BaseModel):
    """Complete diagnosis for a single analysis frame (fluency + articulation + error)."""
    frame_id: int = Field(..., description="Frame index")
    timestamp: float = Field(..., ge=0.0, description="Timestamp in seconds")
    phoneme: str = Field(..., description="Expected phoneme for this frame")
    fluency: FluencyInfo = Field(..., description="Fluency classification")
    articulation: ArticulationInfo = Field(..., description="Articulation classification")
    error: Optional[ErrorDetailSchema] = Field(None, description="Error details if error detected")
    severity_level: str = Field(..., description="Severity level: none, low, medium, high")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Overall confidence")


class ErrorReport(BaseModel):
    """Error report for a frame; unlike FrameDiagnosis, ``error`` is always present."""
    frame_id: int = Field(..., description="Frame index")
    timestamp: float = Field(..., ge=0.0, description="Timestamp in seconds")
    phoneme: str = Field(..., description="Expected phoneme")
    error: ErrorDetailSchema = Field(..., description="Error details")
    severity_level: str = Field(..., description="Severity level: none, low, medium, high")


class SummaryMetrics(BaseModel):
    """Aggregate metrics computed over all frames of an analysis."""
    fluency_score: float = Field(..., ge=0.0, le=1.0, description="Average fluency score (0=stutter, 1=normal)")
    fluency_percentage: float = Field(..., ge=0.0, le=100.0, description="Fluency percentage")
    articulation_score: float = Field(..., ge=0.0, le=1.0, description="Average articulation correctness")
    error_count: int = Field(..., ge=0, description="Total number of errors detected")
    error_rate: float = Field(..., ge=0.0, le=1.0, description="Error rate (errors/total_frames)")


class BatchDiagnosisResponse(BaseModel):
    """Top-level response for batch file diagnosis; also the cached session payload."""
    session_id: str = Field(..., description="Session identifier")
    filename: str = Field(..., description="Processed filename")
    duration: float = Field(..., ge=0.0, description="Audio duration in seconds")
    total_frames: int = Field(..., ge=0, description="Total number of frames analyzed")
    error_count: int = Field(..., ge=0, description="Number of errors detected")
    errors: List[ErrorReport] = Field(default_factory=list, description="List of error reports")
    frame_diagnoses: List[FrameDiagnosis] = Field(default_factory=list, description="All frame diagnoses")
    summary: SummaryMetrics = Field(..., description="Summary metrics")
    therapy_plan: List[str] = Field(default_factory=list, description="Therapy recommendations")
    processing_time_ms: float = Field(..., ge=0.0, description="Processing time in milliseconds")
    created_at: datetime = Field(default_factory=datetime.now, description="Analysis timestamp")


class StreamingDiagnosisRequest(BaseModel):
    """Request for streaming diagnosis (one audio chunk)."""
    audio_chunk: bytes = Field(..., description="Audio chunk data (320 samples for 20ms @ 16kHz)")
    sample_rate: int = Field(16000, description="Sample rate in Hz")
    session_id: str = Field(..., description="Session identifier")
    frame_index: Optional[int] = Field(None, description="Frame index for tracking")


class StreamingDiagnosisResponse(BaseModel):
    """Response for streaming diagnosis (single frame), including processing latency."""
    session_id: str = Field(..., description="Session identifier")
    frame_id: int = Field(..., description="Frame index")
    timestamp: float = Field(..., ge=0.0, description="Timestamp in seconds")
    phoneme: str = Field(..., description="Expected phoneme")
    fluency: FluencyInfo = Field(..., description="Fluency classification")
    articulation: ArticulationInfo = Field(..., description="Articulation classification")
    error: Optional[ErrorDetailSchema] = Field(None, description="Error details if error detected")
    severity_level: str = Field(..., description="Severity level")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Overall confidence")
    latency_ms: float = Field(..., ge=0.0, description="Processing latency in milliseconds")


class SessionListResponse(BaseModel):
    """Response for listing cached sessions."""
    sessions: List[Dict[str, Any]] = Field(..., description="List of session metadata")
    total: int = Field(..., ge=0, description="Total number of sessions")


class HealthResponse(BaseModel):
    """Health check response."""
    status: str = Field(..., description="Service status")
    version: str = Field(..., description="API version")
    model_loaded: bool = Field(..., description="Whether model is loaded")
    uptime_seconds: float = Field(..., ge=0.0, description="Service uptime in seconds")
115
+
api/streaming.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ WebSocket streaming for real-time speech pathology diagnosis.
3
+
4
+ This module provides WebSocket endpoint for streaming audio analysis
5
+ with <50ms latency per frame requirement.
6
+ """
7
+
8
+ import logging
9
+ import time
10
+ import uuid
11
+ import numpy as np
12
+ from typing import Optional, Dict
13
+ from collections import deque
14
+ from datetime import datetime
15
+
16
+ from fastapi import WebSocket, WebSocketDisconnect, HTTPException
17
+
18
+ from api.schemas import StreamingDiagnosisResponse, FluencyInfo, ArticulationInfo, ErrorDetailSchema
19
+ from models.phoneme_mapper import PhonemeMapper
20
+ from models.error_taxonomy import ErrorMapper, ErrorType
21
+ from inference.inference_pipeline import InferencePipeline
22
+ from config import AudioConfig, default_audio_config
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
class StreamingBuffer:
    """
    Sliding-window buffer for streaming audio.

    Accumulates incoming samples and exposes fixed-size analysis windows
    with hop-based overlap management. Samples older than one window plus
    one hop are discarded automatically (bounded deque).
    """

    def __init__(self, window_size_samples: int, hop_size_samples: int):
        """
        Initialize streaming buffer.

        Args:
            window_size_samples: Size of analysis window in samples
            hop_size_samples: Hop size between windows in samples
        """
        self.window_size_samples = window_size_samples
        self.hop_size_samples = hop_size_samples
        # maxlen bounds memory: anything older than window+hop is useless.
        self.buffer = deque(maxlen=window_size_samples + hop_size_samples)
        self.frame_index = 0

        # Resolve the module logger at call time so the class has no
        # dependency on module-level state.
        logging.getLogger(__name__).debug(
            f"StreamingBuffer initialized: window={window_size_samples}, hop={hop_size_samples}"
        )

    def add_chunk(self, audio_chunk: np.ndarray) -> bool:
        """
        Add audio chunk to buffer.

        Args:
            audio_chunk: Audio samples to add

        Returns:
            True if buffer has enough data for a frame, False otherwise
        """
        self.buffer.extend(audio_chunk)
        return len(self.buffer) >= self.window_size_samples

    def get_frame(self) -> Optional[np.ndarray]:
        """
        Get the most recent full window from the buffer.

        Returns:
            Audio frame array if ready, None otherwise
        """
        if len(self.buffer) < self.window_size_samples:
            return None

        # FIX: slice the deque lazily instead of materializing the whole
        # buffer with list() just to take a suffix.
        start = len(self.buffer) - self.window_size_samples
        return np.array(list(itertools.islice(self.buffer, start, len(self.buffer))))

    def slide(self):
        """Advance the window by one hop, discarding the oldest samples."""
        # FIX: min() already caps the iteration count, so the inner
        # `if self.buffer` guard from the original was redundant.
        for _ in range(min(self.hop_size_samples, len(self.buffer))):
            self.buffer.popleft()
        self.frame_index += 1
84
+
85
+
86
+ # Global instances (will be injected)
87
+ inference_pipeline: Optional[InferencePipeline] = None
88
+ phoneme_mapper: Optional[PhonemeMapper] = None
89
+ error_mapper: Optional[ErrorMapper] = None
90
+
91
+ # Active streaming sessions
92
+ streaming_sessions: Dict[str, Dict] = {}
93
+
94
+
95
def initialize_streaming(
    pipeline: InferencePipeline,
    mapper: Optional[PhonemeMapper] = None,
    error_mapper_instance: Optional[ErrorMapper] = None
):
    """
    Initialize streaming with dependencies.

    Args:
        pipeline: InferencePipeline instance
        mapper: Optional PhonemeMapper instance. BUG FIX: previously this
            argument was silently ignored and a fresh PhonemeMapper was
            always constructed; it is now honoured when provided.
        error_mapper_instance: Optional ErrorMapper instance (same fix).
    """
    global inference_pipeline, phoneme_mapper, error_mapper

    inference_pipeline = pipeline

    if mapper is not None:
        # Use the caller-supplied mapper instead of ignoring it.
        phoneme_mapper = mapper
    else:
        try:
            phoneme_mapper = PhonemeMapper(
                frame_duration_ms=default_audio_config.chunk_duration_ms,
                sample_rate=default_audio_config.sample_rate
            )
            logger.info("✅ PhonemeMapper initialized for streaming")
        except Exception as e:
            logger.warning(f"⚠️ PhonemeMapper not available: {e}")
            phoneme_mapper = None

    if error_mapper_instance is not None:
        # Use the caller-supplied error mapper instead of ignoring it.
        error_mapper = error_mapper_instance
    else:
        try:
            error_mapper = ErrorMapper()
            logger.info("✅ ErrorMapper initialized for streaming")
        except Exception as e:
            logger.error(f"❌ ErrorMapper failed to initialize: {e}")
            error_mapper = None
130
+
131
+
132
async def handle_streaming_websocket(websocket: WebSocket, session_id: Optional[str] = None):
    """
    Handle WebSocket connection for streaming diagnosis.

    Receives 16-bit PCM audio chunks, runs per-frame phone-level inference,
    and sends one StreamingDiagnosisResponse JSON message per analyzed frame.

    Args:
        websocket: WebSocket connection
        session_id: Optional session ID (auto-generated if not provided)
    """
    if inference_pipeline is None:
        await websocket.close(code=1003, reason="Inference pipeline not loaded")
        return

    if not session_id:
        session_id = str(uuid.uuid4())

    await websocket.accept()
    logger.info(f"🔌 WebSocket connected: session_id={session_id}")

    # Derive window/hop sizes in samples from the pipeline configuration.
    sample_rate = inference_pipeline.audio_config.sample_rate
    window_size_samples = int(
        inference_pipeline.inference_config.window_size_ms * sample_rate / 1000
    )
    hop_size_samples = int(
        inference_pipeline.inference_config.hop_size_ms * sample_rate / 1000
    )

    buffer = StreamingBuffer(window_size_samples, hop_size_samples)

    # Session metadata for post-hoc latency statistics.
    streaming_sessions[session_id] = {
        "session_id": session_id,
        "connected_at": datetime.now(),
        "frame_count": 0,
        "total_latency_ms": 0.0
    }

    frame_index = 0  # FIX: removed unused `start_time` local from the original

    try:
        while True:
            try:
                data = await websocket.receive_bytes()

                # 16-bit PCM mono assumed; normalize to [-1, 1) float32.
                audio_chunk = np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0
                buffer.add_chunk(audio_chunk)

                # FIX: fetch the frame once via walrus; the original called
                # buffer.get_frame() twice, copying the whole buffer twice.
                if (frame := buffer.get_frame()) is not None:
                    frame_start_time = time.time()

                    try:
                        result = inference_pipeline.predict_phone_level(
                            frame,
                            return_timestamps=False
                        )

                        if result.frame_predictions:
                            frame_pred = result.frame_predictions[0]  # Single frame result

                            # 8-class encoding: articulation class 0-3, +4 when stuttered.
                            class_id = frame_pred.articulation_class
                            if frame_pred.fluency_label == 'stutter':
                                class_id += 4

                            error_detail = None
                            phoneme = ''  # Streaming doesn't have text input

                            if error_mapper:
                                try:
                                    error_detail_obj = error_mapper.map_classifier_output(
                                        class_id=class_id,
                                        confidence=frame_pred.confidence,
                                        phoneme=phoneme,
                                        fluency_label=frame_pred.fluency_label
                                    )

                                    if error_detail_obj.error_type != ErrorType.NORMAL:
                                        error_detail = ErrorDetailSchema(
                                            phoneme=error_detail_obj.phoneme,
                                            error_type=error_detail_obj.error_type.value,
                                            wrong_sound=error_detail_obj.wrong_sound,
                                            severity=error_detail_obj.severity,
                                            confidence=error_detail_obj.confidence,
                                            therapy=error_detail_obj.therapy,
                                            frame_indices=[frame_index]
                                        )
                                except Exception as e:
                                    logger.warning(f"Error mapping failed: {e}")

                            latency_ms = (time.time() - frame_start_time) * 1000

                            severity_level = "none"
                            if error_detail and error_mapper:
                                severity_level = error_mapper.get_severity_level(error_detail.severity).value

                            response = StreamingDiagnosisResponse(
                                session_id=session_id,
                                frame_id=frame_index,
                                timestamp=frame_index * (inference_pipeline.inference_config.hop_size_ms / 1000.0),
                                phoneme=phoneme,
                                fluency=FluencyInfo(
                                    label=frame_pred.fluency_label,
                                    confidence=frame_pred.fluency_prob if frame_pred.fluency_label == 'stutter' else (1.0 - frame_pred.fluency_prob)
                                ),
                                articulation=ArticulationInfo(
                                    label=frame_pred.articulation_label,
                                    confidence=frame_pred.confidence,
                                    class_id=frame_pred.articulation_class
                                ),
                                error=error_detail,
                                severity_level=severity_level,
                                confidence=frame_pred.confidence,
                                latency_ms=latency_ms
                            )

                            await websocket.send_json(response.model_dump())

                            streaming_sessions[session_id]["frame_count"] += 1
                            streaming_sessions[session_id]["total_latency_ms"] += latency_ms

                            # Soft real-time budget: warn when a frame breaches 50ms.
                            if latency_ms > 50.0:
                                logger.warning(f"⚠️ Latency exceeded 50ms: {latency_ms:.1f}ms")

                        # Advance the window even when no prediction came back,
                        # so the stream cannot stall on the same frame.
                        buffer.slide()
                        frame_index += 1

                    except Exception as e:
                        logger.error(f"❌ Inference failed: {e}", exc_info=True)
                        await websocket.send_json({
                            "error": f"Inference failed: {str(e)}",
                            "frame_id": frame_index
                        })

            except Exception as e:
                logger.error(f"❌ Error processing chunk: {e}", exc_info=True)
                await websocket.send_json({
                    "error": f"Processing failed: {str(e)}",
                    "frame_id": frame_index
                })

    except WebSocketDisconnect:
        logger.info(f"🔌 WebSocket disconnected: session_id={session_id}")
    except Exception as e:
        logger.error(f"❌ WebSocket error: {e}", exc_info=True)
    finally:
        # Always report per-session stats and free the session entry.
        if session_id in streaming_sessions:
            session_data = streaming_sessions[session_id]
            avg_latency = session_data["total_latency_ms"] / session_data["frame_count"] if session_data["frame_count"] > 0 else 0.0
            logger.info(f"📊 Session {session_id} stats: {session_data['frame_count']} frames, "
                        f"avg_latency={avg_latency:.1f}ms")
            del streaming_sessions[session_id]
305
+
app.py CHANGED
@@ -173,33 +173,65 @@ async def diagnose_speech(
173
 
174
  # Run inference
175
  logger.info("🔄 Running inference pipeline...")
176
- result = inference_pipeline.predict_batch(
 
177
  temp_file,
178
- return_timestamps=True,
179
- apply_smoothing=True
180
  )
181
 
182
  processing_time_ms = (time.time() - start_time) * 1000
183
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
184
  # Format response
185
  response = {
186
  "status": "success",
187
  "fluency_metrics": {
188
- "mean_fluency": result.fluency_metrics.get("mean", 0.0),
189
- "fluency_percentage": result.fluency_metrics.get("mean", 0.0) * 100,
190
- "fluent_frames_ratio": result.fluency_metrics.get("fluent_frames_ratio", 0.0),
191
- "std": result.fluency_metrics.get("std", 0.0),
192
- "min": result.fluency_metrics.get("min", 0.0),
193
- "max": result.fluency_metrics.get("max", 0.0),
194
- "median": result.fluency_metrics.get("median", 0.0)
195
  },
196
  "articulation_results": {
197
- "total_frames": len(result.articulation_scores),
198
- "frame_duration_ms": result.frame_duration_ms,
199
- "scores": result.articulation_scores # All frames
 
 
 
 
 
 
 
 
 
 
 
 
 
200
  },
201
- "confidence": result.confidence,
202
- "confidence_percentage": result.confidence * 100,
203
  "processing_time_ms": processing_time_ms
204
  }
205
 
 
173
 
174
  # Run inference
175
  logger.info("🔄 Running inference pipeline...")
176
+ # Use new phone-level prediction
177
+ result = inference_pipeline.predict_phone_level(
178
  temp_file,
179
+ return_timestamps=True
 
180
  )
181
 
182
  processing_time_ms = (time.time() - start_time) * 1000
183
 
184
+ # Extract metrics from new PhoneLevelResult format
185
+ aggregate = result.aggregate
186
+ mean_fluency_stutter = aggregate.get("fluency_score", 0.0)
187
+ fluency_percentage = (1.0 - mean_fluency_stutter) * 100 # Convert stutter prob to fluency percentage
188
+
189
+ # Count fluent frames
190
+ fluent_frames = sum(1 for fp in result.frame_predictions if fp.fluency_label == 'normal')
191
+ fluent_frames_ratio = fluent_frames / result.num_frames if result.num_frames > 0 else 0.0
192
+
193
+ # Extract articulation class distribution
194
+ articulation_class_counts = {}
195
+ for fp in result.frame_predictions:
196
+ label = fp.articulation_label
197
+ articulation_class_counts[label] = articulation_class_counts.get(label, 0) + 1
198
+
199
+ # Get dominant articulation class
200
+ dominant_articulation = aggregate.get("articulation_label", "normal")
201
+
202
+ # Calculate average confidence
203
+ avg_confidence = sum(fp.confidence for fp in result.frame_predictions) / result.num_frames if result.num_frames > 0 else 0.0
204
+
205
  # Format response
206
  response = {
207
  "status": "success",
208
  "fluency_metrics": {
209
+ "mean_fluency": fluency_percentage / 100.0,
210
+ "fluency_percentage": fluency_percentage,
211
+ "fluent_frames_ratio": fluent_frames_ratio,
212
+ "fluent_frames_percentage": fluent_frames_ratio * 100,
213
+ "stutter_probability": mean_fluency_stutter
 
 
214
  },
215
  "articulation_results": {
216
+ "total_frames": result.num_frames,
217
+ "frame_duration_ms": int(inference_pipeline.inference_config.hop_size_ms),
218
+ "dominant_class": aggregate.get("articulation_class", 0),
219
+ "dominant_label": dominant_articulation,
220
+ "class_distribution": articulation_class_counts,
221
+ "frame_predictions": [
222
+ {
223
+ "time": fp.time,
224
+ "fluency_prob": fp.fluency_prob,
225
+ "fluency_label": fp.fluency_label,
226
+ "articulation_class": fp.articulation_class,
227
+ "articulation_label": fp.articulation_label,
228
+ "confidence": fp.confidence
229
+ }
230
+ for fp in result.frame_predictions
231
+ ]
232
  },
233
+ "confidence": avg_confidence,
234
+ "confidence_percentage": avg_confidence * 100,
235
  "processing_time_ms": processing_time_ms
236
  }
237
 
config.py CHANGED
@@ -92,12 +92,24 @@ class InferenceConfig:
92
  Reduces jitter in frame-level predictions.
93
  batch_size: Number of chunks to process in parallel during inference.
94
  Higher values = faster but more memory usage.
 
 
 
 
 
 
 
95
  """
96
  fluency_threshold: float = 0.5
97
  articulation_threshold: float = 0.6
98
  min_chunk_duration_ms: int = 10
99
  smoothing_window: int = 5
100
  batch_size: int = 32
 
 
 
 
 
101
 
102
 
103
  @dataclass
 
92
  Reduces jitter in frame-level predictions.
93
  batch_size: Number of chunks to process in parallel during inference.
94
  Higher values = faster but more memory usage.
95
+ window_size_ms: Size of sliding window in milliseconds (default: 1000ms = 1 second).
96
+ Minimum for Wav2Vec2 stability.
97
+ hop_size_ms: Hop size between windows in milliseconds (default: 10ms).
98
+ Controls temporal resolution (100 frames/second).
99
+ frame_rate: Frames per second (calculated from hop_size_ms).
100
+ minimum_audio_length: Minimum audio length in seconds (must be >= window_size_ms).
101
+ phone_level_strategy: Strategy for phone-level analysis ("sliding_window").
102
  """
103
  fluency_threshold: float = 0.5
104
  articulation_threshold: float = 0.6
105
  min_chunk_duration_ms: int = 10
106
  smoothing_window: int = 5
107
  batch_size: int = 32
108
+ window_size_ms: int = 1000 # 1 second minimum for Wav2Vec2
109
+ hop_size_ms: int = 10 # 10ms for phone-level resolution
110
+ frame_rate: float = 100.0 # 100 frames per second (1/hop_size_ms)
111
+ minimum_audio_length: float = 1.0 # Must be >= window_size_ms
112
+ phone_level_strategy: str = "sliding_window"
113
 
114
 
115
  @dataclass
data/therapy_recommendations.json ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "substitutions": {
3
+ "/s/→/θ/": "Lisp - Use tongue tip placement behind upper teeth. Practice /s/ in isolation, then in words. Use mirror feedback to ensure tongue is not protruding.",
4
+ "/s/→/ʃ/": "Sibilant confusion - Practice /s/ vs /sh/ distinction. Focus on tongue position: /s/ has tongue tip up, /sh/ has tongue body back.",
5
+ "/s/→/z/": "Voicing error - Practice voiceless /s/ vs voiced /z/. Place hand on throat to feel vibration difference.",
6
+ "/r/→/w/": "Rhotacism - Practice tongue position: curl tongue back, avoid lip rounding. Start with /r/ in isolation, then CV syllables (ra, re, ri, ro, ru).",
7
+ "/r/→/l/": "Rhotacism - Focus on tongue tip position vs. tongue body placement. /r/ uses tongue body, /l/ uses tongue tip.",
8
+ "/r/→/ɹ/": "Rhotacism variant - Practice standard /r/ production with proper tongue curl.",
9
+ "/l/→/w/": "Liquid substitution - Practice lateral tongue placement with tongue tip up to alveolar ridge.",
10
+ "/l/→/j/": "Liquid substitution - Focus on tongue tip contact for /l/ vs. tongue body for /j/.",
11
+ "/k/→/t/": "Velar to alveolar substitution - Practice back tongue placement for /k/. Use mirror to see tongue position.",
12
+ "/k/→/p/": "Velar to bilabial substitution - Practice velar placement: tongue back, soft palate contact.",
13
+ "/g/→/d/": "Velar to alveolar substitution - Practice voiced velar /g/ with tongue back position.",
14
+ "/g/→/b/": "Velar to bilabial substitution - Practice velar placement for /g/.",
15
+ "/θ/→/f/": "Th-fronting - Practice tongue tip placement between teeth for /θ/. Use mirror to ensure correct position.",
16
+ "/θ/→/s/": "Th-fronting - Practice interdental placement for /θ/ vs. alveolar /s/.",
17
+ "/ð/→/v/": "Voiced th-fronting - Practice interdental placement for /ð/ vs. labiodental /v/.",
18
+ "/ð/→/z/": "Voiced th-fronting - Practice interdental /ð/ vs. alveolar /z/.",
19
+ "/ʃ/→/s/": "Sh-sound confusion - Practice /sh/ with tongue body back vs. /s/ with tongue tip up.",
20
+ "/ʃ/→/tʃ/": "Fricative to affricate - Practice sustained /sh/ vs. stop-release /ch/.",
21
+ "/tʃ/→/ʃ/": "Affricate to fricative - Practice stop component of /ch/ before fricative release.",
22
+ "/tʃ/→/ts/": "Affricate substitution - Practice /ch/ with proper tongue placement and air release.",
23
+ "generic": "Substitution error for {phoneme}. Practice correct articulator placement with mirror feedback. Start in isolation, then syllables, then words."
24
+ },
25
+ "omissions": {
26
+ "/r/": "Practice /r/ in isolation, then in CV syllables (ra, re, ri, ro, ru). Focus on tongue curl and lip position. Use visual cues and mirror feedback.",
27
+ "/l/": "Lateral tongue placement - practice with tongue tip up to alveolar ridge. Start with /l/ in isolation, then blend into words.",
28
+ "/s/": "Practice /s/ with tongue tip placement, use mirror to check position. Start in isolation, then fricative-only words (sss, see, say).",
29
+ "/k/": "Practice velar /k/ with tongue back position. Use mirror to see tongue placement. Start with /k/ in isolation.",
30
+ "/g/": "Practice voiced velar /g/ with tongue back. Start in isolation, then CV syllables.",
31
+ "/t/": "Practice alveolar /t/ with tongue tip contact. Use mirror feedback for placement.",
32
+ "/d/": "Practice voiced alveolar /d/ with tongue tip contact.",
33
+ "/p/": "Practice bilabial /p/ with lip closure. Use mirror to ensure proper closure.",
34
+ "/b/": "Practice voiced bilabial /b/ with lip closure.",
35
+ "/f/": "Practice labiodental /f/ with lower lip to upper teeth contact.",
36
+ "/v/": "Practice voiced labiodental /v/ with lip-teeth contact.",
37
+ "/θ/": "Practice interdental /θ/ with tongue tip between teeth.",
38
+ "/ð/": "Practice voiced interdental /ð/ with tongue tip between teeth.",
39
+ "/ʃ/": "Practice /sh/ with tongue body back and lip rounding.",
40
+ "/tʃ/": "Practice /ch/ with stop then fricative release.",
41
+ "/dʒ/": "Practice /j/ (as in judge) with stop then fricative release.",
42
+ "generic": "Omission error for {phoneme}. Say the sound separately first, then blend into syllables, then words. Use visual cues and mirror feedback."
43
+ },
44
+ "distortions": {
45
+ "/s/": "Sibilant clarity - use mirror feedback, ensure tongue tip is up and air stream is central. Practice sustained /s/ sound.",
46
+ "/z/": "Voiced sibilant clarity - practice with voicing, ensure proper tongue placement.",
47
+ "/ʃ/": "Fricative voicing exercise - practice /sh/ vs /s/ distinction. Focus on tongue body position.",
48
+ "/tʃ/": "Affricate clarity - practice stop component then fricative release. Ensure proper timing.",
49
+ "/r/": "Rhotacism - practice tongue position and lip rounding control. Use mirror to see tongue curl.",
50
+ "/l/": "Lateral clarity - ensure tongue tip is up and air flows over sides of tongue.",
51
+ "/θ/": "Interdental clarity - practice with tongue tip between teeth, ensure air stream is correct.",
52
+ "/ð/": "Voiced interdental clarity - practice with voicing and proper tongue placement.",
53
+ "/k/": "Velar stop clarity - practice with proper tongue back placement and release.",
54
+ "/g/": "Voiced velar stop clarity - practice with voicing and tongue placement.",
55
+ "/t/": "Alveolar stop clarity - practice with tongue tip contact and clean release.",
56
+ "/d/": "Voiced alveolar stop clarity - practice with voicing and proper contact.",
57
+ "generic": "Distortion error for {phoneme}. Use mirror feedback and watch articulator position carefully. Practice in isolation first, then in words."
58
+ }
59
+ }
60
+
diagnosis/ai_engine/detect_stuttering.py CHANGED
@@ -30,6 +30,9 @@ from scipy.spatial import ConvexHull
30
  from scipy.stats import pearsonr
31
  from difflib import SequenceMatcher
32
 
 
 
 
33
  logger = logging.getLogger(__name__)
34
 
35
  # === CONFIGURATION ===
 
30
  from scipy.stats import pearsonr
31
  from difflib import SequenceMatcher
32
 
33
+ import warnings
34
+ warnings.filterwarnings("ignore", message="CUDA requested but not available")
35
+
36
  logger = logging.getLogger(__name__)
37
 
38
  # === CONFIGURATION ===
inference/inference_pipeline.py CHANGED
@@ -2,7 +2,7 @@
2
  Inference Pipeline for Speech Pathology Diagnosis
3
 
4
  This module provides the inference pipeline for real-time and batch processing
5
- of audio for fluency and articulation analysis.
6
  """
7
 
8
  import logging
@@ -13,7 +13,8 @@ import soundfile as sf
13
  from typing import Dict, List, Optional, Tuple, Union
14
  from pathlib import Path
15
  import time
16
- from dataclasses import dataclass
 
17
 
18
  from models.speech_pathology_model import SpeechPathologyClassifier, load_speech_pathology_model
19
  from config import AudioConfig, ModelConfig, InferenceConfig
@@ -22,60 +23,50 @@ logger = logging.getLogger(__name__)
22
 
23
 
24
  @dataclass
25
- class PredictionResult:
26
  """
27
- Container for prediction results.
28
 
29
  Attributes:
30
- fluency_score: Probability of fluent speech (0-1)
 
 
31
  articulation_class: Class index (0-3)
32
- articulation_class_name: Class name string
33
- articulation_probs: Probabilities for all 4 classes
34
  confidence: Overall confidence score
35
- timestamp_ms: Timestamp in milliseconds (for streaming)
36
- frame_index: Frame index (for streaming)
37
  """
38
- fluency_score: float
 
 
39
  articulation_class: int
40
- articulation_class_name: str
41
- articulation_probs: List[float]
42
  confidence: float
43
- timestamp_ms: Optional[float] = None
44
- frame_index: Optional[int] = None
45
 
46
 
47
  @dataclass
48
- class BatchPredictionResult:
49
  """
50
- Container for batch prediction results.
51
 
52
  Attributes:
53
- fluency_metrics: Dictionary with fluency statistics
54
- articulation_scores: List of articulation predictions per frame
55
- confidence: Overall confidence
56
- timestamps: List of timestamps in milliseconds
57
- frame_duration_ms: Duration of each frame in milliseconds
58
  """
59
- fluency_metrics: Dict[str, float]
60
- articulation_scores: List[Dict[str, Union[int, str, List[float]]]]
61
- confidence: float
62
- timestamps: List[float]
63
- frame_duration_ms: float
64
 
65
 
66
  class InferencePipeline:
67
  """
68
- Inference pipeline for speech pathology diagnosis.
69
-
70
- Handles both batch processing (full audio files) and streaming
71
- (real-time chunk-by-chunk) inference with phone-level granularity.
72
 
73
- Features:
74
- - Batch inference for complete audio files
75
- - Streaming inference for real-time processing
76
- - Phone-level analysis (20ms frames)
77
- - Automatic audio preprocessing and normalization
78
- - Temporal smoothing of predictions
79
  """
80
 
81
  def __init__(
@@ -93,9 +84,6 @@ class InferencePipeline:
93
  audio_config: Audio processing configuration
94
  model_config: Model configuration
95
  inference_config: Inference configuration
96
-
97
- Raises:
98
- RuntimeError: If model cannot be loaded or initialized
99
  """
100
  # Load configurations
101
  from config import default_audio_config, default_model_config, default_inference_config
@@ -104,18 +92,17 @@ class InferencePipeline:
104
  self.model_config = model_config or default_model_config
105
  self.inference_config = inference_config or default_inference_config
106
 
107
- logger.info("Initializing InferencePipeline...")
108
- logger.info(f"Audio config: sample_rate={self.audio_config.sample_rate}, "
109
- f"chunk_duration_ms={self.audio_config.chunk_duration_ms}")
110
- logger.info(f"Inference config: fluency_threshold={self.inference_config.fluency_threshold}, "
111
- f"articulation_threshold={self.inference_config.articulation_threshold}")
112
 
113
  # Initialize or use provided model
114
  if model is None:
115
  logger.info("Loading SpeechPathologyClassifier...")
116
  self.model = load_speech_pathology_model(
117
  model_name=self.model_config.model_name,
118
- classifier_hidden_dims=self.model_config.classifier_hidden_dims,
119
  dropout=self.model_config.dropout,
120
  device=self.model_config.device,
121
  use_fp16=self.model_config.use_fp16
@@ -127,16 +114,16 @@ class InferencePipeline:
127
  # Get processor for audio preprocessing
128
  self.processor = self.model.processor
129
 
130
- # Calculate frame size in samples
131
- self.frame_size_samples = int(
132
- self.audio_config.chunk_duration_ms * self.audio_config.sample_rate / 1000
133
  )
134
  self.hop_size_samples = int(
135
- self.audio_config.hop_length_ms * self.audio_config.sample_rate / 1000
136
  )
137
 
138
- logger.info(f"Frame size: {self.frame_size_samples} samples "
139
- f"({self.audio_config.chunk_duration_ms}ms)")
140
  logger.info("✅ InferencePipeline initialized successfully")
141
 
142
  def preprocess_audio(
@@ -153,335 +140,276 @@ class InferencePipeline:
153
 
154
  Returns:
155
  Preprocessed audio array normalized to [-1, 1] range
156
-
157
- Raises:
158
- ValueError: If audio cannot be loaded or processed
159
  """
160
  target_sr = target_sr or self.audio_config.sample_rate
161
 
162
  # Load audio if path provided
163
  if isinstance(audio, (str, Path)):
164
- try:
165
- audio_array, sr = librosa.load(str(audio), sr=target_sr, mono=True)
166
- logger.debug(f"Loaded audio from {audio}: shape={audio_array.shape}, sr={sr}")
167
- except Exception as e:
168
- logger.error(f"Failed to load audio from {audio}: {e}")
169
- raise ValueError(f"Cannot load audio file: {e}") from e
170
  else:
171
- audio_array = audio
172
- # Resample if needed
173
  if len(audio_array.shape) > 1:
174
- audio_array = librosa.to_mono(audio_array)
175
 
176
- # Normalize to [-1, 1] range
177
- if audio_array.max() > 1.0 or audio_array.min() < -1.0:
178
- max_val = np.abs(audio_array).max()
179
- if max_val > 0:
180
- audio_array = audio_array / max_val
181
- logger.debug("Normalized audio to [-1, 1] range")
182
 
183
  return audio_array
184
 
185
- def predict_batch(
186
  self,
187
- audio_path: Union[str, Path, np.ndarray],
188
- return_timestamps: bool = True,
189
- apply_smoothing: bool = True
190
- ) -> BatchPredictionResult:
191
  """
192
- Predict fluency and articulation for a complete audio file.
193
-
194
- Processes audio in overlapping frames for phone-level analysis.
195
 
196
  Args:
197
- audio_path: Path to audio file or audio array
198
- return_timestamps: Whether to include timestamps for each frame
199
- apply_smoothing: Whether to apply temporal smoothing
200
 
201
  Returns:
202
- BatchPredictionResult with aggregated metrics and per-frame predictions
203
-
204
- Raises:
205
- ValueError: If audio cannot be processed
206
- RuntimeError: If inference fails
207
  """
208
- logger.info(f"Starting batch prediction for: {audio_path}")
209
- start_time = time.time()
210
 
211
- try:
212
- # Preprocess audio
213
- audio_array = self.preprocess_audio(audio_path)
214
- duration_seconds = len(audio_array) / self.audio_config.sample_rate
215
- logger.info(f"Audio duration: {duration_seconds:.2f}s, "
216
- f"frames: {len(audio_array) // self.hop_size_samples}")
 
 
217
 
218
- # Process in overlapping frames
219
- predictions = []
220
- timestamps = []
221
 
222
- frame_idx = 0
223
- for start_sample in range(0, len(audio_array) - self.frame_size_samples + 1,
224
- self.hop_size_samples):
225
- # Extract frame
226
- frame = audio_array[start_sample:start_sample + self.frame_size_samples]
 
 
 
 
 
 
 
 
 
 
 
227
 
228
- # Skip if frame is too short
229
- if len(frame) < self.inference_config.min_chunk_duration_ms * \
230
- self.audio_config.sample_rate / 1000:
231
- continue
232
 
233
- # Predict on frame
234
- frame_timestamp_ms = (start_sample / self.audio_config.sample_rate) * 1000
 
 
235
 
236
- try:
237
- pred_result = self._predict_single_frame(
238
- frame,
239
- frame_idx=frame_idx,
240
- timestamp_ms=frame_timestamp_ms if return_timestamps else None
241
- )
242
- predictions.append(pred_result)
243
- if return_timestamps:
244
- timestamps.append(frame_timestamp_ms)
245
- frame_idx += 1
246
- except Exception as e:
247
- logger.warning(f"Failed to predict frame {frame_idx}: {e}")
248
- continue
249
-
250
- if not predictions:
251
- raise ValueError("No valid frames extracted from audio")
252
-
253
- logger.info(f"Processed {len(predictions)} frames")
254
-
255
- # Apply temporal smoothing if requested
256
- if apply_smoothing and len(predictions) > 1:
257
- predictions = self._apply_temporal_smoothing(predictions)
258
-
259
- # Aggregate results
260
- result = self._aggregate_predictions(predictions, timestamps)
261
-
262
- elapsed_time = time.time() - start_time
263
- logger.info(f"Batch prediction completed in {elapsed_time:.2f}s "
264
- f"({duration_seconds/elapsed_time:.1f}x real-time)")
265
-
266
- return result
267
 
268
- except Exception as e:
269
- logger.error(f"Batch prediction failed: {e}", exc_info=True)
270
- raise RuntimeError(f"Batch prediction failed: {e}") from e
 
 
 
 
 
 
 
 
271
 
272
- def predict_streaming(
273
  self,
274
- audio_chunk: np.ndarray,
275
- frame_index: Optional[int] = None,
276
- timestamp_ms: Optional[float] = None
277
- ) -> PredictionResult:
278
  """
279
- Predict fluency and articulation for a single audio chunk (streaming).
280
-
281
- Designed for real-time processing with <200ms latency requirement.
282
 
283
  Args:
284
- audio_chunk: Audio chunk array (should match chunk_duration_ms)
285
- frame_index: Optional frame index for tracking
286
- timestamp_ms: Optional timestamp in milliseconds
287
 
288
  Returns:
289
- PredictionResult for the chunk
290
-
291
- Raises:
292
- ValueError: If chunk is invalid
293
- RuntimeError: If inference fails
294
  """
 
 
295
  try:
296
- # Preprocess chunk
297
- chunk = self.preprocess_audio(audio_chunk)
 
298
 
299
- # Validate chunk size
300
- expected_samples = self.frame_size_samples
301
- if len(chunk) < expected_samples * 0.5: # Allow some tolerance
302
- logger.warning(f"Chunk size {len(chunk)} is smaller than expected {expected_samples}")
303
 
304
- # Pad or truncate to expected size
305
- if len(chunk) < expected_samples:
306
- chunk = np.pad(chunk, (0, expected_samples - len(chunk)), mode='constant')
307
- elif len(chunk) > expected_samples:
308
- chunk = chunk[:expected_samples]
 
 
 
 
309
 
310
- # Predict
311
- return self._predict_single_frame(
312
- chunk,
313
- frame_index=frame_index,
314
- timestamp_ms=timestamp_ms
315
- )
316
 
317
- except Exception as e:
318
- logger.error(f"Streaming prediction failed: {e}", exc_info=True)
319
- raise RuntimeError(f"Streaming prediction failed: {e}") from e
320
-
321
- def _predict_single_frame(
322
- self,
323
- frame: np.ndarray,
324
- frame_index: Optional[int] = None,
325
- timestamp_ms: Optional[float] = None
326
- ) -> PredictionResult:
327
- """
328
- Predict on a single audio frame.
329
-
330
- Args:
331
- frame: Audio frame array
332
- frame_index: Optional frame index
333
- timestamp_ms: Optional timestamp
334
-
335
- Returns:
336
- PredictionResult
337
- """
338
- # Convert to tensor
339
- audio_tensor = torch.from_numpy(frame).float()
340
-
341
- # Process through model
342
- with torch.no_grad():
343
- # Use processor to prepare input
344
- inputs = self.processor(
345
- audio_tensor,
346
- sampling_rate=self.audio_config.sample_rate,
347
- return_tensors="pt",
348
- padding=True
349
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350
 
351
- # Move to device
352
- input_values = inputs.input_values.to(self.model.device)
 
353
 
354
- # Get predictions
355
- outputs = self.model.predict(
356
- input_values.squeeze(0), # Remove batch dimension if needed
357
- sample_rate=self.audio_config.sample_rate,
358
- return_dict=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
359
  )
360
-
361
- return PredictionResult(
362
- fluency_score=outputs["fluency_score"],
363
- articulation_class=outputs["articulation_class"],
364
- articulation_class_name=outputs["articulation_class_name"],
365
- articulation_probs=outputs["articulation_probs"],
366
- confidence=outputs["confidence"],
367
- timestamp_ms=timestamp_ms,
368
- frame_index=frame_index
369
- )
370
 
371
- def _apply_temporal_smoothing(
372
  self,
373
- predictions: List[PredictionResult],
374
- window_size: Optional[int] = None
375
- ) -> List[PredictionResult]:
 
376
  """
377
- Apply temporal smoothing to predictions using moving average.
378
 
379
  Args:
380
- predictions: List of prediction results
381
- window_size: Smoothing window size (defaults to inference_config.smoothing_window)
 
382
 
383
  Returns:
384
- List of smoothed predictions
385
  """
386
- window_size = window_size or self.inference_config.smoothing_window
387
-
388
- if len(predictions) <= window_size:
389
- return predictions
390
-
391
- smoothed = []
392
- for i in range(len(predictions)):
393
- # Get window indices
394
- start_idx = max(0, i - window_size // 2)
395
- end_idx = min(len(predictions), i + window_size // 2 + 1)
396
- window_preds = predictions[start_idx:end_idx]
397
-
398
- # Average fluency scores
399
- avg_fluency = np.mean([p.fluency_score for p in window_preds])
400
-
401
- # Average articulation probabilities
402
- avg_articulation_probs = np.mean(
403
- [p.articulation_probs for p in window_preds],
404
- axis=0
405
- )
406
-
407
- # Get most likely class from averaged probabilities
408
- articulation_class = int(np.argmax(avg_articulation_probs))
409
- articulation_class_name = self.model.get_articulation_class_name(articulation_class)
410
-
411
- # Calculate confidence
412
- confidence = (avg_fluency + avg_articulation_probs[articulation_class]) / 2.0
413
-
414
- smoothed.append(PredictionResult(
415
- fluency_score=float(avg_fluency),
416
- articulation_class=articulation_class,
417
- articulation_class_name=articulation_class_name,
418
- articulation_probs=avg_articulation_probs.tolist(),
419
- confidence=float(confidence),
420
- timestamp_ms=predictions[i].timestamp_ms,
421
- frame_index=predictions[i].frame_index
422
- ))
423
-
424
- return smoothed
425
 
426
- def _aggregate_predictions(
427
  self,
428
- predictions: List[PredictionResult],
429
- timestamps: Optional[List[float]] = None
430
- ) -> BatchPredictionResult:
 
431
  """
432
- Aggregate frame-level predictions into batch results.
433
 
434
  Args:
435
- predictions: List of frame predictions
436
- timestamps: Optional list of timestamps
 
437
 
438
  Returns:
439
- BatchPredictionResult with aggregated metrics
440
  """
441
- if not predictions:
442
- raise ValueError("Cannot aggregate empty predictions")
443
-
444
- # Calculate fluency metrics
445
- fluency_scores = [p.fluency_score for p in predictions]
446
- fluency_metrics = {
447
- "mean": float(np.mean(fluency_scores)),
448
- "std": float(np.std(fluency_scores)),
449
- "min": float(np.min(fluency_scores)),
450
- "max": float(np.max(fluency_scores)),
451
- "median": float(np.median(fluency_scores)),
452
- "fluent_frames_ratio": float(np.mean([s >= self.inference_config.fluency_threshold
453
- for s in fluency_scores]))
454
- }
455
-
456
- # Articulation scores per frame
457
- articulation_scores = [
458
- {
459
- "class": p.articulation_class,
460
- "class_name": p.articulation_class_name,
461
- "probs": p.articulation_probs,
462
- "confidence": p.confidence
463
- }
464
- for p in predictions
465
- ]
466
 
467
- # Overall confidence
468
- overall_confidence = float(np.mean([p.confidence for p in predictions]))
469
 
470
- # Timestamps
471
- if timestamps is None:
472
- timestamps = [p.timestamp_ms for p in predictions if p.timestamp_ms is not None]
 
 
 
 
 
 
 
 
 
 
473
 
474
- return BatchPredictionResult(
475
- fluency_metrics=fluency_metrics,
476
- articulation_scores=articulation_scores,
477
- confidence=overall_confidence,
478
- timestamps=timestamps or [],
479
- frame_duration_ms=self.audio_config.chunk_duration_ms
480
- )
481
 
482
 
483
  def create_inference_pipeline(
484
- model_path: Optional[str] = None,
485
  audio_config: Optional[AudioConfig] = None,
486
  model_config: Optional[ModelConfig] = None,
487
  inference_config: Optional[InferenceConfig] = None
@@ -490,31 +418,17 @@ def create_inference_pipeline(
490
  Factory function to create an InferencePipeline instance.
491
 
492
  Args:
493
- model_path: Optional path to saved model checkpoint
494
- audio_config: Audio configuration
495
- model_config: Model configuration
496
- inference_config: Inference configuration
497
 
498
  Returns:
499
- InferencePipeline instance
500
  """
501
- model = None
502
- if model_path:
503
- logger.info(f"Loading model from: {model_path}")
504
- model_config = model_config or ModelConfig()
505
- model = load_speech_pathology_model(
506
- model_name=model_config.model_name,
507
- classifier_hidden_dims=model_config.classifier_hidden_dims,
508
- dropout=model_config.dropout,
509
- device=model_config.device,
510
- use_fp16=model_config.use_fp16,
511
- model_path=model_path
512
- )
513
-
514
  return InferencePipeline(
515
  model=model,
516
  audio_config=audio_config,
517
  model_config=model_config,
518
  inference_config=inference_config
519
  )
520
-
 
2
  Inference Pipeline for Speech Pathology Diagnosis
3
 
4
  This module provides the inference pipeline for real-time and batch processing
5
+ of audio for fluency and articulation analysis using sliding window approach.
6
  """
7
 
8
  import logging
 
13
  from typing import Dict, List, Optional, Tuple, Union
14
  from pathlib import Path
15
  import time
16
+ from dataclasses import dataclass, field
17
+ from collections import deque
18
 
19
  from models.speech_pathology_model import SpeechPathologyClassifier, load_speech_pathology_model
20
  from config import AudioConfig, ModelConfig, InferenceConfig
 
23
 
24
 
25
@dataclass
class FramePrediction:
    """
    Result of classifying a single analysis frame.

    Attributes:
        time: Timestamp in seconds (centre of the analysis window).
        fluency_prob: Probability of stutter (0-1).
        fluency_label: Either 'normal' or 'stutter'.
        articulation_class: Articulation class index (0-3).
        articulation_label: Human-readable articulation class name.
        confidence: Overall confidence score for this frame.
    """
    time: float
    fluency_prob: float
    fluency_label: str
    articulation_class: int
    articulation_label: str
    confidence: float
 
 
44
 
45
 
46
@dataclass
class PhoneLevelResult:
    """
    Full result of one phone-level analysis run.

    Attributes:
        frame_predictions: Per-frame predictions, in temporal order.
        aggregate: Aggregated statistics over all frames.
        duration: Analysed audio duration in seconds.
        num_frames: Number of frames analysed.
    """
    frame_predictions: List[FramePrediction]
    aggregate: Dict[str, Union[float, int, str]]
    duration: float
    num_frames: int
 
59
 
60
 
61
  class InferencePipeline:
62
  """
63
+ Inference pipeline for speech pathology diagnosis using sliding window approach.
 
 
 
64
 
65
+ Architecture:
66
+ - 1-second sliding windows (minimum for Wav2Vec2)
67
+ - 10ms hop size for phone-level resolution
68
+ - Wav2Vec2 feature extraction per window
69
+ - Multi-task classifier for fluency + articulation
 
70
  """
71
 
72
  def __init__(
 
84
  audio_config: Audio processing configuration
85
  model_config: Model configuration
86
  inference_config: Inference configuration
 
 
 
87
  """
88
  # Load configurations
89
  from config import default_audio_config, default_model_config, default_inference_config
 
92
  self.model_config = model_config or default_model_config
93
  self.inference_config = inference_config or default_inference_config
94
 
95
+ logger.info("Initializing InferencePipeline (sliding window)...")
96
+ logger.info(f"Window size: {self.inference_config.window_size_ms}ms")
97
+ logger.info(f"Hop size: {self.inference_config.hop_size_ms}ms")
98
+ logger.info(f"Frame rate: {self.inference_config.frame_rate} fps")
 
99
 
100
  # Initialize or use provided model
101
  if model is None:
102
  logger.info("Loading SpeechPathologyClassifier...")
103
  self.model = load_speech_pathology_model(
104
  model_name=self.model_config.model_name,
105
+ classifier_hidden_dims=[512, 256], # 1024 → 512 → 256
106
  dropout=self.model_config.dropout,
107
  device=self.model_config.device,
108
  use_fp16=self.model_config.use_fp16
 
114
  # Get processor for audio preprocessing
115
  self.processor = self.model.processor
116
 
117
+ # Calculate window and hop sizes in samples
118
+ self.window_size_samples = int(
119
+ self.inference_config.window_size_ms * self.audio_config.sample_rate / 1000
120
  )
121
  self.hop_size_samples = int(
122
+ self.inference_config.hop_size_ms * self.audio_config.sample_rate / 1000
123
  )
124
 
125
+ logger.info(f"Window size: {self.window_size_samples} samples")
126
+ logger.info(f"Hop size: {self.hop_size_samples} samples")
127
  logger.info("✅ InferencePipeline initialized successfully")
128
 
129
  def preprocess_audio(
 
140
 
141
  Returns:
142
  Preprocessed audio array normalized to [-1, 1] range
 
 
 
143
  """
144
  target_sr = target_sr or self.audio_config.sample_rate
145
 
146
  # Load audio if path provided
147
  if isinstance(audio, (str, Path)):
148
+ audio_path = Path(audio)
149
+ if not audio_path.exists():
150
+ raise ValueError(f"Audio file not found: {audio_path}")
151
+
152
+ audio_array, sr = librosa.load(str(audio_path), sr=target_sr, mono=True)
 
153
  else:
154
+ audio_array = np.array(audio, dtype=np.float32)
 
155
  if len(audio_array.shape) > 1:
156
+ audio_array = np.mean(audio_array, axis=0) # Convert to mono
157
 
158
+ # Normalize to [-1, 1]
159
+ max_val = np.abs(audio_array).max()
160
+ if max_val > 0:
161
+ audio_array = audio_array / max_val
 
 
162
 
163
  return audio_array
164
 
165
+ def get_phone_level_features(
166
  self,
167
+ audio: np.ndarray
168
+ ) -> Tuple[torch.Tensor, np.ndarray]:
 
 
169
  """
170
+ Extract phone-level features using sliding window approach.
 
 
171
 
172
  Args:
173
+ audio: Preprocessed audio array (16kHz, mono, normalized)
 
 
174
 
175
  Returns:
176
+ Tuple of (frame_features, frame_times):
177
+ - frame_features: Tensor of shape (num_frames, 1024)
178
+ - frame_times: Array of timestamps in seconds
 
 
179
  """
180
+ num_samples = len(audio)
181
+ num_windows = max(1, (num_samples - self.window_size_samples) // self.hop_size_samples + 1)
182
 
183
+ frame_features_list = []
184
+ frame_times = []
185
+
186
+ logger.info(f"Extracting features from {num_windows} windows...")
187
+
188
+ for i in range(num_windows):
189
+ start_sample = i * self.hop_size_samples
190
+ end_sample = min(start_sample + self.window_size_samples, num_samples)
191
 
192
+ # Extract window
193
+ window = audio[start_sample:end_sample]
 
194
 
195
+ # Pad if necessary (at the end of audio)
196
+ if len(window) < self.window_size_samples:
197
+ padding = self.window_size_samples - len(window)
198
+ window = np.pad(window, (0, padding), mode='constant')
199
+
200
+ # Convert to tensor
201
+ audio_tensor = torch.from_numpy(window).float()
202
+
203
+ # Process through feature extractor
204
+ with torch.no_grad():
205
+ inputs = self.processor(
206
+ audio_tensor,
207
+ sampling_rate=self.audio_config.sample_rate,
208
+ return_tensors="pt",
209
+ padding=True
210
+ )
211
 
212
+ input_values = inputs.input_values.to(self.model.device)
 
 
 
213
 
214
+ # Extract Wav2Vec2 features
215
+ wav2vec2_outputs = self.model.wav2vec2_model(
216
+ input_values=input_values
217
+ )
218
 
219
+ # Get features: (batch_size, seq_len, 1024)
220
+ features = wav2vec2_outputs.last_hidden_state
221
+
222
+ # Pool to single vector: mean over sequence length
223
+ pooled_features = torch.mean(features, dim=1) # (batch_size, 1024)
224
+ frame_features_list.append(pooled_features.cpu())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
 
226
+ # Calculate timestamp (center of window)
227
+ frame_time = (start_sample + self.window_size_samples / 2) / self.audio_config.sample_rate
228
+ frame_times.append(frame_time)
229
+
230
+ # Stack all features
231
+ frame_features = torch.cat(frame_features_list, dim=0) # (num_frames, 1024)
232
+ frame_times = np.array(frame_times)
233
+
234
+ logger.info(f"Extracted {len(frame_features)} frame features")
235
+
236
+ return frame_features, frame_times
237
 
238
+ def predict_phone_level(
239
  self,
240
+ audio: Union[np.ndarray, str, Path],
241
+ return_timestamps: bool = True
242
+ ) -> PhoneLevelResult:
 
243
  """
244
+ Predict fluency and articulation at phone-level resolution.
 
 
245
 
246
  Args:
247
+ audio: Audio array, file path, or Path object
248
+ return_timestamps: Whether to include timestamps in results
 
249
 
250
  Returns:
251
+ PhoneLevelResult with frame-level predictions and aggregates
 
 
 
 
252
  """
253
+ start_time = time.time()
254
+
255
  try:
256
+ # Preprocess audio
257
+ audio_array = self.preprocess_audio(audio)
258
+ duration = len(audio_array) / self.audio_config.sample_rate
259
 
260
+ logger.info(f"Processing audio: {duration:.2f}s")
 
 
 
261
 
262
+ # Check minimum length
263
+ if duration < self.inference_config.minimum_audio_length:
264
+ logger.warning(f"Audio shorter than minimum ({duration:.2f}s < {self.inference_config.minimum_audio_length}s), "
265
+ f"padding to minimum")
266
+ min_samples = int(self.inference_config.minimum_audio_length * self.audio_config.sample_rate)
267
+ if len(audio_array) < min_samples:
268
+ padding = min_samples - len(audio_array)
269
+ audio_array = np.pad(audio_array, (0, padding), mode='constant')
270
+ duration = len(audio_array) / self.audio_config.sample_rate
271
 
272
+ # Extract phone-level features
273
+ frame_features, frame_times = self.get_phone_level_features(audio_array)
 
 
 
 
274
 
275
+ # Move features to device
276
+ frame_features = frame_features.to(self.model.device)
277
+
278
+ # Predict using classifier
279
+ self.model.eval()
280
+ with torch.no_grad():
281
+ # Pass through shared layers and heads
282
+ shared_features = self.model.classifier_head.shared_layers(frame_features)
283
+
284
+ # Get predictions from all heads
285
+ fluency_logits = self.model.classifier_head.fluency_head(shared_features)
286
+ articulation_logits = self.model.classifier_head.articulation_head(shared_features)
287
+ full_logits = self.model.classifier_head.full_head(shared_features)
288
+
289
+ # Apply softmax
290
+ fluency_probs = torch.softmax(fluency_logits, dim=-1) # (num_frames, 2)
291
+ articulation_probs = torch.softmax(articulation_logits, dim=-1) # (num_frames, 4)
292
+ full_probs = torch.softmax(full_logits, dim=-1) # (num_frames, 8)
293
+
294
+ # Convert to numpy
295
+ fluency_probs = fluency_probs.cpu().numpy()
296
+ articulation_probs = articulation_probs.cpu().numpy()
297
+ full_probs = full_probs.cpu().numpy()
298
+
299
+ # Create frame predictions
300
+ frame_predictions = []
301
+ for i in range(len(frame_features)):
302
+ # Fluency: class 0 = normal, class 1 = stutter
303
+ fluency_prob_stutter = fluency_probs[i, 1]
304
+ fluency_label = 'stutter' if fluency_prob_stutter > self.inference_config.fluency_threshold else 'normal'
305
+
306
+ # Articulation: get class with highest probability
307
+ articulation_class = int(np.argmax(articulation_probs[i]))
308
+ articulation_label = self.model.get_articulation_class_name(articulation_class)
309
+
310
+ # Confidence: average of max probabilities
311
+ confidence = (np.max(fluency_probs[i]) + np.max(articulation_probs[i])) / 2.0
312
+
313
+ frame_pred = FramePrediction(
314
+ time=frame_times[i] if return_timestamps else 0.0,
315
+ fluency_prob=float(fluency_prob_stutter),
316
+ fluency_label=fluency_label,
317
+ articulation_class=articulation_class,
318
+ articulation_label=articulation_label,
319
+ confidence=float(confidence)
320
+ )
321
+ frame_predictions.append(frame_pred)
322
 
323
+ # Aggregate statistics
324
+ fluency_scores = [fp.fluency_prob for fp in frame_predictions]
325
+ articulation_classes = [fp.articulation_class for fp in frame_predictions]
326
 
327
+ aggregate = {
328
+ 'fluency_score': float(np.mean(fluency_scores)),
329
+ 'articulation_class': int(np.bincount(articulation_classes).argmax()),
330
+ 'articulation_label': self.model.get_articulation_class_name(
331
+ int(np.bincount(articulation_classes).argmax())
332
+ ),
333
+ 'num_frames': len(frame_predictions),
334
+ 'duration': duration
335
+ }
336
+
337
+ elapsed_time = time.time() - start_time
338
+ logger.info(f"Phone-level prediction completed in {elapsed_time:.2f}s "
339
+ f"({duration/elapsed_time:.1f}x real-time)")
340
+
341
+ return PhoneLevelResult(
342
+ frame_predictions=frame_predictions,
343
+ aggregate=aggregate,
344
+ duration=duration,
345
+ num_frames=len(frame_predictions)
346
  )
347
+
348
+ except Exception as e:
349
+ logger.error(f"Phone-level prediction failed: {e}", exc_info=True)
350
+ raise RuntimeError(f"Phone-level prediction failed: {e}") from e
 
 
 
 
 
 
351
 
352
+ def predict_batch(
353
  self,
354
+ audio_path: Union[str, Path],
355
+ return_timestamps: bool = True,
356
+ apply_smoothing: bool = True
357
+ ) -> PhoneLevelResult:
358
  """
359
+ Predict on audio file (batch processing).
360
 
361
  Args:
362
+ audio_path: Path to audio file
363
+ return_timestamps: Whether to include timestamps
364
+ apply_smoothing: Whether to apply temporal smoothing (not implemented yet)
365
 
366
  Returns:
367
+ PhoneLevelResult
368
  """
369
+ return self.predict_phone_level(audio_path, return_timestamps=return_timestamps)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
 
371
+ def predict_streaming_chunk(
372
  self,
373
+ chunk: np.ndarray,
374
+ buffer: Optional[deque] = None,
375
+ timestamp: Optional[float] = None
376
+ ) -> Optional[FramePrediction]:
377
  """
378
+ Predict on streaming audio chunk.
379
 
380
  Args:
381
+ chunk: Audio chunk array
382
+ buffer: Optional buffer for maintaining sliding window
383
+ timestamp: Optional timestamp for the chunk
384
 
385
  Returns:
386
+ FramePrediction if enough data accumulated, None otherwise
387
  """
388
+ if buffer is None:
389
+ buffer = deque(maxlen=self.window_size_samples)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
390
 
391
+ # Add chunk to buffer
392
+ buffer.extend(chunk)
393
 
394
+ # Check if we have enough data for a window
395
+ if len(buffer) >= self.window_size_samples:
396
+ # Extract window
397
+ window = np.array(list(buffer)[-self.window_size_samples:])
398
+
399
+ # Process window
400
+ try:
401
+ result = self.predict_phone_level(window, return_timestamps=False)
402
+ if result.frame_predictions:
403
+ return result.frame_predictions[-1] # Return latest frame
404
+ except Exception as e:
405
+ logger.warning(f"Streaming prediction failed: {e}")
406
+ return None
407
 
408
+ return None
 
 
 
 
 
 
409
 
410
 
411
  def create_inference_pipeline(
412
+ model: Optional[SpeechPathologyClassifier] = None,
413
  audio_config: Optional[AudioConfig] = None,
414
  model_config: Optional[ModelConfig] = None,
415
  inference_config: Optional[InferenceConfig] = None
 
418
  Factory function to create an InferencePipeline instance.
419
 
420
  Args:
421
+ model: Optional pre-initialized model
422
+ audio_config: Optional audio configuration
423
+ model_config: Optional model configuration
424
+ inference_config: Optional inference configuration
425
 
426
  Returns:
427
+ Initialized InferencePipeline instance
428
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
429
  return InferencePipeline(
430
  model=model,
431
  audio_config=audio_config,
432
  model_config=model_config,
433
  inference_config=inference_config
434
  )
 
models/error_taxonomy.py ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Error Taxonomy for Speech Pathology Analysis
3
+
4
+ This module defines error types, severity levels, and therapy recommendations
5
+ for phoneme-level error detection.
6
+ """
7
+
8
+ import logging
9
+ import json
10
+ from enum import Enum
11
+ from typing import Optional, Dict, List
12
+ from pathlib import Path
13
+ from dataclasses import dataclass, field
14
+ from pydantic import BaseModel, Field
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
class ErrorType(str, Enum):
    """Closed set of articulation error categories."""

    NORMAL = "normal"              # correct production
    SUBSTITUTION = "substitution"  # a different phoneme was produced instead
    OMISSION = "omission"          # the phoneme was dropped entirely
    DISTORTION = "distortion"      # the phoneme was produced but degraded
25
+
26
+
27
class SeverityLevel(str, Enum):
    """Discrete severity buckets derived from a 0-1 severity score."""

    NONE = "none"
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
33
+
34
+
35
@dataclass
class ErrorDetail:
    """
    Detailed error report for one phoneme.

    Attributes:
        phoneme: Expected phoneme symbol (e.g., '/s/')
        error_type: Type of error (NORMAL, SUBSTITUTION, OMISSION, DISTORTION)
        wrong_sound: For substitutions, the incorrect phoneme produced (e.g., '/θ/')
        severity: Severity score (0.0-1.0)
        confidence: Model confidence in the error detection (0.0-1.0)
        therapy: Therapy recommendation text
        frame_indices: Frame indices where this error occurs
    """
    phoneme: str
    error_type: ErrorType
    wrong_sound: Optional[str] = None
    severity: float = 0.0
    confidence: float = 0.0
    therapy: str = ""
    frame_indices: List[int] = field(default_factory=list)
56
+
57
+
58
class ErrorDetailPydantic(BaseModel):
    """Pydantic mirror of ErrorDetail, used for API serialization."""

    phoneme: str
    error_type: str
    wrong_sound: Optional[str] = None
    severity: float = Field(ge=0.0, le=1.0)
    confidence: float = Field(ge=0.0, le=1.0)
    therapy: str
    frame_indices: List[int] = Field(default_factory=list)
67
+
68
+
69
class ErrorMapper:
    """
    Maps classifier outputs to error types and generates therapy recommendations.

    Classifier output mapping (8 classes):
    - Class 0: Normal articulation, normal fluency
    - Class 1: Substitution, normal fluency
    - Class 2: Omission, normal fluency
    - Class 3: Distortion, normal fluency
    - Class 4: Normal articulation, stutter
    - Class 5: Substitution, stutter
    - Class 6: Omission, stutter
    - Class 7: Distortion, stutter
    """

    def __init__(self, therapy_db_path: Optional[str] = None):
        """
        Initialize the ErrorMapper.

        Args:
            therapy_db_path: Path to therapy recommendations JSON file.
                If None, uses default location: data/therapy_recommendations.json
        """
        self.therapy_db: Dict = {}

        # Resolve the database location; the default lives beside the package.
        if therapy_db_path is None:
            db_path = Path(__file__).parent.parent / "data" / "therapy_recommendations.json"
        else:
            db_path = Path(therapy_db_path)

        # Load the therapy database, falling back to built-in defaults on any failure.
        try:
            if db_path.exists():
                with open(db_path, 'r', encoding='utf-8') as f:
                    self.therapy_db = json.load(f)
                logger.info(f"✅ Loaded therapy database from {db_path}")
            else:
                logger.warning(f"Therapy database not found at {db_path}, using defaults")
                self.therapy_db = self._get_default_therapy_db()
        except Exception as e:
            logger.error(f"Failed to load therapy database: {e}, using defaults")
            self.therapy_db = self._get_default_therapy_db()

        # Common substitution mappings (phoneme → likely wrong sound)
        self.substitution_map: Dict[str, List[str]] = {
            '/s/': ['/θ/', '/ʃ/', '/z/'],   # lisp, sh-sound, voicing
            '/r/': ['/w/', '/l/', '/ɹ/'],   # rhotacism variants
            '/l/': ['/w/', '/j/'],          # liquid substitutions
            '/k/': ['/t/', '/p/'],          # velar → alveolar/bilabial
            '/g/': ['/d/', '/b/'],          # velar → alveolar/bilabial
            '/θ/': ['/f/', '/s/'],          # th → f or s
            '/ð/': ['/v/', '/z/'],          # voiced th → v or z
            '/ʃ/': ['/s/', '/tʃ/'],         # sh → s or ch
            '/tʃ/': ['/ʃ/', '/ts/'],        # ch → sh or ts
        }

    def map_classifier_output(
        self,
        class_id: int,
        confidence: float,
        phoneme: str,
        fluency_label: str = "normal"
    ) -> ErrorDetail:
        """
        Map classifier output to ErrorDetail.

        Args:
            class_id: Classifier output class (0-7)
            confidence: Model confidence (0.0-1.0)
            phoneme: Expected phoneme symbol
            fluency_label: Fluency label ("normal" or "stutter"); currently
                not folded into the result, only the articulation axis is.

        Returns:
            ErrorDetail object with error information
        """
        # Classes 0-3 (normal fluency) and 4-7 (stutter) share the same
        # articulation-error axis, so both rows map to the same ErrorType.
        class_to_error = {
            0: ErrorType.NORMAL, 4: ErrorType.NORMAL,
            1: ErrorType.SUBSTITUTION, 5: ErrorType.SUBSTITUTION,
            2: ErrorType.OMISSION, 6: ErrorType.OMISSION,
            3: ErrorType.DISTORTION, 7: ErrorType.DISTORTION,
        }
        error_type = class_to_error.get(class_id)
        if error_type is None:
            logger.warning(f"Unknown class_id: {class_id}, defaulting to NORMAL")
            error_type = ErrorType.NORMAL

        # Severity: zero for correct productions, confidence-proxied otherwise.
        severity = 0.0 if error_type == ErrorType.NORMAL else confidence

        # Likely wrong sound is only meaningful for substitutions.
        wrong_sound = None
        if error_type == ErrorType.SUBSTITUTION:
            wrong_sound = self._map_substitution(phoneme, confidence)

        return ErrorDetail(
            phoneme=phoneme,
            error_type=error_type,
            wrong_sound=wrong_sound,
            severity=severity,
            confidence=confidence,
            therapy=self.get_therapy(error_type, phoneme, wrong_sound)
        )

    def _map_substitution(self, phoneme: str, confidence: float) -> Optional[str]:
        """
        Map substitution error to likely wrong sound.

        Args:
            phoneme: Expected phoneme
            confidence: Model confidence (currently unused by the lookup)

        Returns:
            Most likely (first / most common) wrong phoneme, or None if unknown
        """
        candidates = self.substitution_map.get(phoneme)
        return candidates[0] if candidates else None

    def get_therapy(
        self,
        error_type: ErrorType,
        phoneme: str,
        wrong_sound: Optional[str] = None
    ) -> str:
        """
        Get therapy recommendation for an error.

        Lookup order per error type: specific entry in the therapy database,
        then the database's "generic" template, then a hard-coded fallback.

        Args:
            error_type: Type of error
            phoneme: Expected phoneme
            wrong_sound: For substitutions, the wrong sound produced

        Returns:
            Therapy recommendation text
        """
        if error_type == ErrorType.NORMAL:
            return "No therapy needed - production is correct."

        if error_type == ErrorType.SUBSTITUTION:
            section = self.therapy_db.get("substitutions", {})
            if wrong_sound:
                key = f"{phoneme}→{wrong_sound}"
                if key in section:
                    return section[key]
            if "generic" in section:
                return section["generic"].replace("{phoneme}", phoneme)
            return f"Substitution error for {phoneme}. Practice correct articulator placement."

        if error_type == ErrorType.OMISSION:
            section = self.therapy_db.get("omissions", {})
            if phoneme in section:
                return section[phoneme]
            if "generic" in section:
                return section["generic"].replace("{phoneme}", phoneme)
            return f"Omission error for {phoneme}. Practice saying the sound separately first."

        if error_type == ErrorType.DISTORTION:
            section = self.therapy_db.get("distortions", {})
            if phoneme in section:
                return section[phoneme]
            if "generic" in section:
                return section["generic"].replace("{phoneme}", phoneme)
            return f"Distortion error for {phoneme}. Use mirror feedback and watch articulator position."

        return "Consult with speech-language pathologist for personalized therapy plan."

    def get_severity_level(self, severity: float) -> SeverityLevel:
        """
        Convert a numeric severity score to a discrete SeverityLevel.

        Args:
            severity: Severity score (0.0-1.0)

        Returns:
            SeverityLevel enum (NONE at exactly 0.0, LOW < 0.3,
            MEDIUM < 0.7, HIGH otherwise)
        """
        if severity == 0.0:
            return SeverityLevel.NONE
        if severity < 0.3:
            return SeverityLevel.LOW
        if severity < 0.7:
            return SeverityLevel.MEDIUM
        return SeverityLevel.HIGH

    def _get_default_therapy_db(self) -> Dict:
        """Built-in fallback therapy database used when no JSON file is found."""
        return {
            "substitutions": {
                "/s/→/θ/": "Lisp - Use tongue tip placement behind upper teeth. Practice /s/ in isolation.",
                "/r/→/w/": "Rhotacism - Practice tongue position: curl tongue back, avoid lip rounding.",
                "/r/→/l/": "Rhotacism - Focus on tongue tip position vs. tongue body placement.",
                "generic": "Substitution error for {phoneme}. Practice correct articulator placement with mirror feedback."
            },
            "omissions": {
                "/r/": "Practice /r/ in isolation, then in CV syllables (ra, re, ri, ro, ru).",
                "/l/": "Lateral tongue placement - practice with tongue tip up to alveolar ridge.",
                "/s/": "Practice /s/ with tongue tip placement, use mirror to check position.",
                "generic": "Omission error for {phoneme}. Say the sound separately first, then blend into words."
            },
            "distortions": {
                "/s/": "Sibilant clarity - use mirror feedback, ensure tongue tip is up and air stream is central.",
                "/ʃ/": "Fricative voicing exercise - practice /sh/ vs /s/ distinction.",
                "/r/": "Rhotacism - practice tongue position and lip rounding control.",
                "generic": "Distortion error for {phoneme}. Use mirror feedback and watch articulator position carefully."
            }
        }
287
+
288
+
289
# Unit test function
def test_error_mapper():
    """Smoke-test ErrorMapper class mapping, therapy lookup, and severity buckets."""
    print("Testing ErrorMapper...")

    mapper = ErrorMapper()

    # Normal production (class 0): no error, zero severity.
    detail = mapper.map_classifier_output(0, 0.95, "/k/")
    assert detail.error_type == ErrorType.NORMAL
    assert detail.severity == 0.0
    print(f"✅ Normal error: {detail.error_type}, therapy: {detail.therapy[:50]}...")

    # Substitution (class 1): a likely wrong sound must be suggested.
    detail = mapper.map_classifier_output(1, 0.78, "/s/")
    assert detail.error_type == ErrorType.SUBSTITUTION
    assert detail.wrong_sound is not None
    print(f"✅ Substitution error: {detail.error_type}, wrong_sound: {detail.wrong_sound}")
    print(f"   Therapy: {detail.therapy[:80]}...")

    # Omission (class 2).
    detail = mapper.map_classifier_output(2, 0.85, "/r/")
    assert detail.error_type == ErrorType.OMISSION
    print(f"✅ Omission error: {detail.error_type}")
    print(f"   Therapy: {detail.therapy[:80]}...")

    # Distortion (class 3).
    detail = mapper.map_classifier_output(3, 0.65, "/s/")
    assert detail.error_type == ErrorType.DISTORTION
    print(f"✅ Distortion error: {detail.error_type}")
    print(f"   Therapy: {detail.therapy[:80]}...")

    # Severity score → discrete level bucketing.
    assert mapper.get_severity_level(0.0) == SeverityLevel.NONE
    assert mapper.get_severity_level(0.2) == SeverityLevel.LOW
    assert mapper.get_severity_level(0.5) == SeverityLevel.MEDIUM
    assert mapper.get_severity_level(0.8) == SeverityLevel.HIGH
    print("✅ Severity level mapping correct")

    print("\n✅ All tests passed!")


if __name__ == "__main__":
    test_error_mapper()
333
+
models/phoneme_mapper.py ADDED
@@ -0,0 +1,364 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Phoneme Mapper for Speech Pathology Analysis
3
+
4
+ This module provides grapheme-to-phoneme (G2P) conversion and alignment
5
+ of phonemes to audio frames for phone-level error detection.
6
+ """
7
+
8
+ import logging
9
+ from typing import List, Tuple, Optional, Dict
10
+ from dataclasses import dataclass
11
+ import numpy as np
12
+
13
+ try:
14
+ import g2p_en
15
+ G2P_AVAILABLE = True
16
+ except ImportError:
17
+ G2P_AVAILABLE = False
18
+ logging.warning("g2p_en not available. Install with: pip install g2p-en")
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ @dataclass
24
+ class PhonemeSegment:
25
+ """
26
+ Represents a phoneme segment with timing information.
27
+
28
+ Attributes:
29
+ phoneme: Phoneme symbol (e.g., '/r/', '/k/')
30
+ start_time: Start time in seconds
31
+ end_time: End time in seconds
32
+ duration: Duration in seconds
33
+ frame_start: Starting frame index
34
+ frame_end: Ending frame index (exclusive)
35
+ """
36
+ phoneme: str
37
+ start_time: float
38
+ end_time: float
39
+ duration: float
40
+ frame_start: int
41
+ frame_end: int
42
+
43
+
44
class PhonemeMapper:
    """
    Maps text to phonemes and aligns them to audio frames.

    Uses the g2p_en library for English grapheme-to-phoneme conversion, then
    distributes the phonemes evenly over the audio so that each fixed-length
    frame (default 20 ms) can be tagged with the phoneme it most overlaps.

    Example:
        >>> mapper = PhonemeMapper()
        >>> phonemes = mapper.text_to_phonemes("robot")
        >>> # [('/R/', 0.0), ('/OW1/', 0.08), ...]
        >>> frames = mapper.align_phonemes_to_frames(phonemes, num_frames=25)
        >>> # one phoneme label per 20 ms frame
    """

    def __init__(self, frame_duration_ms: int = 20, sample_rate: int = 16000):
        """
        Initialize the PhonemeMapper.

        Args:
            frame_duration_ms: Duration of each frame in milliseconds (default: 20ms)
            sample_rate: Audio sample rate in Hz (default: 16000)

        Raises:
            ImportError: If g2p_en is not available
        """
        if not G2P_AVAILABLE:
            raise ImportError(
                "g2p_en library is required. Install with: pip install g2p-en"
            )

        try:
            self.g2p = g2p_en.G2p()
            logger.info("✅ G2P model loaded successfully")
        except Exception as e:
            logger.error(f"❌ Failed to load G2P model: {e}")
            raise

        self.frame_duration_ms = frame_duration_ms
        self.frame_duration_s = frame_duration_ms / 1000.0
        self.sample_rate = sample_rate

        # Typical English phonemes last 50-100 ms; 80 ms is used as the
        # fallback when the caller does not provide the true audio duration.
        self.avg_phoneme_duration_ms = 80
        self.avg_phoneme_duration_s = self.avg_phoneme_duration_ms / 1000.0

        logger.info(f"PhonemeMapper initialized: frame_duration={frame_duration_ms}ms, "
                    f"avg_phoneme_duration={self.avg_phoneme_duration_ms}ms")

    def text_to_phonemes(
        self,
        text: str,
        duration: Optional[float] = None
    ) -> List[Tuple[str, float]]:
        """
        Convert text to phonemes with timing information.

        Args:
            text: Input text string (e.g., "robot", "cat")
            duration: Optional audio duration in seconds. If provided, phonemes
                are distributed evenly across this duration; otherwise the
                duration is estimated from the phoneme count.

        Returns:
            List of (phoneme, start_time) tuples, where phoneme is slash-wrapped
            (e.g. '/K/') and start_time is in seconds.

        Raises:
            RuntimeError: If the grapheme-to-phoneme conversion fails.
        """
        if not text or not text.strip():
            logger.warning("Empty text provided, returning empty phoneme list")
            return []

        try:
            # Convert to phonemes using g2p_en
            raw_tokens = self.g2p(text.lower().strip())

            # Filter out punctuation and empty strings.
            # BUG FIX: g2p_en emits punctuation tokens such as ',' and '.'
            # which are non-empty and non-whitespace, so a strip()/isspace()
            # check alone lets them through. Require at least one
            # alphanumeric character (ARPAbet phonemes are alphanumeric).
            raw_tokens = [
                p for p in raw_tokens
                if p and p.strip() and any(ch.isalnum() for ch in p)
            ]

            if not raw_tokens:
                logger.warning(f"No phonemes extracted from text: '{text}'")
                return []

            # Normalize to the slash-wrapped form used throughout the project.
            formatted_phonemes = []
            for p in raw_tokens:
                if not p.startswith('/'):
                    p = '/' + p
                if not p.endswith('/'):
                    p = p + '/'
                formatted_phonemes.append(p)

            logger.debug(f"Extracted {len(formatted_phonemes)} phonemes from '{text}': {formatted_phonemes}")

            # Total duration: either the caller-supplied audio length or an
            # estimate of avg_phoneme_duration * phoneme count.
            if duration is None:
                total_duration = len(formatted_phonemes) * self.avg_phoneme_duration_s
            else:
                total_duration = duration

            # Distribute phonemes evenly across the duration.
            if len(formatted_phonemes) == 1:
                phoneme_duration = total_duration
            else:
                phoneme_duration = total_duration / len(formatted_phonemes)

            phoneme_times = [
                (phoneme, i * phoneme_duration)
                for i, phoneme in enumerate(formatted_phonemes)
            ]

            logger.info(f"Converted '{text}' to {len(phoneme_times)} phonemes over {total_duration:.2f}s")

            return phoneme_times

        except Exception as e:
            logger.error(f"Error converting text to phonemes: {e}", exc_info=True)
            raise RuntimeError(f"Failed to convert text to phonemes: {e}") from e

    def align_phonemes_to_frames(
        self,
        phoneme_times: List[Tuple[str, float]],
        num_frames: int,
        frame_duration_ms: Optional[int] = None
    ) -> List[str]:
        """
        Align phonemes to audio frames.

        Each frame is assigned the phoneme whose time span overlaps it most;
        frames with no overlapping phoneme get the temporally closest one.

        Args:
            phoneme_times: List of (phoneme, start_time) tuples from text_to_phonemes()
            num_frames: Total number of frames in the audio
            frame_duration_ms: Optional frame duration override

        Returns:
            List of phonemes, one per frame (empty strings if no phonemes given).
        """
        if not phoneme_times:
            logger.warning("No phonemes provided, returning empty frame list")
            return [''] * num_frames

        frame_duration_s = (frame_duration_ms / 1000.0) if frame_duration_ms else self.frame_duration_s

        # Build explicit segments; each phoneme ends where the next begins,
        # and the last one is padded with the average observed duration.
        phoneme_segments = []
        for i, (phoneme, start_time) in enumerate(phoneme_times):
            if i < len(phoneme_times) - 1:
                end_time = phoneme_times[i + 1][1]
            else:
                if len(phoneme_times) > 1:
                    avg_duration = phoneme_times[1][1] - phoneme_times[0][1]
                else:
                    avg_duration = self.avg_phoneme_duration_s
                end_time = start_time + avg_duration

            phoneme_segments.append(PhonemeSegment(
                phoneme=phoneme,
                start_time=start_time,
                end_time=end_time,
                duration=end_time - start_time,
                frame_start=-1,  # frame indices not needed for alignment
                frame_end=-1
            ))

        # Assign each frame the phoneme with maximal temporal overlap.
        frame_phonemes = []
        for frame_idx in range(num_frames):
            frame_start_time = frame_idx * frame_duration_s
            frame_end_time = (frame_idx + 1) * frame_duration_s
            frame_center_time = frame_start_time + (frame_duration_s / 2.0)

            best_phoneme = ''
            max_overlap = 0.0

            for seg in phoneme_segments:
                overlap_start = max(frame_start_time, seg.start_time)
                overlap_end = min(frame_end_time, seg.end_time)
                overlap = max(0.0, overlap_end - overlap_start)

                if overlap > max_overlap:
                    max_overlap = overlap
                    best_phoneme = seg.phoneme

            # Frames past the last phoneme (or in gaps) fall back to the
            # segment whose midpoint is closest to the frame center.
            if not best_phoneme:
                closest_seg = min(
                    phoneme_segments,
                    key=lambda s: abs(frame_center_time - (s.start_time + s.duration / 2))
                )
                best_phoneme = closest_seg.phoneme

            frame_phonemes.append(best_phoneme)

        logger.debug(f"Aligned {len(phoneme_times)} phonemes to {num_frames} frames")

        return frame_phonemes

    def get_phoneme_boundaries(
        self,
        phoneme_times: List[Tuple[str, float]],
        duration: float
    ) -> List[PhonemeSegment]:
        """
        Get detailed phoneme boundary information.

        Args:
            phoneme_times: List of (phoneme, start_time) tuples
            duration: Total audio duration in seconds

        Returns:
            List of PhonemeSegment objects with timing and frame indices
            (frame indices are computed from the mapper's frame duration).
        """
        segments = []

        for i, (phoneme, start_time) in enumerate(phoneme_times):
            # Last phoneme extends to the end of the audio.
            end_time = phoneme_times[i + 1][1] if i < len(phoneme_times) - 1 else duration

            segments.append(PhonemeSegment(
                phoneme=phoneme,
                start_time=start_time,
                end_time=end_time,
                duration=end_time - start_time,
                frame_start=int(start_time / self.frame_duration_s),
                frame_end=int(end_time / self.frame_duration_s)
            ))

        return segments

    def map_text_to_frames(
        self,
        text: str,
        num_frames: int,
        audio_duration: Optional[float] = None
    ) -> List[str]:
        """
        Complete pipeline: text → phonemes → frame alignment.

        Args:
            text: Input text string
            num_frames: Number of audio frames
            audio_duration: Optional audio duration in seconds

        Returns:
            List of phonemes, one per frame (empty strings if text yields none).
        """
        phoneme_times = self.text_to_phonemes(text, duration=audio_duration)

        if not phoneme_times:
            return [''] * num_frames

        return self.align_phonemes_to_frames(phoneme_times, num_frames)
324
+
325
+
326
+ # Unit test function
327
def test_phoneme_mapper():
    """Smoke-test PhonemeMapper end to end on a couple of words."""
    print("Testing PhonemeMapper...")

    try:
        mapper = PhonemeMapper(frame_duration_ms=20)

        # 1) Grapheme-to-phoneme extraction on a simple word.
        print("\n1. Testing 'robot':")
        robot_phonemes = mapper.text_to_phonemes("robot")
        print(f" Phonemes: {robot_phonemes}")
        assert len(robot_phonemes) > 0, "Should extract phonemes"

        # 2) Alignment of those phonemes onto a fixed frame grid.
        print("\n2. Testing frame alignment:")
        aligned = mapper.align_phonemes_to_frames(robot_phonemes, num_frames=25)
        print(f" Frame phonemes (first 10): {aligned[:10]}")
        assert len(aligned) == 25, "Should have 25 frames"

        # 3) Full text → frames pipeline.
        print("\n3. Testing complete pipeline with 'cat':")
        cat_frames = mapper.map_text_to_frames("cat", num_frames=15)
        print(f" Frame phonemes: {cat_frames}")
        assert len(cat_frames) == 15, "Should have 15 frames"

        print("\n✅ All tests passed!")

    except ImportError as e:
        # g2p-en is an optional dependency; report rather than crash.
        print(f"❌ G2P library not available: {e}")
        print(" Install with: pip install g2p-en")
    except Exception as e:
        print(f"❌ Test failed: {e}")
        raise
360
+
361
+
362
+ if __name__ == "__main__":
363
+ test_phoneme_mapper()
364
+
models/speech_pathology_model.py CHANGED
@@ -11,7 +11,7 @@ import logging
11
  import torch
12
  import torch.nn as nn
13
  from torch.nn import functional as F
14
- from transformers import Wav2Vec2Model, Wav2Vec2Processor, Wav2Vec2Config
15
  from typing import Dict, Optional, Tuple, List
16
  import os
17
 
@@ -51,7 +51,7 @@ class MultiTaskClassifierHead(nn.Module):
51
 
52
  self.num_articulation_classes = num_articulation_classes
53
 
54
- # Build shared feature layers
55
  layers = []
56
  prev_dim = input_dim
57
 
@@ -67,15 +67,15 @@ class MultiTaskClassifierHead(nn.Module):
67
  self.shared_layers = nn.Sequential(*layers)
68
  shared_output_dim = prev_dim
69
 
70
- # Fluency head (binary classification: fluent vs disfluent)
71
  self.fluency_head = nn.Sequential(
72
  nn.Linear(shared_output_dim, 64),
73
  nn.ReLU(),
74
  nn.Dropout(dropout),
75
- nn.Linear(64, 1), # Binary output (sigmoid)
76
  )
77
 
78
- # Articulation head (multi-class classification)
79
  self.articulation_head = nn.Sequential(
80
  nn.Linear(shared_output_dim, 64),
81
  nn.ReLU(),
@@ -83,6 +83,14 @@ class MultiTaskClassifierHead(nn.Module):
83
  nn.Linear(64, num_articulation_classes), # 4 classes
84
  )
85
 
 
 
 
 
 
 
 
 
86
  logger.info(
87
  f"Initialized MultiTaskClassifierHead: "
88
  f"input_dim={input_dim}, hidden_dims={hidden_dims}, "
@@ -124,18 +132,23 @@ class MultiTaskClassifierHead(nn.Module):
124
  shared_features = self.shared_layers(pooled_features)
125
 
126
  # Task-specific heads
127
- fluency_logits = self.fluency_head(shared_features)
128
- articulation_logits = self.articulation_head(shared_features)
 
129
 
130
  # Apply activations
131
- fluency_probs = torch.sigmoid(fluency_logits)
132
- articulation_probs = F.softmax(articulation_logits, dim=-1)
 
133
 
134
  return {
135
  "fluency_logits": fluency_logits,
136
  "articulation_logits": articulation_logits,
 
137
  "fluency_probs": fluency_probs,
138
  "articulation_probs": articulation_probs,
 
 
139
  }
140
 
141
 
@@ -210,13 +223,15 @@ class SpeechPathologyClassifier(nn.Module):
210
  # Load Wav2Vec2 model and processor
211
  hf_token = os.getenv("HF_TOKEN")
212
 
213
- logger.info("Loading Wav2Vec2 model and processor...")
214
  self.wav2vec2_model = Wav2Vec2Model.from_pretrained(
215
  model_name,
216
  token=hf_token if hf_token else None
217
  )
218
 
219
- self.processor = Wav2Vec2Processor.from_pretrained(
 
 
220
  model_name,
221
  token=hf_token if hf_token else None
222
  )
@@ -281,16 +296,49 @@ class SpeechPathologyClassifier(nn.Module):
281
  - articulation_probs: Articulation class probabilities
282
  - wav2vec2_features: Raw Wav2Vec2 features (for debugging)
283
  """
 
 
 
 
 
 
 
 
284
  # Extract features using Wav2Vec2
285
- with torch.no_grad() if not self.training else torch.enable_grad():
286
- wav2vec2_outputs = self.wav2vec2_model(
287
- input_values=input_values,
288
- attention_mask=attention_mask
289
- )
 
 
 
 
 
 
 
 
 
 
290
 
291
  # Get last hidden state (features)
292
  features = wav2vec2_outputs.last_hidden_state # (batch_size, seq_len, feature_dim)
293
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
294
  # Pass through classifier head
295
  outputs = self.classifier_head(features, attention_mask)
296
 
 
11
  import torch
12
  import torch.nn as nn
13
  from torch.nn import functional as F
14
+ from transformers import Wav2Vec2Model, Wav2Vec2FeatureExtractor, Wav2Vec2Config
15
  from typing import Dict, Optional, Tuple, List
16
  import os
17
 
 
51
 
52
  self.num_articulation_classes = num_articulation_classes
53
 
54
+ # Build shared feature layers: 1024 → 512 → 256
55
  layers = []
56
  prev_dim = input_dim
57
 
 
67
  self.shared_layers = nn.Sequential(*layers)
68
  shared_output_dim = prev_dim
69
 
70
+ # Fluency head: 256 64 2 (stutter/normal)
71
  self.fluency_head = nn.Sequential(
72
  nn.Linear(shared_output_dim, 64),
73
  nn.ReLU(),
74
  nn.Dropout(dropout),
75
+ nn.Linear(64, 2), # 2 classes: stutter/normal
76
  )
77
 
78
+ # Articulation head: 256 → 64 → 4 (normal/sub/omit/dist)
79
  self.articulation_head = nn.Sequential(
80
  nn.Linear(shared_output_dim, 64),
81
  nn.ReLU(),
 
83
  nn.Linear(64, num_articulation_classes), # 4 classes
84
  )
85
 
86
+ # Full combined head: 256 → 128 → 8 (all classes combined)
87
+ self.full_head = nn.Sequential(
88
+ nn.Linear(shared_output_dim, 128),
89
+ nn.ReLU(),
90
+ nn.Dropout(dropout),
91
+ nn.Linear(128, 8), # 8 classes (combined fluency + articulation)
92
+ )
93
+
94
  logger.info(
95
  f"Initialized MultiTaskClassifierHead: "
96
  f"input_dim={input_dim}, hidden_dims={hidden_dims}, "
 
132
  shared_features = self.shared_layers(pooled_features)
133
 
134
  # Task-specific heads
135
+ fluency_logits = self.fluency_head(shared_features) # (batch, 2)
136
+ articulation_logits = self.articulation_head(shared_features) # (batch, 4)
137
+ full_logits = self.full_head(shared_features) # (batch, 8)
138
 
139
  # Apply activations
140
+ fluency_probs = F.softmax(fluency_logits, dim=-1) # (batch, 2)
141
+ articulation_probs = F.softmax(articulation_logits, dim=-1) # (batch, 4)
142
+ full_probs = F.softmax(full_logits, dim=-1) # (batch, 8)
143
 
144
  return {
145
  "fluency_logits": fluency_logits,
146
  "articulation_logits": articulation_logits,
147
+ "full_logits": full_logits,
148
  "fluency_probs": fluency_probs,
149
  "articulation_probs": articulation_probs,
150
+ "full_probs": full_probs,
151
+ "shared_features": shared_features,
152
  }
153
 
154
 
 
223
  # Load Wav2Vec2 model and processor
224
  hf_token = os.getenv("HF_TOKEN")
225
 
226
+ logger.info("Loading Wav2Vec2 model and feature extractor...")
227
  self.wav2vec2_model = Wav2Vec2Model.from_pretrained(
228
  model_name,
229
  token=hf_token if hf_token else None
230
  )
231
 
232
+ # Use FeatureExtractor instead of Processor for feature extraction tasks
233
+ # Processor includes tokenizer which requires vocab file (not available for pre-trained models)
234
+ self.processor = Wav2Vec2FeatureExtractor.from_pretrained(
235
  model_name,
236
  token=hf_token if hf_token else None
237
  )
 
296
  - articulation_probs: Articulation class probabilities
297
  - wav2vec2_features: Raw Wav2Vec2 features (for debugging)
298
  """
299
+ # #region agent log
300
+ try:
301
+ with open(r'c:\Users\kpanfas\Desktop\zlaqa\slaq-version-d-to-a\zlaqa-version-b\ai-enginee\zlaqa-version-b-ai-enginee\.cursor\debug.log', 'a') as f:
302
+ import json, time
303
+ f.write(json.dumps({"sessionId":"debug-session","runId":"run1","hypothesisId":"D","location":"speech_pathology_model.py:288","message":"Before Wav2Vec2 forward","data":{"input_values_shape":list(input_values.shape)},"timestamp":int(time.time()*1000)}) + '\n')
304
+ except: pass
305
+ # #endregion
306
+
307
  # Extract features using Wav2Vec2
308
+ try:
309
+ with torch.no_grad() if not self.training else torch.enable_grad():
310
+ wav2vec2_outputs = self.wav2vec2_model(
311
+ input_values=input_values,
312
+ attention_mask=attention_mask
313
+ )
314
+ except Exception as e:
315
+ # #region agent log
316
+ try:
317
+ with open(r'c:\Users\kpanfas\Desktop\zlaqa\slaq-version-d-to-a\zlaqa-version-b\ai-enginee\zlaqa-version-b-ai-enginee\.cursor\debug.log', 'a') as f:
318
+ import json, time
319
+ f.write(json.dumps({"sessionId":"debug-session","runId":"run1","hypothesisId":"D","location":"speech_pathology_model.py:288","message":"Wav2Vec2 forward exception","data":{"error":str(e),"error_type":type(e).__name__,"input_shape":list(input_values.shape)},"timestamp":int(time.time()*1000)}) + '\n')
320
+ except: pass
321
+ # #endregion
322
+ raise
323
 
324
  # Get last hidden state (features)
325
  features = wav2vec2_outputs.last_hidden_state # (batch_size, seq_len, feature_dim)
326
 
327
+ # #region agent log
328
+ try:
329
+ with open(r'c:\Users\kpanfas\Desktop\zlaqa\slaq-version-d-to-a\zlaqa-version-b\ai-enginee\zlaqa-version-b-ai-enginee\.cursor\debug.log', 'a') as f:
330
+ import json, time
331
+ f.write(json.dumps({"sessionId":"debug-session","runId":"run1","hypothesisId":"D","location":"speech_pathology_model.py:297","message":"After Wav2Vec2 forward","data":{"features_shape":list(features.shape),"seq_len":features.shape[1] if len(features.shape) > 1 else 0},"timestamp":int(time.time()*1000)}) + '\n')
332
+ except: pass
333
+ # #endregion
334
+
335
+ # Safety check: ensure sequence length is valid (at least 1)
336
+ if features.shape[1] < 1:
337
+ raise ValueError(
338
+ f"Wav2Vec2 output sequence length is too short: {features.shape[1]}. "
339
+ f"Input was {input_values.shape}. Try using longer audio segments (>= 500ms)."
340
+ )
341
+
342
  # Pass through classifier head
343
  outputs = self.classifier_head(features, attention_mask)
344
 
tests/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ """
2
+ Test module for speech pathology diagnosis system.
3
+ """
4
+
tests/integration_tests.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Integration tests for speech pathology diagnosis API.
3
+
4
+ Tests API endpoints, error mapping, and therapy recommendations.
5
+ """
6
+
7
+ import logging
8
+ import numpy as np
9
+ import tempfile
10
+ import soundfile as sf
11
+ from pathlib import Path
12
+ import json
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
def test_phoneme_mapping():
    """Test phoneme mapping functionality (returns True on success)."""
    logger.info("Testing phoneme mapping...")

    try:
        from models.phoneme_mapper import PhonemeMapper

        pm = PhonemeMapper(frame_duration_ms=20)

        # Word → phoneme extraction.
        extracted = pm.text_to_phonemes("robot")
        assert len(extracted) > 0, "Should extract phonemes"
        logger.info(f"✅ 'robot' → {len(extracted)} phonemes: {[p[0] for p in extracted]}")

        # Phoneme → frame alignment.
        per_frame = pm.align_phonemes_to_frames(extracted, num_frames=25)
        assert len(per_frame) == 25, "Should have 25 frames"
        logger.info(f"✅ Aligned to {len(per_frame)} frames")

        # Full text → frame pipeline.
        cat_alignment = pm.map_text_to_frames("cat", num_frames=15)
        assert len(cat_alignment) == 15, "Should have 15 frames"
        logger.info(f"✅ 'cat' → {len(cat_alignment)} frame phonemes")

        return True

    except ImportError as e:
        # g2p-en is optional in some environments; treat as a skip/fail.
        logger.warning(f"⚠️ G2P library not available: {e}")
        return False
    except Exception as e:
        logger.error(f"❌ Phoneme mapping test failed: {e}")
        return False
49
+
50
+
51
def test_error_taxonomy():
    """Test error taxonomy and therapy mapping (returns True on success)."""
    logger.info("Testing error taxonomy...")

    try:
        from models.error_taxonomy import ErrorMapper, ErrorType, SeverityLevel

        em = ErrorMapper()

        # Class 0 → NORMAL, zero severity.
        mapped = em.map_classifier_output(0, 0.95, "/k/")
        assert mapped.error_type == ErrorType.NORMAL
        assert mapped.severity == 0.0
        logger.info(f"✅ Normal error mapping: {mapped.error_type}")

        # Class 1 → SUBSTITUTION, must carry the substituted sound.
        mapped = em.map_classifier_output(1, 0.78, "/s/")
        assert mapped.error_type == ErrorType.SUBSTITUTION
        assert mapped.wrong_sound is not None
        logger.info(f"✅ Substitution error: {mapped.error_type}, wrong_sound={mapped.wrong_sound}")
        logger.info(f" Therapy: {mapped.therapy[:60]}...")

        # Class 2 → OMISSION.
        mapped = em.map_classifier_output(2, 0.85, "/r/")
        assert mapped.error_type == ErrorType.OMISSION
        logger.info(f"✅ Omission error: {mapped.error_type}")
        logger.info(f" Therapy: {mapped.therapy[:60]}...")

        # Class 3 → DISTORTION.
        mapped = em.map_classifier_output(3, 0.65, "/s/")
        assert mapped.error_type == ErrorType.DISTORTION
        logger.info(f"✅ Distortion error: {mapped.error_type}")
        logger.info(f" Therapy: {mapped.therapy[:60]}...")

        # Severity thresholds across the 0-1 range.
        assert em.get_severity_level(0.0) == SeverityLevel.NONE
        assert em.get_severity_level(0.2) == SeverityLevel.LOW
        assert em.get_severity_level(0.5) == SeverityLevel.MEDIUM
        assert em.get_severity_level(0.8) == SeverityLevel.HIGH
        logger.info("✅ Severity level mapping correct")

        return True

    except Exception as e:
        logger.error(f"❌ Error taxonomy test failed: {e}")
        return False
97
+
98
+
99
def test_batch_diagnosis_endpoint(pipeline, phoneme_mapper, error_mapper):
    """
    Test batch diagnosis endpoint functionality.

    Args:
        pipeline: Inference pipeline exposing predict_phone_level().
        phoneme_mapper: PhonemeMapper used to align text to frames.
        error_mapper: ErrorMapper turning class ids into error details.

    Returns:
        True when the diagnosis pipeline ran end to end, False otherwise.
    """
    logger.info("Testing batch diagnosis endpoint...")

    try:
        # BUG FIX: ErrorType was referenced below but never imported in this
        # module, so the first non-normal frame raised NameError. Import it
        # function-locally, matching the style of the sibling tests.
        from models.error_taxonomy import ErrorType

        # Generate test audio: 2 s of a 440 Hz sine at 16 kHz.
        duration = 2.0
        sample_rate = 16000
        num_samples = int(duration * sample_rate)
        audio = 0.5 * np.sin(2 * np.pi * 440 * np.linspace(0, duration, num_samples))
        audio = audio.astype(np.float32)

        # Save to a temp file for the file-based inference path.
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
            temp_path = f.name
        sf.write(temp_path, audio, sample_rate)

        try:
            # Run phone-level inference.
            result = pipeline.predict_phone_level(temp_path, return_timestamps=True)

            # Align the reference text onto the predicted frames.
            text = "test audio"
            frame_phonemes = phoneme_mapper.map_text_to_frames(
                text,
                num_frames=result.num_frames,
                audio_duration=result.duration
            )

            # Collect every non-normal frame-level error.
            errors = []
            for i, frame_pred in enumerate(result.frame_predictions):
                class_id = frame_pred.articulation_class
                # Stuttered frames occupy the upper half of the combined
                # class space (4 articulation classes x 2 fluency states).
                if frame_pred.fluency_label == 'stutter':
                    class_id += 4

                error_detail = error_mapper.map_classifier_output(
                    class_id=class_id,
                    confidence=frame_pred.confidence,
                    phoneme=frame_phonemes[i] if i < len(frame_phonemes) else '',
                    fluency_label=frame_pred.fluency_label
                )

                if error_detail.error_type != ErrorType.NORMAL:
                    errors.append(error_detail)

            logger.info(f"✅ Batch diagnosis: {result.num_frames} frames, {len(errors)} errors detected")

            return True

        finally:
            # Always remove the temp wav, even on failure.
            import os
            if os.path.exists(temp_path):
                os.remove(temp_path)

    except Exception as e:
        logger.error(f"❌ Batch diagnosis test failed: {e}")
        return False
157
+
158
+
159
def test_therapy_recommendations():
    """Test therapy recommendation coverage (returns True on success)."""
    logger.info("Testing therapy recommendations...")

    try:
        from models.error_taxonomy import ErrorMapper, ErrorType

        em = ErrorMapper()

        # (phoneme, error kind, substituted sound) combinations to cover.
        cases = [
            ("/s/", ErrorType.SUBSTITUTION, "/θ/"),
            ("/r/", ErrorType.OMISSION, None),
            ("/s/", ErrorType.DISTORTION, None),
        ]

        for phoneme, error_type, wrong_sound in cases:
            advice = em.get_therapy(error_type, phoneme, wrong_sound)
            assert advice and len(advice) > 0, f"Therapy should not be empty for {phoneme}"
            logger.info(f"✅ {phoneme} {error_type.value}: {advice[:50]}...")

        return True

    except Exception as e:
        logger.error(f"❌ Therapy recommendations test failed: {e}")
        return False
185
+
186
+
187
def run_all_integration_tests():
    """Run every integration test and return (all_passed, per-test results)."""
    banner = "=" * 60
    logger.info(banner)
    logger.info("Running Integration Tests")
    logger.info(banner)

    results = {}

    logger.info("\n1. Phoneme Mapping Test")
    results["phoneme_mapping"] = test_phoneme_mapping()

    logger.info("\n2. Error Taxonomy Test")
    results["error_taxonomy"] = test_error_taxonomy()

    logger.info("\n3. Therapy Recommendations Test")
    results["therapy_recommendations"] = test_therapy_recommendations()

    # The batch test needs a live inference stack; skip gracefully when it
    # cannot be constructed in this environment.
    try:
        from inference.inference_pipeline import create_inference_pipeline
        from models.phoneme_mapper import PhonemeMapper
        from models.error_taxonomy import ErrorMapper

        logger.info("\n4. Batch Diagnosis Test")
        results["batch_diagnosis"] = test_batch_diagnosis_endpoint(
            create_inference_pipeline(), PhonemeMapper(), ErrorMapper()
        )
    except Exception as e:
        logger.warning(f"⚠️ Batch diagnosis test skipped: {e}")
        results["batch_diagnosis"] = False

    # Per-test summary.
    logger.info("\n" + banner)
    logger.info("Integration Test Summary")
    logger.info(banner)

    for test_name, passed in results.items():
        status = "✅ PASSED" if passed else "❌ FAILED"
        logger.info(f"{status}: {test_name}")

    return all(results.values()), results
236
+
237
+
238
+ if __name__ == "__main__":
239
+ logging.basicConfig(level=logging.INFO)
240
+
241
+ all_passed, results = run_all_integration_tests()
242
+
243
+ if all_passed:
244
+ logger.info("\n✅ All integration tests passed!")
245
+ exit(0)
246
+ else:
247
+ logger.error("\n❌ Some integration tests failed!")
248
+ exit(1)
249
+
tests/performance_tests.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Performance tests for speech pathology diagnosis system.
3
+
4
+ Tests latency requirements:
5
+ - File batch: <200ms per file
6
+ - Per-frame: <50ms
7
+ - WebSocket roundtrip: <100ms
8
+ """
9
+
10
+ import time
11
+ import numpy as np
12
+ import logging
13
+ from pathlib import Path
14
+ import asyncio
15
+ from typing import Dict, List
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ def generate_test_audio(duration_seconds: float = 1.0, sample_rate: int = 16000) -> np.ndarray:
21
+ """
22
+ Generate synthetic test audio.
23
+
24
+ Args:
25
+ duration_seconds: Duration in seconds
26
+ sample_rate: Sample rate in Hz
27
+
28
+ Returns:
29
+ Audio array
30
+ """
31
+ num_samples = int(duration_seconds * sample_rate)
32
+ # Generate simple sine wave
33
+ t = np.linspace(0, duration_seconds, num_samples)
34
+ audio = 0.5 * np.sin(2 * np.pi * 440 * t) # 440 Hz tone
35
+ return audio.astype(np.float32)
36
+
37
+
38
+ def test_batch_latency(pipeline, num_files: int = 10) -> Dict[str, float]:
39
+ """
40
+ Test batch file processing latency.
41
+
42
+ Args:
43
+ pipeline: InferencePipeline instance
44
+ num_files: Number of test files to process
45
+
46
+ Returns:
47
+ Dictionary with latency statistics
48
+ """
49
+ logger.info(f"Testing batch latency with {num_files} files...")
50
+
51
+ latencies = []
52
+
53
+ for i in range(num_files):
54
+ # Generate test audio
55
+ audio = generate_test_audio(duration_seconds=1.0)
56
+
57
+ # Save to temp file
58
+ import tempfile
59
+ import soundfile as sf
60
+
61
+ with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
62
+ temp_path = f.name
63
+ sf.write(temp_path, audio, 16000)
64
+
65
+ try:
66
+ start_time = time.time()
67
+ result = pipeline.predict_phone_level(temp_path, return_timestamps=True)
68
+ latency_ms = (time.time() - start_time) * 1000
69
+ latencies.append(latency_ms)
70
+
71
+ logger.info(f" File {i+1}: {latency_ms:.1f}ms ({result.num_frames} frames)")
72
+ except Exception as e:
73
+ logger.error(f" File {i+1} failed: {e}")
74
+ finally:
75
+ import os
76
+ if os.path.exists(temp_path):
77
+ os.remove(temp_path)
78
+
79
+ if not latencies:
80
+ return {"error": "No successful runs"}
81
+
82
+ avg_latency = sum(latencies) / len(latencies)
83
+ max_latency = max(latencies)
84
+ min_latency = min(latencies)
85
+
86
+ result = {
87
+ "avg_latency_ms": avg_latency,
88
+ "max_latency_ms": max_latency,
89
+ "min_latency_ms": min_latency,
90
+ "num_files": len(latencies),
91
+ "target_ms": 200.0,
92
+ "passed": avg_latency < 200.0
93
+ }
94
+
95
+ logger.info(f"✅ Batch latency test: avg={avg_latency:.1f}ms, max={max_latency:.1f}ms, "
96
+ f"target=200ms, passed={result['passed']}")
97
+
98
+ return result
99
+
100
+
101
+ def test_frame_latency(pipeline, num_frames: int = 100) -> Dict[str, float]:
102
+ """
103
+ Test per-frame processing latency.
104
+
105
+ Args:
106
+ pipeline: InferencePipeline instance
107
+ num_frames: Number of frames to test
108
+
109
+ Returns:
110
+ Dictionary with latency statistics
111
+ """
112
+ logger.info(f"Testing frame latency with {num_frames} frames...")
113
+
114
+ # Generate 1 second of audio (enough for one window)
115
+ audio = generate_test_audio(duration_seconds=1.0)
116
+
117
+ latencies = []
118
+
119
+ for i in range(num_frames):
120
+ start_time = time.time()
121
+ try:
122
+ result = pipeline.predict_phone_level(audio, return_timestamps=False)
123
+ latency_ms = (time.time() - start_time) * 1000
124
+ latencies.append(latency_ms)
125
+ except Exception as e:
126
+ logger.error(f" Frame {i+1} failed: {e}")
127
+
128
+ if not latencies:
129
+ return {"error": "No successful runs"}
130
+
131
+ avg_latency = sum(latencies) / len(latencies)
132
+ max_latency = max(latencies)
133
+ min_latency = min(latencies)
134
+ p95_latency = sorted(latencies)[int(len(latencies) * 0.95)]
135
+
136
+ result = {
137
+ "avg_latency_ms": avg_latency,
138
+ "max_latency_ms": max_latency,
139
+ "min_latency_ms": min_latency,
140
+ "p95_latency_ms": p95_latency,
141
+ "num_frames": len(latencies),
142
+ "target_ms": 50.0,
143
+ "passed": avg_latency < 50.0
144
+ }
145
+
146
+ logger.info(f"✅ Frame latency test: avg={avg_latency:.1f}ms, p95={p95_latency:.1f}ms, "
147
+ f"target=50ms, passed={result['passed']}")
148
+
149
+ return result
150
+
151
+
152
+ async def test_websocket_latency(websocket_url: str, num_chunks: int = 50) -> Dict[str, float]:
153
+ """
154
+ Test WebSocket streaming latency.
155
+
156
+ Args:
157
+ websocket_url: WebSocket URL
158
+ num_chunks: Number of chunks to send
159
+
160
+ Returns:
161
+ Dictionary with latency statistics
162
+ """
163
+ try:
164
+ import websockets
165
+
166
+ logger.info(f"Testing WebSocket latency with {num_chunks} chunks...")
167
+
168
+ latencies = []
169
+
170
+ async with websockets.connect(websocket_url) as websocket:
171
+ # Generate test audio chunk (20ms @ 16kHz = 320 samples)
172
+ chunk_samples = 320
173
+ audio_chunk = generate_test_audio(duration_seconds=0.02)
174
+ chunk_bytes = (audio_chunk * 32768).astype(np.int16).tobytes()
175
+
176
+ for i in range(num_chunks):
177
+ start_time = time.time()
178
+
179
+ # Send chunk
180
+ await websocket.send(chunk_bytes)
181
+
182
+ # Receive response
183
+ response = await websocket.recv()
184
+
185
+ latency_ms = (time.time() - start_time) * 1000
186
+ latencies.append(latency_ms)
187
+
188
+ if i % 10 == 0:
189
+ logger.info(f" Chunk {i+1}: {latency_ms:.1f}ms")
190
+
191
+ if not latencies:
192
+ return {"error": "No successful runs"}
193
+
194
+ avg_latency = sum(latencies) / len(latencies)
195
+ max_latency = max(latencies)
196
+ p95_latency = sorted(latencies)[int(len(latencies) * 0.95)]
197
+
198
+ result = {
199
+ "avg_latency_ms": avg_latency,
200
+ "max_latency_ms": max_latency,
201
+ "p95_latency_ms": p95_latency,
202
+ "num_chunks": len(latencies),
203
+ "target_ms": 100.0,
204
+ "passed": avg_latency < 100.0
205
+ }
206
+
207
+ logger.info(f"✅ WebSocket latency test: avg={avg_latency:.1f}ms, p95={p95_latency:.1f}ms, "
208
+ f"target=100ms, passed={result['passed']}")
209
+
210
+ return result
211
+
212
+ except ImportError:
213
+ logger.warning("websockets library not available, skipping WebSocket test")
214
+ return {"error": "websockets library not available"}
215
+ except Exception as e:
216
+ logger.error(f"WebSocket test failed: {e}")
217
+ return {"error": str(e)}
218
+
219
+
220
def test_concurrent_connections(pipeline, num_connections: int = 10) -> Dict[str, Any]:
    """
    Test concurrent processing (simulated).

    Fans out ``num_connections`` independent predictions onto a thread pool
    and reports success counts, average latency, and throughput.

    Args:
        pipeline: InferencePipeline instance
        num_connections: Number of concurrent requests

    Returns:
        Dictionary with results
    """
    logger.info(f"Testing {num_connections} concurrent connections...")

    import concurrent.futures

    def _one_request(idx: int):
        # Each worker performs one half-second prediction and times it.
        try:
            sample = generate_test_audio(duration_seconds=0.5)
            began = time.time()
            prediction = pipeline.predict_phone_level(sample, return_timestamps=False)
            elapsed_ms = (time.time() - began) * 1000
            return {"success": True, "latency_ms": elapsed_ms, "frames": prediction.num_frames}
        except Exception as e:
            return {"success": False, "error": str(e)}

    wall_start = time.time()

    with concurrent.futures.ThreadPoolExecutor(max_workers=num_connections) as pool:
        pending = [pool.submit(_one_request, idx) for idx in range(num_connections)]
        outcomes = [fut.result() for fut in concurrent.futures.as_completed(pending)]

    total_time = time.time() - wall_start

    succeeded = [r for r in outcomes if r.get("success", False)]
    successful = len(succeeded)
    avg_latency = (sum(r["latency_ms"] for r in succeeded) / successful) if successful > 0 else 0.0

    result = {
        "total_connections": num_connections,
        "successful": successful,
        "failed": num_connections - successful,
        "total_time_seconds": total_time,
        "avg_latency_ms": avg_latency,
        "throughput_per_second": successful / total_time if total_time > 0 else 0.0
    }

    logger.info(f"✅ Concurrent test: {successful}/{num_connections} successful, "
                f"avg_latency={avg_latency:.1f}ms, throughput={result['throughput_per_second']:.1f}/s")

    return result
269
+
270
+
271
def run_all_performance_tests(pipeline, websocket_url: Optional[str] = None) -> Dict[str, Any]:
    """
    Run all performance tests.

    Executes the batch, per-frame, and concurrency tests, plus the WebSocket
    streaming test when a URL is supplied, then logs a pass/fail summary.

    Args:
        pipeline: InferencePipeline instance
        websocket_url: Optional WebSocket URL for streaming tests

    Returns:
        Dictionary with all test results
    """
    banner = "=" * 60
    logger.info(banner)
    logger.info("Running Performance Tests")
    logger.info(banner)

    results = {}

    # Test 1: Batch latency
    logger.info("\n1. Batch File Latency Test")
    results["batch_latency"] = test_batch_latency(pipeline)

    # Test 2: Frame latency
    logger.info("\n2. Per-Frame Latency Test")
    results["frame_latency"] = test_frame_latency(pipeline)

    # Test 3: Concurrent connections
    logger.info("\n3. Concurrent Connections Test")
    results["concurrent"] = test_concurrent_connections(pipeline, num_connections=10)

    # Test 4: WebSocket latency (async, so drive it with its own event loop)
    if websocket_url:
        logger.info("\n4. WebSocket Latency Test")
        results["websocket_latency"] = asyncio.run(test_websocket_latency(websocket_url))

    # Summary
    logger.info("\n" + banner)
    logger.info("Performance Test Summary")
    logger.info(banner)

    for key, label in (("batch_latency", "Batch latency"),
                       ("frame_latency", "Frame latency")):
        if results.get(key, {}).get("passed"):
            logger.info(f"✅ {label}: PASSED")
        else:
            logger.warning(f"❌ {label}: FAILED")

    # WebSocket verdict is only reported when the test actually ran.
    ws_result = results.get("websocket_latency")
    if ws_result is not None:
        if ws_result.get("passed"):
            logger.info("✅ WebSocket latency: PASSED")
        else:
            logger.warning("❌ WebSocket latency: FAILED")

    return results
326
+
327
+
328
if __name__ == "__main__":
    # Example usage: build a pipeline and run the full performance suite.
    logging.basicConfig(level=logging.INFO)

    try:
        from inference.inference_pipeline import create_inference_pipeline

        suite_results = run_all_performance_tests(create_inference_pipeline())

        import json
        print("\nTest Results:")
        print(json.dumps(suite_results, indent=2))

    except Exception as e:
        logger.error(f"Test failed: {e}", exc_info=True)
344
+
ui/gradio_interface.py CHANGED
@@ -163,21 +163,21 @@ def analyze_speech(
163
  try:
164
  with open(r'c:\Users\kpanfas\Desktop\zlaqa\slaq-version-d-to-a\zlaqa-version-b\ai-enginee\zlaqa-version-b-ai-enginee\.cursor\debug.log', 'a') as f:
165
  import json
166
- f.write(json.dumps({"sessionId":"debug-session","runId":"run1","hypothesisId":"B","location":"gradio_interface.py:139","message":"After predict_batch call","data":{"success":True},"timestamp":int(time.time()*1000)}) + '\n')
167
  except: pass
168
  # #endregion
169
 
170
  # Calculate processing time
171
  processing_time_ms = (time.time() - start_time) * 1000
172
 
173
- # Extract metrics
174
- fluency_metrics = result.fluency_metrics
175
- mean_fluency = fluency_metrics.get("mean", 0.0)
176
- fluent_frames_ratio = fluency_metrics.get("fluent_frames_ratio", 0.0)
177
 
178
- # Convert fluency score to percentage (0-100)
179
- fluency_percentage = mean_fluency * 100
180
- fluent_frames_percentage = fluent_frames_ratio * 100
181
 
182
  # Format fluency score with color coding
183
  if fluency_percentage >= 80:
@@ -203,10 +203,22 @@ def analyze_speech(
203
  """
204
 
205
  # Format articulation issues
206
- articulation_text = format_articulation_issues(result.articulation_scores)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
 
208
- # Format confidence
209
- confidence_percentage = result.confidence * 100
210
  confidence_html = f"""
211
  <div style='text-align: center; padding: 10px;'>
212
  <h3 style='color: #2196F3; font-size: 32px; margin: 5px 0;'>
@@ -223,7 +235,7 @@ def analyze_speech(
223
  ⏱️ Processing Time: <strong>{processing_time_ms:.0f}ms</strong>
224
  </p>
225
  <p style='color: #999; font-size: 12px;'>
226
- Analyzed {len(result.articulation_scores)} frames
227
  </p>
228
  </div>
229
  """
@@ -232,24 +244,32 @@ def analyze_speech(
232
  json_output = {
233
  "status": "success",
234
  "fluency_metrics": {
235
- "mean_fluency": mean_fluency,
236
  "fluency_percentage": fluency_percentage,
237
- "fluent_frames_ratio": fluent_frames_ratio,
238
  "fluent_frames_percentage": fluent_frames_percentage,
239
- "std": fluency_metrics.get("std", 0.0),
240
- "min": fluency_metrics.get("min", 0.0),
241
- "max": fluency_metrics.get("max", 0.0),
242
- "median": fluency_metrics.get("median", 0.0)
243
  },
244
  "articulation_results": {
245
- "total_frames": len(result.articulation_scores),
246
- "frame_duration_ms": result.frame_duration_ms,
247
- "scores": result.articulation_scores[:10] # First 10 frames for preview
 
248
  },
249
- "confidence": result.confidence,
250
  "confidence_percentage": confidence_percentage,
251
  "processing_time_ms": processing_time_ms,
252
- "timestamps": result.timestamps[:10] if result.timestamps else []
 
 
 
 
 
 
 
 
 
 
253
  }
254
 
255
  logger.info(f"✅ Analysis complete: fluency={fluency_percentage:.1f}%, "
 
163
  try:
164
  with open(r'c:\Users\kpanfas\Desktop\zlaqa\slaq-version-d-to-a\zlaqa-version-b\ai-enginee\zlaqa-version-b-ai-enginee\.cursor\debug.log', 'a') as f:
165
  import json
166
+ f.write(json.dumps({"sessionId":"debug-session","runId":"run1","hypothesisId":"B","location":"gradio_interface.py:139","message":"After predict_batch call","data":{"success":True,"num_frames":result.num_frames},"timestamp":int(time.time()*1000)}) + '\n')
167
  except: pass
168
  # #endregion
169
 
170
  # Calculate processing time
171
  processing_time_ms = (time.time() - start_time) * 1000
172
 
173
+ # Extract metrics from new PhoneLevelResult format
174
+ aggregate = result.aggregate
175
+ mean_fluency_stutter = aggregate.get("fluency_score", 0.0)
176
+ fluency_percentage = (1.0 - mean_fluency_stutter) * 100 # Convert stutter prob to fluency percentage
177
 
178
+ # Count fluent frames
179
+ fluent_frames = sum(1 for fp in result.frame_predictions if fp.fluency_label == 'normal')
180
+ fluent_frames_percentage = (fluent_frames / result.num_frames * 100) if result.num_frames > 0 else 0.0
181
 
182
  # Format fluency score with color coding
183
  if fluency_percentage >= 80:
 
203
  """
204
 
205
  # Format articulation issues
206
+ articulation_class = aggregate.get("articulation_class", 0)
207
+ articulation_label = aggregate.get("articulation_label", "normal")
208
+ articulation_text = f"**Dominant Class:** {articulation_label.capitalize()}\n\n"
209
+ articulation_text += f"**Frame Breakdown:**\n"
210
+ class_counts = {}
211
+ for fp in result.frame_predictions:
212
+ label = fp.articulation_label
213
+ class_counts[label] = class_counts.get(label, 0) + 1
214
+ for label, count in sorted(class_counts.items(), key=lambda x: x[1], reverse=True):
215
+ percentage = (count / result.num_frames * 100) if result.num_frames > 0 else 0.0
216
+ articulation_text += f"- {label.capitalize()}: {count} frames ({percentage:.1f}%)\n"
217
+
218
+ # Calculate average confidence
219
+ avg_confidence = sum(fp.confidence for fp in result.frame_predictions) / result.num_frames if result.num_frames > 0 else 0.0
220
+ confidence_percentage = avg_confidence * 100
221
 
 
 
222
  confidence_html = f"""
223
  <div style='text-align: center; padding: 10px;'>
224
  <h3 style='color: #2196F3; font-size: 32px; margin: 5px 0;'>
 
235
  ⏱️ Processing Time: <strong>{processing_time_ms:.0f}ms</strong>
236
  </p>
237
  <p style='color: #999; font-size: 12px;'>
238
+ Analyzed {result.num_frames} frames ({result.duration:.2f}s audio)
239
  </p>
240
  </div>
241
  """
 
244
  json_output = {
245
  "status": "success",
246
  "fluency_metrics": {
247
+ "mean_fluency": fluency_percentage / 100.0,
248
  "fluency_percentage": fluency_percentage,
249
+ "fluent_frames_ratio": fluent_frames / result.num_frames if result.num_frames > 0 else 0.0,
250
  "fluent_frames_percentage": fluent_frames_percentage,
251
+ "stutter_probability": mean_fluency_stutter
 
 
 
252
  },
253
  "articulation_results": {
254
+ "total_frames": result.num_frames,
255
+ "dominant_class": articulation_class,
256
+ "dominant_label": articulation_label,
257
+ "class_distribution": class_counts
258
  },
259
+ "confidence": avg_confidence,
260
  "confidence_percentage": confidence_percentage,
261
  "processing_time_ms": processing_time_ms,
262
+ "frame_predictions": [
263
+ {
264
+ "time": fp.time,
265
+ "fluency_prob": fp.fluency_prob,
266
+ "fluency_label": fp.fluency_label,
267
+ "articulation_class": fp.articulation_class,
268
+ "articulation_label": fp.articulation_label,
269
+ "confidence": fp.confidence
270
+ }
271
+ for fp in result.frame_predictions[:20] # First 20 frames for preview
272
+ ]
273
  }
274
 
275
  logger.info(f"✅ Analysis complete: fluency={fluency_percentage:.1f}%, "