# Author: anfastech
# Commit 1cd6149: "New: Phoneme-level speech pathology diagnosis MVP with real-time streaming"
"""
Pydantic schemas for Speech Pathology Diagnosis API.
This module defines request and response models for REST API and WebSocket endpoints.
"""
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field
from datetime import datetime
class FluencyInfo(BaseModel):
    """Fluency classification information.

    Holds a fluency label together with the confidence attached to it.
    """

    # Expected values are 'normal' or 'stutter'; the schema itself accepts
    # any string, so the allowed set is documented only in the description.
    label: str = Field(
        ...,
        description="Fluency label: 'normal' or 'stutter'",
    )
    # Constrained to the closed interval [0, 1].
    confidence: float = Field(
        ...,
        description="Confidence score (0-1)",
        ge=0.0,
        le=1.0,
    )
class ArticulationInfo(BaseModel):
    """Articulation classification information.

    Reports the articulation class both as a name (``label``) and as a
    numeric ID (``class_id``), plus the confidence of the prediction.
    """

    # Expected names: 'normal', 'substitution', 'omission', 'distortion'.
    # Not enforced by the schema; any string is accepted.
    label: str = Field(
        ...,
        description="Articulation label: 'normal', 'substitution', 'omission', 'distortion'",
    )
    # Constrained to the closed interval [0, 1].
    confidence: float = Field(
        ...,
        description="Confidence score (0-1)",
        ge=0.0,
        le=1.0,
    )
    # Bounds restrict the ID to the four documented classes (0..3).
    # NOTE(review): nothing in this model checks that class_id agrees with
    # label; consistency is left to whoever builds the instance.
    class_id: int = Field(
        ...,
        description="Class ID: 0=normal, 1=substitution, 2=omission, 3=distortion",
        ge=0,
        le=3,
    )
class ErrorDetailSchema(BaseModel):
    """Error detail schema for API responses.

    Describes one detected error: the expected phoneme, the error category,
    the substituted sound (if any), severity and model confidence, a therapy
    recommendation, and the frames where the error occurs.
    """

    phoneme: str = Field(..., description="Expected phoneme symbol")
    # Unconstrained string; expected categories are listed in the description.
    error_type: str = Field(
        ...,
        description="Error type: normal, substitution, omission, distortion",
    )
    # Defaults to None when omitted; per its description it is used only
    # for substitution errors.
    wrong_sound: Optional[str] = Field(
        default=None,
        description="For substitutions, the incorrect phoneme produced",
    )
    # Both scores are constrained to the closed interval [0, 1].
    severity: float = Field(..., description="Severity score (0-1)", ge=0.0, le=1.0)
    confidence: float = Field(..., description="Model confidence (0-1)", ge=0.0, le=1.0)
    therapy: str = Field(..., description="Therapy recommendation")
    # default_factory gives each instance its own empty list when omitted.
    frame_indices: List[int] = Field(
        default_factory=list,
        description="Frame indices where error occurs",
    )
class FrameDiagnosis(BaseModel):
    """Diagnosis for a single frame.

    Combines the frame's position (index and timestamp), its expected
    phoneme, fluency and articulation classifications, optional error
    details, and an overall severity level and confidence.
    """

    # Frame position.
    frame_id: int = Field(..., description="Frame index")
    timestamp: float = Field(..., description="Timestamp in seconds", ge=0.0)
    phoneme: str = Field(..., description="Expected phoneme for this frame")

    # Classifier outputs for this frame.
    fluency: FluencyInfo = Field(..., description="Fluency classification")
    articulation: ArticulationInfo = Field(..., description="Articulation classification")
    # Defaults to None; per its description, populated when an error is detected.
    error: Optional[ErrorDetailSchema] = Field(
        default=None,
        description="Error details if error detected",
    )

    # Overall verdict. severity_level is an unconstrained string; the
    # expected values are listed only in the description.
    severity_level: str = Field(..., description="Severity level: none, low, medium, high")
    confidence: float = Field(..., description="Overall confidence", ge=0.0, le=1.0)
class ErrorReport(BaseModel):
    """Detailed error report for a frame.

    Unlike a full frame diagnosis, the error details here are mandatory.
    """

    # Frame position and expected phoneme.
    frame_id: int = Field(..., description="Frame index")
    timestamp: float = Field(..., description="Timestamp in seconds", ge=0.0)
    phoneme: str = Field(..., description="Expected phoneme")
    # Required (no default).
    error: ErrorDetailSchema = Field(..., description="Error details")
    # Unconstrained string; expected values are listed in the description.
    severity_level: str = Field(..., description="Severity level: none, low, medium, high")
class SummaryMetrics(BaseModel):
    """Summary metrics for the analysis.

    Aggregate fluency, articulation and error statistics. Scores and the
    error rate are bounded to [0, 1]; the fluency percentage to [0, 100].
    """

    fluency_score: float = Field(
        ...,
        description="Average fluency score (0=stutter, 1=normal)",
        ge=0.0,
        le=1.0,
    )
    # NOTE(review): presumably fluency_score * 100, but nothing here enforces
    # that relationship; confirm against the code that builds this model.
    fluency_percentage: float = Field(
        ...,
        description="Fluency percentage",
        ge=0.0,
        le=100.0,
    )
    articulation_score: float = Field(
        ...,
        description="Average articulation correctness",
        ge=0.0,
        le=1.0,
    )
    error_count: int = Field(..., description="Total number of errors detected", ge=0)
    error_rate: float = Field(
        ...,
        description="Error rate (errors/total_frames)",
        ge=0.0,
        le=1.0,
    )
