|
|
""" |
|
|
REST API routes for Speech Pathology Diagnosis. |
|
|
|
|
|
This module provides FastAPI endpoints for batch file analysis, |
|
|
session management, and health checks. |
|
|
""" |
|
|
|
|
|
import logging |
|
|
import os |
|
|
import time |
|
|
import tempfile |
|
|
import uuid |
|
|
from pathlib import Path |
|
|
from typing import Optional, List, Dict, Any |
|
|
from datetime import datetime |
|
|
|
|
|
from fastapi import APIRouter, UploadFile, File, HTTPException, Query |
|
|
from fastapi.responses import JSONResponse |
|
|
|
|
|
from api.schemas import ( |
|
|
BatchDiagnosisResponse, |
|
|
FrameDiagnosis, |
|
|
ErrorReport, |
|
|
SummaryMetrics, |
|
|
SessionListResponse, |
|
|
HealthResponse, |
|
|
ErrorDetailSchema, |
|
|
FluencyInfo, |
|
|
ArticulationInfo |
|
|
) |
|
|
from models.phoneme_mapper import PhonemeMapper |
|
|
from models.error_taxonomy import ErrorMapper, ErrorType, SeverityLevel |
|
|
from inference.inference_pipeline import InferencePipeline |
|
|
from config import AudioConfig, default_audio_config |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
# Router for all diagnosis endpoints, mounted under the /diagnose prefix.
router = APIRouter(prefix="/diagnose", tags=["diagnosis"])


# In-memory cache of completed diagnoses, keyed by session id.
# NOTE(review): unbounded and process-local — entries live until explicitly
# deleted or the process restarts, and are not shared across workers.
sessions: Dict[str, BatchDiagnosisResponse] = {}


# Module-level singletons populated by initialize_routes() at startup;
# endpoints treat a None value as "feature unavailable".
inference_pipeline: Optional[InferencePipeline] = None
phoneme_mapper: Optional[PhonemeMapper] = None
error_mapper: Optional[ErrorMapper] = None
|
|
|
|
|
|
|
|
def get_phoneme_mapper() -> Optional[PhonemeMapper]:
    """Accessor for the module-level PhonemeMapper singleton.

    Returns:
        The PhonemeMapper set by initialize_routes(), or None if it was
        never initialized or failed to construct.
    """
    return phoneme_mapper
|
|
|
|
|
|
|
|
def get_error_mapper() -> Optional[ErrorMapper]:
    """Accessor for the module-level ErrorMapper singleton.

    Returns:
        The ErrorMapper set by initialize_routes(), or None if it was
        never initialized or failed to construct.
    """
    return error_mapper
|
|
|
|
|
|
|
|
def initialize_routes(
    pipeline: InferencePipeline,
    mapper: Optional[PhonemeMapper] = None,
    error_mapper_instance: Optional[ErrorMapper] = None
):
    """
    Initialize routes with dependencies.

    Populates the module-level singletons consumed by the endpoints. When a
    mapper instance is not supplied, a default one is constructed here;
    construction failures are logged and the corresponding global is left
    None (the feature is then skipped at request time).

    Args:
        pipeline: InferencePipeline instance
        mapper: Optional PhonemeMapper instance
        error_mapper_instance: Optional ErrorMapper instance
    """
    global inference_pipeline, phoneme_mapper, error_mapper

    inference_pipeline = pipeline

    if mapper is not None:
        # Bug fix: a caller-supplied mapper was previously ignored — only the
        # default-construction branch ever assigned the global.
        phoneme_mapper = mapper
    else:
        try:
            phoneme_mapper = PhonemeMapper(
                frame_duration_ms=default_audio_config.chunk_duration_ms,
                sample_rate=default_audio_config.sample_rate
            )
            logger.info("✅ PhonemeMapper initialized")
        except Exception as e:
            # Non-fatal: diagnosis still runs, without phoneme mapping.
            logger.warning(f"⚠️ PhonemeMapper not available: {e}")
            phoneme_mapper = None

    if error_mapper_instance is not None:
        # Bug fix: same as above for the caller-supplied error mapper.
        error_mapper = error_mapper_instance
    else:
        try:
            error_mapper = ErrorMapper()
            logger.info("✅ ErrorMapper initialized")
        except Exception as e:
            # Non-fatal: per-frame error mapping is skipped when None.
            logger.error(f"❌ ErrorMapper failed to initialize: {e}")
            error_mapper = None
