| """ | |
| Pydantic schemas for Speech Pathology Diagnosis API. | |
| This module defines request and response models for REST API and WebSocket endpoints. | |
| """ | |
| from typing import List, Optional, Dict, Any | |
| from pydantic import BaseModel, Field | |
| from datetime import datetime | |
class FluencyInfo(BaseModel):
    """Fluency verdict paired with the model's confidence in it."""

    # Documented values are 'normal' or 'stutter'; the type is a plain
    # ``str``, so other strings are not rejected.
    label: str = Field(
        default=...,
        description="Fluency label: 'normal' or 'stutter'",
    )
    # Validated to the closed interval [0, 1].
    confidence: float = Field(
        default=...,
        ge=0.0,
        le=1.0,
        description="Confidence score (0-1)",
    )
class ArticulationInfo(BaseModel):
    """Articulation verdict: textual label, confidence and numeric class id."""

    # Documented values: 'normal', 'substitution', 'omission', 'distortion'.
    # Stored as a free-form ``str``; no consistency check against class_id.
    label: str = Field(
        default=...,
        description="Articulation label: 'normal', 'substitution', 'omission', 'distortion'",
    )
    # Validated to the closed interval [0, 1].
    confidence: float = Field(
        default=...,
        ge=0.0,
        le=1.0,
        description="Confidence score (0-1)",
    )
    # Validated to 0..3 inclusive; mapping given in the description.
    class_id: int = Field(
        default=...,
        ge=0,
        le=3,
        description="Class ID: 0=normal, 1=substitution, 2=omission, 3=distortion",
    )
class ErrorDetailSchema(BaseModel):
    """Describes one detected pronunciation error as exposed by the API."""

    phoneme: str = Field(default=..., description="Expected phoneme symbol")
    # Free-form string; the description lists the expected categories.
    error_type: str = Field(
        default=...,
        description="Error type: normal, substitution, omission, distortion",
    )
    # Optional: defaults to None when no wrong sound is reported.
    wrong_sound: Optional[str] = Field(
        default=None,
        description="For substitutions, the incorrect phoneme produced",
    )
    severity: float = Field(
        default=...,
        ge=0.0,
        le=1.0,
        description="Severity score (0-1)",
    )
    confidence: float = Field(
        default=...,
        ge=0.0,
        le=1.0,
        description="Model confidence (0-1)",
    )
    therapy: str = Field(default=..., description="Therapy recommendation")
    # default_factory gives each instance its own fresh empty list.
    frame_indices: List[int] = Field(
        default_factory=list,
        description="Frame indices where error occurs",
    )
class FrameDiagnosis(BaseModel):
    """Complete diagnosis for one analysed frame.

    Combines the fluency and articulation verdicts for the frame with an
    optional error detail and an overall severity/confidence summary.
    """

    frame_id: int = Field(default=..., description="Frame index")
    timestamp: float = Field(
        default=...,
        ge=0.0,
        description="Timestamp in seconds",
    )
    phoneme: str = Field(default=..., description="Expected phoneme for this frame")
    fluency: FluencyInfo = Field(default=..., description="Fluency classification")
    articulation: ArticulationInfo = Field(
        default=...,
        description="Articulation classification",
    )
    # None means no error detail is attached for this frame.
    error: Optional[ErrorDetailSchema] = Field(
        default=None,
        description="Error details if error detected",
    )
    # Free-form string; documented values are none/low/medium/high.
    severity_level: str = Field(
        default=...,
        description="Severity level: none, low, medium, high",
    )
    confidence: float = Field(
        default=...,
        ge=0.0,
        le=1.0,
        description="Overall confidence",
    )
class ErrorReport(BaseModel):
    """Error-focused view of a single frame.

    Unlike the per-frame diagnosis, the ``error`` detail here is required.
    """

    frame_id: int = Field(default=..., description="Frame index")
    timestamp: float = Field(
        default=...,
        ge=0.0,
        description="Timestamp in seconds",
    )
    phoneme: str = Field(default=..., description="Expected phoneme")
    error: ErrorDetailSchema = Field(default=..., description="Error details")
    # Free-form string; documented values are none/low/medium/high.
    severity_level: str = Field(
        default=...,
        description="Severity level: none, low, medium, high",
    )
class SummaryMetrics(BaseModel):
    """Aggregate scores computed over an entire analysis."""

    # Ratio in [0, 1]; per the description 0 leans stutter, 1 leans normal.
    fluency_score: float = Field(
        default=...,
        ge=0.0,
        le=1.0,
        description="Average fluency score (0=stutter, 1=normal)",
    )
    # Same notion expressed on a 0-100 scale.
    fluency_percentage: float = Field(
        default=...,
        ge=0.0,
        le=100.0,
        description="Fluency percentage",
    )
    articulation_score: float = Field(
        default=...,
        ge=0.0,
        le=1.0,
        description="Average articulation correctness",
    )
    error_count: int = Field(
        default=...,
        ge=0,
        description="Total number of errors detected",
    )
    # Bounded by 1.0, i.e. at most one error per frame is assumed.
    error_rate: float = Field(
        default=...,
        ge=0.0,
        le=1.0,
        description="Error rate (errors/total_frames)",
    )