class BatchDiagnosisResponse(BaseModel):
    """Response for batch file diagnosis.

    Bundles per-frame diagnoses, the frames reported as errors, aggregate
    metrics and therapy recommendations, together with run metadata
    (processing time, timestamp, model version, confidence threshold).
    """

    # --- Identification ---
    session_id: str = Field(..., description="Session identifier")
    filename: str = Field(..., description="Processed filename")

    # --- Extent of the analysed audio ---
    duration: float = Field(..., description="Audio duration in seconds", ge=0.0)
    total_frames: int = Field(..., description="Total number of frames analyzed", ge=0)

    # --- Findings ---
    # NOTE(review): error_count is stored independently of len(errors) and
    # summary.error_count; this model does not keep them in sync.
    error_count: int = Field(..., description="Number of errors detected", ge=0)
    errors: List[ErrorReport] = Field(
        default_factory=list,
        description="List of error reports",
    )
    frame_diagnoses: List[FrameDiagnosis] = Field(
        default_factory=list,
        description="All frame diagnoses",
    )
    summary: SummaryMetrics = Field(..., description="Summary metrics")
    therapy_plan: List[str] = Field(
        default_factory=list,
        description="Therapy recommendations",
    )

    # --- Run metadata ---
    processing_time_ms: float = Field(
        ...,
        description="Processing time in milliseconds",
        ge=0.0,
    )
    # NOTE(review): datetime.now produces a naive local-time value with no
    # UTC offset; consider a timezone-aware timestamp if clients compare
    # times across hosts.
    created_at: datetime = Field(
        default_factory=datetime.now,
        description="Analysis timestamp",
    )
    # NOTE(review): some Pydantic v2 releases warn about field names that
    # start with "model_" (protected namespace); check the pinned version.
    model_version: str = Field(
        default="wav2vec2-xlsr-53-v2",
        description="Model version identifier",
    )
    model_trained: bool = Field(
        default=False,
        description="Whether classifier head is trained",
    )
    confidence_filter_threshold: float = Field(
        default=0.65,
        description="Confidence threshold for filtering predictions",
        ge=0.0,
        le=1.0,
    )
class StreamingDiagnosisRequest(BaseModel):
    """Request for streaming diagnosis.

    Carries one chunk of audio for a streaming session, the sample rate the
    chunk was captured at, and an optional frame index for tracking.
    """

    # Raw audio bytes. The description documents 320 samples per chunk
    # (20 ms at 16 kHz); the byte encoding of a sample is not fixed here.
    audio_chunk: bytes = Field(..., description="Audio chunk data (320 samples for 20ms @ 16kHz)")
    # Must be strictly positive: a zero or negative rate cannot describe real
    # audio, and any duration computed as samples / rate would be wrong or
    # raise ZeroDivisionError. Rejecting it here gives the client a clear
    # validation error instead of a failure deeper in processing.
    sample_rate: int = Field(16000, gt=0, description="Sample rate in Hz")
    session_id: str = Field(..., description="Session identifier")
    # Optional; defaults to None when the client omits it.
    frame_index: Optional[int] = Field(None, description="Frame index for tracking")
class StreamingDiagnosisResponse(BaseModel):
    """Response for streaming diagnosis (single frame).

    Carries the same per-frame diagnosis fields as a batch frame result,
    tagged with the session it belongs to and the processing latency.
    """

    # Routing / position.
    session_id: str = Field(..., description="Session identifier")
    frame_id: int = Field(..., description="Frame index")
    timestamp: float = Field(..., description="Timestamp in seconds", ge=0.0)
    phoneme: str = Field(..., description="Expected phoneme")

    # Classifier outputs.
    fluency: FluencyInfo = Field(..., description="Fluency classification")
    articulation: ArticulationInfo = Field(..., description="Articulation classification")
    # Defaults to None; per its description, populated when an error is detected.
    error: Optional[ErrorDetailSchema] = Field(
        default=None,
        description="Error details if error detected",
    )

    # Overall verdict and timing.
    severity_level: str = Field(..., description="Severity level")
    confidence: float = Field(..., description="Overall confidence", ge=0.0, le=1.0)
    latency_ms: float = Field(..., description="Processing latency in milliseconds", ge=0.0)
class SessionListResponse(BaseModel):
    """Response for listing sessions.

    Session metadata entries are free-form string-keyed dictionaries.
    """

    # Each entry's keys are not fixed by this schema.
    sessions: List[Dict[str, Any]] = Field(
        ...,
        description="List of session metadata",
    )
    # NOTE(review): not tied to len(sessions) by this model; confirm whether
    # it counts only the returned page or all stored sessions.
    total: int = Field(..., description="Total number of sessions", ge=0)
class HealthResponse(BaseModel):
    """Health check response.

    Reports service status, API version, whether the model is loaded, and
    how long the service has been up.
    """

    status: str = Field(..., description="Service status")
    version: str = Field(..., description="API version")
    # NOTE(review): some Pydantic v2 releases warn about field names that
    # start with "model_" (protected namespace); check the pinned version.
    model_loaded: bool = Field(..., description="Whether model is loaded")
    uptime_seconds: float = Field(
        ...,
        description="Service uptime in seconds",
        ge=0.0,
    )