|
|
|
|
|
|
|
|
@router.post("/file", response_model=BatchDiagnosisResponse)
async def diagnose_file(
    audio: UploadFile = File(...),
    text: Optional[str] = Query(None, description="Expected text/transcript for phoneme mapping"),
    session_id: Optional[str] = Query(None, description="Optional session ID")
):
    """
    Analyze audio file for speech pathology errors.

    Performs complete phoneme-level analysis:
    - Extracts Wav2Vec2 features
    - Classifies fluency and articulation per frame
    - Maps phonemes to frames
    - Detects errors and generates therapy recommendations

    Args:
        audio: Audio file (WAV, MP3, etc.)
        text: Optional expected text for phoneme mapping
        session_id: Optional session ID (auto-generated if not provided)

    Returns:
        BatchDiagnosisResponse with detailed error analysis

    Raises:
        HTTPException: 503 if the inference pipeline is not loaded,
            500 if analysis fails for any other reason.
    """
    if inference_pipeline is None:
        raise HTTPException(status_code=503, detail="Inference pipeline not loaded")

    start_time = time.time()

    if not session_id:
        session_id = str(uuid.uuid4())

    temp_file = None
    try:
        # --- Persist the upload to disk: the pipeline consumes a file path ---
        temp_dir = tempfile.gettempdir()
        os.makedirs(temp_dir, exist_ok=True)
        # Security fix: sanitize the client-supplied filename so a crafted
        # name such as "../../x" cannot escape the temp directory.
        safe_name = os.path.basename(audio.filename or "upload")
        temp_file = os.path.join(temp_dir, f"diagnosis_{session_id}_{safe_name}")

        content = await audio.read()
        with open(temp_file, "wb") as f:
            f.write(content)

        file_size_mb = len(content) / 1024 / 1024
        logger.info(f"📁 Saved file: {temp_file} ({file_size_mb:.2f} MB)")

        # --- Frame-level inference ---
        logger.info("🚀 Running phone-level inference...")
        result = inference_pipeline.predict_phone_level(
            temp_file,
            return_timestamps=True
        )

        # --- Map expected text to per-frame phonemes (best-effort) ---
        frame_phonemes = []
        if text and phoneme_mapper:
            try:
                frame_phonemes = phoneme_mapper.map_text_to_frames(
                    text,
                    num_frames=result.num_frames,
                    audio_duration=result.duration
                )
                logger.info(f"✅ Mapped {len(frame_phonemes)} phonemes to frames")
            except Exception as e:
                # Fall back to empty phonemes rather than failing the request.
                logger.warning(f"⚠️ Phoneme mapping failed: {e}, using empty phonemes")
                frame_phonemes = [''] * result.num_frames
        else:
            frame_phonemes = [''] * result.num_frames
            if not text:
                logger.warning("⚠️ No text provided, phoneme mapping skipped")

        # --- Per-frame diagnosis and error reporting ---
        frame_diagnoses = []
        error_reports = []
        error_count = 0

        for i, frame_pred in enumerate(result.frame_predictions):
            phoneme = frame_phonemes[i] if i < len(frame_phonemes) else ''

            # Combined class id for the error taxonomy: articulation class,
            # shifted by 4 when the fluency head predicts a stutter.
            class_id = frame_pred.articulation_class
            if frame_pred.fluency_label == 'stutter':
                class_id += 4

            error_detail = None
            if error_mapper:
                try:
                    error_detail_obj = error_mapper.map_classifier_output(
                        class_id=class_id,
                        confidence=frame_pred.confidence,
                        phoneme=phoneme if phoneme else 'unknown',
                        fluency_label=frame_pred.fluency_label
                    )
                    error_detail_obj.frame_indices = [i]

                    # Only non-normal frames become reported errors.
                    if error_detail_obj.error_type != ErrorType.NORMAL:
                        error_detail = ErrorDetailSchema(
                            phoneme=error_detail_obj.phoneme,
                            error_type=error_detail_obj.error_type.value,
                            wrong_sound=error_detail_obj.wrong_sound,
                            severity=error_detail_obj.severity,
                            confidence=error_detail_obj.confidence,
                            therapy=error_detail_obj.therapy,
                            frame_indices=[i]
                        )
                        error_count += 1

                        severity_level = error_mapper.get_severity_level(error_detail_obj.severity)
                        error_reports.append(ErrorReport(
                            frame_id=i,
                            timestamp=frame_pred.time,
                            phoneme=error_detail_obj.phoneme,
                            error=error_detail,
                            severity_level=severity_level.value
                        ))
                except Exception as e:
                    # Best-effort: a failed mapping skips this frame's error
                    # but never aborts the whole diagnosis.
                    logger.warning(f"Error mapping failed for frame {i}: {e}")

            severity_level_str = "none"
            if error_detail:
                severity_level_str = error_mapper.get_severity_level(error_detail.severity).value if error_mapper else "none"

            frame_diagnoses.append(FrameDiagnosis(
                frame_id=i,
                timestamp=frame_pred.time,
                phoneme=phoneme if phoneme else 'unknown',
                fluency=FluencyInfo(
                    label=frame_pred.fluency_label,
                    # fluency_prob appears to be P(stutter); report confidence
                    # in the predicted label either way — TODO confirm.
                    confidence=frame_pred.fluency_prob if frame_pred.fluency_label == 'stutter' else (1.0 - frame_pred.fluency_prob)
                ),
                articulation=ArticulationInfo(
                    label=frame_pred.articulation_label,
                    confidence=frame_pred.confidence,
                    class_id=frame_pred.articulation_class
                ),
                error=error_detail,
                severity_level=severity_level_str,
                confidence=frame_pred.confidence
            ))

        # --- Summary metrics ---
        # Average fluency = mean of (1 - fluency_prob) over all frames.
        fluency_scores = [1.0 - fp.fluency_prob for fp in result.frame_predictions]
        avg_fluency = sum(fluency_scores) / len(fluency_scores) if fluency_scores else 0.0

        # Articulation score = fraction of frames classified as class 0.
        normal_frames = sum(1 for fp in result.frame_predictions if fp.articulation_class == 0)
        articulation_score = normal_frames / result.num_frames if result.num_frames > 0 else 0.0

        summary = SummaryMetrics(
            fluency_score=avg_fluency,
            fluency_percentage=avg_fluency * 100.0,
            articulation_score=articulation_score,
            error_count=error_count,
            error_rate=error_count / result.num_frames if result.num_frames > 0 else 0.0
        )

        # --- De-duplicated therapy plan, preserving first-seen order ---
        therapy_plan = []
        if error_mapper:
            seen_therapies = set()
            for error_report in error_reports:
                if error_report.error.therapy and error_report.error.therapy not in seen_therapies:
                    therapy_plan.append(error_report.error.therapy)
                    seen_therapies.add(error_report.error.therapy)

        processing_time_ms = (time.time() - start_time) * 1000

        # Version string advertises whether trained weights are loaded.
        model_trained = inference_pipeline.model.is_trained if hasattr(inference_pipeline.model, 'is_trained') else False
        model_version = "wav2vec2-xlsr-53-v2-trained" if model_trained else "wav2vec2-xlsr-53-v2-beta"

        response = BatchDiagnosisResponse(
            session_id=session_id,
            filename=audio.filename or "unknown",
            duration=result.duration,
            total_frames=result.num_frames,
            error_count=error_count,
            errors=error_reports,
            frame_diagnoses=frame_diagnoses,
            summary=summary,
            therapy_plan=therapy_plan,
            processing_time_ms=processing_time_ms,
            # NOTE(review): utcnow() is naive and deprecated since 3.12;
            # consider datetime.now(timezone.utc) if the schema tolerates
            # timezone-aware timestamps.
            created_at=datetime.utcnow(),
            model_version=model_version,
            model_trained=model_trained,
            confidence_filter_threshold=0.65
        )

        # Cache for retrieval via GET /diagnose/results/{session_id}.
        sessions[session_id] = response

        logger.info(f"✅ Diagnosis complete: {error_count} errors, {processing_time_ms:.0f}ms")

        return response

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"❌ Diagnosis failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Diagnosis failed: {str(e)}")

    finally:
        # Always remove the temp copy, even on failure.
        if temp_file and os.path.exists(temp_file):
            try:
                os.remove(temp_file)
                logger.debug(f"🧹 Cleaned up: {temp_file}")
            except Exception as e:
                logger.warning(f"Could not clean up {temp_file}: {e}")
