"""
Error Taxonomy for Speech Pathology Analysis

This module defines error types, severity levels, and therapy recommendations
for phoneme-level error detection.
"""

import logging
import json
from enum import Enum
from typing import Optional, Dict, List
from pathlib import Path
from dataclasses import dataclass, field

from pydantic import BaseModel, Field

logger = logging.getLogger(__name__)


class ErrorType(str, Enum):
    """Types of articulation errors."""
    NORMAL = "normal"
    SUBSTITUTION = "substitution"
    OMISSION = "omission"
    DISTORTION = "distortion"


class SeverityLevel(str, Enum):
    """Severity levels for errors."""
    NONE = "none"
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"


@dataclass
class ErrorDetail:
    """
    Detailed error information for a phoneme.

    Attributes:
        phoneme: Expected phoneme symbol (e.g., '/s/')
        error_type: Type of error (NORMAL, SUBSTITUTION, OMISSION, DISTORTION)
        wrong_sound: For substitutions, the incorrect phoneme produced (e.g., '/θ/')
        severity: Severity score (0.0-1.0)
        confidence: Model confidence in the error detection (0.0-1.0)
        therapy: Therapy recommendation text
        frame_indices: List of frame indices where this error occurs
    """
    phoneme: str
    error_type: ErrorType
    wrong_sound: Optional[str] = None
    severity: float = 0.0
    confidence: float = 0.0
    therapy: str = ""
    frame_indices: List[int] = field(default_factory=list)


class ErrorDetailPydantic(BaseModel):
    """Pydantic model for API serialization."""
    phoneme: str
    error_type: str
    wrong_sound: Optional[str] = None
    # Both scores are validated to lie in [0.0, 1.0] on construction.
    severity: float = Field(ge=0.0, le=1.0)
    confidence: float = Field(ge=0.0, le=1.0)
    therapy: str
    frame_indices: List[int] = Field(default_factory=list)


