"""
REST API routes for Speech Pathology Diagnosis.

This module provides FastAPI endpoints for batch file analysis,
session management, and health checks.
"""

import logging
import os
import time
import tempfile
import uuid
from itertools import islice
from pathlib import Path
from typing import Optional, List, Dict, Any, Tuple
from datetime import datetime

from fastapi import APIRouter, UploadFile, File, HTTPException, Query
from fastapi.responses import JSONResponse

from api.schemas import (
    BatchDiagnosisResponse,
    FrameDiagnosis,
    ErrorReport,
    SummaryMetrics,
    SessionListResponse,
    HealthResponse,
    ErrorDetailSchema,
    FluencyInfo,
    ArticulationInfo
)
from models.phoneme_mapper import PhonemeMapper
from models.error_taxonomy import ErrorMapper, ErrorType, SeverityLevel
from inference.inference_pipeline import InferencePipeline
from config import AudioConfig, default_audio_config

logger = logging.getLogger(__name__)

# Create router
router = APIRouter(prefix="/diagnose", tags=["diagnosis"])

# In-memory session storage (in production, use Redis or database).
# NOTE: nothing in this module evicts entries; they live until deleted via the API.
sessions: Dict[str, BatchDiagnosisResponse] = {}

# Global instances (will be injected)
inference_pipeline: Optional[InferencePipeline] = None
phoneme_mapper: Optional[PhonemeMapper] = None
error_mapper: Optional[ErrorMapper] = None

# Monotonic timestamp taken when this module is imported; health_check reports
# uptime relative to it.
_SERVICE_START_MONOTONIC = time.monotonic()

# Value reported in BatchDiagnosisResponse.confidence_filter_threshold.
_CONFIDENCE_FILTER_THRESHOLD = 0.65

# Longest file extension (including the dot) copied from an uploaded filename.
_MAX_SUFFIX_LENGTH = 10


def get_phoneme_mapper() -> Optional[PhonemeMapper]:
    """Get the global PhonemeMapper instance (None if not initialized)."""
    return phoneme_mapper


def get_error_mapper() -> Optional[ErrorMapper]:
    """Get the global ErrorMapper instance (None if not initialized)."""
    return error_mapper


def initialize_routes(
    pipeline: InferencePipeline,
    mapper: Optional[PhonemeMapper] = None,
    error_mapper_instance: Optional[ErrorMapper] = None
):
    """
    Initialize routes with dependencies.

    Explicitly supplied instances are installed as the module globals. When an
    instance is omitted, a default one is constructed; if construction fails
    the corresponding global is set to None and the failure is logged.

    Args:
        pipeline: InferencePipeline instance
        mapper: Optional PhonemeMapper instance
        error_mapper_instance: Optional ErrorMapper instance
    """
    global inference_pipeline, phoneme_mapper, error_mapper

    inference_pipeline = pipeline

    if mapper is not None:
        # Honour the injected instance (previously it was silently dropped).
        phoneme_mapper = mapper
    else:
        try:
            phoneme_mapper = PhonemeMapper(
                frame_duration_ms=default_audio_config.chunk_duration_ms,
                sample_rate=default_audio_config.sample_rate
            )
            logger.info("✅ PhonemeMapper initialized")
        except Exception as e:
            logger.warning(f"⚠️ PhonemeMapper not available: {e}")
            phoneme_mapper = None

    if error_mapper_instance is not None:
        # Honour the injected instance (previously it was silently dropped).
        error_mapper = error_mapper_instance
    else:
        try:
            error_mapper = ErrorMapper()
            logger.info("✅ ErrorMapper initialized")
        except Exception as e:
            logger.error(f"❌ ErrorMapper failed to initialize: {e}")
            error_mapper = None


def _safe_upload_suffix(filename: Optional[str]) -> str:
    """Return a short ASCII-alphanumeric extension from *filename*, or ''."""
    suffix = Path(filename or "").suffix
    extension = suffix[1:]
    if (
        extension
        and len(suffix) <= _MAX_SUFFIX_LENGTH
        and extension.isascii()
        and extension.isalnum()
    ):
        return suffix
    return ""


def _resolve_frame_phonemes(text: Optional[str], result: Any) -> List[str]:
    """Map *text* onto the frames of *result*, or return blank placeholders."""
    blank_phonemes = [''] * result.num_frames

    if not (text and phoneme_mapper):
        if not text:
            logger.warning("⚠️ No text provided, phoneme mapping skipped")
        return blank_phonemes

    try:
        frame_phonemes = phoneme_mapper.map_text_to_frames(
            text,
            num_frames=result.num_frames,
            audio_duration=result.duration
        )
    except Exception as e:
        logger.warning(f"⚠️ Phoneme mapping failed: {e}, using empty phonemes")
        return blank_phonemes

    logger.info(f"✅ Mapped {len(frame_phonemes)} phonemes to frames")
    return frame_phonemes