|
|
|
|
|
|
|
|
@router.get("/results/{session_id}", response_model=BatchDiagnosisResponse)
async def get_results(session_id: str):
    """
    Get cached diagnosis results for a session.

    Args:
        session_id: Session identifier

    Returns:
        BatchDiagnosisResponse

    Raises:
        HTTPException: 404 when no cached session matches.
    """
    # EAFP: a single dict lookup instead of membership test + index.
    try:
        return sessions[session_id]
    except KeyError:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
|
|
|
|
|
|
|
@router.get("/results", response_model=SessionListResponse)
async def list_results(limit: int = Query(10, ge=1, le=100)):
    """
    List all cached diagnosis sessions.

    Args:
        limit: Maximum number of sessions to return

    Returns:
        SessionListResponse with session metadata
    """
    # Build lightweight metadata dicts for at most `limit` cached sessions.
    session_list = [
        {
            "session_id": sid,
            "filename": resp.filename,
            "duration": resp.duration,
            "error_count": resp.error_count,
            "created_at": resp.created_at.isoformat(),
            "processing_time_ms": resp.processing_time_ms
        }
        for sid, resp in list(sessions.items())[:limit]
    ]

    # `total` reflects the full cache size, not the truncated page.
    return SessionListResponse(
        sessions=session_list,
        total=len(sessions)
    )