class ErrorMapper:
    """
    Maps classifier outputs to error types and generates therapy recommendations.

    Classifier output mapping (8 classes):
        - Class 0: Normal articulation, normal fluency
        - Class 1: Substitution, normal fluency
        - Class 2: Omission, normal fluency
        - Class 3: Distortion, normal fluency
        - Class 4: Normal articulation, stutter
        - Class 5: Substitution, stutter
        - Class 6: Omission, stutter
        - Class 7: Distortion, stutter
    """

    def __init__(self, therapy_db_path: Optional[str] = None):
        """
        Initialize the ErrorMapper.

        Args:
            therapy_db_path: Path to therapy recommendations JSON file.
                If None, uses default location: data/therapy_recommendations.json
                (resolved relative to the parent of this module's directory).
        """
        if therapy_db_path is None:
            db_path = Path(__file__).parent.parent / "data" / "therapy_recommendations.json"
        else:
            db_path = Path(therapy_db_path)

        # Top-level sections read by get_therapy(): "substitutions",
        # "omissions", "distortions" (each a mapping of key -> text).
        self.therapy_db: Dict = self._load_therapy_db(db_path)

        # Common substitution mappings (phoneme → likely wrong sound).
        # The first entry of each list is what _map_substitution() returns.
        self.substitution_map: Dict[str, List[str]] = {
            '/s/': ['/θ/', '/ʃ/', '/z/'],     # lisp, sh-sound, voicing
            '/r/': ['/w/', '/l/', '/ɹ/'],     # rhotacism variants
            '/l/': ['/w/', '/j/'],            # liquid substitutions
            '/k/': ['/t/', '/p/'],            # velar → alveolar/bilabial
            '/g/': ['/d/', '/b/'],            # velar → alveolar/bilabial
            '/θ/': ['/f/', '/s/'],            # th → f or s
            '/ð/': ['/v/', '/z/'],            # voiced th → v or z
            '/ʃ/': ['/s/', '/tʃ/'],           # sh → s or ch
            '/tʃ/': ['/ʃ/', '/ts/'],          # ch → sh or ts
        }

    def _load_therapy_db(self, db_path: Path) -> Dict:
        """
        Load the therapy database from JSON, falling back to built-in defaults.

        Falls back (and logs) when the file is missing, unreadable, not valid
        JSON/UTF-8, or when its top-level value is not a JSON object.

        Args:
            db_path: Location of the therapy recommendations JSON file.

        Returns:
            The loaded database dict, or the result of _get_default_therapy_db().
        """
        if not db_path.exists():
            logger.warning("Therapy database not found at %s, using defaults", db_path)
            return self._get_default_therapy_db()

        try:
            with open(db_path, 'r', encoding='utf-8') as f:
                loaded = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logger.error("Failed to load therapy database: %s, using defaults", e)
            return self._get_default_therapy_db()

        if not isinstance(loaded, dict):
            # A list/string/number at the top level would break the section
            # lookups in get_therapy(), so treat it like a corrupt file.
            logger.error(
                "Failed to load therapy database: top-level JSON value in %s is %s, "
                "expected an object, using defaults",
                db_path,
                type(loaded).__name__,
            )
            return self._get_default_therapy_db()

        logger.info("✅ Loaded therapy database from %s", db_path)
        return loaded

    def map_classifier_output(
        self,
        class_id: int,
        confidence: float,
        phoneme: str,
        fluency_label: str = "normal"
    ) -> ErrorDetail:
        """
        Map classifier output to ErrorDetail.

        Args:
            class_id: Classifier output class (0-7). Classes 4-7 mirror 0-3
                with a stutter flag; only the articulation error type is kept.
            confidence: Model confidence (0.0-1.0)
            phoneme: Expected phoneme symbol
            fluency_label: Fluency label ("normal" or "stutter"). Accepted for
                API compatibility; currently not used by this method.

        Returns:
            ErrorDetail object with error information. Unknown class ids are
            logged and treated as NORMAL.
        """
        # Tuple membership compares with ==, so numpy/torch scalar ids work too.
        if class_id in (0, 4):
            error_type = ErrorType.NORMAL
        elif class_id in (1, 5):
            error_type = ErrorType.SUBSTITUTION
        elif class_id in (2, 6):
            error_type = ErrorType.OMISSION
        elif class_id in (3, 7):
            error_type = ErrorType.DISTORTION
        else:
            logger.warning("Unknown class_id: %s, defaulting to NORMAL", class_id)
            error_type = ErrorType.NORMAL

        # Confidence in the error is used as a proxy for severity.
        severity = 0.0 if error_type == ErrorType.NORMAL else confidence

        wrong_sound = None
        if error_type == ErrorType.SUBSTITUTION:
            wrong_sound = self._map_substitution(phoneme, confidence)

        therapy = self.get_therapy(error_type, phoneme, wrong_sound)

        return ErrorDetail(
            phoneme=phoneme,
            error_type=error_type,
            wrong_sound=wrong_sound,
            severity=severity,
            confidence=confidence,
            therapy=therapy
        )

    def _map_substitution(self, phoneme: str, confidence: float) -> Optional[str]:
        """
        Map substitution error to likely wrong sound.

        Args:
            phoneme: Expected phoneme
            confidence: Model confidence (currently unused)

        Returns:
            First listed (most common) wrong phoneme from substitution_map,
            or None if the phoneme has no entry.
        """
        candidates = self.substitution_map.get(phoneme)
        return candidates[0] if candidates else None

    def _lookup_therapy(self, section: str, key: str) -> Optional[str]:
        """
        Return therapy_db[section][key] when it exists and is a string.

        Returns None for a missing section/key or malformed (non-dict section,
        non-string entry) data, so callers can fall through to defaults.
        """
        entries = self.therapy_db.get(section) if isinstance(self.therapy_db, dict) else None
        if isinstance(entries, dict):
            text = entries.get(key)
            if isinstance(text, str):
                return text
        return None

    def get_therapy(
        self,
        error_type: ErrorType,
        phoneme: str,
        wrong_sound: Optional[str] = None
    ) -> str:
        """
        Get therapy recommendation for an error.

        Lookup order: a specific entry (for substitutions "<phoneme>→<wrong>",
        otherwise the phoneme itself), then the section's "generic" template
        with "{phoneme}" filled in, then a hard-coded fallback sentence.

        Args:
            error_type: Type of error
            phoneme: Expected phoneme
            wrong_sound: For substitutions, the wrong sound produced

        Returns:
            Therapy recommendation text
        """
        if error_type == ErrorType.NORMAL:
            return "No therapy needed - production is correct."

        if error_type == ErrorType.SUBSTITUTION:
            if wrong_sound:
                specific = self._lookup_therapy("substitutions", f"{phoneme}→{wrong_sound}")
                if specific is not None:
                    return specific
            generic = self._lookup_therapy("substitutions", "generic")
            if generic is not None:
                return generic.replace("{phoneme}", phoneme)
            return f"Substitution error for {phoneme}. Practice correct articulator placement."

        if error_type == ErrorType.OMISSION:
            specific = self._lookup_therapy("omissions", phoneme)
            if specific is not None:
                return specific
            generic = self._lookup_therapy("omissions", "generic")
            if generic is not None:
                return generic.replace("{phoneme}", phoneme)
            return f"Omission error for {phoneme}. Practice saying the sound separately first."

        if error_type == ErrorType.DISTORTION:
            specific = self._lookup_therapy("distortions", phoneme)
            if specific is not None:
                return specific
            generic = self._lookup_therapy("distortions", "generic")
            if generic is not None:
                return generic.replace("{phoneme}", phoneme)
            return f"Distortion error for {phoneme}. Use mirror feedback and watch articulator position."

        return "Consult with speech-language pathologist for personalized therapy plan."

    def get_severity_level(self, severity: float) -> SeverityLevel:
        """
        Convert severity score to severity level.

        Thresholds: <= 0.0 → NONE (negative scores are invalid and treated as
        no error), < 0.3 → LOW, < 0.7 → MEDIUM, otherwise HIGH.

        Args:
            severity: Severity score (0.0-1.0)

        Returns:
            SeverityLevel enum
        """
        if severity <= 0.0:
            return SeverityLevel.NONE
        if severity < 0.3:
            return SeverityLevel.LOW
        if severity < 0.7:
            return SeverityLevel.MEDIUM
        return SeverityLevel.HIGH

    def _get_default_therapy_db(self) -> Dict:
        """Get default therapy database if file not found (fresh dict per call)."""
        return {
            "substitutions": {
                "/s/→/θ/": "Lisp - Use tongue tip placement behind upper teeth. Practice /s/ in isolation.",
                "/r/→/w/": "Rhotacism - Practice tongue position: curl tongue back, avoid lip rounding.",
                "/r/→/l/": "Rhotacism - Focus on tongue tip position vs. tongue body placement.",
                "generic": "Substitution error for {phoneme}. Practice correct articulator placement with mirror feedback."
            },
            "omissions": {
                "/r/": "Practice /r/ in isolation, then in CV syllables (ra, re, ri, ro, ru).",
                "/l/": "Lateral tongue placement - practice with tongue tip up to alveolar ridge.",
                "/s/": "Practice /s/ with tongue tip placement, use mirror to check position.",
                "generic": "Omission error for {phoneme}. Say the sound separately first, then blend into words."
            },
            "distortions": {
                "/s/": "Sibilant clarity - use mirror feedback, ensure tongue tip is up and air stream is central.",
                "/ʃ/": "Fricative voicing exercise - practice /sh/ vs /s/ distinction.",
                "/r/": "Rhotacism - practice tongue position and lip rounding control.",
                "generic": "Distortion error for {phoneme}. Use mirror feedback and watch articulator position carefully."
            }
        }