def _diagnose_frame(
    index: int,
    frame_pred: Any,
    phoneme: str
) -> Tuple[FrameDiagnosis, Optional[ErrorReport]]:
    """Build the FrameDiagnosis (and ErrorReport, if any) for a single frame."""
    phoneme_label = phoneme if phoneme else 'unknown'
    is_stutter = frame_pred.fluency_label == 'stutter'

    # Combine fluency and articulation into one class id:
    # the articulation class as-is for non-stutter frames, +4 for stutter frames.
    class_id = frame_pred.articulation_class
    if is_stutter:
        class_id += 4

    error_detail = None
    error_report = None
    severity_level_str = "none"

    if error_mapper:
        try:
            mapped = error_mapper.map_classifier_output(
                class_id=class_id,
                confidence=frame_pred.confidence,
                phoneme=phoneme_label,
                fluency_label=frame_pred.fluency_label
            )
            mapped.frame_indices = [index]

            if mapped.error_type != ErrorType.NORMAL:
                detail = ErrorDetailSchema(
                    phoneme=mapped.phoneme,
                    error_type=mapped.error_type.value,
                    wrong_sound=mapped.wrong_sound,
                    severity=mapped.severity,
                    confidence=mapped.confidence,
                    therapy=mapped.therapy,
                    frame_indices=[index]
                )
                severity = error_mapper.get_severity_level(mapped.severity).value
                report = ErrorReport(
                    frame_id=index,
                    timestamp=frame_pred.time,
                    phoneme=mapped.phoneme,
                    error=detail,
                    severity_level=severity
                )
                # Publish only once every piece was built, so a failure above
                # leaves the frame consistently "no error" instead of half-set.
                error_detail = detail
                error_report = report
                severity_level_str = severity
        except Exception as e:
            logger.warning(f"Error mapping failed for frame {index}: {e}")

    # fluency_prob is used as the stutter probability; report the confidence
    # of whichever label was actually chosen.
    fluency_confidence = (
        frame_pred.fluency_prob if is_stutter else (1.0 - frame_pred.fluency_prob)
    )

    diagnosis = FrameDiagnosis(
        frame_id=index,
        timestamp=frame_pred.time,
        phoneme=phoneme_label,
        fluency=FluencyInfo(
            label=frame_pred.fluency_label,
            confidence=fluency_confidence
        ),
        articulation=ArticulationInfo(
            label=frame_pred.articulation_label,
            confidence=frame_pred.confidence,
            class_id=frame_pred.articulation_class
        ),
        error=error_detail,
        severity_level=severity_level_str,
        confidence=frame_pred.confidence
    )
    return diagnosis, error_report


def _build_summary(result: Any, error_count: int) -> SummaryMetrics:
    """Aggregate per-frame predictions into session-level summary metrics."""
    predictions = result.frame_predictions
    num_frames = result.num_frames

    # Convert stutter probability to a fluency score and average it.
    fluency_scores = [1.0 - fp.fluency_prob for fp in predictions]
    avg_fluency = sum(fluency_scores) / len(fluency_scores) if fluency_scores else 0.0

    # Articulation score: fraction of frames whose articulation class is 0.
    normal_frames = sum(1 for fp in predictions if fp.articulation_class == 0)

    return SummaryMetrics(
        fluency_score=avg_fluency,
        fluency_percentage=avg_fluency * 100.0,
        articulation_score=normal_frames / num_frames if num_frames > 0 else 0.0,
        error_count=error_count,
        error_rate=error_count / num_frames if num_frames > 0 else 0.0
    )


