New: implemented many changes; phone-level detection (10% complete): WORKING
Browse files- Docs/QUICK_START.md +2 -0
- api/__init__.py +4 -0
- api/routes.py +389 -0
- api/schemas.py +115 -0
- api/streaming.py +305 -0
- app.py +47 -15
- config.py +12 -0
- data/therapy_recommendations.json +60 -0
- diagnosis/ai_engine/detect_stuttering.py +3 -0
- inference/inference_pipeline.py +250 -336
- models/error_taxonomy.py +333 -0
- models/phoneme_mapper.py +364 -0
- models/speech_pathology_model.py +64 -16
- tests/__init__.py +4 -0
- tests/integration_tests.py +249 -0
- tests/performance_tests.py +344 -0
- ui/gradio_interface.py +43 -23
Docs/QUICK_START.md
CHANGED
|
@@ -327,6 +327,7 @@ with Pool(4) as pool:
|
|
| 327 |
## API Reference
|
| 328 |
|
| 329 |
### Main Method
|
|
|
|
| 330 |
```python
|
| 331 |
analyze_audio(
|
| 332 |
audio_path: str, # Path to .wav file
|
|
@@ -336,6 +337,7 @@ analyze_audio(
|
|
| 336 |
```
|
| 337 |
|
| 338 |
### Utility Methods
|
|
|
|
| 339 |
```python
|
| 340 |
# Phonetic similarity (0-1)
|
| 341 |
_calculate_phonetic_similarity(char1: str, char2: str) -> float
|
|
|
|
| 327 |
## API Reference
|
| 328 |
|
| 329 |
### Main Method
|
| 330 |
+
|
| 331 |
```python
|
| 332 |
analyze_audio(
|
| 333 |
audio_path: str, # Path to .wav file
|
|
|
|
| 337 |
```
|
| 338 |
|
| 339 |
### Utility Methods
|
| 340 |
+
|
| 341 |
```python
|
| 342 |
# Phonetic similarity (0-1)
|
| 343 |
_calculate_phonetic_similarity(char1: str, char2: str) -> float
|
api/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
API module for speech pathology diagnosis endpoints.
|
| 3 |
+
"""
|
| 4 |
+
|
api/routes.py
ADDED
|
@@ -0,0 +1,389 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
REST API routes for Speech Pathology Diagnosis.
|
| 3 |
+
|
| 4 |
+
This module provides FastAPI endpoints for batch file analysis,
|
| 5 |
+
session management, and health checks.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
import os
|
| 10 |
+
import time
|
| 11 |
+
import tempfile
|
| 12 |
+
import uuid
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import Optional, List, Dict, Any
|
| 15 |
+
from datetime import datetime
|
| 16 |
+
|
| 17 |
+
from fastapi import APIRouter, UploadFile, File, HTTPException, Query
|
| 18 |
+
from fastapi.responses import JSONResponse
|
| 19 |
+
|
| 20 |
+
from api.schemas import (
|
| 21 |
+
BatchDiagnosisResponse,
|
| 22 |
+
FrameDiagnosis,
|
| 23 |
+
ErrorReport,
|
| 24 |
+
SummaryMetrics,
|
| 25 |
+
SessionListResponse,
|
| 26 |
+
HealthResponse,
|
| 27 |
+
ErrorDetailSchema,
|
| 28 |
+
FluencyInfo,
|
| 29 |
+
ArticulationInfo
|
| 30 |
+
)
|
| 31 |
+
from models.phoneme_mapper import PhonemeMapper
|
| 32 |
+
from models.error_taxonomy import ErrorMapper, ErrorType, SeverityLevel
|
| 33 |
+
from inference.inference_pipeline import InferencePipeline
|
| 34 |
+
from config import AudioConfig, default_audio_config
|
| 35 |
+
|
| 36 |
+
# Module-level logger for this routes module.
logger = logging.getLogger(__name__)

# Create router
router = APIRouter(prefix="/diagnose", tags=["diagnosis"])

# In-memory session storage (in production, use Redis or database)
# NOTE(review): unbounded dict — grows for the lifetime of the process.
sessions: Dict[str, BatchDiagnosisResponse] = {}

# Global instances (will be injected)
# Populated by initialize_routes(); endpoints treat a None pipeline
# as "service unavailable" (503).
inference_pipeline: Optional[InferencePipeline] = None
phoneme_mapper: Optional[PhonemeMapper] = None
error_mapper: Optional[ErrorMapper] = None
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def initialize_routes(
    pipeline: InferencePipeline,
    mapper: Optional[PhonemeMapper] = None,
    error_mapper_instance: Optional[ErrorMapper] = None
):
    """
    Initialize routes with dependencies.

    When `mapper` or `error_mapper_instance` is supplied it is used
    directly; otherwise a default instance is constructed. Construction
    failures are logged and leave the corresponding dependency disabled
    (set to None) so endpoints can degrade gracefully.

    Args:
        pipeline: InferencePipeline instance
        mapper: Optional PhonemeMapper instance
        error_mapper_instance: Optional ErrorMapper instance
    """
    global inference_pipeline, phoneme_mapper, error_mapper

    inference_pipeline = pipeline

    if mapper is not None:
        # BUG FIX: previously an explicitly supplied mapper was ignored
        # (only the fallback branch assigned the global).
        phoneme_mapper = mapper
    else:
        try:
            phoneme_mapper = PhonemeMapper(
                frame_duration_ms=default_audio_config.chunk_duration_ms,
                sample_rate=default_audio_config.sample_rate
            )
            logger.info("✅ PhonemeMapper initialized")
        except Exception as e:
            logger.warning(f"⚠️ PhonemeMapper not available: {e}")
            phoneme_mapper = None

    if error_mapper_instance is not None:
        # BUG FIX: previously a supplied ErrorMapper was ignored as well.
        error_mapper = error_mapper_instance
    else:
        try:
            error_mapper = ErrorMapper()
            logger.info("✅ ErrorMapper initialized")
        except Exception as e:
            logger.error(f"❌ ErrorMapper failed to initialize: {e}")
            error_mapper = None
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
@router.post("/file", response_model=BatchDiagnosisResponse)
async def diagnose_file(
    audio: UploadFile = File(...),
    text: Optional[str] = Query(None, description="Expected text/transcript for phoneme mapping"),
    session_id: Optional[str] = Query(None, description="Optional session ID")
):
    """
    Analyze audio file for speech pathology errors.

    Performs complete phoneme-level analysis:
    - Extracts Wav2Vec2 features
    - Classifies fluency and articulation per frame
    - Maps phonemes to frames
    - Detects errors and generates therapy recommendations

    Args:
        audio: Audio file (WAV, MP3, etc.)
        text: Optional expected text for phoneme mapping
        session_id: Optional session ID (auto-generated if not provided)

    Returns:
        BatchDiagnosisResponse with detailed error analysis

    Raises:
        HTTPException: 503 if the inference pipeline is not loaded,
            500 if diagnosis fails for any other reason.
    """
    if inference_pipeline is None:
        raise HTTPException(status_code=503, detail="Inference pipeline not loaded")

    start_time = time.time()

    # Generate session ID
    if not session_id:
        session_id = str(uuid.uuid4())

    # Save uploaded file
    temp_file = None
    try:
        # SECURITY FIX: keep only the basename of the client-supplied
        # filename so a crafted name (e.g. "../../x") cannot escape the
        # temp directory (path traversal).
        safe_name = os.path.basename(audio.filename or "upload")
        temp_dir = tempfile.gettempdir()
        os.makedirs(temp_dir, exist_ok=True)
        temp_file = os.path.join(temp_dir, f"diagnosis_{session_id}_{safe_name}")

        # Save file
        content = await audio.read()
        with open(temp_file, "wb") as f:
            f.write(content)

        file_size_mb = len(content) / 1024 / 1024
        logger.info(f"📂 Saved file: {temp_file} ({file_size_mb:.2f} MB)")

        # Run inference
        logger.info("🔄 Running phone-level inference...")
        result = inference_pipeline.predict_phone_level(
            temp_file,
            return_timestamps=True
        )

        # Map phonemes to frames if text provided
        frame_phonemes = []
        if text and phoneme_mapper:
            try:
                frame_phonemes = phoneme_mapper.map_text_to_frames(
                    text,
                    num_frames=result.num_frames,
                    audio_duration=result.duration
                )
                logger.info(f"✅ Mapped {len(frame_phonemes)} phonemes to frames")
            except Exception as e:
                logger.warning(f"⚠️ Phoneme mapping failed: {e}, using empty phonemes")
                frame_phonemes = [''] * result.num_frames
        else:
            frame_phonemes = [''] * result.num_frames
            if not text:
                logger.warning("⚠️ No text provided, phoneme mapping skipped")

        # Process frame predictions with error mapping
        frame_diagnoses = []
        error_reports = []
        error_count = 0

        for i, frame_pred in enumerate(result.frame_predictions):
            # Get phoneme for this frame (empty string when not mapped)
            phoneme = frame_phonemes[i] if i < len(frame_phonemes) else ''

            # Combine fluency and articulation into the 8-class system:
            # class = articulation_class + (4 if stutter else 0)
            class_id = frame_pred.articulation_class
            if frame_pred.fluency_label == 'stutter':
                class_id += 4  # Add 4 for stutter classes (4-7)

            # Map the combined class to a detailed error description
            error_detail = None
            if error_mapper:
                try:
                    error_detail_obj = error_mapper.map_classifier_output(
                        class_id=class_id,
                        confidence=frame_pred.confidence,
                        phoneme=phoneme if phoneme else 'unknown',
                        fluency_label=frame_pred.fluency_label
                    )

                    # Record which frame produced this error
                    error_detail_obj.frame_indices = [i]

                    # Only non-NORMAL classifications become reported errors
                    if error_detail_obj.error_type != ErrorType.NORMAL:
                        error_detail = ErrorDetailSchema(
                            phoneme=error_detail_obj.phoneme,
                            error_type=error_detail_obj.error_type.value,
                            wrong_sound=error_detail_obj.wrong_sound,
                            severity=error_detail_obj.severity,
                            confidence=error_detail_obj.confidence,
                            therapy=error_detail_obj.therapy,
                            frame_indices=[i]
                        )
                        error_count += 1

                        severity_level = error_mapper.get_severity_level(error_detail_obj.severity)
                        error_reports.append(ErrorReport(
                            frame_id=i,
                            timestamp=frame_pred.time,
                            phoneme=error_detail_obj.phoneme,
                            error=error_detail,
                            severity_level=severity_level.value
                        ))
                except Exception as e:
                    logger.warning(f"Error mapping failed for frame {i}: {e}")

            # Derive a human-readable severity level for the frame.
            # error_detail is only ever set inside the `if error_mapper:`
            # branch above, so the mapper can be used unconditionally here
            # (the original redundant ternary was removed).
            severity_level_str = "none"
            if error_detail is not None:
                severity_level_str = error_mapper.get_severity_level(error_detail.severity).value

            frame_diagnoses.append(FrameDiagnosis(
                frame_id=i,
                timestamp=frame_pred.time,
                phoneme=phoneme if phoneme else 'unknown',
                fluency=FluencyInfo(
                    label=frame_pred.fluency_label,
                    # fluency_prob is P(stutter); flip it for 'normal' frames
                    confidence=frame_pred.fluency_prob if frame_pred.fluency_label == 'stutter' else (1.0 - frame_pred.fluency_prob)
                ),
                articulation=ArticulationInfo(
                    label=frame_pred.articulation_label,
                    confidence=frame_pred.confidence,
                    class_id=frame_pred.articulation_class
                ),
                error=error_detail,
                severity_level=severity_level_str,
                confidence=frame_pred.confidence
            ))

        # Calculate summary metrics
        fluency_scores = [1.0 - fp.fluency_prob for fp in result.frame_predictions]  # Convert stutter prob to fluency
        avg_fluency = sum(fluency_scores) / len(fluency_scores) if fluency_scores else 0.0

        # Articulation score: percentage of normal frames
        normal_frames = sum(1 for fp in result.frame_predictions if fp.articulation_class == 0)
        articulation_score = normal_frames / result.num_frames if result.num_frames > 0 else 0.0

        summary = SummaryMetrics(
            fluency_score=avg_fluency,
            fluency_percentage=avg_fluency * 100.0,
            articulation_score=articulation_score,
            error_count=error_count,
            error_rate=error_count / result.num_frames if result.num_frames > 0 else 0.0
        )

        # Generate therapy plan (unique therapy recommendations, first-seen order)
        therapy_plan = []
        if error_mapper:
            seen_therapies = set()
            for error_report in error_reports:
                if error_report.error.therapy and error_report.error.therapy not in seen_therapies:
                    therapy_plan.append(error_report.error.therapy)
                    seen_therapies.add(error_report.error.therapy)

        processing_time_ms = (time.time() - start_time) * 1000

        # Create response
        response = BatchDiagnosisResponse(
            session_id=session_id,
            filename=audio.filename or "unknown",
            duration=result.duration,
            total_frames=result.num_frames,
            error_count=error_count,
            errors=error_reports,
            frame_diagnoses=frame_diagnoses,
            summary=summary,
            therapy_plan=therapy_plan,
            processing_time_ms=processing_time_ms,
            created_at=datetime.now()
        )

        # Store in sessions
        sessions[session_id] = response

        logger.info(f"✅ Diagnosis complete: {error_count} errors, {processing_time_ms:.0f}ms")

        return response

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"❌ Diagnosis failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Diagnosis failed: {str(e)}")

    finally:
        # Cleanup temp file
        if temp_file and os.path.exists(temp_file):
            try:
                os.remove(temp_file)
                logger.debug(f"🧹 Cleaned up: {temp_file}")
            except Exception as e:
                logger.warning(f"Could not clean up {temp_file}: {e}")
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
@router.get("/results/{session_id}", response_model=BatchDiagnosisResponse)
async def get_results(session_id: str):
    """
    Retrieve a previously computed diagnosis for a session.

    Args:
        session_id: Session identifier

    Returns:
        BatchDiagnosisResponse

    Raises:
        HTTPException: 404 when the session is unknown.
    """
    cached = sessions.get(session_id)
    if cached is None:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    return cached
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
@router.get("/results", response_model=SessionListResponse)
async def list_results(limit: int = Query(10, ge=1, le=100)):
    """
    List cached diagnosis sessions (up to `limit` entries).

    Args:
        limit: Maximum number of sessions to return

    Returns:
        SessionListResponse with session metadata
    """
    selected = list(sessions.items())[:limit]
    session_list = [
        {
            "session_id": sid,
            "filename": response.filename,
            "duration": response.duration,
            "error_count": response.error_count,
            "created_at": response.created_at.isoformat(),
            "processing_time_ms": response.processing_time_ms,
        }
        for sid, response in selected
    ]

    return SessionListResponse(
        sessions=session_list,
        total=len(sessions)
    )
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
@router.delete("/results/{session_id}")
async def delete_results(session_id: str):
    """
    Remove a cached diagnosis session.

    Args:
        session_id: Session identifier

    Returns:
        Success message

    Raises:
        HTTPException: 404 when the session is unknown.
    """
    removed = sessions.pop(session_id, None)
    if removed is None:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")

    logger.info(f"🗑️ Deleted session: {session_id}")

    return {"status": "success", "message": f"Session {session_id} deleted"}
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
@router.get("/health", response_model=HealthResponse)
async def health_check():
    """
    Health check endpoint.

    Reports service status, whether the inference pipeline is loaded,
    and the uptime measured from the first health check call.

    Returns:
        HealthResponse with service status
    """
    # Record the timestamp of the first health check exactly once; the
    # redundant function-local `import time` was removed (time is already
    # imported at module level) and the getattr/hasattr dance simplified.
    if not hasattr(health_check, '_start_time'):
        health_check._start_time = time.time()

    uptime = time.time() - health_check._start_time

    return HealthResponse(
        status="healthy" if inference_pipeline is not None else "degraded",
        version="2.0.0",
        model_loaded=inference_pipeline is not None,
        uptime_seconds=uptime
    )
|
| 389 |
+
|
api/schemas.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Pydantic schemas for Speech Pathology Diagnosis API.
|
| 3 |
+
|
| 4 |
+
This module defines request and response models for REST API and WebSocket endpoints.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from typing import List, Optional, Dict, Any
|
| 8 |
+
from pydantic import BaseModel, Field
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class FluencyInfo(BaseModel):
    """Fluency classification information (per frame)."""
    label: str = Field(..., description="Fluency label: 'normal' or 'stutter'")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence score (0-1)")


class ArticulationInfo(BaseModel):
    """Articulation classification information (per frame)."""
    label: str = Field(..., description="Articulation label: 'normal', 'substitution', 'omission', 'distortion'")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence score (0-1)")
    # class_id mirrors `label` numerically for downstream consumers
    class_id: int = Field(..., ge=0, le=3, description="Class ID: 0=normal, 1=substitution, 2=omission, 3=distortion")


class ErrorDetailSchema(BaseModel):
    """Error detail schema for API responses."""
    phoneme: str = Field(..., description="Expected phoneme symbol")
    error_type: str = Field(..., description="Error type: normal, substitution, omission, distortion")
    wrong_sound: Optional[str] = Field(None, description="For substitutions, the incorrect phoneme produced")
    severity: float = Field(..., ge=0.0, le=1.0, description="Severity score (0-1)")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Model confidence (0-1)")
    therapy: str = Field(..., description="Therapy recommendation")
    frame_indices: List[int] = Field(default_factory=list, description="Frame indices where error occurs")


class FrameDiagnosis(BaseModel):
    """Diagnosis for a single frame (combines fluency + articulation)."""
    frame_id: int = Field(..., description="Frame index")
    timestamp: float = Field(..., ge=0.0, description="Timestamp in seconds")
    phoneme: str = Field(..., description="Expected phoneme for this frame")
    fluency: FluencyInfo = Field(..., description="Fluency classification")
    articulation: ArticulationInfo = Field(..., description="Articulation classification")
    # None when the frame is classified as normal
    error: Optional[ErrorDetailSchema] = Field(None, description="Error details if error detected")
    severity_level: str = Field(..., description="Severity level: none, low, medium, high")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Overall confidence")


class ErrorReport(BaseModel):
    """Detailed error report for a frame (only emitted for non-normal frames)."""
    frame_id: int = Field(..., description="Frame index")
    timestamp: float = Field(..., ge=0.0, description="Timestamp in seconds")
    phoneme: str = Field(..., description="Expected phoneme")
    error: ErrorDetailSchema = Field(..., description="Error details")
    severity_level: str = Field(..., description="Severity level: none, low, medium, high")


class SummaryMetrics(BaseModel):
    """Summary metrics aggregated over all frames of one analysis."""
    fluency_score: float = Field(..., ge=0.0, le=1.0, description="Average fluency score (0=stutter, 1=normal)")
    fluency_percentage: float = Field(..., ge=0.0, le=100.0, description="Fluency percentage")
    articulation_score: float = Field(..., ge=0.0, le=1.0, description="Average articulation correctness")
    error_count: int = Field(..., ge=0, description="Total number of errors detected")
    error_rate: float = Field(..., ge=0.0, le=1.0, description="Error rate (errors/total_frames)")


class BatchDiagnosisResponse(BaseModel):
    """Response for batch file diagnosis (the /diagnose/file endpoint)."""
    session_id: str = Field(..., description="Session identifier")
    filename: str = Field(..., description="Processed filename")
    duration: float = Field(..., ge=0.0, description="Audio duration in seconds")
    total_frames: int = Field(..., ge=0, description="Total number of frames analyzed")
    error_count: int = Field(..., ge=0, description="Number of errors detected")
    errors: List[ErrorReport] = Field(default_factory=list, description="List of error reports")
    frame_diagnoses: List[FrameDiagnosis] = Field(default_factory=list, description="All frame diagnoses")
    summary: SummaryMetrics = Field(..., description="Summary metrics")
    therapy_plan: List[str] = Field(default_factory=list, description="Therapy recommendations")
    processing_time_ms: float = Field(..., ge=0.0, description="Processing time in milliseconds")
    created_at: datetime = Field(default_factory=datetime.now, description="Analysis timestamp")


class StreamingDiagnosisRequest(BaseModel):
    """Request for streaming diagnosis."""
    audio_chunk: bytes = Field(..., description="Audio chunk data (320 samples for 20ms @ 16kHz)")
    sample_rate: int = Field(16000, description="Sample rate in Hz")
    session_id: str = Field(..., description="Session identifier")
    frame_index: Optional[int] = Field(None, description="Frame index for tracking")


class StreamingDiagnosisResponse(BaseModel):
    """Response for streaming diagnosis (single frame)."""
    session_id: str = Field(..., description="Session identifier")
    frame_id: int = Field(..., description="Frame index")
    timestamp: float = Field(..., ge=0.0, description="Timestamp in seconds")
    phoneme: str = Field(..., description="Expected phoneme")
    fluency: FluencyInfo = Field(..., description="Fluency classification")
    articulation: ArticulationInfo = Field(..., description="Articulation classification")
    error: Optional[ErrorDetailSchema] = Field(None, description="Error details if error detected")
    severity_level: str = Field(..., description="Severity level")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Overall confidence")
    latency_ms: float = Field(..., ge=0.0, description="Processing latency in milliseconds")


class SessionListResponse(BaseModel):
    """Response for listing sessions."""
    sessions: List[Dict[str, Any]] = Field(..., description="List of session metadata")
    total: int = Field(..., ge=0, description="Total number of sessions")


class HealthResponse(BaseModel):
    """Health check response."""
    status: str = Field(..., description="Service status")
    version: str = Field(..., description="API version")
    model_loaded: bool = Field(..., description="Whether model is loaded")
    uptime_seconds: float = Field(..., ge=0.0, description="Service uptime in seconds")
|
| 115 |
+
|
api/streaming.py
ADDED
|
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
WebSocket streaming for real-time speech pathology diagnosis.
|
| 3 |
+
|
| 4 |
+
This module provides WebSocket endpoint for streaming audio analysis
|
| 5 |
+
with <50ms latency per frame requirement.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
import time
|
| 10 |
+
import uuid
|
| 11 |
+
import numpy as np
|
| 12 |
+
from typing import Optional, Dict
|
| 13 |
+
from collections import deque
|
| 14 |
+
from datetime import datetime
|
| 15 |
+
|
| 16 |
+
from fastapi import WebSocket, WebSocketDisconnect, HTTPException
|
| 17 |
+
|
| 18 |
+
from api.schemas import StreamingDiagnosisResponse, FluencyInfo, ArticulationInfo, ErrorDetailSchema
|
| 19 |
+
from models.phoneme_mapper import PhonemeMapper
|
| 20 |
+
from models.error_taxonomy import ErrorMapper, ErrorType
|
| 21 |
+
from inference.inference_pipeline import InferencePipeline
|
| 22 |
+
from config import AudioConfig, default_audio_config
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class StreamingBuffer:
    """
    Sliding-window buffer for streaming audio.

    Accumulates incoming samples and exposes fixed-size analysis
    windows, advancing by a configurable hop between frames.
    """

    def __init__(self, window_size_samples: int, hop_size_samples: int):
        """
        Initialize streaming buffer.

        Args:
            window_size_samples: Size of analysis window in samples
            hop_size_samples: Hop size between windows in samples
        """
        self.window_size_samples = window_size_samples
        self.hop_size_samples = hop_size_samples
        # Bounded deque: old samples fall off automatically once the
        # buffer exceeds one window plus one hop.
        self.buffer = deque(maxlen=window_size_samples + hop_size_samples)
        self.frame_index = 0

        logger.debug(f"StreamingBuffer initialized: window={window_size_samples}, hop={hop_size_samples}")

    def add_chunk(self, audio_chunk: np.ndarray) -> bool:
        """
        Append incoming samples to the buffer.

        Args:
            audio_chunk: Audio samples to add

        Returns:
            True if buffer has enough data for a frame, False otherwise
        """
        self.buffer.extend(audio_chunk)
        ready = len(self.buffer) >= self.window_size_samples
        return ready

    def get_frame(self) -> Optional[np.ndarray]:
        """
        Return the most recent full analysis window, if available.

        Returns:
            Audio frame array if ready, None otherwise
        """
        window = self.window_size_samples
        if len(self.buffer) >= window:
            # Take the newest window-sized slice of the buffer
            return np.array(list(self.buffer)[-window:])
        return None

    def slide(self):
        """Advance the buffer position by one hop."""
        # Drop the oldest hop_size_samples (bounded by buffer length;
        # the min() already guarantees popleft() never underflows).
        drop = min(self.hop_size_samples, len(self.buffer))
        for _ in range(drop):
            self.buffer.popleft()
        self.frame_index += 1
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# Global instances (will be injected)
inference_pipeline: Optional[InferencePipeline] = None
phoneme_mapper: Optional[PhonemeMapper] = None
error_mapper: Optional[ErrorMapper] = None

# Active streaming sessions
streaming_sessions: Dict[str, Dict] = {}


def initialize_streaming(
    pipeline: InferencePipeline,
    mapper: Optional[PhonemeMapper] = None,
    error_mapper_instance: Optional[ErrorMapper] = None
):
    """
    Initialize streaming with dependencies.

    Args:
        pipeline: InferencePipeline instance
        mapper: Optional PhonemeMapper instance. When omitted, a default
            mapper is built from the audio configuration.
        error_mapper_instance: Optional ErrorMapper instance. When omitted,
            a default ErrorMapper is constructed.
    """
    global inference_pipeline, phoneme_mapper, error_mapper

    inference_pipeline = pipeline

    # BUG FIX: previously a caller-supplied mapper / error mapper was
    # silently ignored — only the `is None` branches assigned the globals.
    # Honor injected instances first; fall back to defaults otherwise.
    if mapper is not None:
        phoneme_mapper = mapper
    else:
        try:
            phoneme_mapper = PhonemeMapper(
                frame_duration_ms=default_audio_config.chunk_duration_ms,
                sample_rate=default_audio_config.sample_rate
            )
            logger.info("✅ PhonemeMapper initialized for streaming")
        except Exception as e:
            # PhonemeMapper is optional for streaming; degrade gracefully.
            logger.warning(f"⚠️ PhonemeMapper not available: {e}")
            phoneme_mapper = None

    if error_mapper_instance is not None:
        error_mapper = error_mapper_instance
    else:
        try:
            error_mapper = ErrorMapper()
            logger.info("✅ ErrorMapper initialized for streaming")
        except Exception as e:
            logger.error(f"❌ ErrorMapper failed to initialize: {e}")
            error_mapper = None
async def handle_streaming_websocket(websocket: WebSocket, session_id: Optional[str] = None):
    """
    Handle WebSocket connection for streaming diagnosis.

    Receives raw audio chunks (assumed 16-bit PCM, mono — TODO confirm the
    client contract), maintains a sliding-window buffer, runs phone-level
    inference on every complete window, and sends one JSON diagnosis
    message per analyzed frame.

    Args:
        websocket: WebSocket connection
        session_id: Optional session ID (auto-generated if not provided)
    """
    if inference_pipeline is None:
        await websocket.close(code=1003, reason="Inference pipeline not loaded")
        return

    # Generate session ID
    if not session_id:
        session_id = str(uuid.uuid4())

    # Accept connection
    await websocket.accept()
    logger.info(f"🔌 WebSocket connected: session_id={session_id}")

    # Initialize sliding-window buffer from the inference configuration.
    sample_rate = inference_pipeline.audio_config.sample_rate
    window_size_samples = int(
        inference_pipeline.inference_config.window_size_ms * sample_rate / 1000
    )
    hop_size_samples = int(
        inference_pipeline.inference_config.hop_size_ms * sample_rate / 1000
    )
    buffer = StreamingBuffer(window_size_samples, hop_size_samples)

    # Session metadata
    streaming_sessions[session_id] = {
        "session_id": session_id,
        "connected_at": datetime.now(),
        "frame_count": 0,
        "total_latency_ms": 0.0
    }

    frame_index = 0

    try:
        while True:
            try:
                # Receive audio chunk
                data = await websocket.receive_bytes()

                # Convert bytes to numpy array.
                # Assuming 16-bit PCM, mono, 16kHz
                audio_chunk = np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0
                buffer.add_chunk(audio_chunk)

                # FIX: drain every complete frame and fetch each frame only
                # once. The previous code called get_frame() twice per chunk
                # (each call copies the whole buffer) and analyzed at most one
                # frame per received chunk, so chunks larger than one hop made
                # the bounded buffer silently drop audio.
                while (frame := buffer.get_frame()) is not None:
                    frame_start_time = time.time()

                    try:
                        result = inference_pipeline.predict_phone_level(
                            frame,
                            return_timestamps=False
                        )

                        if result.frame_predictions:
                            frame_pred = result.frame_predictions[0]  # Single frame result

                            # Stuttered frames map onto the upper class range.
                            class_id = frame_pred.articulation_class
                            if frame_pred.fluency_label == 'stutter':
                                class_id += 4

                            error_detail = None
                            phoneme = ''  # Streaming doesn't have text input

                            if error_mapper:
                                try:
                                    error_detail_obj = error_mapper.map_classifier_output(
                                        class_id=class_id,
                                        confidence=frame_pred.confidence,
                                        phoneme=phoneme,
                                        fluency_label=frame_pred.fluency_label
                                    )

                                    if error_detail_obj.error_type != ErrorType.NORMAL:
                                        error_detail = ErrorDetailSchema(
                                            phoneme=error_detail_obj.phoneme,
                                            error_type=error_detail_obj.error_type.value,
                                            wrong_sound=error_detail_obj.wrong_sound,
                                            severity=error_detail_obj.severity,
                                            confidence=error_detail_obj.confidence,
                                            therapy=error_detail_obj.therapy,
                                            frame_indices=[frame_index]
                                        )
                                except Exception as e:
                                    logger.warning(f"Error mapping failed: {e}")

                            # Calculate latency
                            latency_ms = (time.time() - frame_start_time) * 1000

                            # Get severity level
                            severity_level = "none"
                            if error_detail and error_mapper:
                                severity_level = error_mapper.get_severity_level(error_detail.severity).value

                            # Create response
                            response = StreamingDiagnosisResponse(
                                session_id=session_id,
                                frame_id=frame_index,
                                timestamp=frame_index * (inference_pipeline.inference_config.hop_size_ms / 1000.0),
                                phoneme=phoneme,
                                fluency=FluencyInfo(
                                    label=frame_pred.fluency_label,
                                    confidence=frame_pred.fluency_prob if frame_pred.fluency_label == 'stutter' else (1.0 - frame_pred.fluency_prob)
                                ),
                                articulation=ArticulationInfo(
                                    label=frame_pred.articulation_label,
                                    confidence=frame_pred.confidence,
                                    class_id=frame_pred.articulation_class
                                ),
                                error=error_detail,
                                severity_level=severity_level,
                                confidence=frame_pred.confidence,
                                latency_ms=latency_ms
                            )

                            # Send response
                            await websocket.send_json(response.model_dump())

                            # Update session stats
                            stats = streaming_sessions[session_id]
                            stats["frame_count"] += 1
                            stats["total_latency_ms"] += latency_ms

                            # Soft real-time budget check
                            if latency_ms > 50.0:
                                logger.warning(f"⚠️ Latency exceeded 50ms: {latency_ms:.1f}ms")

                    except Exception as e:
                        logger.error(f"❌ Inference failed: {e}", exc_info=True)
                        await websocket.send_json({
                            "error": f"Inference failed: {str(e)}",
                            "frame_id": frame_index
                        })

                    # FIX: always advance the window, even when inference fails
                    # or yields no predictions; otherwise the same frame would
                    # be retried forever and the drain loop could not terminate.
                    buffer.slide()
                    frame_index += 1

            except WebSocketDisconnect:
                # FIX: WebSocketDisconnect subclasses Exception, so the broad
                # handler below used to swallow it and then call send_json on
                # a closed socket. Re-raise so the outer handler logs cleanly.
                raise
            except Exception as e:
                logger.error(f"❌ Error processing chunk: {e}", exc_info=True)
                await websocket.send_json({
                    "error": f"Processing failed: {str(e)}",
                    "frame_id": frame_index
                })

    except WebSocketDisconnect:
        logger.info(f"🔌 WebSocket disconnected: session_id={session_id}")
    except Exception as e:
        logger.error(f"❌ WebSocket error: {e}", exc_info=True)
    finally:
        # Cleanup session
        if session_id in streaming_sessions:
            session_data = streaming_sessions[session_id]
            avg_latency = session_data["total_latency_ms"] / session_data["frame_count"] if session_data["frame_count"] > 0 else 0.0
            logger.info(f"📊 Session {session_id} stats: {session_data['frame_count']} frames, "
                        f"avg_latency={avg_latency:.1f}ms")
            del streaming_sessions[session_id]
app.py
CHANGED
|
@@ -173,33 +173,65 @@ async def diagnose_speech(
|
|
| 173 |
|
| 174 |
# Run inference
|
| 175 |
logger.info("🔄 Running inference pipeline...")
|
| 176 |
-
|
|
|
|
| 177 |
temp_file,
|
| 178 |
-
return_timestamps=True
|
| 179 |
-
apply_smoothing=True
|
| 180 |
)
|
| 181 |
|
| 182 |
processing_time_ms = (time.time() - start_time) * 1000
|
| 183 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 184 |
# Format response
|
| 185 |
response = {
|
| 186 |
"status": "success",
|
| 187 |
"fluency_metrics": {
|
| 188 |
-
"mean_fluency":
|
| 189 |
-
"fluency_percentage":
|
| 190 |
-
"fluent_frames_ratio":
|
| 191 |
-
"
|
| 192 |
-
"
|
| 193 |
-
"max": result.fluency_metrics.get("max", 0.0),
|
| 194 |
-
"median": result.fluency_metrics.get("median", 0.0)
|
| 195 |
},
|
| 196 |
"articulation_results": {
|
| 197 |
-
"total_frames":
|
| 198 |
-
"frame_duration_ms":
|
| 199 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
},
|
| 201 |
-
"confidence":
|
| 202 |
-
"confidence_percentage":
|
| 203 |
"processing_time_ms": processing_time_ms
|
| 204 |
}
|
| 205 |
|
|
|
|
| 173 |
|
| 174 |
# Run inference
|
| 175 |
logger.info("🔄 Running inference pipeline...")
|
| 176 |
+
# Use new phone-level prediction
|
| 177 |
+
result = inference_pipeline.predict_phone_level(
|
| 178 |
temp_file,
|
| 179 |
+
return_timestamps=True
|
|
|
|
| 180 |
)
|
| 181 |
|
| 182 |
processing_time_ms = (time.time() - start_time) * 1000
|
| 183 |
|
| 184 |
+
# Extract metrics from new PhoneLevelResult format
|
| 185 |
+
aggregate = result.aggregate
|
| 186 |
+
mean_fluency_stutter = aggregate.get("fluency_score", 0.0)
|
| 187 |
+
fluency_percentage = (1.0 - mean_fluency_stutter) * 100 # Convert stutter prob to fluency percentage
|
| 188 |
+
|
| 189 |
+
# Count fluent frames
|
| 190 |
+
fluent_frames = sum(1 for fp in result.frame_predictions if fp.fluency_label == 'normal')
|
| 191 |
+
fluent_frames_ratio = fluent_frames / result.num_frames if result.num_frames > 0 else 0.0
|
| 192 |
+
|
| 193 |
+
# Extract articulation class distribution
|
| 194 |
+
articulation_class_counts = {}
|
| 195 |
+
for fp in result.frame_predictions:
|
| 196 |
+
label = fp.articulation_label
|
| 197 |
+
articulation_class_counts[label] = articulation_class_counts.get(label, 0) + 1
|
| 198 |
+
|
| 199 |
+
# Get dominant articulation class
|
| 200 |
+
dominant_articulation = aggregate.get("articulation_label", "normal")
|
| 201 |
+
|
| 202 |
+
# Calculate average confidence
|
| 203 |
+
avg_confidence = sum(fp.confidence for fp in result.frame_predictions) / result.num_frames if result.num_frames > 0 else 0.0
|
| 204 |
+
|
| 205 |
# Format response
|
| 206 |
response = {
|
| 207 |
"status": "success",
|
| 208 |
"fluency_metrics": {
|
| 209 |
+
"mean_fluency": fluency_percentage / 100.0,
|
| 210 |
+
"fluency_percentage": fluency_percentage,
|
| 211 |
+
"fluent_frames_ratio": fluent_frames_ratio,
|
| 212 |
+
"fluent_frames_percentage": fluent_frames_ratio * 100,
|
| 213 |
+
"stutter_probability": mean_fluency_stutter
|
|
|
|
|
|
|
| 214 |
},
|
| 215 |
"articulation_results": {
|
| 216 |
+
"total_frames": result.num_frames,
|
| 217 |
+
"frame_duration_ms": int(inference_pipeline.inference_config.hop_size_ms),
|
| 218 |
+
"dominant_class": aggregate.get("articulation_class", 0),
|
| 219 |
+
"dominant_label": dominant_articulation,
|
| 220 |
+
"class_distribution": articulation_class_counts,
|
| 221 |
+
"frame_predictions": [
|
| 222 |
+
{
|
| 223 |
+
"time": fp.time,
|
| 224 |
+
"fluency_prob": fp.fluency_prob,
|
| 225 |
+
"fluency_label": fp.fluency_label,
|
| 226 |
+
"articulation_class": fp.articulation_class,
|
| 227 |
+
"articulation_label": fp.articulation_label,
|
| 228 |
+
"confidence": fp.confidence
|
| 229 |
+
}
|
| 230 |
+
for fp in result.frame_predictions
|
| 231 |
+
]
|
| 232 |
},
|
| 233 |
+
"confidence": avg_confidence,
|
| 234 |
+
"confidence_percentage": avg_confidence * 100,
|
| 235 |
"processing_time_ms": processing_time_ms
|
| 236 |
}
|
| 237 |
|
config.py
CHANGED
|
@@ -92,12 +92,24 @@ class InferenceConfig:
|
|
| 92 |
Reduces jitter in frame-level predictions.
|
| 93 |
batch_size: Number of chunks to process in parallel during inference.
|
| 94 |
Higher values = faster but more memory usage.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
"""
|
| 96 |
fluency_threshold: float = 0.5
|
| 97 |
articulation_threshold: float = 0.6
|
| 98 |
min_chunk_duration_ms: int = 10
|
| 99 |
smoothing_window: int = 5
|
| 100 |
batch_size: int = 32
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
|
| 103 |
@dataclass
|
|
|
|
| 92 |
Reduces jitter in frame-level predictions.
|
| 93 |
batch_size: Number of chunks to process in parallel during inference.
|
| 94 |
Higher values = faster but more memory usage.
|
| 95 |
+
window_size_ms: Size of sliding window in milliseconds (default: 1000ms = 1 second).
|
| 96 |
+
Minimum for Wav2Vec2 stability.
|
| 97 |
+
hop_size_ms: Hop size between windows in milliseconds (default: 10ms).
|
| 98 |
+
Controls temporal resolution (100 frames/second).
|
| 99 |
+
frame_rate: Frames per second (calculated from hop_size_ms).
|
| 100 |
+
minimum_audio_length: Minimum audio length in seconds (must be >= window_size_ms).
|
| 101 |
+
phone_level_strategy: Strategy for phone-level analysis ("sliding_window").
|
| 102 |
"""
|
| 103 |
fluency_threshold: float = 0.5
|
| 104 |
articulation_threshold: float = 0.6
|
| 105 |
min_chunk_duration_ms: int = 10
|
| 106 |
smoothing_window: int = 5
|
| 107 |
batch_size: int = 32
|
| 108 |
+
window_size_ms: int = 1000 # 1 second minimum for Wav2Vec2
|
| 109 |
+
hop_size_ms: int = 10 # 10ms for phone-level resolution
|
| 110 |
+
frame_rate: float = 100.0 # 100 frames per second (1/hop_size_ms)
|
| 111 |
+
minimum_audio_length: float = 1.0 # Must be >= window_size_ms
|
| 112 |
+
phone_level_strategy: str = "sliding_window"
|
| 113 |
|
| 114 |
|
| 115 |
@dataclass
|
data/therapy_recommendations.json
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"substitutions": {
|
| 3 |
+
"/s/→/θ/": "Lisp - Use tongue tip placement behind upper teeth. Practice /s/ in isolation, then in words. Use mirror feedback to ensure tongue is not protruding.",
|
| 4 |
+
"/s/→/ʃ/": "Sibilant confusion - Practice /s/ vs /sh/ distinction. Focus on tongue position: /s/ has tongue tip up, /sh/ has tongue body back.",
|
| 5 |
+
"/s/→/z/": "Voicing error - Practice voiceless /s/ vs voiced /z/. Place hand on throat to feel vibration difference.",
|
| 6 |
+
"/r/→/w/": "Rhotacism - Practice tongue position: curl tongue back, avoid lip rounding. Start with /r/ in isolation, then CV syllables (ra, re, ri, ro, ru).",
|
| 7 |
+
"/r/→/l/": "Rhotacism - Focus on tongue tip position vs. tongue body placement. /r/ uses tongue body, /l/ uses tongue tip.",
|
| 8 |
+
"/r/→/ɹ/": "Rhotacism variant - Practice standard /r/ production with proper tongue curl.",
|
| 9 |
+
"/l/→/w/": "Liquid substitution - Practice lateral tongue placement with tongue tip up to alveolar ridge.",
|
| 10 |
+
"/l/→/j/": "Liquid substitution - Focus on tongue tip contact for /l/ vs. tongue body for /j/.",
|
| 11 |
+
"/k/→/t/": "Velar to alveolar substitution - Practice back tongue placement for /k/. Use mirror to see tongue position.",
|
| 12 |
+
"/k/→/p/": "Velar to bilabial substitution - Practice velar placement: tongue back, soft palate contact.",
|
| 13 |
+
"/g/→/d/": "Velar to alveolar substitution - Practice voiced velar /g/ with tongue back position.",
|
| 14 |
+
"/g/→/b/": "Velar to bilabial substitution - Practice velar placement for /g/.",
|
| 15 |
+
"/θ/→/f/": "Th-fronting - Practice tongue tip placement between teeth for /θ/. Use mirror to ensure correct position.",
|
| 16 |
+
"/θ/→/s/": "Th-fronting - Practice interdental placement for /θ/ vs. alveolar /s/.",
|
| 17 |
+
"/ð/→/v/": "Voiced th-fronting - Practice interdental placement for /ð/ vs. labiodental /v/.",
|
| 18 |
+
"/ð/→/z/": "Voiced th-fronting - Practice interdental /ð/ vs. alveolar /z/.",
|
| 19 |
+
"/ʃ/→/s/": "Sh-sound confusion - Practice /sh/ with tongue body back vs. /s/ with tongue tip up.",
|
| 20 |
+
"/ʃ/→/tʃ/": "Fricative to affricate - Practice sustained /sh/ vs. stop-release /ch/.",
|
| 21 |
+
"/tʃ/→/ʃ/": "Affricate to fricative - Practice stop component of /ch/ before fricative release.",
|
| 22 |
+
"/tʃ/→/ts/": "Affricate substitution - Practice /ch/ with proper tongue placement and air release.",
|
| 23 |
+
"generic": "Substitution error for {phoneme}. Practice correct articulator placement with mirror feedback. Start in isolation, then syllables, then words."
|
| 24 |
+
},
|
| 25 |
+
"omissions": {
|
| 26 |
+
"/r/": "Practice /r/ in isolation, then in CV syllables (ra, re, ri, ro, ru). Focus on tongue curl and lip position. Use visual cues and mirror feedback.",
|
| 27 |
+
"/l/": "Lateral tongue placement - practice with tongue tip up to alveolar ridge. Start with /l/ in isolation, then blend into words.",
|
| 28 |
+
"/s/": "Practice /s/ with tongue tip placement, use mirror to check position. Start in isolation, then fricative-only words (sss, see, say).",
|
| 29 |
+
"/k/": "Practice velar /k/ with tongue back position. Use mirror to see tongue placement. Start with /k/ in isolation.",
|
| 30 |
+
"/g/": "Practice voiced velar /g/ with tongue back. Start in isolation, then CV syllables.",
|
| 31 |
+
"/t/": "Practice alveolar /t/ with tongue tip contact. Use mirror feedback for placement.",
|
| 32 |
+
"/d/": "Practice voiced alveolar /d/ with tongue tip contact.",
|
| 33 |
+
"/p/": "Practice bilabial /p/ with lip closure. Use mirror to ensure proper closure.",
|
| 34 |
+
"/b/": "Practice voiced bilabial /b/ with lip closure.",
|
| 35 |
+
"/f/": "Practice labiodental /f/ with lower lip to upper teeth contact.",
|
| 36 |
+
"/v/": "Practice voiced labiodental /v/ with lip-teeth contact.",
|
| 37 |
+
"/θ/": "Practice interdental /θ/ with tongue tip between teeth.",
|
| 38 |
+
"/ð/": "Practice voiced interdental /ð/ with tongue tip between teeth.",
|
| 39 |
+
"/ʃ/": "Practice /sh/ with tongue body back and lip rounding.",
|
| 40 |
+
"/tʃ/": "Practice /ch/ with stop then fricative release.",
|
| 41 |
+
"/dʒ/": "Practice /j/ (as in judge) with stop then fricative release.",
|
| 42 |
+
"generic": "Omission error for {phoneme}. Say the sound separately first, then blend into syllables, then words. Use visual cues and mirror feedback."
|
| 43 |
+
},
|
| 44 |
+
"distortions": {
|
| 45 |
+
"/s/": "Sibilant clarity - use mirror feedback, ensure tongue tip is up and air stream is central. Practice sustained /s/ sound.",
|
| 46 |
+
"/z/": "Voiced sibilant clarity - practice with voicing, ensure proper tongue placement.",
|
| 47 |
+
"/ʃ/": "Fricative voicing exercise - practice /sh/ vs /s/ distinction. Focus on tongue body position.",
|
| 48 |
+
"/tʃ/": "Affricate clarity - practice stop component then fricative release. Ensure proper timing.",
|
| 49 |
+
"/r/": "Rhotacism - practice tongue position and lip rounding control. Use mirror to see tongue curl.",
|
| 50 |
+
"/l/": "Lateral clarity - ensure tongue tip is up and air flows over sides of tongue.",
|
| 51 |
+
"/θ/": "Interdental clarity - practice with tongue tip between teeth, ensure air stream is correct.",
|
| 52 |
+
"/ð/": "Voiced interdental clarity - practice with voicing and proper tongue placement.",
|
| 53 |
+
"/k/": "Velar stop clarity - practice with proper tongue back placement and release.",
|
| 54 |
+
"/g/": "Voiced velar stop clarity - practice with voicing and tongue placement.",
|
| 55 |
+
"/t/": "Alveolar stop clarity - practice with tongue tip contact and clean release.",
|
| 56 |
+
"/d/": "Voiced alveolar stop clarity - practice with voicing and proper contact.",
|
| 57 |
+
"generic": "Distortion error for {phoneme}. Use mirror feedback and watch articulator position carefully. Practice in isolation first, then in words."
|
| 58 |
+
}
|
| 59 |
+
}
|
| 60 |
+
|
diagnosis/ai_engine/detect_stuttering.py
CHANGED
|
@@ -30,6 +30,9 @@ from scipy.spatial import ConvexHull
|
|
| 30 |
from scipy.stats import pearsonr
|
| 31 |
from difflib import SequenceMatcher
|
| 32 |
|
|
|
|
|
|
|
|
|
|
| 33 |
logger = logging.getLogger(__name__)
|
| 34 |
|
| 35 |
# === CONFIGURATION ===
|
|
|
|
| 30 |
from scipy.stats import pearsonr
|
| 31 |
from difflib import SequenceMatcher
|
| 32 |
|
| 33 |
+
import warnings
|
| 34 |
+
warnings.filterwarnings("ignore", message="CUDA requested but not available")
|
| 35 |
+
|
| 36 |
logger = logging.getLogger(__name__)
|
| 37 |
|
| 38 |
# === CONFIGURATION ===
|
inference/inference_pipeline.py
CHANGED
|
@@ -2,7 +2,7 @@
|
|
| 2 |
Inference Pipeline for Speech Pathology Diagnosis
|
| 3 |
|
| 4 |
This module provides the inference pipeline for real-time and batch processing
|
| 5 |
-
of audio for fluency and articulation analysis.
|
| 6 |
"""
|
| 7 |
|
| 8 |
import logging
|
|
@@ -13,7 +13,8 @@ import soundfile as sf
|
|
| 13 |
from typing import Dict, List, Optional, Tuple, Union
|
| 14 |
from pathlib import Path
|
| 15 |
import time
|
| 16 |
-
from dataclasses import dataclass
|
|
|
|
| 17 |
|
| 18 |
from models.speech_pathology_model import SpeechPathologyClassifier, load_speech_pathology_model
|
| 19 |
from config import AudioConfig, ModelConfig, InferenceConfig
|
|
@@ -22,60 +23,50 @@ logger = logging.getLogger(__name__)
|
|
| 22 |
|
| 23 |
|
| 24 |
@dataclass
|
| 25 |
-
class
|
| 26 |
"""
|
| 27 |
-
Container for
|
| 28 |
|
| 29 |
Attributes:
|
| 30 |
-
|
|
|
|
|
|
|
| 31 |
articulation_class: Class index (0-3)
|
| 32 |
-
|
| 33 |
-
articulation_probs: Probabilities for all 4 classes
|
| 34 |
confidence: Overall confidence score
|
| 35 |
-
timestamp_ms: Timestamp in milliseconds (for streaming)
|
| 36 |
-
frame_index: Frame index (for streaming)
|
| 37 |
"""
|
| 38 |
-
|
|
|
|
|
|
|
| 39 |
articulation_class: int
|
| 40 |
-
|
| 41 |
-
articulation_probs: List[float]
|
| 42 |
confidence: float
|
| 43 |
-
timestamp_ms: Optional[float] = None
|
| 44 |
-
frame_index: Optional[int] = None
|
| 45 |
|
| 46 |
|
| 47 |
@dataclass
|
| 48 |
-
class
|
| 49 |
"""
|
| 50 |
-
Container for
|
| 51 |
|
| 52 |
Attributes:
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
confidence: Overall confidence
|
| 56 |
-
timestamps: List of timestamps in milliseconds
|
| 57 |
-
frame_duration_ms: Duration of each frame in milliseconds
|
| 58 |
"""
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
frame_duration_ms: float
|
| 64 |
|
| 65 |
|
| 66 |
class InferencePipeline:
|
| 67 |
"""
|
| 68 |
-
Inference pipeline for speech pathology diagnosis.
|
| 69 |
-
|
| 70 |
-
Handles both batch processing (full audio files) and streaming
|
| 71 |
-
(real-time chunk-by-chunk) inference with phone-level granularity.
|
| 72 |
|
| 73 |
-
|
| 74 |
-
-
|
| 75 |
-
-
|
| 76 |
-
-
|
| 77 |
-
-
|
| 78 |
-
- Temporal smoothing of predictions
|
| 79 |
"""
|
| 80 |
|
| 81 |
def __init__(
|
|
@@ -93,9 +84,6 @@ class InferencePipeline:
|
|
| 93 |
audio_config: Audio processing configuration
|
| 94 |
model_config: Model configuration
|
| 95 |
inference_config: Inference configuration
|
| 96 |
-
|
| 97 |
-
Raises:
|
| 98 |
-
RuntimeError: If model cannot be loaded or initialized
|
| 99 |
"""
|
| 100 |
# Load configurations
|
| 101 |
from config import default_audio_config, default_model_config, default_inference_config
|
|
@@ -104,18 +92,17 @@ class InferencePipeline:
|
|
| 104 |
self.model_config = model_config or default_model_config
|
| 105 |
self.inference_config = inference_config or default_inference_config
|
| 106 |
|
| 107 |
-
logger.info("Initializing InferencePipeline...")
|
| 108 |
-
logger.info(f"
|
| 109 |
-
|
| 110 |
-
logger.info(f"
|
| 111 |
-
f"articulation_threshold={self.inference_config.articulation_threshold}")
|
| 112 |
|
| 113 |
# Initialize or use provided model
|
| 114 |
if model is None:
|
| 115 |
logger.info("Loading SpeechPathologyClassifier...")
|
| 116 |
self.model = load_speech_pathology_model(
|
| 117 |
model_name=self.model_config.model_name,
|
| 118 |
-
classifier_hidden_dims=
|
| 119 |
dropout=self.model_config.dropout,
|
| 120 |
device=self.model_config.device,
|
| 121 |
use_fp16=self.model_config.use_fp16
|
|
@@ -127,16 +114,16 @@ class InferencePipeline:
|
|
| 127 |
# Get processor for audio preprocessing
|
| 128 |
self.processor = self.model.processor
|
| 129 |
|
| 130 |
-
# Calculate
|
| 131 |
-
self.
|
| 132 |
-
self.
|
| 133 |
)
|
| 134 |
self.hop_size_samples = int(
|
| 135 |
-
self.
|
| 136 |
)
|
| 137 |
|
| 138 |
-
logger.info(f"
|
| 139 |
-
|
| 140 |
logger.info("✅ InferencePipeline initialized successfully")
|
| 141 |
|
| 142 |
def preprocess_audio(
|
|
@@ -153,335 +140,276 @@ class InferencePipeline:
|
|
| 153 |
|
| 154 |
Returns:
|
| 155 |
Preprocessed audio array normalized to [-1, 1] range
|
| 156 |
-
|
| 157 |
-
Raises:
|
| 158 |
-
ValueError: If audio cannot be loaded or processed
|
| 159 |
"""
|
| 160 |
target_sr = target_sr or self.audio_config.sample_rate
|
| 161 |
|
| 162 |
# Load audio if path provided
|
| 163 |
if isinstance(audio, (str, Path)):
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
raise ValueError(f"Cannot load audio file: {e}") from e
|
| 170 |
else:
|
| 171 |
-
audio_array = audio
|
| 172 |
-
# Resample if needed
|
| 173 |
if len(audio_array.shape) > 1:
|
| 174 |
-
audio_array =
|
| 175 |
|
| 176 |
-
# Normalize to [-1, 1]
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
audio_array = audio_array / max_val
|
| 181 |
-
logger.debug("Normalized audio to [-1, 1] range")
|
| 182 |
|
| 183 |
return audio_array
|
| 184 |
|
| 185 |
-
def
|
| 186 |
self,
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
apply_smoothing: bool = True
|
| 190 |
-
) -> BatchPredictionResult:
|
| 191 |
"""
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
Processes audio in overlapping frames for phone-level analysis.
|
| 195 |
|
| 196 |
Args:
|
| 197 |
-
|
| 198 |
-
return_timestamps: Whether to include timestamps for each frame
|
| 199 |
-
apply_smoothing: Whether to apply temporal smoothing
|
| 200 |
|
| 201 |
Returns:
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
ValueError: If audio cannot be processed
|
| 206 |
-
RuntimeError: If inference fails
|
| 207 |
"""
|
| 208 |
-
|
| 209 |
-
|
| 210 |
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
|
|
|
|
|
|
| 217 |
|
| 218 |
-
#
|
| 219 |
-
|
| 220 |
-
timestamps = []
|
| 221 |
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
|
| 228 |
-
|
| 229 |
-
if len(frame) < self.inference_config.min_chunk_duration_ms * \
|
| 230 |
-
self.audio_config.sample_rate / 1000:
|
| 231 |
-
continue
|
| 232 |
|
| 233 |
-
#
|
| 234 |
-
|
|
|
|
|
|
|
| 235 |
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
predictions.append(pred_result)
|
| 243 |
-
if return_timestamps:
|
| 244 |
-
timestamps.append(frame_timestamp_ms)
|
| 245 |
-
frame_idx += 1
|
| 246 |
-
except Exception as e:
|
| 247 |
-
logger.warning(f"Failed to predict frame {frame_idx}: {e}")
|
| 248 |
-
continue
|
| 249 |
-
|
| 250 |
-
if not predictions:
|
| 251 |
-
raise ValueError("No valid frames extracted from audio")
|
| 252 |
-
|
| 253 |
-
logger.info(f"Processed {len(predictions)} frames")
|
| 254 |
-
|
| 255 |
-
# Apply temporal smoothing if requested
|
| 256 |
-
if apply_smoothing and len(predictions) > 1:
|
| 257 |
-
predictions = self._apply_temporal_smoothing(predictions)
|
| 258 |
-
|
| 259 |
-
# Aggregate results
|
| 260 |
-
result = self._aggregate_predictions(predictions, timestamps)
|
| 261 |
-
|
| 262 |
-
elapsed_time = time.time() - start_time
|
| 263 |
-
logger.info(f"Batch prediction completed in {elapsed_time:.2f}s "
|
| 264 |
-
f"({duration_seconds/elapsed_time:.1f}x real-time)")
|
| 265 |
-
|
| 266 |
-
return result
|
| 267 |
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 271 |
|
| 272 |
-
def
|
| 273 |
self,
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
) -> PredictionResult:
|
| 278 |
"""
|
| 279 |
-
Predict fluency and articulation
|
| 280 |
-
|
| 281 |
-
Designed for real-time processing with <200ms latency requirement.
|
| 282 |
|
| 283 |
Args:
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
timestamp_ms: Optional timestamp in milliseconds
|
| 287 |
|
| 288 |
Returns:
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
Raises:
|
| 292 |
-
ValueError: If chunk is invalid
|
| 293 |
-
RuntimeError: If inference fails
|
| 294 |
"""
|
|
|
|
|
|
|
| 295 |
try:
|
| 296 |
-
# Preprocess
|
| 297 |
-
|
|
|
|
| 298 |
|
| 299 |
-
|
| 300 |
-
expected_samples = self.frame_size_samples
|
| 301 |
-
if len(chunk) < expected_samples * 0.5: # Allow some tolerance
|
| 302 |
-
logger.warning(f"Chunk size {len(chunk)} is smaller than expected {expected_samples}")
|
| 303 |
|
| 304 |
-
#
|
| 305 |
-
if
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 309 |
|
| 310 |
-
#
|
| 311 |
-
|
| 312 |
-
chunk,
|
| 313 |
-
frame_index=frame_index,
|
| 314 |
-
timestamp_ms=timestamp_ms
|
| 315 |
-
)
|
| 316 |
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 350 |
|
| 351 |
-
#
|
| 352 |
-
|
|
|
|
| 353 |
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 359 |
)
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
articulation_class_name=outputs["articulation_class_name"],
|
| 365 |
-
articulation_probs=outputs["articulation_probs"],
|
| 366 |
-
confidence=outputs["confidence"],
|
| 367 |
-
timestamp_ms=timestamp_ms,
|
| 368 |
-
frame_index=frame_index
|
| 369 |
-
)
|
| 370 |
|
| 371 |
-
def
|
| 372 |
self,
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
|
|
|
| 376 |
"""
|
| 377 |
-
|
| 378 |
|
| 379 |
Args:
|
| 380 |
-
|
| 381 |
-
|
|
|
|
| 382 |
|
| 383 |
Returns:
|
| 384 |
-
|
| 385 |
"""
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
if len(predictions) <= window_size:
|
| 389 |
-
return predictions
|
| 390 |
-
|
| 391 |
-
smoothed = []
|
| 392 |
-
for i in range(len(predictions)):
|
| 393 |
-
# Get window indices
|
| 394 |
-
start_idx = max(0, i - window_size // 2)
|
| 395 |
-
end_idx = min(len(predictions), i + window_size // 2 + 1)
|
| 396 |
-
window_preds = predictions[start_idx:end_idx]
|
| 397 |
-
|
| 398 |
-
# Average fluency scores
|
| 399 |
-
avg_fluency = np.mean([p.fluency_score for p in window_preds])
|
| 400 |
-
|
| 401 |
-
# Average articulation probabilities
|
| 402 |
-
avg_articulation_probs = np.mean(
|
| 403 |
-
[p.articulation_probs for p in window_preds],
|
| 404 |
-
axis=0
|
| 405 |
-
)
|
| 406 |
-
|
| 407 |
-
# Get most likely class from averaged probabilities
|
| 408 |
-
articulation_class = int(np.argmax(avg_articulation_probs))
|
| 409 |
-
articulation_class_name = self.model.get_articulation_class_name(articulation_class)
|
| 410 |
-
|
| 411 |
-
# Calculate confidence
|
| 412 |
-
confidence = (avg_fluency + avg_articulation_probs[articulation_class]) / 2.0
|
| 413 |
-
|
| 414 |
-
smoothed.append(PredictionResult(
|
| 415 |
-
fluency_score=float(avg_fluency),
|
| 416 |
-
articulation_class=articulation_class,
|
| 417 |
-
articulation_class_name=articulation_class_name,
|
| 418 |
-
articulation_probs=avg_articulation_probs.tolist(),
|
| 419 |
-
confidence=float(confidence),
|
| 420 |
-
timestamp_ms=predictions[i].timestamp_ms,
|
| 421 |
-
frame_index=predictions[i].frame_index
|
| 422 |
-
))
|
| 423 |
-
|
| 424 |
-
return smoothed
|
| 425 |
|
| 426 |
-
def
|
| 427 |
self,
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
|
|
|
| 431 |
"""
|
| 432 |
-
|
| 433 |
|
| 434 |
Args:
|
| 435 |
-
|
| 436 |
-
|
|
|
|
| 437 |
|
| 438 |
Returns:
|
| 439 |
-
|
| 440 |
"""
|
| 441 |
-
if
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
# Calculate fluency metrics
|
| 445 |
-
fluency_scores = [p.fluency_score for p in predictions]
|
| 446 |
-
fluency_metrics = {
|
| 447 |
-
"mean": float(np.mean(fluency_scores)),
|
| 448 |
-
"std": float(np.std(fluency_scores)),
|
| 449 |
-
"min": float(np.min(fluency_scores)),
|
| 450 |
-
"max": float(np.max(fluency_scores)),
|
| 451 |
-
"median": float(np.median(fluency_scores)),
|
| 452 |
-
"fluent_frames_ratio": float(np.mean([s >= self.inference_config.fluency_threshold
|
| 453 |
-
for s in fluency_scores]))
|
| 454 |
-
}
|
| 455 |
-
|
| 456 |
-
# Articulation scores per frame
|
| 457 |
-
articulation_scores = [
|
| 458 |
-
{
|
| 459 |
-
"class": p.articulation_class,
|
| 460 |
-
"class_name": p.articulation_class_name,
|
| 461 |
-
"probs": p.articulation_probs,
|
| 462 |
-
"confidence": p.confidence
|
| 463 |
-
}
|
| 464 |
-
for p in predictions
|
| 465 |
-
]
|
| 466 |
|
| 467 |
-
#
|
| 468 |
-
|
| 469 |
|
| 470 |
-
#
|
| 471 |
-
if
|
| 472 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 473 |
|
| 474 |
-
return
|
| 475 |
-
fluency_metrics=fluency_metrics,
|
| 476 |
-
articulation_scores=articulation_scores,
|
| 477 |
-
confidence=overall_confidence,
|
| 478 |
-
timestamps=timestamps or [],
|
| 479 |
-
frame_duration_ms=self.audio_config.chunk_duration_ms
|
| 480 |
-
)
|
| 481 |
|
| 482 |
|
| 483 |
def create_inference_pipeline(
|
| 484 |
-
|
| 485 |
audio_config: Optional[AudioConfig] = None,
|
| 486 |
model_config: Optional[ModelConfig] = None,
|
| 487 |
inference_config: Optional[InferenceConfig] = None
|
|
@@ -490,31 +418,17 @@ def create_inference_pipeline(
|
|
| 490 |
Factory function to create an InferencePipeline instance.
|
| 491 |
|
| 492 |
Args:
|
| 493 |
-
|
| 494 |
-
audio_config:
|
| 495 |
-
model_config:
|
| 496 |
-
inference_config:
|
| 497 |
|
| 498 |
Returns:
|
| 499 |
-
InferencePipeline instance
|
| 500 |
"""
|
| 501 |
-
model = None
|
| 502 |
-
if model_path:
|
| 503 |
-
logger.info(f"Loading model from: {model_path}")
|
| 504 |
-
model_config = model_config or ModelConfig()
|
| 505 |
-
model = load_speech_pathology_model(
|
| 506 |
-
model_name=model_config.model_name,
|
| 507 |
-
classifier_hidden_dims=model_config.classifier_hidden_dims,
|
| 508 |
-
dropout=model_config.dropout,
|
| 509 |
-
device=model_config.device,
|
| 510 |
-
use_fp16=model_config.use_fp16,
|
| 511 |
-
model_path=model_path
|
| 512 |
-
)
|
| 513 |
-
|
| 514 |
return InferencePipeline(
|
| 515 |
model=model,
|
| 516 |
audio_config=audio_config,
|
| 517 |
model_config=model_config,
|
| 518 |
inference_config=inference_config
|
| 519 |
)
|
| 520 |
-
|
|
|
|
| 2 |
Inference Pipeline for Speech Pathology Diagnosis
|
| 3 |
|
| 4 |
This module provides the inference pipeline for real-time and batch processing
|
| 5 |
+
of audio for fluency and articulation analysis using sliding window approach.
|
| 6 |
"""
|
| 7 |
|
| 8 |
import logging
|
|
|
|
| 13 |
from typing import Dict, List, Optional, Tuple, Union
|
| 14 |
from pathlib import Path
|
| 15 |
import time
|
| 16 |
+
from dataclasses import dataclass, field
|
| 17 |
+
from collections import deque
|
| 18 |
|
| 19 |
from models.speech_pathology_model import SpeechPathologyClassifier, load_speech_pathology_model
|
| 20 |
from config import AudioConfig, ModelConfig, InferenceConfig
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
@dataclass
class FramePrediction:
    """
    Prediction for one analysis frame (one sliding window).

    Attributes:
        time: Timestamp in seconds (center of the analysis window)
        fluency_prob: Probability that the frame contains a stutter (0-1)
        fluency_label: Either 'normal' or 'stutter'
        articulation_class: Articulation class index (0-3)
        articulation_label: Human-readable articulation class name
        confidence: Overall confidence score for this frame
    """
    time: float
    fluency_prob: float
    fluency_label: str
    articulation_class: int
    articulation_label: str
    confidence: float
|
|
|
|
|
|
|
| 44 |
|
| 45 |
|
| 46 |
@dataclass
class PhoneLevelResult:
    """
    Full phone-level prediction result for one audio clip.

    Attributes:
        frame_predictions: Per-frame predictions, in temporal order
        aggregate: Clip-level summary statistics (fluency score,
            dominant articulation class, frame count, duration)
        duration: Audio duration in seconds
        num_frames: Number of analysis frames produced
    """
    frame_predictions: List[FramePrediction]
    aggregate: Dict[str, Union[float, int, str]]
    duration: float
    num_frames: int
|
|
|
|
| 59 |
|
| 60 |
|
| 61 |
class InferencePipeline:
|
| 62 |
"""
|
| 63 |
+
Inference pipeline for speech pathology diagnosis using sliding window approach.
|
|
|
|
|
|
|
|
|
|
| 64 |
|
| 65 |
+
Architecture:
|
| 66 |
+
- 1-second sliding windows (minimum for Wav2Vec2)
|
| 67 |
+
- 10ms hop size for phone-level resolution
|
| 68 |
+
- Wav2Vec2 feature extraction per window
|
| 69 |
+
- Multi-task classifier for fluency + articulation
|
|
|
|
| 70 |
"""
|
| 71 |
|
| 72 |
def __init__(
|
|
|
|
| 84 |
audio_config: Audio processing configuration
|
| 85 |
model_config: Model configuration
|
| 86 |
inference_config: Inference configuration
|
|
|
|
|
|
|
|
|
|
| 87 |
"""
|
| 88 |
# Load configurations
|
| 89 |
from config import default_audio_config, default_model_config, default_inference_config
|
|
|
|
| 92 |
self.model_config = model_config or default_model_config
|
| 93 |
self.inference_config = inference_config or default_inference_config
|
| 94 |
|
| 95 |
+
logger.info("Initializing InferencePipeline (sliding window)...")
|
| 96 |
+
logger.info(f"Window size: {self.inference_config.window_size_ms}ms")
|
| 97 |
+
logger.info(f"Hop size: {self.inference_config.hop_size_ms}ms")
|
| 98 |
+
logger.info(f"Frame rate: {self.inference_config.frame_rate} fps")
|
|
|
|
| 99 |
|
| 100 |
# Initialize or use provided model
|
| 101 |
if model is None:
|
| 102 |
logger.info("Loading SpeechPathologyClassifier...")
|
| 103 |
self.model = load_speech_pathology_model(
|
| 104 |
model_name=self.model_config.model_name,
|
| 105 |
+
classifier_hidden_dims=[512, 256], # 1024 → 512 → 256
|
| 106 |
dropout=self.model_config.dropout,
|
| 107 |
device=self.model_config.device,
|
| 108 |
use_fp16=self.model_config.use_fp16
|
|
|
|
| 114 |
# Get processor for audio preprocessing
|
| 115 |
self.processor = self.model.processor
|
| 116 |
|
| 117 |
+
# Calculate window and hop sizes in samples
|
| 118 |
+
self.window_size_samples = int(
|
| 119 |
+
self.inference_config.window_size_ms * self.audio_config.sample_rate / 1000
|
| 120 |
)
|
| 121 |
self.hop_size_samples = int(
|
| 122 |
+
self.inference_config.hop_size_ms * self.audio_config.sample_rate / 1000
|
| 123 |
)
|
| 124 |
|
| 125 |
+
logger.info(f"Window size: {self.window_size_samples} samples")
|
| 126 |
+
logger.info(f"Hop size: {self.hop_size_samples} samples")
|
| 127 |
logger.info("✅ InferencePipeline initialized successfully")
|
| 128 |
|
| 129 |
def preprocess_audio(
|
|
|
|
| 140 |
|
| 141 |
Returns:
|
| 142 |
Preprocessed audio array normalized to [-1, 1] range
|
|
|
|
|
|
|
|
|
|
| 143 |
"""
|
| 144 |
target_sr = target_sr or self.audio_config.sample_rate
|
| 145 |
|
| 146 |
# Load audio if path provided
|
| 147 |
if isinstance(audio, (str, Path)):
|
| 148 |
+
audio_path = Path(audio)
|
| 149 |
+
if not audio_path.exists():
|
| 150 |
+
raise ValueError(f"Audio file not found: {audio_path}")
|
| 151 |
+
|
| 152 |
+
audio_array, sr = librosa.load(str(audio_path), sr=target_sr, mono=True)
|
|
|
|
| 153 |
else:
|
| 154 |
+
audio_array = np.array(audio, dtype=np.float32)
|
|
|
|
| 155 |
if len(audio_array.shape) > 1:
|
| 156 |
+
audio_array = np.mean(audio_array, axis=0) # Convert to mono
|
| 157 |
|
| 158 |
+
# Normalize to [-1, 1]
|
| 159 |
+
max_val = np.abs(audio_array).max()
|
| 160 |
+
if max_val > 0:
|
| 161 |
+
audio_array = audio_array / max_val
|
|
|
|
|
|
|
| 162 |
|
| 163 |
return audio_array
|
| 164 |
|
| 165 |
+
def get_phone_level_features(
|
| 166 |
self,
|
| 167 |
+
audio: np.ndarray
|
| 168 |
+
) -> Tuple[torch.Tensor, np.ndarray]:
|
|
|
|
|
|
|
| 169 |
"""
|
| 170 |
+
Extract phone-level features using sliding window approach.
|
|
|
|
|
|
|
| 171 |
|
| 172 |
Args:
|
| 173 |
+
audio: Preprocessed audio array (16kHz, mono, normalized)
|
|
|
|
|
|
|
| 174 |
|
| 175 |
Returns:
|
| 176 |
+
Tuple of (frame_features, frame_times):
|
| 177 |
+
- frame_features: Tensor of shape (num_frames, 1024)
|
| 178 |
+
- frame_times: Array of timestamps in seconds
|
|
|
|
|
|
|
| 179 |
"""
|
| 180 |
+
num_samples = len(audio)
|
| 181 |
+
num_windows = max(1, (num_samples - self.window_size_samples) // self.hop_size_samples + 1)
|
| 182 |
|
| 183 |
+
frame_features_list = []
|
| 184 |
+
frame_times = []
|
| 185 |
+
|
| 186 |
+
logger.info(f"Extracting features from {num_windows} windows...")
|
| 187 |
+
|
| 188 |
+
for i in range(num_windows):
|
| 189 |
+
start_sample = i * self.hop_size_samples
|
| 190 |
+
end_sample = min(start_sample + self.window_size_samples, num_samples)
|
| 191 |
|
| 192 |
+
# Extract window
|
| 193 |
+
window = audio[start_sample:end_sample]
|
|
|
|
| 194 |
|
| 195 |
+
# Pad if necessary (at the end of audio)
|
| 196 |
+
if len(window) < self.window_size_samples:
|
| 197 |
+
padding = self.window_size_samples - len(window)
|
| 198 |
+
window = np.pad(window, (0, padding), mode='constant')
|
| 199 |
+
|
| 200 |
+
# Convert to tensor
|
| 201 |
+
audio_tensor = torch.from_numpy(window).float()
|
| 202 |
+
|
| 203 |
+
# Process through feature extractor
|
| 204 |
+
with torch.no_grad():
|
| 205 |
+
inputs = self.processor(
|
| 206 |
+
audio_tensor,
|
| 207 |
+
sampling_rate=self.audio_config.sample_rate,
|
| 208 |
+
return_tensors="pt",
|
| 209 |
+
padding=True
|
| 210 |
+
)
|
| 211 |
|
| 212 |
+
input_values = inputs.input_values.to(self.model.device)
|
|
|
|
|
|
|
|
|
|
| 213 |
|
| 214 |
+
# Extract Wav2Vec2 features
|
| 215 |
+
wav2vec2_outputs = self.model.wav2vec2_model(
|
| 216 |
+
input_values=input_values
|
| 217 |
+
)
|
| 218 |
|
| 219 |
+
# Get features: (batch_size, seq_len, 1024)
|
| 220 |
+
features = wav2vec2_outputs.last_hidden_state
|
| 221 |
+
|
| 222 |
+
# Pool to single vector: mean over sequence length
|
| 223 |
+
pooled_features = torch.mean(features, dim=1) # (batch_size, 1024)
|
| 224 |
+
frame_features_list.append(pooled_features.cpu())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
|
| 226 |
+
# Calculate timestamp (center of window)
|
| 227 |
+
frame_time = (start_sample + self.window_size_samples / 2) / self.audio_config.sample_rate
|
| 228 |
+
frame_times.append(frame_time)
|
| 229 |
+
|
| 230 |
+
# Stack all features
|
| 231 |
+
frame_features = torch.cat(frame_features_list, dim=0) # (num_frames, 1024)
|
| 232 |
+
frame_times = np.array(frame_times)
|
| 233 |
+
|
| 234 |
+
logger.info(f"Extracted {len(frame_features)} frame features")
|
| 235 |
+
|
| 236 |
+
return frame_features, frame_times
|
| 237 |
|
| 238 |
+
def predict_phone_level(
|
| 239 |
self,
|
| 240 |
+
audio: Union[np.ndarray, str, Path],
|
| 241 |
+
return_timestamps: bool = True
|
| 242 |
+
) -> PhoneLevelResult:
|
|
|
|
| 243 |
"""
|
| 244 |
+
Predict fluency and articulation at phone-level resolution.
|
|
|
|
|
|
|
| 245 |
|
| 246 |
Args:
|
| 247 |
+
audio: Audio array, file path, or Path object
|
| 248 |
+
return_timestamps: Whether to include timestamps in results
|
|
|
|
| 249 |
|
| 250 |
Returns:
|
| 251 |
+
PhoneLevelResult with frame-level predictions and aggregates
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
"""
|
| 253 |
+
start_time = time.time()
|
| 254 |
+
|
| 255 |
try:
|
| 256 |
+
# Preprocess audio
|
| 257 |
+
audio_array = self.preprocess_audio(audio)
|
| 258 |
+
duration = len(audio_array) / self.audio_config.sample_rate
|
| 259 |
|
| 260 |
+
logger.info(f"Processing audio: {duration:.2f}s")
|
|
|
|
|
|
|
|
|
|
| 261 |
|
| 262 |
+
# Check minimum length
|
| 263 |
+
if duration < self.inference_config.minimum_audio_length:
|
| 264 |
+
logger.warning(f"Audio shorter than minimum ({duration:.2f}s < {self.inference_config.minimum_audio_length}s), "
|
| 265 |
+
f"padding to minimum")
|
| 266 |
+
min_samples = int(self.inference_config.minimum_audio_length * self.audio_config.sample_rate)
|
| 267 |
+
if len(audio_array) < min_samples:
|
| 268 |
+
padding = min_samples - len(audio_array)
|
| 269 |
+
audio_array = np.pad(audio_array, (0, padding), mode='constant')
|
| 270 |
+
duration = len(audio_array) / self.audio_config.sample_rate
|
| 271 |
|
| 272 |
+
# Extract phone-level features
|
| 273 |
+
frame_features, frame_times = self.get_phone_level_features(audio_array)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
|
| 275 |
+
# Move features to device
|
| 276 |
+
frame_features = frame_features.to(self.model.device)
|
| 277 |
+
|
| 278 |
+
# Predict using classifier
|
| 279 |
+
self.model.eval()
|
| 280 |
+
with torch.no_grad():
|
| 281 |
+
# Pass through shared layers and heads
|
| 282 |
+
shared_features = self.model.classifier_head.shared_layers(frame_features)
|
| 283 |
+
|
| 284 |
+
# Get predictions from all heads
|
| 285 |
+
fluency_logits = self.model.classifier_head.fluency_head(shared_features)
|
| 286 |
+
articulation_logits = self.model.classifier_head.articulation_head(shared_features)
|
| 287 |
+
full_logits = self.model.classifier_head.full_head(shared_features)
|
| 288 |
+
|
| 289 |
+
# Apply softmax
|
| 290 |
+
fluency_probs = torch.softmax(fluency_logits, dim=-1) # (num_frames, 2)
|
| 291 |
+
articulation_probs = torch.softmax(articulation_logits, dim=-1) # (num_frames, 4)
|
| 292 |
+
full_probs = torch.softmax(full_logits, dim=-1) # (num_frames, 8)
|
| 293 |
+
|
| 294 |
+
# Convert to numpy
|
| 295 |
+
fluency_probs = fluency_probs.cpu().numpy()
|
| 296 |
+
articulation_probs = articulation_probs.cpu().numpy()
|
| 297 |
+
full_probs = full_probs.cpu().numpy()
|
| 298 |
+
|
| 299 |
+
# Create frame predictions
|
| 300 |
+
frame_predictions = []
|
| 301 |
+
for i in range(len(frame_features)):
|
| 302 |
+
# Fluency: class 0 = normal, class 1 = stutter
|
| 303 |
+
fluency_prob_stutter = fluency_probs[i, 1]
|
| 304 |
+
fluency_label = 'stutter' if fluency_prob_stutter > self.inference_config.fluency_threshold else 'normal'
|
| 305 |
+
|
| 306 |
+
# Articulation: get class with highest probability
|
| 307 |
+
articulation_class = int(np.argmax(articulation_probs[i]))
|
| 308 |
+
articulation_label = self.model.get_articulation_class_name(articulation_class)
|
| 309 |
+
|
| 310 |
+
# Confidence: average of max probabilities
|
| 311 |
+
confidence = (np.max(fluency_probs[i]) + np.max(articulation_probs[i])) / 2.0
|
| 312 |
+
|
| 313 |
+
frame_pred = FramePrediction(
|
| 314 |
+
time=frame_times[i] if return_timestamps else 0.0,
|
| 315 |
+
fluency_prob=float(fluency_prob_stutter),
|
| 316 |
+
fluency_label=fluency_label,
|
| 317 |
+
articulation_class=articulation_class,
|
| 318 |
+
articulation_label=articulation_label,
|
| 319 |
+
confidence=float(confidence)
|
| 320 |
+
)
|
| 321 |
+
frame_predictions.append(frame_pred)
|
| 322 |
|
| 323 |
+
# Aggregate statistics
|
| 324 |
+
fluency_scores = [fp.fluency_prob for fp in frame_predictions]
|
| 325 |
+
articulation_classes = [fp.articulation_class for fp in frame_predictions]
|
| 326 |
|
| 327 |
+
aggregate = {
|
| 328 |
+
'fluency_score': float(np.mean(fluency_scores)),
|
| 329 |
+
'articulation_class': int(np.bincount(articulation_classes).argmax()),
|
| 330 |
+
'articulation_label': self.model.get_articulation_class_name(
|
| 331 |
+
int(np.bincount(articulation_classes).argmax())
|
| 332 |
+
),
|
| 333 |
+
'num_frames': len(frame_predictions),
|
| 334 |
+
'duration': duration
|
| 335 |
+
}
|
| 336 |
+
|
| 337 |
+
elapsed_time = time.time() - start_time
|
| 338 |
+
logger.info(f"Phone-level prediction completed in {elapsed_time:.2f}s "
|
| 339 |
+
f"({duration/elapsed_time:.1f}x real-time)")
|
| 340 |
+
|
| 341 |
+
return PhoneLevelResult(
|
| 342 |
+
frame_predictions=frame_predictions,
|
| 343 |
+
aggregate=aggregate,
|
| 344 |
+
duration=duration,
|
| 345 |
+
num_frames=len(frame_predictions)
|
| 346 |
)
|
| 347 |
+
|
| 348 |
+
except Exception as e:
|
| 349 |
+
logger.error(f"Phone-level prediction failed: {e}", exc_info=True)
|
| 350 |
+
raise RuntimeError(f"Phone-level prediction failed: {e}") from e
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 351 |
|
| 352 |
+
def predict_batch(
|
| 353 |
self,
|
| 354 |
+
audio_path: Union[str, Path],
|
| 355 |
+
return_timestamps: bool = True,
|
| 356 |
+
apply_smoothing: bool = True
|
| 357 |
+
) -> PhoneLevelResult:
|
| 358 |
"""
|
| 359 |
+
Predict on audio file (batch processing).
|
| 360 |
|
| 361 |
Args:
|
| 362 |
+
audio_path: Path to audio file
|
| 363 |
+
return_timestamps: Whether to include timestamps
|
| 364 |
+
apply_smoothing: Whether to apply temporal smoothing (not implemented yet)
|
| 365 |
|
| 366 |
Returns:
|
| 367 |
+
PhoneLevelResult
|
| 368 |
"""
|
| 369 |
+
return self.predict_phone_level(audio_path, return_timestamps=return_timestamps)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 370 |
|
| 371 |
+
def predict_streaming_chunk(
|
| 372 |
self,
|
| 373 |
+
chunk: np.ndarray,
|
| 374 |
+
buffer: Optional[deque] = None,
|
| 375 |
+
timestamp: Optional[float] = None
|
| 376 |
+
) -> Optional[FramePrediction]:
|
| 377 |
"""
|
| 378 |
+
Predict on streaming audio chunk.
|
| 379 |
|
| 380 |
Args:
|
| 381 |
+
chunk: Audio chunk array
|
| 382 |
+
buffer: Optional buffer for maintaining sliding window
|
| 383 |
+
timestamp: Optional timestamp for the chunk
|
| 384 |
|
| 385 |
Returns:
|
| 386 |
+
FramePrediction if enough data accumulated, None otherwise
|
| 387 |
"""
|
| 388 |
+
if buffer is None:
|
| 389 |
+
buffer = deque(maxlen=self.window_size_samples)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 390 |
|
| 391 |
+
# Add chunk to buffer
|
| 392 |
+
buffer.extend(chunk)
|
| 393 |
|
| 394 |
+
# Check if we have enough data for a window
|
| 395 |
+
if len(buffer) >= self.window_size_samples:
|
| 396 |
+
# Extract window
|
| 397 |
+
window = np.array(list(buffer)[-self.window_size_samples:])
|
| 398 |
+
|
| 399 |
+
# Process window
|
| 400 |
+
try:
|
| 401 |
+
result = self.predict_phone_level(window, return_timestamps=False)
|
| 402 |
+
if result.frame_predictions:
|
| 403 |
+
return result.frame_predictions[-1] # Return latest frame
|
| 404 |
+
except Exception as e:
|
| 405 |
+
logger.warning(f"Streaming prediction failed: {e}")
|
| 406 |
+
return None
|
| 407 |
|
| 408 |
+
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 409 |
|
| 410 |
|
| 411 |
def create_inference_pipeline(
|
| 412 |
+
model: Optional[SpeechPathologyClassifier] = None,
|
| 413 |
audio_config: Optional[AudioConfig] = None,
|
| 414 |
model_config: Optional[ModelConfig] = None,
|
| 415 |
inference_config: Optional[InferenceConfig] = None
|
|
|
|
| 418 |
Factory function to create an InferencePipeline instance.
|
| 419 |
|
| 420 |
Args:
|
| 421 |
+
model: Optional pre-initialized model
|
| 422 |
+
audio_config: Optional audio configuration
|
| 423 |
+
model_config: Optional model configuration
|
| 424 |
+
inference_config: Optional inference configuration
|
| 425 |
|
| 426 |
Returns:
|
| 427 |
+
Initialized InferencePipeline instance
|
| 428 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 429 |
return InferencePipeline(
|
| 430 |
model=model,
|
| 431 |
audio_config=audio_config,
|
| 432 |
model_config=model_config,
|
| 433 |
inference_config=inference_config
|
| 434 |
)
|
|
|
models/error_taxonomy.py
ADDED
|
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Error Taxonomy for Speech Pathology Analysis
|
| 3 |
+
|
| 4 |
+
This module defines error types, severity levels, and therapy recommendations
|
| 5 |
+
for phoneme-level error detection.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
import json
|
| 10 |
+
from enum import Enum
|
| 11 |
+
from typing import Optional, Dict, List
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from dataclasses import dataclass, field
|
| 14 |
+
from pydantic import BaseModel, Field
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class ErrorType(str, Enum):
    """Closed set of articulation error categories.

    Inherits from ``str`` so members compare equal to their plain string
    values and serialize directly in JSON/API payloads.
    """
    NORMAL = "normal"            # correct production, no error
    SUBSTITUTION = "substitution"  # wrong phoneme produced in place of target
    OMISSION = "omission"          # target phoneme dropped entirely
    DISTORTION = "distortion"      # target attempted but produced imprecisely
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class SeverityLevel(str, Enum):
    """Coarse severity buckets for detected errors.

    String-valued so levels serialize directly and compare equal to their
    plain string forms.
    """
    NONE = "none"      # no error present
    LOW = "low"        # severity score below 0.3
    MEDIUM = "medium"  # severity score in [0.3, 0.7)
    HIGH = "high"      # severity score 0.7 and above
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclass
class ErrorDetail:
    """
    Detailed error record for a single phoneme.

    Attributes:
        phoneme: Expected phoneme symbol (e.g., '/s/')
        error_type: Error category (NORMAL, SUBSTITUTION, OMISSION, DISTORTION)
        wrong_sound: For substitutions, the incorrect phoneme produced
            (e.g., '/θ/'); None otherwise
        severity: Severity score in [0.0, 1.0]
        confidence: Model confidence in the detection, in [0.0, 1.0]
        therapy: Therapy recommendation text
        frame_indices: Frame indices at which this error occurs
    """
    phoneme: str
    error_type: ErrorType
    wrong_sound: Optional[str] = None
    severity: float = 0.0
    confidence: float = 0.0
    therapy: str = ""
    frame_indices: List[int] = field(default_factory=list)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class ErrorDetailPydantic(BaseModel):
    """Pydantic mirror of ``ErrorDetail`` used for API serialization.

    Field semantics match the dataclass; ``severity`` and ``confidence``
    are validated to the [0.0, 1.0] range.
    """
    phoneme: str
    error_type: str
    wrong_sound: Optional[str] = None
    severity: float = Field(ge=0.0, le=1.0)
    confidence: float = Field(ge=0.0, le=1.0)
    therapy: str
    frame_indices: List[int] = Field(default_factory=list)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class ErrorMapper:
|
| 70 |
+
"""
|
| 71 |
+
Maps classifier outputs to error types and generates therapy recommendations.
|
| 72 |
+
|
| 73 |
+
Classifier output mapping (8 classes):
|
| 74 |
+
- Class 0: Normal articulation, normal fluency
|
| 75 |
+
- Class 1: Substitution, normal fluency
|
| 76 |
+
- Class 2: Omission, normal fluency
|
| 77 |
+
- Class 3: Distortion, normal fluency
|
| 78 |
+
- Class 4: Normal articulation, stutter
|
| 79 |
+
- Class 5: Substitution, stutter
|
| 80 |
+
- Class 6: Omission, stutter
|
| 81 |
+
- Class 7: Distortion, stutter
|
| 82 |
+
"""
|
| 83 |
+
|
| 84 |
+
def __init__(self, therapy_db_path: Optional[str] = None):
|
| 85 |
+
"""
|
| 86 |
+
Initialize the ErrorMapper.
|
| 87 |
+
|
| 88 |
+
Args:
|
| 89 |
+
therapy_db_path: Path to therapy recommendations JSON file.
|
| 90 |
+
If None, uses default location: data/therapy_recommendations.json
|
| 91 |
+
"""
|
| 92 |
+
self.therapy_db: Dict = {}
|
| 93 |
+
|
| 94 |
+
# Default path
|
| 95 |
+
if therapy_db_path is None:
|
| 96 |
+
therapy_db_path = Path(__file__).parent.parent / "data" / "therapy_recommendations.json"
|
| 97 |
+
else:
|
| 98 |
+
therapy_db_path = Path(therapy_db_path)
|
| 99 |
+
|
| 100 |
+
# Load therapy database
|
| 101 |
+
try:
|
| 102 |
+
if therapy_db_path.exists():
|
| 103 |
+
with open(therapy_db_path, 'r', encoding='utf-8') as f:
|
| 104 |
+
self.therapy_db = json.load(f)
|
| 105 |
+
logger.info(f"✅ Loaded therapy database from {therapy_db_path}")
|
| 106 |
+
else:
|
| 107 |
+
logger.warning(f"Therapy database not found at {therapy_db_path}, using defaults")
|
| 108 |
+
self.therapy_db = self._get_default_therapy_db()
|
| 109 |
+
except Exception as e:
|
| 110 |
+
logger.error(f"Failed to load therapy database: {e}, using defaults")
|
| 111 |
+
self.therapy_db = self._get_default_therapy_db()
|
| 112 |
+
|
| 113 |
+
# Common substitution mappings (phoneme → likely wrong sound)
|
| 114 |
+
self.substitution_map: Dict[str, List[str]] = {
|
| 115 |
+
'/s/': ['/θ/', '/ʃ/', '/z/'], # lisp, sh-sound, voicing
|
| 116 |
+
'/r/': ['/w/', '/l/', '/ɹ/'], # rhotacism variants
|
| 117 |
+
'/l/': ['/w/', '/j/'], # liquid substitutions
|
| 118 |
+
'/k/': ['/t/', '/p/'], # velar → alveolar/bilabial
|
| 119 |
+
'/g/': ['/d/', '/b/'], # velar → alveolar/bilabial
|
| 120 |
+
'/θ/': ['/f/', '/s/'], # th → f or s
|
| 121 |
+
'/ð/': ['/v/', '/z/'], # voiced th → v or z
|
| 122 |
+
'/ʃ/': ['/s/', '/tʃ/'], # sh → s or ch
|
| 123 |
+
'/tʃ/': ['/ʃ/', '/ts/'], # ch → sh or ts
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
def map_classifier_output(
|
| 127 |
+
self,
|
| 128 |
+
class_id: int,
|
| 129 |
+
confidence: float,
|
| 130 |
+
phoneme: str,
|
| 131 |
+
fluency_label: str = "normal"
|
| 132 |
+
) -> ErrorDetail:
|
| 133 |
+
"""
|
| 134 |
+
Map classifier output to ErrorDetail.
|
| 135 |
+
|
| 136 |
+
Args:
|
| 137 |
+
class_id: Classifier output class (0-7)
|
| 138 |
+
confidence: Model confidence (0.0-1.0)
|
| 139 |
+
phoneme: Expected phoneme symbol
|
| 140 |
+
fluency_label: Fluency label ("normal" or "stutter")
|
| 141 |
+
|
| 142 |
+
Returns:
|
| 143 |
+
ErrorDetail object with error information
|
| 144 |
+
"""
|
| 145 |
+
# Determine error type from class_id
|
| 146 |
+
if class_id == 0 or class_id == 4:
|
| 147 |
+
error_type = ErrorType.NORMAL
|
| 148 |
+
elif class_id == 1 or class_id == 5:
|
| 149 |
+
error_type = ErrorType.SUBSTITUTION
|
| 150 |
+
elif class_id == 2 or class_id == 6:
|
| 151 |
+
error_type = ErrorType.OMISSION
|
| 152 |
+
elif class_id == 3 or class_id == 7:
|
| 153 |
+
error_type = ErrorType.DISTORTION
|
| 154 |
+
else:
|
| 155 |
+
logger.warning(f"Unknown class_id: {class_id}, defaulting to NORMAL")
|
| 156 |
+
error_type = ErrorType.NORMAL
|
| 157 |
+
|
| 158 |
+
# Calculate severity from confidence
|
| 159 |
+
# Higher confidence in error = higher severity
|
| 160 |
+
if error_type == ErrorType.NORMAL:
|
| 161 |
+
severity = 0.0
|
| 162 |
+
else:
|
| 163 |
+
severity = confidence # Use confidence as severity proxy
|
| 164 |
+
|
| 165 |
+
# Get wrong sound for substitutions
|
| 166 |
+
wrong_sound = None
|
| 167 |
+
if error_type == ErrorType.SUBSTITUTION:
|
| 168 |
+
wrong_sound = self._map_substitution(phoneme, confidence)
|
| 169 |
+
|
| 170 |
+
# Get therapy recommendation
|
| 171 |
+
therapy = self.get_therapy(error_type, phoneme, wrong_sound)
|
| 172 |
+
|
| 173 |
+
return ErrorDetail(
|
| 174 |
+
phoneme=phoneme,
|
| 175 |
+
error_type=error_type,
|
| 176 |
+
wrong_sound=wrong_sound,
|
| 177 |
+
severity=severity,
|
| 178 |
+
confidence=confidence,
|
| 179 |
+
therapy=therapy
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
def _map_substitution(self, phoneme: str, confidence: float) -> Optional[str]:
|
| 183 |
+
"""
|
| 184 |
+
Map substitution error to likely wrong sound.
|
| 185 |
+
|
| 186 |
+
Args:
|
| 187 |
+
phoneme: Expected phoneme
|
| 188 |
+
confidence: Model confidence
|
| 189 |
+
|
| 190 |
+
Returns:
|
| 191 |
+
Most likely wrong phoneme, or None if unknown
|
| 192 |
+
"""
|
| 193 |
+
if phoneme in self.substitution_map:
|
| 194 |
+
# Return first (most common) substitution
|
| 195 |
+
return self.substitution_map[phoneme][0]
|
| 196 |
+
return None
|
| 197 |
+
|
| 198 |
+
def get_therapy(
|
| 199 |
+
self,
|
| 200 |
+
error_type: ErrorType,
|
| 201 |
+
phoneme: str,
|
| 202 |
+
wrong_sound: Optional[str] = None
|
| 203 |
+
) -> str:
|
| 204 |
+
"""
|
| 205 |
+
Get therapy recommendation for an error.
|
| 206 |
+
|
| 207 |
+
Args:
|
| 208 |
+
error_type: Type of error
|
| 209 |
+
phoneme: Expected phoneme
|
| 210 |
+
wrong_sound: For substitutions, the wrong sound produced
|
| 211 |
+
|
| 212 |
+
Returns:
|
| 213 |
+
Therapy recommendation text
|
| 214 |
+
"""
|
| 215 |
+
if error_type == ErrorType.NORMAL:
|
| 216 |
+
return "No therapy needed - production is correct."
|
| 217 |
+
|
| 218 |
+
# Build lookup key
|
| 219 |
+
if error_type == ErrorType.SUBSTITUTION and wrong_sound:
|
| 220 |
+
key = f"{phoneme}→{wrong_sound}"
|
| 221 |
+
if "substitutions" in self.therapy_db and key in self.therapy_db["substitutions"]:
|
| 222 |
+
return self.therapy_db["substitutions"][key]
|
| 223 |
+
|
| 224 |
+
# Fallback to generic recommendations
|
| 225 |
+
if error_type == ErrorType.SUBSTITUTION:
|
| 226 |
+
if "substitutions" in self.therapy_db and "generic" in self.therapy_db["substitutions"]:
|
| 227 |
+
return self.therapy_db["substitutions"]["generic"].replace("{phoneme}", phoneme)
|
| 228 |
+
return f"Substitution error for {phoneme}. Practice correct articulator placement."
|
| 229 |
+
|
| 230 |
+
elif error_type == ErrorType.OMISSION:
|
| 231 |
+
if "omissions" in self.therapy_db and phoneme in self.therapy_db["omissions"]:
|
| 232 |
+
return self.therapy_db["omissions"][phoneme]
|
| 233 |
+
if "omissions" in self.therapy_db and "generic" in self.therapy_db["omissions"]:
|
| 234 |
+
return self.therapy_db["omissions"]["generic"].replace("{phoneme}", phoneme)
|
| 235 |
+
return f"Omission error for {phoneme}. Practice saying the sound separately first."
|
| 236 |
+
|
| 237 |
+
elif error_type == ErrorType.DISTORTION:
|
| 238 |
+
if "distortions" in self.therapy_db and phoneme in self.therapy_db["distortions"]:
|
| 239 |
+
return self.therapy_db["distortions"][phoneme]
|
| 240 |
+
if "distortions" in self.therapy_db and "generic" in self.therapy_db["distortions"]:
|
| 241 |
+
return self.therapy_db["distortions"]["generic"].replace("{phoneme}", phoneme)
|
| 242 |
+
return f"Distortion error for {phoneme}. Use mirror feedback and watch articulator position."
|
| 243 |
+
|
| 244 |
+
return "Consult with speech-language pathologist for personalized therapy plan."
|
| 245 |
+
|
| 246 |
+
def get_severity_level(self, severity: float) -> SeverityLevel:
|
| 247 |
+
"""
|
| 248 |
+
Convert severity score to severity level.
|
| 249 |
+
|
| 250 |
+
Args:
|
| 251 |
+
severity: Severity score (0.0-1.0)
|
| 252 |
+
|
| 253 |
+
Returns:
|
| 254 |
+
SeverityLevel enum
|
| 255 |
+
"""
|
| 256 |
+
if severity == 0.0:
|
| 257 |
+
return SeverityLevel.NONE
|
| 258 |
+
elif severity < 0.3:
|
| 259 |
+
return SeverityLevel.LOW
|
| 260 |
+
elif severity < 0.7:
|
| 261 |
+
return SeverityLevel.MEDIUM
|
| 262 |
+
else:
|
| 263 |
+
return SeverityLevel.HIGH
|
| 264 |
+
|
| 265 |
+
    def _get_default_therapy_db(self) -> Dict:
        """Get default therapy database if file not found.

        Returns:
            Dict keyed by error category ("substitutions", "omissions",
            "distortions"). Each category maps a phoneme pattern to a therapy
            tip string, plus a "generic" fallback whose "{phoneme}"
            placeholder is substituted by the caller.
        """
        return {
            # Substitution entries are keyed "/target/→/produced/".
            "substitutions": {
                "/s/→/θ/": "Lisp - Use tongue tip placement behind upper teeth. Practice /s/ in isolation.",
                "/r/→/w/": "Rhotacism - Practice tongue position: curl tongue back, avoid lip rounding.",
                "/r/→/l/": "Rhotacism - Focus on tongue tip position vs. tongue body placement.",
                "generic": "Substitution error for {phoneme}. Practice correct articulator placement with mirror feedback."
            },
            # Omission entries are keyed by the omitted phoneme alone.
            "omissions": {
                "/r/": "Practice /r/ in isolation, then in CV syllables (ra, re, ri, ro, ru).",
                "/l/": "Lateral tongue placement - practice with tongue tip up to alveolar ridge.",
                "/s/": "Practice /s/ with tongue tip placement, use mirror to check position.",
                "generic": "Omission error for {phoneme}. Say the sound separately first, then blend into words."
            },
            # Distortion entries are keyed by the distorted phoneme alone.
            "distortions": {
                "/s/": "Sibilant clarity - use mirror feedback, ensure tongue tip is up and air stream is central.",
                "/ʃ/": "Fricative voicing exercise - practice /sh/ vs /s/ distinction.",
                "/r/": "Rhotacism - practice tongue position and lip rounding control.",
                "generic": "Distortion error for {phoneme}. Use mirror feedback and watch articulator position carefully."
            }
        }
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
# Unit test function
|
| 290 |
+
def test_error_mapper():
    """Test the ErrorMapper."""
    print("Testing ErrorMapper...")

    mapper = ErrorMapper()

    # Case 1: classifier class 0 → no error
    result = mapper.map_classifier_output(0, 0.95, "/k/")
    assert result.error_type == ErrorType.NORMAL
    assert result.severity == 0.0
    print(f"✅ Normal error: {result.error_type}, therapy: {result.therapy[:50]}...")

    # Case 2: classifier class 1 → substitution
    result = mapper.map_classifier_output(1, 0.78, "/s/")
    assert result.error_type == ErrorType.SUBSTITUTION
    assert result.wrong_sound is not None
    print(f"✅ Substitution error: {result.error_type}, wrong_sound: {result.wrong_sound}")
    print(f"   Therapy: {result.therapy[:80]}...")

    # Case 3: classifier class 2 → omission
    result = mapper.map_classifier_output(2, 0.85, "/r/")
    assert result.error_type == ErrorType.OMISSION
    print(f"✅ Omission error: {result.error_type}")
    print(f"   Therapy: {result.therapy[:80]}...")

    # Case 4: classifier class 3 → distortion
    result = mapper.map_classifier_output(3, 0.65, "/s/")
    assert result.error_type == ErrorType.DISTORTION
    print(f"✅ Distortion error: {result.error_type}")
    print(f"   Therapy: {result.therapy[:80]}...")

    # Case 5: severity score → severity bucket, table-driven
    expected_levels = [
        (0.0, SeverityLevel.NONE),
        (0.2, SeverityLevel.LOW),
        (0.5, SeverityLevel.MEDIUM),
        (0.8, SeverityLevel.HIGH),
    ]
    for score, level in expected_levels:
        assert mapper.get_severity_level(score) == level
    print("✅ Severity level mapping correct")

    print("\n✅ All tests passed!")
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
# Run the module's self-test when executed as a script.
if __name__ == "__main__":
    test_error_mapper()
|
| 333 |
+
|
models/phoneme_mapper.py
ADDED
|
@@ -0,0 +1,364 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Phoneme Mapper for Speech Pathology Analysis
|
| 3 |
+
|
| 4 |
+
This module provides grapheme-to-phoneme (G2P) conversion and alignment
|
| 5 |
+
of phonemes to audio frames for phone-level error detection.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
from typing import List, Tuple, Optional, Dict
|
| 10 |
+
from dataclasses import dataclass
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
try:
|
| 14 |
+
import g2p_en
|
| 15 |
+
G2P_AVAILABLE = True
|
| 16 |
+
except ImportError:
|
| 17 |
+
G2P_AVAILABLE = False
|
| 18 |
+
logging.warning("g2p_en not available. Install with: pip install g2p-en")
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclass
class PhonemeSegment:
    """A phoneme together with its time span and frame span in the audio.

    Fields:
        phoneme:     phoneme symbol, e.g. '/r/' or '/k/'
        start_time:  segment start in seconds
        end_time:    segment end in seconds
        duration:    end_time - start_time, in seconds
        frame_start: index of the first frame covered
        frame_end:   index one past the last frame covered (exclusive)
    """
    phoneme: str
    start_time: float
    end_time: float
    duration: float
    frame_start: int
    frame_end: int
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class PhonemeMapper:
    """
    Maps text to phonemes and aligns them to audio frames.

    Uses g2p_en library for English grapheme-to-phoneme conversion.
    Aligns phonemes to 20ms frames for phone-level analysis.

    Example (illustrative — actual g2p_en output uses ARPAbet symbols
    such as '/R/', '/OW1/'):
        >>> mapper = PhonemeMapper()
        >>> phonemes = mapper.text_to_phonemes("robot")
        >>> # Returns: [('/r/', 0.0), ('/o/', 0.1), ('/b/', 0.2), ('/o/', 0.3), ('/t/', 0.4)]
        >>> frame_phonemes = mapper.align_phonemes_to_frames(phonemes, num_frames=25, frame_duration_ms=20)
        >>> # Returns: ['/r/', '/r/', '/r/', '/o/', '/o/', '/b/', '/b/', ...]
    """

    def __init__(self, frame_duration_ms: int = 20, sample_rate: int = 16000):
        """
        Initialize the PhonemeMapper.

        Args:
            frame_duration_ms: Duration of each frame in milliseconds (default: 20ms)
            sample_rate: Audio sample rate in Hz (default: 16000)

        Raises:
            ImportError: If g2p_en is not available
        """
        if not G2P_AVAILABLE:
            raise ImportError(
                "g2p_en library is required. Install with: pip install g2p-en"
            )

        try:
            self.g2p = g2p_en.G2p()
            logger.info("✅ G2P model loaded successfully")
        except Exception as e:
            logger.error(f"❌ Failed to load G2P model: {e}")
            raise

        self.frame_duration_ms = frame_duration_ms
        self.frame_duration_s = frame_duration_ms / 1000.0
        self.sample_rate = sample_rate

        # Average phoneme duration (typical English: 50-100ms)
        # We'll use 80ms as default, but adjust based on text length
        self.avg_phoneme_duration_ms = 80
        self.avg_phoneme_duration_s = self.avg_phoneme_duration_ms / 1000.0

        logger.info(f"PhonemeMapper initialized: frame_duration={frame_duration_ms}ms, "
                    f"avg_phoneme_duration={self.avg_phoneme_duration_ms}ms")

    def text_to_phonemes(
        self,
        text: str,
        duration: Optional[float] = None
    ) -> List[Tuple[str, float]]:
        """
        Convert text to phonemes with timing information.

        Args:
            text: Input text string (e.g., "robot", "cat")
            duration: Optional audio duration in seconds. If provided, phonemes
                     are distributed evenly across this duration. If None, uses
                     estimated duration based on phoneme count.

        Returns:
            List of tuples: [(phoneme, start_time), ...]
            - phoneme: Phoneme symbol with slashes (e.g., '/r/', '/k/')
            - start_time: Start time in seconds

        Example:
            >>> mapper = PhonemeMapper()
            >>> phonemes = mapper.text_to_phonemes("cat")
            >>> # Returns: [('/k/', 0.0), ('/æ/', 0.08), ('/t/', 0.16)]
        """
        if not text or not text.strip():
            logger.warning("Empty text provided, returning empty phoneme list")
            return []

        try:
            # Convert to phonemes using g2p_en
            phoneme_list = self.g2p(text.lower().strip())

            # Filter out punctuation, whitespace, and empty tokens.
            # BUGFIX: g2p_en emits punctuation (',', '.', ...) as standalone
            # tokens; a whitespace-only filter let them through and they were
            # wrapped as bogus phonemes like '/,/', skewing the even timing
            # distribution. Keep only tokens with at least one alphanumeric
            # character (ARPAbet symbols like 'R', 'OW1' all qualify).
            phoneme_list = [
                p for p in phoneme_list
                if p and p.strip() and any(c.isalnum() for c in p)
            ]

            if not phoneme_list:
                logger.warning(f"No phonemes extracted from text: '{text}'")
                return []

            # Add slashes if not present
            formatted_phonemes = []
            for p in phoneme_list:
                if not p.startswith('/'):
                    p = '/' + p
                if not p.endswith('/'):
                    p = p + '/'
                formatted_phonemes.append(p)

            logger.debug(f"Extracted {len(formatted_phonemes)} phonemes from '{text}': {formatted_phonemes}")

            # Calculate timing
            if duration is None:
                # Estimate duration: avg_phoneme_duration * num_phonemes
                total_duration = len(formatted_phonemes) * self.avg_phoneme_duration_s
            else:
                total_duration = duration

            # Distribute phonemes evenly across duration
            if len(formatted_phonemes) == 1:
                phoneme_duration = total_duration
            else:
                phoneme_duration = total_duration / len(formatted_phonemes)

            # Create phoneme-time pairs
            phoneme_times = []
            for i, phoneme in enumerate(formatted_phonemes):
                start_time = i * phoneme_duration
                phoneme_times.append((phoneme, start_time))

            logger.info(f"Converted '{text}' to {len(phoneme_times)} phonemes over {total_duration:.2f}s")

            return phoneme_times

        except Exception as e:
            logger.error(f"Error converting text to phonemes: {e}", exc_info=True)
            raise RuntimeError(f"Failed to convert text to phonemes: {e}") from e

    def align_phonemes_to_frames(
        self,
        phoneme_times: List[Tuple[str, float]],
        num_frames: int,
        frame_duration_ms: Optional[int] = None
    ) -> List[str]:
        """
        Align phonemes to audio frames.

        Each frame gets assigned the phoneme that overlaps with its time window.
        If multiple phonemes overlap, uses the one with the most overlap.

        Args:
            phoneme_times: List of (phoneme, start_time) tuples from text_to_phonemes()
            num_frames: Total number of frames in the audio
            frame_duration_ms: Optional frame duration override

        Returns:
            List of phonemes, one per frame: ['/r/', '/r/', '/o/', '/b/', ...]

        Example:
            >>> mapper = PhonemeMapper()
            >>> phonemes = [('/k/', 0.0), ('/æ/', 0.08), ('/t/', 0.16)]
            >>> frames = mapper.align_phonemes_to_frames(phonemes, num_frames=15, frame_duration_ms=20)
            >>> # Returns: ['/k/', '/k/', '/k/', '/k/', '/æ/', '/æ/', '/æ/', '/æ/', '/t/', ...]
        """
        if not phoneme_times:
            logger.warning("No phonemes provided, returning empty frame list")
            return [''] * num_frames

        frame_duration_s = (frame_duration_ms / 1000.0) if frame_duration_ms else self.frame_duration_s

        # Calculate phoneme end times (assume equal duration for simplicity)
        phoneme_segments = []
        for i, (phoneme, start_time) in enumerate(phoneme_times):
            if i < len(phoneme_times) - 1:
                end_time = phoneme_times[i + 1][1]
            else:
                # Last phoneme: estimate duration from the spacing of the
                # first two phonemes, or fall back to the configured average
                if len(phoneme_times) > 1:
                    avg_duration = phoneme_times[1][1] - phoneme_times[0][1]
                else:
                    avg_duration = self.avg_phoneme_duration_s
                end_time = start_time + avg_duration

            phoneme_segments.append(PhonemeSegment(
                phoneme=phoneme,
                start_time=start_time,
                end_time=end_time,
                duration=end_time - start_time,
                frame_start=-1,  # Not used by this method; see get_phoneme_boundaries()
                frame_end=-1
            ))

        # Map each frame to a phoneme
        frame_phonemes = []
        for frame_idx in range(num_frames):
            frame_start_time = frame_idx * frame_duration_s
            frame_end_time = (frame_idx + 1) * frame_duration_s
            frame_center_time = frame_start_time + (frame_duration_s / 2.0)

            # Find phoneme with most overlap
            best_phoneme = ''
            max_overlap = 0.0

            for seg in phoneme_segments:
                # Calculate overlap
                overlap_start = max(frame_start_time, seg.start_time)
                overlap_end = min(frame_end_time, seg.end_time)
                overlap = max(0.0, overlap_end - overlap_start)

                if overlap > max_overlap:
                    max_overlap = overlap
                    best_phoneme = seg.phoneme

            # If no overlap (frame lies past the last phoneme), use the
            # phoneme whose center is nearest to the frame center
            if not best_phoneme:
                closest_seg = min(
                    phoneme_segments,
                    key=lambda s: abs(frame_center_time - (s.start_time + s.duration / 2))
                )
                best_phoneme = closest_seg.phoneme

            frame_phonemes.append(best_phoneme)

        logger.debug(f"Aligned {len(phoneme_times)} phonemes to {num_frames} frames")

        return frame_phonemes

    def get_phoneme_boundaries(
        self,
        phoneme_times: List[Tuple[str, float]],
        duration: float
    ) -> List[PhonemeSegment]:
        """
        Get detailed phoneme boundary information.

        Args:
            phoneme_times: List of (phoneme, start_time) tuples
            duration: Total audio duration in seconds

        Returns:
            List of PhonemeSegment objects with timing and frame information
        """
        segments = []

        for i, (phoneme, start_time) in enumerate(phoneme_times):
            if i < len(phoneme_times) - 1:
                end_time = phoneme_times[i + 1][1]
            else:
                end_time = duration

            frame_start = int(start_time / self.frame_duration_s)
            frame_end = int(end_time / self.frame_duration_s)

            segments.append(PhonemeSegment(
                phoneme=phoneme,
                start_time=start_time,
                end_time=end_time,
                duration=end_time - start_time,
                frame_start=frame_start,
                frame_end=frame_end
            ))

        return segments

    def map_text_to_frames(
        self,
        text: str,
        num_frames: int,
        audio_duration: Optional[float] = None
    ) -> List[str]:
        """
        Complete pipeline: text → phonemes → frame alignment.

        Args:
            text: Input text string
            num_frames: Number of audio frames
            audio_duration: Optional audio duration in seconds

        Returns:
            List of phonemes, one per frame (empty strings if no phonemes
            could be extracted from the text)
        """
        # Convert text to phonemes
        phoneme_times = self.text_to_phonemes(text, duration=audio_duration)

        if not phoneme_times:
            return [''] * num_frames

        # Align to frames
        frame_phonemes = self.align_phonemes_to_frames(phoneme_times, num_frames)

        return frame_phonemes
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
# Unit test function
|
| 327 |
+
def test_phoneme_mapper():
    """Test the PhonemeMapper with example text."""
    print("Testing PhonemeMapper...")

    try:
        pm = PhonemeMapper(frame_duration_ms=20)

        # Case 1: phoneme extraction from a simple word
        print("\n1. Testing 'robot':")
        word_phonemes = pm.text_to_phonemes("robot")
        print(f"   Phonemes: {word_phonemes}")
        assert len(word_phonemes) > 0, "Should extract phonemes"

        # Case 2: alignment of phonemes onto a fixed frame grid
        print("\n2. Testing frame alignment:")
        aligned = pm.align_phonemes_to_frames(word_phonemes, num_frames=25)
        print(f"   Frame phonemes (first 10): {aligned[:10]}")
        assert len(aligned) == 25, "Should have 25 frames"

        # Case 3: the end-to-end text → frames pipeline
        print("\n3. Testing complete pipeline with 'cat':")
        cat_frames = pm.map_text_to_frames("cat", num_frames=15)
        print(f"   Frame phonemes: {cat_frames}")
        assert len(cat_frames) == 15, "Should have 15 frames"

        print("\n✅ All tests passed!")

    except ImportError as e:
        print(f"❌ G2P library not available: {e}")
        print("   Install with: pip install g2p-en")
    except Exception as e:
        print(f"❌ Test failed: {e}")
        raise
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
# Run the module's self-test when executed as a script.
if __name__ == "__main__":
    test_phoneme_mapper()
|
| 364 |
+
|
models/speech_pathology_model.py
CHANGED
|
@@ -11,7 +11,7 @@ import logging
|
|
| 11 |
import torch
|
| 12 |
import torch.nn as nn
|
| 13 |
from torch.nn import functional as F
|
| 14 |
-
from transformers import Wav2Vec2Model,
|
| 15 |
from typing import Dict, Optional, Tuple, List
|
| 16 |
import os
|
| 17 |
|
|
@@ -51,7 +51,7 @@ class MultiTaskClassifierHead(nn.Module):
|
|
| 51 |
|
| 52 |
self.num_articulation_classes = num_articulation_classes
|
| 53 |
|
| 54 |
-
# Build shared feature layers
|
| 55 |
layers = []
|
| 56 |
prev_dim = input_dim
|
| 57 |
|
|
@@ -67,15 +67,15 @@ class MultiTaskClassifierHead(nn.Module):
|
|
| 67 |
self.shared_layers = nn.Sequential(*layers)
|
| 68 |
shared_output_dim = prev_dim
|
| 69 |
|
| 70 |
-
# Fluency head
|
| 71 |
self.fluency_head = nn.Sequential(
|
| 72 |
nn.Linear(shared_output_dim, 64),
|
| 73 |
nn.ReLU(),
|
| 74 |
nn.Dropout(dropout),
|
| 75 |
-
nn.Linear(64,
|
| 76 |
)
|
| 77 |
|
| 78 |
-
# Articulation head
|
| 79 |
self.articulation_head = nn.Sequential(
|
| 80 |
nn.Linear(shared_output_dim, 64),
|
| 81 |
nn.ReLU(),
|
|
@@ -83,6 +83,14 @@ class MultiTaskClassifierHead(nn.Module):
|
|
| 83 |
nn.Linear(64, num_articulation_classes), # 4 classes
|
| 84 |
)
|
| 85 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
logger.info(
|
| 87 |
f"Initialized MultiTaskClassifierHead: "
|
| 88 |
f"input_dim={input_dim}, hidden_dims={hidden_dims}, "
|
|
@@ -124,18 +132,23 @@ class MultiTaskClassifierHead(nn.Module):
|
|
| 124 |
shared_features = self.shared_layers(pooled_features)
|
| 125 |
|
| 126 |
# Task-specific heads
|
| 127 |
-
fluency_logits = self.fluency_head(shared_features)
|
| 128 |
-
articulation_logits = self.articulation_head(shared_features)
|
|
|
|
| 129 |
|
| 130 |
# Apply activations
|
| 131 |
-
fluency_probs =
|
| 132 |
-
articulation_probs = F.softmax(articulation_logits, dim=-1)
|
|
|
|
| 133 |
|
| 134 |
return {
|
| 135 |
"fluency_logits": fluency_logits,
|
| 136 |
"articulation_logits": articulation_logits,
|
|
|
|
| 137 |
"fluency_probs": fluency_probs,
|
| 138 |
"articulation_probs": articulation_probs,
|
|
|
|
|
|
|
| 139 |
}
|
| 140 |
|
| 141 |
|
|
@@ -210,13 +223,15 @@ class SpeechPathologyClassifier(nn.Module):
|
|
| 210 |
# Load Wav2Vec2 model and processor
|
| 211 |
hf_token = os.getenv("HF_TOKEN")
|
| 212 |
|
| 213 |
-
logger.info("Loading Wav2Vec2 model and
|
| 214 |
self.wav2vec2_model = Wav2Vec2Model.from_pretrained(
|
| 215 |
model_name,
|
| 216 |
token=hf_token if hf_token else None
|
| 217 |
)
|
| 218 |
|
| 219 |
-
|
|
|
|
|
|
|
| 220 |
model_name,
|
| 221 |
token=hf_token if hf_token else None
|
| 222 |
)
|
|
@@ -281,16 +296,49 @@ class SpeechPathologyClassifier(nn.Module):
|
|
| 281 |
- articulation_probs: Articulation class probabilities
|
| 282 |
- wav2vec2_features: Raw Wav2Vec2 features (for debugging)
|
| 283 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
# Extract features using Wav2Vec2
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
|
| 291 |
# Get last hidden state (features)
|
| 292 |
features = wav2vec2_outputs.last_hidden_state # (batch_size, seq_len, feature_dim)
|
| 293 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
# Pass through classifier head
|
| 295 |
outputs = self.classifier_head(features, attention_mask)
|
| 296 |
|
|
|
|
| 11 |
import torch
|
| 12 |
import torch.nn as nn
|
| 13 |
from torch.nn import functional as F
|
| 14 |
+
from transformers import Wav2Vec2Model, Wav2Vec2FeatureExtractor, Wav2Vec2Config
|
| 15 |
from typing import Dict, Optional, Tuple, List
|
| 16 |
import os
|
| 17 |
|
|
|
|
| 51 |
|
| 52 |
self.num_articulation_classes = num_articulation_classes
|
| 53 |
|
| 54 |
+
# Build shared feature layers: 1024 → 512 → 256
|
| 55 |
layers = []
|
| 56 |
prev_dim = input_dim
|
| 57 |
|
|
|
|
| 67 |
self.shared_layers = nn.Sequential(*layers)
|
| 68 |
shared_output_dim = prev_dim
|
| 69 |
|
| 70 |
+
# Fluency head: 256 → 64 → 2 (stutter/normal)
|
| 71 |
self.fluency_head = nn.Sequential(
|
| 72 |
nn.Linear(shared_output_dim, 64),
|
| 73 |
nn.ReLU(),
|
| 74 |
nn.Dropout(dropout),
|
| 75 |
+
nn.Linear(64, 2), # 2 classes: stutter/normal
|
| 76 |
)
|
| 77 |
|
| 78 |
+
# Articulation head: 256 → 64 → 4 (normal/sub/omit/dist)
|
| 79 |
self.articulation_head = nn.Sequential(
|
| 80 |
nn.Linear(shared_output_dim, 64),
|
| 81 |
nn.ReLU(),
|
|
|
|
| 83 |
nn.Linear(64, num_articulation_classes), # 4 classes
|
| 84 |
)
|
| 85 |
|
| 86 |
+
# Full combined head: 256 → 128 → 8 (all classes combined)
|
| 87 |
+
self.full_head = nn.Sequential(
|
| 88 |
+
nn.Linear(shared_output_dim, 128),
|
| 89 |
+
nn.ReLU(),
|
| 90 |
+
nn.Dropout(dropout),
|
| 91 |
+
nn.Linear(128, 8), # 8 classes (combined fluency + articulation)
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
logger.info(
|
| 95 |
f"Initialized MultiTaskClassifierHead: "
|
| 96 |
f"input_dim={input_dim}, hidden_dims={hidden_dims}, "
|
|
|
|
| 132 |
shared_features = self.shared_layers(pooled_features)
|
| 133 |
|
| 134 |
# Task-specific heads
|
| 135 |
+
fluency_logits = self.fluency_head(shared_features) # (batch, 2)
|
| 136 |
+
articulation_logits = self.articulation_head(shared_features) # (batch, 4)
|
| 137 |
+
full_logits = self.full_head(shared_features) # (batch, 8)
|
| 138 |
|
| 139 |
# Apply activations
|
| 140 |
+
fluency_probs = F.softmax(fluency_logits, dim=-1) # (batch, 2)
|
| 141 |
+
articulation_probs = F.softmax(articulation_logits, dim=-1) # (batch, 4)
|
| 142 |
+
full_probs = F.softmax(full_logits, dim=-1) # (batch, 8)
|
| 143 |
|
| 144 |
return {
|
| 145 |
"fluency_logits": fluency_logits,
|
| 146 |
"articulation_logits": articulation_logits,
|
| 147 |
+
"full_logits": full_logits,
|
| 148 |
"fluency_probs": fluency_probs,
|
| 149 |
"articulation_probs": articulation_probs,
|
| 150 |
+
"full_probs": full_probs,
|
| 151 |
+
"shared_features": shared_features,
|
| 152 |
}
|
| 153 |
|
| 154 |
|
|
|
|
| 223 |
# Load Wav2Vec2 model and processor
|
| 224 |
hf_token = os.getenv("HF_TOKEN")
|
| 225 |
|
| 226 |
+
logger.info("Loading Wav2Vec2 model and feature extractor...")
|
| 227 |
self.wav2vec2_model = Wav2Vec2Model.from_pretrained(
|
| 228 |
model_name,
|
| 229 |
token=hf_token if hf_token else None
|
| 230 |
)
|
| 231 |
|
| 232 |
+
# Use FeatureExtractor instead of Processor for feature extraction tasks
|
| 233 |
+
# Processor includes tokenizer which requires vocab file (not available for pre-trained models)
|
| 234 |
+
self.processor = Wav2Vec2FeatureExtractor.from_pretrained(
|
| 235 |
model_name,
|
| 236 |
token=hf_token if hf_token else None
|
| 237 |
)
|
|
|
|
| 296 |
- articulation_probs: Articulation class probabilities
|
| 297 |
- wav2vec2_features: Raw Wav2Vec2 features (for debugging)
|
| 298 |
"""
|
| 299 |
+
# #region agent log
|
| 300 |
+
try:
|
| 301 |
+
with open(r'c:\Users\kpanfas\Desktop\zlaqa\slaq-version-d-to-a\zlaqa-version-b\ai-enginee\zlaqa-version-b-ai-enginee\.cursor\debug.log', 'a') as f:
|
| 302 |
+
import json, time
|
| 303 |
+
f.write(json.dumps({"sessionId":"debug-session","runId":"run1","hypothesisId":"D","location":"speech_pathology_model.py:288","message":"Before Wav2Vec2 forward","data":{"input_values_shape":list(input_values.shape)},"timestamp":int(time.time()*1000)}) + '\n')
|
| 304 |
+
except: pass
|
| 305 |
+
# #endregion
|
| 306 |
+
|
| 307 |
# Extract features using Wav2Vec2
|
| 308 |
+
try:
|
| 309 |
+
with torch.no_grad() if not self.training else torch.enable_grad():
|
| 310 |
+
wav2vec2_outputs = self.wav2vec2_model(
|
| 311 |
+
input_values=input_values,
|
| 312 |
+
attention_mask=attention_mask
|
| 313 |
+
)
|
| 314 |
+
except Exception as e:
|
| 315 |
+
# #region agent log
|
| 316 |
+
try:
|
| 317 |
+
with open(r'c:\Users\kpanfas\Desktop\zlaqa\slaq-version-d-to-a\zlaqa-version-b\ai-enginee\zlaqa-version-b-ai-enginee\.cursor\debug.log', 'a') as f:
|
| 318 |
+
import json, time
|
| 319 |
+
f.write(json.dumps({"sessionId":"debug-session","runId":"run1","hypothesisId":"D","location":"speech_pathology_model.py:288","message":"Wav2Vec2 forward exception","data":{"error":str(e),"error_type":type(e).__name__,"input_shape":list(input_values.shape)},"timestamp":int(time.time()*1000)}) + '\n')
|
| 320 |
+
except: pass
|
| 321 |
+
# #endregion
|
| 322 |
+
raise
|
| 323 |
|
| 324 |
# Get last hidden state (features)
|
| 325 |
features = wav2vec2_outputs.last_hidden_state # (batch_size, seq_len, feature_dim)
|
| 326 |
|
| 327 |
+
# #region agent log
|
| 328 |
+
try:
|
| 329 |
+
with open(r'c:\Users\kpanfas\Desktop\zlaqa\slaq-version-d-to-a\zlaqa-version-b\ai-enginee\zlaqa-version-b-ai-enginee\.cursor\debug.log', 'a') as f:
|
| 330 |
+
import json, time
|
| 331 |
+
f.write(json.dumps({"sessionId":"debug-session","runId":"run1","hypothesisId":"D","location":"speech_pathology_model.py:297","message":"After Wav2Vec2 forward","data":{"features_shape":list(features.shape),"seq_len":features.shape[1] if len(features.shape) > 1 else 0},"timestamp":int(time.time()*1000)}) + '\n')
|
| 332 |
+
except: pass
|
| 333 |
+
# #endregion
|
| 334 |
+
|
| 335 |
+
# Safety check: ensure sequence length is valid (at least 1)
|
| 336 |
+
if features.shape[1] < 1:
|
| 337 |
+
raise ValueError(
|
| 338 |
+
f"Wav2Vec2 output sequence length is too short: {features.shape[1]}. "
|
| 339 |
+
f"Input was {input_values.shape}. Try using longer audio segments (>= 500ms)."
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
# Pass through classifier head
|
| 343 |
outputs = self.classifier_head(features, attention_mask)
|
| 344 |
|
tests/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test module for speech pathology diagnosis system.
|
| 3 |
+
"""
|
| 4 |
+
|
tests/integration_tests.py
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Integration tests for speech pathology diagnosis API.
|
| 3 |
+
|
| 4 |
+
Tests API endpoints, error mapping, and therapy recommendations.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import logging
|
| 8 |
+
import numpy as np
|
| 9 |
+
import tempfile
|
| 10 |
+
import soundfile as sf
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
import json
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def test_phoneme_mapping():
    """Test phoneme mapping functionality."""
    logger.info("Testing phoneme mapping...")

    try:
        from models.phoneme_mapper import PhonemeMapper

        pm = PhonemeMapper(frame_duration_ms=20)

        # Case 1: phoneme extraction from a simple word
        extracted = pm.text_to_phonemes("robot")
        assert len(extracted) > 0, "Should extract phonemes"
        logger.info(f"✅ 'robot' → {len(extracted)} phonemes: {[p[0] for p in extracted]}")

        # Case 2: alignment onto a fixed frame grid
        aligned = pm.align_phonemes_to_frames(extracted, num_frames=25)
        assert len(aligned) == 25, "Should have 25 frames"
        logger.info(f"✅ Aligned to {len(aligned)} frames")

        # Case 3: end-to-end text → frames pipeline
        cat_frames = pm.map_text_to_frames("cat", num_frames=15)
        assert len(cat_frames) == 15, "Should have 15 frames"
        logger.info(f"✅ 'cat' → {len(cat_frames)} frame phonemes")

        return True

    except ImportError as e:
        logger.warning(f"⚠️ G2P library not available: {e}")
        return False
    except Exception as e:
        logger.error(f"❌ Phoneme mapping test failed: {e}")
        return False
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def test_error_taxonomy():
    """Verify classifier-output → error-type mapping and severity bucketing."""
    logger.info("Testing error taxonomy...")

    try:
        from models.error_taxonomy import ErrorMapper, ErrorType, SeverityLevel

        taxonomy = ErrorMapper()

        # Class 0 must map to a zero-severity NORMAL result.
        normal = taxonomy.map_classifier_output(0, 0.95, "/k/")
        assert normal.error_type == ErrorType.NORMAL
        assert normal.severity == 0.0
        logger.info(f"✅ Normal error mapping: {normal.error_type}")

        # Class 1 is a substitution and must carry the substituted sound.
        substitution = taxonomy.map_classifier_output(1, 0.78, "/s/")
        assert substitution.error_type == ErrorType.SUBSTITUTION
        assert substitution.wrong_sound is not None
        logger.info(f"✅ Substitution error: {substitution.error_type}, wrong_sound={substitution.wrong_sound}")
        logger.info(f" Therapy: {substitution.therapy[:60]}...")

        # Class 2 is an omission.
        omission = taxonomy.map_classifier_output(2, 0.85, "/r/")
        assert omission.error_type == ErrorType.OMISSION
        logger.info(f"✅ Omission error: {omission.error_type}")
        logger.info(f" Therapy: {omission.therapy[:60]}...")

        # Class 3 is a distortion.
        distortion = taxonomy.map_classifier_output(3, 0.65, "/s/")
        assert distortion.error_type == ErrorType.DISTORTION
        logger.info(f"✅ Distortion error: {distortion.error_type}")
        logger.info(f" Therapy: {distortion.therapy[:60]}...")

        # Severity score → bucket boundaries.
        for score, expected_level in (
            (0.0, SeverityLevel.NONE),
            (0.2, SeverityLevel.LOW),
            (0.5, SeverityLevel.MEDIUM),
            (0.8, SeverityLevel.HIGH),
        ):
            assert taxonomy.get_severity_level(score) == expected_level
        logger.info("✅ Severity level mapping correct")

        return True

    except Exception as e:
        logger.error(f"❌ Error taxonomy test failed: {e}")
        return False
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def test_batch_diagnosis_endpoint(pipeline, phoneme_mapper, error_mapper):
    """Exercise the batch diagnosis path end-to-end on synthetic audio.

    Args:
        pipeline: InferencePipeline used for phone-level prediction.
        phoneme_mapper: PhonemeMapper aligning transcript phonemes to frames.
        error_mapper: ErrorMapper converting class ids to error details.

    Returns:
        True when inference and error mapping completed; False on any
        failure (logged).
    """
    # BUG FIX: `ErrorType` is referenced below but was never imported in this
    # function's scope (the module only imports it inside other tests' try
    # blocks), so this test always failed with NameError. Also hoist the `os`
    # import out of the `finally` clause.
    from models.error_taxonomy import ErrorType
    import os

    logger.info("Testing batch diagnosis endpoint...")

    try:
        # Generate test audio: 2 s of a 440 Hz sine tone at 16 kHz.
        duration = 2.0
        sample_rate = 16000
        num_samples = int(duration * sample_rate)
        audio = 0.5 * np.sin(2 * np.pi * 440 * np.linspace(0, duration, num_samples))
        audio = audio.astype(np.float32)

        # Reserve a temp path, then write AFTER the handle is closed so the
        # write succeeds on platforms that lock open files (Windows).
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
            temp_path = f.name
        sf.write(temp_path, audio, sample_rate)

        try:
            # Run inference
            result = pipeline.predict_phone_level(temp_path, return_timestamps=True)

            # Align a reference transcript's phonemes onto the model frames.
            text = "test audio"
            frame_phonemes = phoneme_mapper.map_text_to_frames(
                text,
                num_frames=result.num_frames,
                audio_duration=result.duration
            )

            # Convert per-frame classifier output into error details.
            # Stutter frames are shifted by 4 class ids — assumes the
            # classifier reserves ids 4+ for fluency errors (confirm against
            # ErrorMapper's taxonomy).
            errors = []
            for i, frame_pred in enumerate(result.frame_predictions):
                class_id = frame_pred.articulation_class
                if frame_pred.fluency_label == 'stutter':
                    class_id += 4

                error_detail = error_mapper.map_classifier_output(
                    class_id=class_id,
                    confidence=frame_pred.confidence,
                    phoneme=frame_phonemes[i] if i < len(frame_phonemes) else '',
                    fluency_label=frame_pred.fluency_label
                )

                # Only keep frames flagged as actual errors.
                if error_detail.error_type != ErrorType.NORMAL:
                    errors.append(error_detail)

            logger.info(f"✅ Batch diagnosis: {result.num_frames} frames, {len(errors)} errors detected")

            return True

        finally:
            # Always clean up the temp wav file.
            if os.path.exists(temp_path):
                os.remove(temp_path)

    except Exception as e:
        logger.error(f"❌ Batch diagnosis test failed: {e}")
        return False
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def test_therapy_recommendations():
    """Check that each covered error type yields a non-empty therapy text."""
    logger.info("Testing therapy recommendations...")

    try:
        from models.error_taxonomy import ErrorMapper, ErrorType

        therapy_source = ErrorMapper()

        # (phoneme, error type, optional substituted sound) combinations.
        cases = (
            ("/s/", ErrorType.SUBSTITUTION, "/θ/"),
            ("/r/", ErrorType.OMISSION, None),
            ("/s/", ErrorType.DISTORTION, None),
        )

        for phoneme, error_type, wrong_sound in cases:
            therapy = therapy_source.get_therapy(error_type, phoneme, wrong_sound)
            assert therapy and len(therapy) > 0, f"Therapy should not be empty for {phoneme}"
            logger.info(f"✅ {phoneme} {error_type.value}: {therapy[:50]}...")

        return True

    except Exception as e:
        logger.error(f"❌ Therapy recommendations test failed: {e}")
        return False
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def run_all_integration_tests():
    """Run every integration test and return (all_passed, per-test results)."""
    logger.info("=" * 60)
    logger.info("Running Integration Tests")
    logger.info("=" * 60)

    results = {}

    # Standalone tests that need no model pipeline.
    standalone = (
        ("phoneme_mapping", "\n1. Phoneme Mapping Test", test_phoneme_mapping),
        ("error_taxonomy", "\n2. Error Taxonomy Test", test_error_taxonomy),
        ("therapy_recommendations", "\n3. Therapy Recommendations Test", test_therapy_recommendations),
    )
    for result_key, banner, test_fn in standalone:
        logger.info(banner)
        results[result_key] = test_fn()

    # The batch diagnosis test needs the full inference stack; skip
    # gracefully when the pipeline cannot be constructed.
    try:
        from inference.inference_pipeline import create_inference_pipeline
        from models.phoneme_mapper import PhonemeMapper
        from models.error_taxonomy import ErrorMapper

        logger.info("\n4. Batch Diagnosis Test")
        results["batch_diagnosis"] = test_batch_diagnosis_endpoint(
            create_inference_pipeline(), PhonemeMapper(), ErrorMapper()
        )
    except Exception as e:
        logger.warning(f"⚠️ Batch diagnosis test skipped: {e}")
        results["batch_diagnosis"] = False

    # Summary
    logger.info("\n" + "=" * 60)
    logger.info("Integration Test Summary")
    logger.info("=" * 60)

    for test_name, passed in results.items():
        status = "✅ PASSED" if passed else "❌ FAILED"
        logger.info(f"{status}: {test_name}")

    return all(results.values()), results
| 236 |
+
|
| 237 |
+
|
| 238 |
+
if __name__ == "__main__":
    # Script entry point: run the suite and exit with a CI-friendly status.
    import sys

    logging.basicConfig(level=logging.INFO)

    all_passed, results = run_all_integration_tests()

    # FIX: use sys.exit instead of the interactive-only exit() helper, which
    # is injected by the site module and absent under `python -S` / frozen apps.
    if all_passed:
        logger.info("\n✅ All integration tests passed!")
        sys.exit(0)
    else:
        logger.error("\n❌ Some integration tests failed!")
        sys.exit(1)
| 249 |
+
|
tests/performance_tests.py
ADDED
|
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Performance tests for speech pathology diagnosis system.
|
| 3 |
+
|
| 4 |
+
Tests latency requirements:
|
| 5 |
+
- File batch: <200ms per file
|
| 6 |
+
- Per-frame: <50ms
|
| 7 |
+
- WebSocket roundtrip: <100ms
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import time
|
| 11 |
+
import numpy as np
|
| 12 |
+
import logging
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
import asyncio
|
| 15 |
+
from typing import Dict, List
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def generate_test_audio(duration_seconds: float = 1.0, sample_rate: int = 16000) -> np.ndarray:
    """Create a synthetic 440 Hz sine tone for latency testing.

    Args:
        duration_seconds: Length of the clip in seconds.
        sample_rate: Sampling rate in Hz.

    Returns:
        float32 mono waveform with peak amplitude 0.5.
    """
    total_samples = int(duration_seconds * sample_rate)
    timeline = np.linspace(0, duration_seconds, total_samples)
    tone = 0.5 * np.sin(2 * np.pi * 440 * timeline)  # pure 440 Hz test tone
    return tone.astype(np.float32)
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def test_batch_latency(pipeline, num_files: int = 10) -> Dict[str, float]:
    """Measure end-to-end file inference latency against the 200 ms target.

    Args:
        pipeline: InferencePipeline instance.
        num_files: Number of synthetic test files to process.

    Returns:
        Latency statistics plus a `passed` flag (avg < 200 ms), or an
        `error` entry when no run succeeded.
    """
    # FIX: these imports previously ran inside the per-file loop; hoist them
    # so the loop body only pays for file I/O and inference.
    import os
    import tempfile
    import soundfile as sf

    logger.info(f"Testing batch latency with {num_files} files...")

    latencies = []

    for i in range(num_files):
        # Generate test audio
        audio = generate_test_audio(duration_seconds=1.0)

        # Reserve the temp path first, then write AFTER the handle is closed
        # so the write works on platforms that lock open files (Windows).
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
            temp_path = f.name
        sf.write(temp_path, audio, 16000)

        try:
            start_time = time.time()
            result = pipeline.predict_phone_level(temp_path, return_timestamps=True)
            latency_ms = (time.time() - start_time) * 1000
            latencies.append(latency_ms)

            logger.info(f"  File {i+1}: {latency_ms:.1f}ms ({result.num_frames} frames)")
        except Exception as e:
            logger.error(f"  File {i+1} failed: {e}")
        finally:
            if os.path.exists(temp_path):
                os.remove(temp_path)

    if not latencies:
        return {"error": "No successful runs"}

    avg_latency = sum(latencies) / len(latencies)
    max_latency = max(latencies)
    min_latency = min(latencies)

    result = {
        "avg_latency_ms": avg_latency,
        "max_latency_ms": max_latency,
        "min_latency_ms": min_latency,
        "num_files": len(latencies),
        "target_ms": 200.0,
        "passed": avg_latency < 200.0
    }

    logger.info(f"✅ Batch latency test: avg={avg_latency:.1f}ms, max={max_latency:.1f}ms, "
                f"target=200ms, passed={result['passed']}")

    return result
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def test_frame_latency(pipeline, num_frames: int = 100) -> Dict[str, float]:
    """Measure repeated single-window inference latency (target: <50 ms).

    Args:
        pipeline: InferencePipeline instance.
        num_frames: Number of timed inference runs.

    Returns:
        Latency statistics (avg/min/max/p95) and a pass flag.
    """
    logger.info(f"Testing frame latency with {num_frames} frames...")

    # One second of audio is enough for a single inference window.
    audio = generate_test_audio(duration_seconds=1.0)

    latencies = []

    for run_index in range(num_frames):
        started = time.time()
        try:
            pipeline.predict_phone_level(audio, return_timestamps=False)
            latencies.append((time.time() - started) * 1000)
        except Exception as e:
            logger.error(f" Frame {run_index+1} failed: {e}")

    if not latencies:
        return {"error": "No successful runs"}

    ordered = sorted(latencies)
    avg_latency = sum(latencies) / len(latencies)
    p95_latency = ordered[int(len(ordered) * 0.95)]

    result = {
        "avg_latency_ms": avg_latency,
        "max_latency_ms": ordered[-1],
        "min_latency_ms": ordered[0],
        "p95_latency_ms": p95_latency,
        "num_frames": len(latencies),
        "target_ms": 50.0,
        "passed": avg_latency < 50.0
    }

    logger.info(f"✅ Frame latency test: avg={avg_latency:.1f}ms, p95={p95_latency:.1f}ms, "
                f"target=50ms, passed={result['passed']}")

    return result
| 150 |
+
|
| 151 |
+
|
| 152 |
+
async def test_websocket_latency(websocket_url: str, num_chunks: int = 50) -> Dict[str, float]:
    """Measure WebSocket send/receive roundtrip latency (target: <100 ms).

    Args:
        websocket_url: WebSocket endpoint URL.
        num_chunks: Number of 20 ms audio chunks to send.

    Returns:
        Latency statistics and a pass flag, or an `error` entry when the
        websockets library is missing or the connection fails.
    """
    try:
        import websockets

        logger.info(f"Testing WebSocket latency with {num_chunks} chunks...")

        latencies = []

        async with websockets.connect(websocket_url) as websocket:
            # One chunk = 20 ms @ 16 kHz (320 samples), sent as 16-bit PCM.
            # FIX: removed an unused `chunk_samples` local.
            audio_chunk = generate_test_audio(duration_seconds=0.02)
            chunk_bytes = (audio_chunk * 32768).astype(np.int16).tobytes()

            for i in range(num_chunks):
                start_time = time.time()

                # Send chunk, then block until the matching response arrives;
                # the roundtrip time is the measured latency.
                await websocket.send(chunk_bytes)
                response = await websocket.recv()

                latency_ms = (time.time() - start_time) * 1000
                latencies.append(latency_ms)

                if i % 10 == 0:
                    logger.info(f" Chunk {i+1}: {latency_ms:.1f}ms")

        if not latencies:
            return {"error": "No successful runs"}

        avg_latency = sum(latencies) / len(latencies)
        max_latency = max(latencies)
        p95_latency = sorted(latencies)[int(len(latencies) * 0.95)]

        result = {
            "avg_latency_ms": avg_latency,
            "max_latency_ms": max_latency,
            "p95_latency_ms": p95_latency,
            "num_chunks": len(latencies),
            "target_ms": 100.0,
            "passed": avg_latency < 100.0
        }

        logger.info(f"✅ WebSocket latency test: avg={avg_latency:.1f}ms, p95={p95_latency:.1f}ms, "
                    f"target=100ms, passed={result['passed']}")

        return result

    except ImportError:
        logger.warning("websockets library not available, skipping WebSocket test")
        return {"error": "websockets library not available"}
    except Exception as e:
        logger.error(f"WebSocket test failed: {e}")
        return {"error": str(e)}
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def test_concurrent_connections(pipeline, num_connections: int = 10) -> "Dict[str, Any]":
    """Simulate concurrent inference requests via a thread pool.

    Args:
        pipeline: InferencePipeline instance. NOTE(review): assumes the
            pipeline is thread-safe — confirm before trusting these numbers.
        num_connections: Number of simultaneous requests.

    Returns:
        Success/failure counts, average latency, and throughput.
    """
    # FIX: the return annotation used `Any`, which is not imported at module
    # level (only `from typing import Dict, List`), so importing this module
    # raised NameError. A string annotation defers evaluation and keeps the API.
    logger.info(f"Testing {num_connections} concurrent connections...")

    import concurrent.futures

    def process_audio(request_index: int):
        """Run one inference and report its latency or failure reason."""
        try:
            audio = generate_test_audio(duration_seconds=0.5)
            started = time.time()
            result = pipeline.predict_phone_level(audio, return_timestamps=False)
            latency_ms = (time.time() - started) * 1000
            return {"success": True, "latency_ms": latency_ms, "frames": result.num_frames}
        except Exception as e:
            return {"success": False, "error": str(e)}

    start_time = time.time()

    with concurrent.futures.ThreadPoolExecutor(max_workers=num_connections) as executor:
        futures = [executor.submit(process_audio, i) for i in range(num_connections)]
        results = [f.result() for f in concurrent.futures.as_completed(futures)]

    total_time = time.time() - start_time

    successful = sum(1 for r in results if r.get("success", False))
    avg_latency = sum(r["latency_ms"] for r in results if r.get("success", False)) / successful if successful > 0 else 0.0

    result = {
        "total_connections": num_connections,
        "successful": successful,
        "failed": num_connections - successful,
        "total_time_seconds": total_time,
        "avg_latency_ms": avg_latency,
        "throughput_per_second": successful / total_time if total_time > 0 else 0.0
    }

    logger.info(f"✅ Concurrent test: {successful}/{num_connections} successful, "
                f"avg_latency={avg_latency:.1f}ms, throughput={result['throughput_per_second']:.1f}/s")

    return result
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def run_all_performance_tests(pipeline, websocket_url: "Optional[str]" = None) -> "Dict[str, Any]":
    """Run batch, frame, concurrency and (optionally) WebSocket latency tests.

    Args:
        pipeline: InferencePipeline instance.
        websocket_url: Optional WebSocket URL; the streaming test runs only
            when this is provided.

    Returns:
        Mapping of test name to its result dictionary.
    """
    # FIX: `Optional` and `Any` appeared in the signature but were never
    # imported (the module imports only `Dict, List` from typing), so merely
    # importing this module raised NameError. String annotations defer
    # evaluation without changing the API.
    logger.info("=" * 60)
    logger.info("Running Performance Tests")
    logger.info("=" * 60)

    results = {}

    # Test 1: Batch latency
    logger.info("\n1. Batch File Latency Test")
    results["batch_latency"] = test_batch_latency(pipeline)

    # Test 2: Frame latency
    logger.info("\n2. Per-Frame Latency Test")
    results["frame_latency"] = test_frame_latency(pipeline)

    # Test 3: Concurrent connections
    logger.info("\n3. Concurrent Connections Test")
    results["concurrent"] = test_concurrent_connections(pipeline, num_connections=10)

    # Test 4: WebSocket latency (if URL provided)
    if websocket_url:
        logger.info("\n4. WebSocket Latency Test")
        results["websocket_latency"] = asyncio.run(test_websocket_latency(websocket_url))

    # Summary
    logger.info("\n" + "=" * 60)
    logger.info("Performance Test Summary")
    logger.info("=" * 60)

    if "batch_latency" in results and results["batch_latency"].get("passed"):
        logger.info("✅ Batch latency: PASSED")
    else:
        logger.warning("❌ Batch latency: FAILED")

    if "frame_latency" in results and results["frame_latency"].get("passed"):
        logger.info("✅ Frame latency: PASSED")
    else:
        logger.warning("❌ Frame latency: FAILED")

    if "websocket_latency" in results and results["websocket_latency"].get("passed"):
        logger.info("✅ WebSocket latency: PASSED")
    elif "websocket_latency" in results:
        logger.warning("❌ WebSocket latency: FAILED")

    return results
| 326 |
+
|
| 327 |
+
|
| 328 |
+
if __name__ == "__main__":
    # Ad-hoc entry point: build a pipeline and dump all timing results.
    logging.basicConfig(level=logging.INFO)

    try:
        import json

        from inference.inference_pipeline import create_inference_pipeline

        perf_results = run_all_performance_tests(create_inference_pipeline())

        print("\nTest Results:")
        print(json.dumps(perf_results, indent=2))

    except Exception as e:
        logger.error(f"Test failed: {e}", exc_info=True)
ui/gradio_interface.py
CHANGED
|
@@ -163,21 +163,21 @@ def analyze_speech(
|
|
| 163 |
try:
|
| 164 |
with open(r'c:\Users\kpanfas\Desktop\zlaqa\slaq-version-d-to-a\zlaqa-version-b\ai-enginee\zlaqa-version-b-ai-enginee\.cursor\debug.log', 'a') as f:
|
| 165 |
import json
|
| 166 |
-
f.write(json.dumps({"sessionId":"debug-session","runId":"run1","hypothesisId":"B","location":"gradio_interface.py:139","message":"After predict_batch call","data":{"success":True},"timestamp":int(time.time()*1000)}) + '\n')
|
| 167 |
except: pass
|
| 168 |
# #endregion
|
| 169 |
|
| 170 |
# Calculate processing time
|
| 171 |
processing_time_ms = (time.time() - start_time) * 1000
|
| 172 |
|
| 173 |
-
# Extract metrics
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
|
| 178 |
-
#
|
| 179 |
-
|
| 180 |
-
fluent_frames_percentage =
|
| 181 |
|
| 182 |
# Format fluency score with color coding
|
| 183 |
if fluency_percentage >= 80:
|
|
@@ -203,10 +203,22 @@ def analyze_speech(
|
|
| 203 |
"""
|
| 204 |
|
| 205 |
# Format articulation issues
|
| 206 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
|
| 208 |
-
# Format confidence
|
| 209 |
-
confidence_percentage = result.confidence * 100
|
| 210 |
confidence_html = f"""
|
| 211 |
<div style='text-align: center; padding: 10px;'>
|
| 212 |
<h3 style='color: #2196F3; font-size: 32px; margin: 5px 0;'>
|
|
@@ -223,7 +235,7 @@ def analyze_speech(
|
|
| 223 |
⏱️ Processing Time: <strong>{processing_time_ms:.0f}ms</strong>
|
| 224 |
</p>
|
| 225 |
<p style='color: #999; font-size: 12px;'>
|
| 226 |
-
Analyzed {
|
| 227 |
</p>
|
| 228 |
</div>
|
| 229 |
"""
|
|
@@ -232,24 +244,32 @@ def analyze_speech(
|
|
| 232 |
json_output = {
|
| 233 |
"status": "success",
|
| 234 |
"fluency_metrics": {
|
| 235 |
-
"mean_fluency":
|
| 236 |
"fluency_percentage": fluency_percentage,
|
| 237 |
-
"fluent_frames_ratio":
|
| 238 |
"fluent_frames_percentage": fluent_frames_percentage,
|
| 239 |
-
"
|
| 240 |
-
"min": fluency_metrics.get("min", 0.0),
|
| 241 |
-
"max": fluency_metrics.get("max", 0.0),
|
| 242 |
-
"median": fluency_metrics.get("median", 0.0)
|
| 243 |
},
|
| 244 |
"articulation_results": {
|
| 245 |
-
"total_frames":
|
| 246 |
-
"
|
| 247 |
-
"
|
|
|
|
| 248 |
},
|
| 249 |
-
"confidence":
|
| 250 |
"confidence_percentage": confidence_percentage,
|
| 251 |
"processing_time_ms": processing_time_ms,
|
| 252 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
}
|
| 254 |
|
| 255 |
logger.info(f"✅ Analysis complete: fluency={fluency_percentage:.1f}%, "
|
|
|
|
| 163 |
try:
|
| 164 |
with open(r'c:\Users\kpanfas\Desktop\zlaqa\slaq-version-d-to-a\zlaqa-version-b\ai-enginee\zlaqa-version-b-ai-enginee\.cursor\debug.log', 'a') as f:
|
| 165 |
import json
|
| 166 |
+
f.write(json.dumps({"sessionId":"debug-session","runId":"run1","hypothesisId":"B","location":"gradio_interface.py:139","message":"After predict_batch call","data":{"success":True,"num_frames":result.num_frames},"timestamp":int(time.time()*1000)}) + '\n')
|
| 167 |
except: pass
|
| 168 |
# #endregion
|
| 169 |
|
| 170 |
# Calculate processing time
|
| 171 |
processing_time_ms = (time.time() - start_time) * 1000
|
| 172 |
|
| 173 |
+
# Extract metrics from new PhoneLevelResult format
|
| 174 |
+
aggregate = result.aggregate
|
| 175 |
+
mean_fluency_stutter = aggregate.get("fluency_score", 0.0)
|
| 176 |
+
fluency_percentage = (1.0 - mean_fluency_stutter) * 100 # Convert stutter prob to fluency percentage
|
| 177 |
|
| 178 |
+
# Count fluent frames
|
| 179 |
+
fluent_frames = sum(1 for fp in result.frame_predictions if fp.fluency_label == 'normal')
|
| 180 |
+
fluent_frames_percentage = (fluent_frames / result.num_frames * 100) if result.num_frames > 0 else 0.0
|
| 181 |
|
| 182 |
# Format fluency score with color coding
|
| 183 |
if fluency_percentage >= 80:
|
|
|
|
| 203 |
"""
|
| 204 |
|
| 205 |
# Format articulation issues
|
| 206 |
+
articulation_class = aggregate.get("articulation_class", 0)
|
| 207 |
+
articulation_label = aggregate.get("articulation_label", "normal")
|
| 208 |
+
articulation_text = f"**Dominant Class:** {articulation_label.capitalize()}\n\n"
|
| 209 |
+
articulation_text += f"**Frame Breakdown:**\n"
|
| 210 |
+
class_counts = {}
|
| 211 |
+
for fp in result.frame_predictions:
|
| 212 |
+
label = fp.articulation_label
|
| 213 |
+
class_counts[label] = class_counts.get(label, 0) + 1
|
| 214 |
+
for label, count in sorted(class_counts.items(), key=lambda x: x[1], reverse=True):
|
| 215 |
+
percentage = (count / result.num_frames * 100) if result.num_frames > 0 else 0.0
|
| 216 |
+
articulation_text += f"- {label.capitalize()}: {count} frames ({percentage:.1f}%)\n"
|
| 217 |
+
|
| 218 |
+
# Calculate average confidence
|
| 219 |
+
avg_confidence = sum(fp.confidence for fp in result.frame_predictions) / result.num_frames if result.num_frames > 0 else 0.0
|
| 220 |
+
confidence_percentage = avg_confidence * 100
|
| 221 |
|
|
|
|
|
|
|
| 222 |
confidence_html = f"""
|
| 223 |
<div style='text-align: center; padding: 10px;'>
|
| 224 |
<h3 style='color: #2196F3; font-size: 32px; margin: 5px 0;'>
|
|
|
|
| 235 |
⏱️ Processing Time: <strong>{processing_time_ms:.0f}ms</strong>
|
| 236 |
</p>
|
| 237 |
<p style='color: #999; font-size: 12px;'>
|
| 238 |
+
Analyzed {result.num_frames} frames ({result.duration:.2f}s audio)
|
| 239 |
</p>
|
| 240 |
</div>
|
| 241 |
"""
|
|
|
|
| 244 |
json_output = {
|
| 245 |
"status": "success",
|
| 246 |
"fluency_metrics": {
|
| 247 |
+
"mean_fluency": fluency_percentage / 100.0,
|
| 248 |
"fluency_percentage": fluency_percentage,
|
| 249 |
+
"fluent_frames_ratio": fluent_frames / result.num_frames if result.num_frames > 0 else 0.0,
|
| 250 |
"fluent_frames_percentage": fluent_frames_percentage,
|
| 251 |
+
"stutter_probability": mean_fluency_stutter
|
|
|
|
|
|
|
|
|
|
| 252 |
},
|
| 253 |
"articulation_results": {
|
| 254 |
+
"total_frames": result.num_frames,
|
| 255 |
+
"dominant_class": articulation_class,
|
| 256 |
+
"dominant_label": articulation_label,
|
| 257 |
+
"class_distribution": class_counts
|
| 258 |
},
|
| 259 |
+
"confidence": avg_confidence,
|
| 260 |
"confidence_percentage": confidence_percentage,
|
| 261 |
"processing_time_ms": processing_time_ms,
|
| 262 |
+
"frame_predictions": [
|
| 263 |
+
{
|
| 264 |
+
"time": fp.time,
|
| 265 |
+
"fluency_prob": fp.fluency_prob,
|
| 266 |
+
"fluency_label": fp.fluency_label,
|
| 267 |
+
"articulation_class": fp.articulation_class,
|
| 268 |
+
"articulation_label": fp.articulation_label,
|
| 269 |
+
"confidence": fp.confidence
|
| 270 |
+
}
|
| 271 |
+
for fp in result.frame_predictions[:20] # First 20 frames for preview
|
| 272 |
+
]
|
| 273 |
}
|
| 274 |
|
| 275 |
logger.info(f"✅ Analysis complete: fluency={fluency_percentage:.1f}%, "
|