# Unit test function
def test_error_mapper():
    """Test the ErrorMapper."""
    print("Testing ErrorMapper...")

    mapper = ErrorMapper()

    # Test 1: Normal (class 0)
    error = mapper.map_classifier_output(0, 0.95, "/k/")
    assert error.error_type == ErrorType.NORMAL
    assert error.severity == 0.0
    print(f"✅ Normal error: {error.error_type}, therapy: {error.therapy[:50]}...")

    # Test 2: Substitution (class 1)
    error = mapper.map_classifier_output(1, 0.78, "/s/")
    assert error.error_type == ErrorType.SUBSTITUTION
    assert error.wrong_sound is not None
    print(f"✅ Substitution error: {error.error_type}, wrong_sound: {error.wrong_sound}")
    print(f" Therapy: {error.therapy[:80]}...")

    # Test 3: Omission (class 2)
    error = mapper.map_classifier_output(2, 0.85, "/r/")
    assert error.error_type == ErrorType.OMISSION
    print(f"✅ Omission error: {error.error_type}")
    print(f" Therapy: {error.therapy[:80]}...")

    # Test 4: Distortion (class 3)
    error = mapper.map_classifier_output(3, 0.65, "/s/")
    assert error.error_type == ErrorType.DISTORTION
    print(f"✅ Distortion error: {error.error_type}")
    print(f" Therapy: {error.therapy[:80]}...")

    # Test 5: Severity levels
    assert mapper.get_severity_level(0.0) == SeverityLevel.NONE
    assert mapper.get_severity_level(0.2) == SeverityLevel.LOW
    assert mapper.get_severity_level(0.5) == SeverityLevel.MEDIUM
    assert mapper.get_severity_level(0.8) == SeverityLevel.HIGH
    print("✅ Severity level mapping correct")

    print("\n✅ All tests passed!")


if __name__ == "__main__":
    test_error_mapper()