@router.post("/file", response_model=BatchDiagnosisResponse)
async def diagnose_file(
    audio: UploadFile = File(...),
    text: Optional[str] = Query(None, description="Expected text/transcript for phoneme mapping"),
    session_id: Optional[str] = Query(None, description="Optional session ID")
):
    """
    Analyze audio file for speech pathology errors.

    The upload is written to a temporary file, passed to
    ``inference_pipeline.predict_phone_level``, and each frame prediction is
    converted into a FrameDiagnosis (plus an ErrorReport for non-normal
    frames). The response is cached in ``sessions`` under the session ID.

    Args:
        audio: Audio file (WAV, MP3, etc.)
        text: Optional expected text for phoneme mapping
        session_id: Optional session ID (auto-generated if not provided)

    Returns:
        BatchDiagnosisResponse with detailed error analysis

    Raises:
        HTTPException: 503 if the inference pipeline is not loaded,
            500 if any step of the diagnosis fails.
    """
    if inference_pipeline is None:
        raise HTTPException(status_code=503, detail="Inference pipeline not loaded")

    start_time = time.time()

    # Generate session ID
    if not session_id:
        session_id = str(uuid.uuid4())

    temp_file = None

    try:
        content = await audio.read()

        # The temp path must not contain client-controlled text: both the
        # filename and session_id come from the request and could carry path
        # separators. mkstemp picks a unique name; only a sanitized extension
        # of the uploaded filename is kept.
        fd, temp_file = tempfile.mkstemp(
            prefix="diagnosis_",
            suffix=_safe_upload_suffix(audio.filename)
        )
        with os.fdopen(fd, "wb") as f:
            f.write(content)

        file_size_mb = len(content) / 1024 / 1024
        logger.info(f"📂 Saved file: {temp_file} ({file_size_mb:.2f} MB)")

        # Run inference
        logger.info("🔄 Running phone-level inference...")
        result = inference_pipeline.predict_phone_level(
            temp_file,
            return_timestamps=True
        )

        # Map phonemes to frames if text provided
        frame_phonemes = _resolve_frame_phonemes(text, result)

        # Process frame predictions with error mapping
        frame_diagnoses = []
        error_reports = []
        for i, frame_pred in enumerate(result.frame_predictions):
            phoneme = frame_phonemes[i] if i < len(frame_phonemes) else ''
            diagnosis, report = _diagnose_frame(i, frame_pred, phoneme)
            frame_diagnoses.append(diagnosis)
            if report is not None:
                error_reports.append(report)

        # One report is produced per erroneous frame, so the count is derived.
        error_count = len(error_reports)

        summary = _build_summary(result, error_count)

        # Therapy plan: unique recommendations in first-seen order.
        therapy_plan = list(dict.fromkeys(
            report.error.therapy
            for report in error_reports
            if report.error.therapy
        ))

        processing_time_ms = (time.time() - start_time) * 1000

        # Check if model is trained
        model_trained = getattr(inference_pipeline.model, 'is_trained', False)
        model_version = "wav2vec2-xlsr-53-v2-trained" if model_trained else "wav2vec2-xlsr-53-v2-beta"

        response = BatchDiagnosisResponse(
            session_id=session_id,
            filename=audio.filename or "unknown",
            duration=result.duration,
            total_frames=result.num_frames,
            error_count=error_count,
            errors=error_reports,
            frame_diagnoses=frame_diagnoses,
            summary=summary,
            therapy_plan=therapy_plan,
            processing_time_ms=processing_time_ms,
            # Naive UTC timestamp; kept as-is so the serialized format is unchanged.
            created_at=datetime.utcnow(),
            model_version=model_version,
            model_trained=model_trained,
            confidence_filter_threshold=_CONFIDENCE_FILTER_THRESHOLD
        )

        # Store in sessions
        sessions[session_id] = response

        logger.info(f"✅ Diagnosis complete: {error_count} errors, {processing_time_ms:.0f}ms")

        return response

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"❌ Diagnosis failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Diagnosis failed: {str(e)}") from e
    finally:
        # Cleanup temp file
        if temp_file and os.path.exists(temp_file):
            try:
                os.remove(temp_file)
                logger.debug(f"🧹 Cleaned up: {temp_file}")
            except OSError as e:
                logger.warning(f"Could not clean up {temp_file}: {e}")


@router.get("/results/{session_id}", response_model=BatchDiagnosisResponse)
async def get_results(session_id: str):
    """
    Get cached diagnosis results for a session.

    Args:
        session_id: Session identifier

    Returns:
        BatchDiagnosisResponse

    Raises:
        HTTPException: 404 if the session is not cached.
    """
    try:
        return sessions[session_id]
    except KeyError:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found") from None


@router.get("/results", response_model=SessionListResponse)
async def list_results(limit: int = Query(10, ge=1, le=100)):
    """
    List cached diagnosis sessions.

    Sessions are returned in the dict's insertion order, truncated to *limit*;
    ``total`` is the number of all cached sessions.

    Args:
        limit: Maximum number of sessions to return

    Returns:
        SessionListResponse with session metadata
    """
    # islice avoids copying every cached session just to take the first few.
    session_list = [
        {
            "session_id": sid,
            "filename": response.filename,
            "duration": response.duration,
            "error_count": response.error_count,
            "created_at": response.created_at.isoformat(),
            "processing_time_ms": response.processing_time_ms
        }
        for sid, response in islice(sessions.items(), limit)
    ]

    return SessionListResponse(
        sessions=session_list,
        total=len(sessions)
    )


@router.delete("/results/{session_id}")
async def delete_results(session_id: str):
    """
    Delete cached diagnosis results for a session.

    Args:
        session_id: Session identifier

    Returns:
        Success message

    Raises:
        HTTPException: 404 if the session is not cached.
    """
    try:
        del sessions[session_id]
    except KeyError:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found") from None

    logger.info(f"🗑️ Deleted session: {session_id}")

    return {"status": "success", "message": f"Session {session_id} deleted"}


@router.get("/health", response_model=HealthResponse)
async def health_check():
    """
    Health check endpoint.

    Status is "healthy" when the inference pipeline has been injected and
    "degraded" otherwise. Uptime is measured from the moment this module was
    imported, using a monotonic clock.

    Returns:
        HealthResponse with service status
    """
    pipeline_loaded = inference_pipeline is not None

    return HealthResponse(
        status="healthy" if pipeline_loaded else "degraded",
        version="2.0.0",
        model_loaded=pipeline_loaded,
        uptime_seconds=time.monotonic() - _SERVICE_START_MONOTONIC
    )