class BatchDiagnosisResponse(BaseModel):
    """Response for batch file diagnosis.

    Bundles the per-frame diagnoses, the frames flagged as errors, aggregate
    metrics and therapy recommendations for one processed file, together
    with run metadata (processing time, model version, filter threshold).
    """

    session_id: str = Field(..., description="Session identifier")
    filename: str = Field(..., description="Processed filename")
    duration: float = Field(..., ge=0.0, description="Audio duration in seconds")
    total_frames: int = Field(..., ge=0, description="Total number of frames analyzed")
    # NOTE(review): supplied independently of `errors` and
    # `summary.error_count`; nothing in this schema checks that they agree.
    error_count: int = Field(..., ge=0, description="Number of errors detected")
    errors: List[ErrorReport] = Field(default_factory=list, description="List of error reports")
    frame_diagnoses: List[FrameDiagnosis] = Field(default_factory=list, description="All frame diagnoses")
    summary: SummaryMetrics = Field(..., description="Summary metrics")
    therapy_plan: List[str] = Field(default_factory=list, description="Therapy recommendations")
    processing_time_ms: float = Field(..., ge=0.0, description="Processing time in milliseconds")
    # `datetime.now` alone yields a naive datetime, so the serialized value
    # carried no UTC offset and was ambiguous to API clients. `.astimezone()`
    # attaches the server's local offset while keeping the same wall-clock time.
    created_at: datetime = Field(
        default_factory=lambda: datetime.now().astimezone(),
        description="Analysis timestamp",
    )
    # NOTE(review): on pydantic v2 releases before 2.10, field names starting
    # with "model_" trigger a protected-namespace UserWarning; if that applies,
    # set `protected_namespaces=()` in the model config.
    model_version: str = Field(default="wav2vec2-xlsr-53-v2", description="Model version identifier")
    model_trained: bool = Field(default=False, description="Whether classifier head is trained")
    confidence_filter_threshold: float = Field(
        default=0.65,
        ge=0.0,
        le=1.0,
        description="Confidence threshold for filtering predictions",
    )
class StreamingDiagnosisRequest(BaseModel):
    """Request for streaming diagnosis.

    Carries one chunk of client-supplied audio for a session. Because this is
    external input, the sample rate must be strictly positive and an explicit
    frame index must be non-negative.
    """

    # Raw audio payload; the description gives the expected chunk size.
    audio_chunk: bytes = Field(..., description="Audio chunk data (320 samples for 20ms @ 16kHz)")
    # Previously unconstrained: 0 or negative rates were accepted even though
    # they are physically meaningless for audio.
    sample_rate: int = Field(16000, gt=0, description="Sample rate in Hz")
    session_id: str = Field(..., description="Session identifier")
    # None is still allowed; the bound only applies when a value is given.
    frame_index: Optional[int] = Field(None, ge=0, description="Frame index for tracking")
class StreamingDiagnosisResponse(BaseModel):
    """Response for streaming diagnosis (single frame).

    Mirrors the per-frame diagnosis fields, adds the owning session id, and
    reports how long the frame took to process.
    """

    session_id: str = Field(..., description="Session identifier")
    frame_id: int = Field(..., description="Frame index")
    timestamp: float = Field(..., ge=0.0, description="Timestamp in seconds")
    phoneme: str = Field(..., description="Expected phoneme")
    fluency: FluencyInfo = Field(..., description="Fluency classification")
    articulation: ArticulationInfo = Field(..., description="Articulation classification")
    error: Optional[ErrorDetailSchema] = Field(None, description="Error details if error detected")
    # Description aligned with the other per-frame schemas so the generated
    # API documentation lists the same allowed values everywhere.
    severity_level: str = Field(..., description="Severity level: none, low, medium, high")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Overall confidence")
    latency_ms: float = Field(..., ge=0.0, description="Processing latency in milliseconds")
class SessionListResponse(BaseModel):
    """Paged-or-full listing of sessions with a total count."""

    # Each entry is an untyped metadata mapping; its keys are not fixed here.
    sessions: List[Dict[str, Any]] = Field(
        default=...,
        description="List of session metadata",
    )
    total: int = Field(
        default=...,
        ge=0,
        description="Total number of sessions",
    )
class HealthResponse(BaseModel):
    """Payload returned by the service health check."""

    status: str = Field(default=..., description="Service status")
    version: str = Field(default=..., description="API version")
    # NOTE(review): on pydantic v2 releases before 2.10, a field named with the
    # "model_" prefix triggers a protected-namespace UserWarning.
    model_loaded: bool = Field(default=..., description="Whether model is loaded")
    uptime_seconds: float = Field(
        default=...,
        ge=0.0,
        description="Service uptime in seconds",
    )