# Source: zlaqa-version-c-ai-enginee/models/error_taxonomy.py
# Author: anfastech, commit 278e294
# Commit message: "New: implemented many, many changes. 10% Phone-level detection: WORKING"
"""
Error Taxonomy for Speech Pathology Analysis
This module defines error types, severity levels, and therapy recommendations
for phoneme-level error detection.
"""
import logging
import json
from enum import Enum
from typing import Optional, Dict, List
from pathlib import Path
from dataclasses import dataclass, field
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
class ErrorType(str, Enum):
    """Category of articulation error assigned to one expected phoneme.

    Mixes in ``str`` so each member compares equal to its lowercase value
    (e.g. ``ErrorType.OMISSION == "omission"``).
    """

    NORMAL = "normal"              # produced correctly
    SUBSTITUTION = "substitution"  # another phoneme was produced instead
    OMISSION = "omission"          # the phoneme was left out
    DISTORTION = "distortion"      # the phoneme was produced imprecisely
class SeverityLevel(str, Enum):
    """Discrete severity band for an error, ordered from NONE up to HIGH.

    Mixes in ``str`` so each member compares equal to its lowercase value.
    """

    NONE = "none"      # no error
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
@dataclass
class ErrorDetail:
    """
    Error record for a single expected phoneme.

    Attributes:
        phoneme: Expected phoneme symbol (e.g. '/s/').
        error_type: One of the ``ErrorType`` members
            (NORMAL, SUBSTITUTION, OMISSION, DISTORTION).
        wrong_sound: For substitutions, the phoneme produced instead
            (e.g. '/θ/'); None when not applicable or unknown.
        severity: Severity score, nominally in 0.0-1.0.
        confidence: Model confidence in the detection, nominally in 0.0-1.0.
        therapy: Therapy recommendation text ("" when not set).
        frame_indices: Frame indices where this error occurs; every
            instance gets its own fresh list.
    """

    phoneme: str
    error_type: ErrorType
    wrong_sound: Optional[str] = field(default=None)
    severity: float = field(default=0.0)
    confidence: float = field(default=0.0)
    therapy: str = field(default="")
    # default_factory avoids sharing one list between instances
    frame_indices: List[int] = field(default_factory=list)
class ErrorDetailPydantic(BaseModel):
    """Pydantic counterpart of ``ErrorDetail`` used for API serialization.

    Unlike the dataclass, ``severity``, ``confidence`` and ``therapy`` are
    required here. Both scores are validated to lie in [0.0, 1.0].
    ``error_type`` is a plain string rather than an ``ErrorType``.
    """

    phoneme: str
    error_type: str
    wrong_sound: Optional[str] = Field(default=None)
    # ``...`` marks the field as required
    severity: float = Field(..., ge=0.0, le=1.0)
    confidence: float = Field(..., ge=0.0, le=1.0)
    therapy: str = Field(...)
    frame_indices: List[int] = Field(default_factory=list)