|
|
|
|
|
|
|
|
@router.delete("/results/{session_id}")
async def delete_results(session_id: str):
    """
    Delete cached diagnosis results for a session.

    Args:
        session_id: Session identifier

    Returns:
        Success message

    Raises:
        HTTPException: 404 when no cached session matches.
    """
    # Guard clause inverted: handle the missing-session case last.
    if session_id in sessions:
        del sessions[session_id]
        logger.info(f"๐๏ธ Deleted session: {session_id}")
        return {"status": "success", "message": f"Session {session_id} deleted"}

    raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
|
|
|
|
|
|
|
|
@router.get("/health", response_model=HealthResponse)
async def health_check():
    """
    Health check endpoint.

    Reports "healthy" when the inference pipeline is loaded, otherwise
    "degraded". Uptime is measured from the FIRST health check (the
    baseline is lazily stored as a function attribute), not process start.

    Returns:
        HealthResponse with service status
    """
    # Fix: removed the redundant function-local `import time` — the module
    # already imports `time` at the top of the file.
    if not hasattr(health_check, '_start_time'):
        health_check._start_time = time.time()

    uptime = time.time() - health_check._start_time

    return HealthResponse(
        status="healthy" if inference_pipeline is not None else "degraded",
        version="2.0.0",
        model_loaded=inference_pipeline is not None,
        uptime_seconds=uptime
    )
|
|
|
|
|
|