class ErrorMapper:
    """
    Map classifier outputs to error types and generate therapy recommendations.

    Classifier output mapping (8 classes):
    - Class 0: Normal articulation, normal fluency
    - Class 1: Substitution, normal fluency
    - Class 2: Omission, normal fluency
    - Class 3: Distortion, normal fluency
    - Class 4: Normal articulation, stutter
    - Class 5: Substitution, stutter
    - Class 6: Omission, stutter
    - Class 7: Distortion, stutter

    Only the articulation half of a class is carried into the returned
    ``ErrorDetail``. The fluency half (classes 0-3 vs 4-7) is not recorded.
    """

    # Articulation error type per classifier class; classes 4-7 mirror 0-3.
    _CLASS_TO_ERROR_TYPE: Dict[int, ErrorType] = {
        0: ErrorType.NORMAL,
        1: ErrorType.SUBSTITUTION,
        2: ErrorType.OMISSION,
        3: ErrorType.DISTORTION,
        4: ErrorType.NORMAL,
        5: ErrorType.SUBSTITUTION,
        6: ErrorType.OMISSION,
        7: ErrorType.DISTORTION,
    }

    def __init__(self, therapy_db_path: Optional[str] = None):
        """
        Initialize the ErrorMapper.

        Args:
            therapy_db_path: Path to the therapy recommendations JSON file
                (a ``str`` or anything ``pathlib.Path`` accepts). If None,
                uses ``data/therapy_recommendations.json`` inside the parent
                of this module's directory. If the file is missing,
                unreadable, not valid JSON, or not a JSON object, the
                built-in defaults from ``_get_default_therapy_db`` are used.
        """
        if therapy_db_path is None:
            db_path = Path(__file__).parent.parent / "data" / "therapy_recommendations.json"
        else:
            db_path = Path(therapy_db_path)

        self.therapy_db: Dict = self._load_therapy_db(db_path)

        # Common substitution mappings (phoneme → likely wrong sound)
        self.substitution_map: Dict[str, List[str]] = {
            '/s/': ['/θ/', '/ʃ/', '/z/'],  # lisp, sh-sound, voicing
            '/r/': ['/w/', '/l/', '/ɹ/'],  # rhotacism variants
            '/l/': ['/w/', '/j/'],  # liquid substitutions
            '/k/': ['/t/', '/p/'],  # velar → alveolar/bilabial
            '/g/': ['/d/', '/b/'],  # velar → alveolar/bilabial
            '/θ/': ['/f/', '/s/'],  # th → f or s
            '/ð/': ['/v/', '/z/'],  # voiced th → v or z
            '/ʃ/': ['/s/', '/tʃ/'],  # sh → s or ch
            '/tʃ/': ['/ʃ/', '/ts/'],  # ch → sh or ts
        }

    def _load_therapy_db(self, db_path: Path) -> Dict:
        """Load the therapy JSON at *db_path*, falling back to the defaults on any problem."""
        try:
            with open(db_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except FileNotFoundError:
            logger.warning("Therapy database not found at %s, using defaults", db_path)
            return self._get_default_therapy_db()
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logger.error("Failed to load therapy database: %s, using defaults", e)
            return self._get_default_therapy_db()

        # get_therapy does membership tests and indexing on this object, so
        # anything other than a JSON object (e.g. a bare number) would crash later.
        if not isinstance(data, dict):
            logger.error("Therapy database at %s is not a JSON object, using defaults", db_path)
            return self._get_default_therapy_db()

        logger.info("✅ Loaded therapy database from %s", db_path)
        return data

    def map_classifier_output(
        self,
        class_id: int,
        confidence: float,
        phoneme: str,
        fluency_label: str = "normal"
    ) -> ErrorDetail:
        """
        Map classifier output to ErrorDetail.

        Args:
            class_id: Classifier output class (0-7). Unknown ids are logged
                and treated as NORMAL.
            confidence: Model confidence (0.0-1.0).
            phoneme: Expected phoneme symbol.
            fluency_label: Fluency label ("normal" or "stutter"). Currently
                not read by this method.

        Returns:
            ErrorDetail object with error information. ``severity`` is 0.0
            for NORMAL; otherwise it is ``confidence`` clamped to [0.0, 1.0].
        """
        error_type = self._CLASS_TO_ERROR_TYPE.get(class_id)
        if error_type is None:
            logger.warning("Unknown class_id: %s, defaulting to NORMAL", class_id)
            error_type = ErrorType.NORMAL

        # Higher confidence in an error = higher severity. Clamp so the score
        # stays in the documented range even for out-of-range confidences.
        if error_type == ErrorType.NORMAL:
            severity = 0.0
        else:
            severity = min(max(float(confidence), 0.0), 1.0)

        # Only substitutions have a "wrong sound" to report.
        wrong_sound = None
        if error_type == ErrorType.SUBSTITUTION:
            wrong_sound = self._map_substitution(phoneme, confidence)

        therapy = self.get_therapy(error_type, phoneme, wrong_sound)

        return ErrorDetail(
            phoneme=phoneme,
            error_type=error_type,
            wrong_sound=wrong_sound,
            severity=severity,
            confidence=confidence,
            therapy=therapy
        )

    def _map_substitution(self, phoneme: str, confidence: float) -> Optional[str]:
        """
        Map substitution error to likely wrong sound.

        Args:
            phoneme: Expected phoneme
            confidence: Model confidence (currently unused)

        Returns:
            The first (most common) listed substitution for *phoneme*, or
            None if the phoneme is unknown or has no candidates.
        """
        candidates = self.substitution_map.get(phoneme)
        return candidates[0] if candidates else None

    def _db_section(self, name: str) -> Dict:
        """Return ``self.therapy_db[name]`` if it is a dict, else an empty dict."""
        db = self.therapy_db if isinstance(self.therapy_db, dict) else {}
        section = db.get(name)
        return section if isinstance(section, dict) else {}

    def _lookup_therapy(
        self,
        section_name: str,
        specific_key: Optional[str],
        phoneme: str,
        fallback: str
    ) -> str:
        """Resolve advice from one DB section: specific entry, then "generic" template, then *fallback*."""
        section = self._db_section(section_name)
        if specific_key is not None and specific_key in section:
            return section[specific_key]
        generic = section.get("generic")
        if isinstance(generic, str):
            return generic.replace("{phoneme}", phoneme)
        return fallback

    def get_therapy(
        self,
        error_type: ErrorType,
        phoneme: str,
        wrong_sound: Optional[str] = None
    ) -> str:
        """
        Get therapy recommendation for an error.

        Lookup order per error type: the specific DB entry (a
        "<phoneme>→<wrong_sound>" pair for substitutions, the phoneme itself
        for omissions/distortions), then the section's "generic" template
        with "{phoneme}" filled in, then a hard-coded message.

        Args:
            error_type: Type of error
            phoneme: Expected phoneme
            wrong_sound: For substitutions, the wrong sound produced

        Returns:
            Therapy recommendation text
        """
        if error_type == ErrorType.NORMAL:
            return "No therapy needed - production is correct."

        if error_type == ErrorType.SUBSTITUTION:
            # Pair-specific advice only applies when the produced sound is known.
            pair_key = f"{phoneme}→{wrong_sound}" if wrong_sound else None
            return self._lookup_therapy(
                "substitutions",
                pair_key,
                phoneme,
                f"Substitution error for {phoneme}. Practice correct articulator placement.",
            )

        if error_type == ErrorType.OMISSION:
            return self._lookup_therapy(
                "omissions",
                phoneme,
                phoneme,
                f"Omission error for {phoneme}. Practice saying the sound separately first.",
            )

        if error_type == ErrorType.DISTORTION:
            return self._lookup_therapy(
                "distortions",
                phoneme,
                phoneme,
                f"Distortion error for {phoneme}. Use mirror feedback and watch articulator position.",
            )

        return "Consult with speech-language pathologist for personalized therapy plan."

    def get_severity_level(self, severity: float) -> SeverityLevel:
        """
        Convert severity score to severity level.

        Args:
            severity: Severity score (0.0-1.0)

        Returns:
            NONE for scores <= 0.0, LOW below 0.3, MEDIUM below 0.7,
            otherwise HIGH.
        """
        # <= rather than == so stray negative scores are not reported as LOW.
        if severity <= 0.0:
            return SeverityLevel.NONE
        if severity < 0.3:
            return SeverityLevel.LOW
        if severity < 0.7:
            return SeverityLevel.MEDIUM
        return SeverityLevel.HIGH

    def _get_default_therapy_db(self) -> Dict:
        """Get default therapy database if file not found or unusable."""
        return {
            "substitutions": {
                "/s/→/θ/": "Lisp - Use tongue tip placement behind upper teeth. Practice /s/ in isolation.",
                "/r/→/w/": "Rhotacism - Practice tongue position: curl tongue back, avoid lip rounding.",
                "/r/→/l/": "Rhotacism - Focus on tongue tip position vs. tongue body placement.",
                "generic": "Substitution error for {phoneme}. Practice correct articulator placement with mirror feedback."
            },
            "omissions": {
                "/r/": "Practice /r/ in isolation, then in CV syllables (ra, re, ri, ro, ru).",
                "/l/": "Lateral tongue placement - practice with tongue tip up to alveolar ridge.",
                "/s/": "Practice /s/ with tongue tip placement, use mirror to check position.",
                "generic": "Omission error for {phoneme}. Say the sound separately first, then blend into words."
            },
            "distortions": {
                "/s/": "Sibilant clarity - use mirror feedback, ensure tongue tip is up and air stream is central.",
                "/ʃ/": "Fricative voicing exercise - practice /sh/ vs /s/ distinction.",
                "/r/": "Rhotacism - practice tongue position and lip rounding control.",
                "generic": "Distortion error for {phoneme}. Use mirror feedback and watch articulator position carefully."
            }
        }
# Unit test function
def test_error_mapper():
    """Smoke-test ErrorMapper end to end, printing one line per check.

    Exercises one example of each classifier error type plus one probe per
    severity band. Any failed check raises AssertionError.
    """
    print("Testing ErrorMapper...")
    mapper = ErrorMapper()

    # Class 0: correct production carries no severity.
    normal = mapper.map_classifier_output(0, 0.95, "/k/")
    assert normal.error_type == ErrorType.NORMAL
    assert normal.severity == 0.0
    print(f"✅ Normal error: {normal.error_type}, therapy: {normal.therapy[:50]}...")

    # Class 1: a substitution should name the sound produced instead.
    subst = mapper.map_classifier_output(1, 0.78, "/s/")
    assert subst.error_type == ErrorType.SUBSTITUTION
    assert subst.wrong_sound is not None
    print(f"✅ Substitution error: {subst.error_type}, wrong_sound: {subst.wrong_sound}")
    print(f" Therapy: {subst.therapy[:80]}...")

    # Classes 2 and 3 share the same check-and-report shape.
    for class_id, conf, phoneme, expected, label in (
        (2, 0.85, "/r/", ErrorType.OMISSION, "Omission"),
        (3, 0.65, "/s/", ErrorType.DISTORTION, "Distortion"),
    ):
        detail = mapper.map_classifier_output(class_id, conf, phoneme)
        assert detail.error_type == expected
        print(f"✅ {label} error: {detail.error_type}")
        print(f" Therapy: {detail.therapy[:80]}...")

    # One probe inside each severity band.
    for score, level in (
        (0.0, SeverityLevel.NONE),
        (0.2, SeverityLevel.LOW),
        (0.5, SeverityLevel.MEDIUM),
        (0.8, SeverityLevel.HIGH),
    ):
        assert mapper.get_severity_level(score) == level
    print("✅ Severity level mapping correct")

    print("\n✅ All tests passed!")
if __name__ == "__main__":
    # Runs the smoke test only when this module is executed as a script,
    # not when it is imported.
    test_error_mapper()