File size: 12,838 Bytes
278e294
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
"""
Error Taxonomy for Speech Pathology Analysis

This module defines error types, severity levels, and therapy recommendations
for phoneme-level error detection.
"""

import logging
import json
from enum import Enum
from typing import Optional, Dict, List
from pathlib import Path
from dataclasses import dataclass, field
from pydantic import BaseModel, Field

# Module-level logger, named after this module via __name__; used by ErrorMapper
# to report therapy-database loading and unknown classifier classes.
logger = logging.getLogger(__name__)


class ErrorType(str, Enum):
    """Types of articulation errors.

    Mixes in ``str``, so every member compares equal to its lowercase string
    value (for example ``ErrorType.NORMAL == "normal"``).
    """
    NORMAL = "normal"  # ErrorMapper assigns this to classifier classes 0 and 4 (and unknown ids)
    SUBSTITUTION = "substitution"  # classifier classes 1 and 5
    OMISSION = "omission"  # classifier classes 2 and 6
    DISTORTION = "distortion"  # classifier classes 3 and 7


class SeverityLevel(str, Enum):
    """Severity levels for errors.

    Mixes in ``str``, so every member compares equal to its lowercase string
    value. The ranges noted below are the ones ErrorMapper.get_severity_level
    applies to a severity score in the 0.0-1.0 range.
    """
    NONE = "none"  # score of 0.0
    LOW = "low"  # score above 0.0 and below 0.3
    MEDIUM = "medium"  # score from 0.3 up to, but not including, 0.7
    HIGH = "high"  # score of 0.7 or more


@dataclass
class ErrorDetail:
    """
    Detailed error information for a phoneme.

    This is a plain dataclass: none of the documented value ranges are
    validated on construction.

    Attributes:
        phoneme: Expected phoneme symbol (e.g., '/s/')
        error_type: Type of error (NORMAL, SUBSTITUTION, OMISSION, DISTORTION)
        wrong_sound: For substitutions, the incorrect phoneme produced (e.g., '/θ/');
            None for other error types or when no substitution is known
        severity: Severity score (0.0-1.0); ErrorMapper.map_classifier_output
            uses 0.0 for NORMAL and derives it from the confidence otherwise
        confidence: Model confidence in the error detection (0.0-1.0)
        therapy: Therapy recommendation text
        frame_indices: List of frame indices where this error occurs; defaults to
            a fresh empty list per instance (ErrorMapper.map_classifier_output
            does not fill it in)
    """
    phoneme: str
    error_type: ErrorType
    wrong_sound: Optional[str] = None
    severity: float = 0.0
    confidence: float = 0.0
    therapy: str = ""
    # default_factory gives each instance its own list instead of a shared one
    frame_indices: List[int] = field(default_factory=list)


class ErrorDetailPydantic(BaseModel):
    """Pydantic model for API serialization.

    Carries the same field names as the ErrorDetail dataclass, with these
    differences visible in the declarations below: ``error_type`` is a plain
    string rather than an ErrorType, ``severity``, ``confidence`` and
    ``therapy`` have no defaults (they are required), and ``severity`` /
    ``confidence`` are constrained to the closed range 0.0-1.0.
    """
    phoneme: str
    error_type: str
    wrong_sound: Optional[str] = None
    severity: float = Field(ge=0.0, le=1.0)  # must satisfy 0.0 <= severity <= 1.0
    confidence: float = Field(ge=0.0, le=1.0)  # must satisfy 0.0 <= confidence <= 1.0
    therapy: str
    frame_indices: List[int] = Field(default_factory=list)


class ErrorMapper:
    """
    Maps classifier outputs to error types and generates therapy recommendations.

    Classifier output mapping (8 classes):
    - Class 0: Normal articulation, normal fluency
    - Class 1: Substitution, normal fluency
    - Class 2: Omission, normal fluency
    - Class 3: Distortion, normal fluency
    - Class 4: Normal articulation, stutter
    - Class 5: Substitution, stutter
    - Class 6: Omission, stutter
    - Class 7: Distortion, stutter

    Only the articulation part of the class is mapped: classes 4-7 yield the
    same error types as classes 0-3, and the fluency part is not carried into
    the returned ErrorDetail.

    Attributes:
        therapy_db: Dict with optional "substitutions", "omissions" and
            "distortions" sections, each mapping a lookup key (or "generic")
            to recommendation text.
        substitution_map: Expected phoneme -> list of likely wrong sounds,
            most common first.
    """

    def __init__(self, therapy_db_path: Optional[str] = None):
        """
        Initialize the ErrorMapper.

        Loading the therapy database is best-effort: if the file is missing,
        unreadable, not valid JSON, or not a JSON object, the built-in default
        database is used instead and the problem is logged.

        Args:
            therapy_db_path: Path to therapy recommendations JSON file.
                           If None, uses default location: data/therapy_recommendations.json
                           (resolved relative to the parent of this file's directory).
        """
        # Default path
        if therapy_db_path is None:
            therapy_db_path = Path(__file__).parent.parent / "data" / "therapy_recommendations.json"
        else:
            therapy_db_path = Path(therapy_db_path)

        # Load therapy database (falls back to defaults on any problem)
        self.therapy_db: Dict = self._load_therapy_db(therapy_db_path)

        # Common substitution mappings (phoneme → likely wrong sound)
        self.substitution_map: Dict[str, List[str]] = {
            '/s/': ['/θ/', '/ʃ/', '/z/'],  # lisp, sh-sound, voicing
            '/r/': ['/w/', '/l/', '/ɹ/'],  # rhotacism variants
            '/l/': ['/w/', '/j/'],  # liquid substitutions
            '/k/': ['/t/', '/p/'],  # velar → alveolar/bilabial
            '/g/': ['/d/', '/b/'],  # velar → alveolar/bilabial
            '/θ/': ['/f/', '/s/'],  # th → f or s
            '/ð/': ['/v/', '/z/'],  # voiced th → v or z
            '/ʃ/': ['/s/', '/tʃ/'],  # sh → s or ch
            '/tʃ/': ['/ʃ/', '/ts/'],  # ch → sh or ts
        }

    def _load_therapy_db(self, db_path: Path) -> Dict:
        """
        Read the therapy database from disk, falling back to the defaults.

        Args:
            db_path: Location of the JSON file.

        Returns:
            The parsed JSON object, or the default database when the file is
            missing, cannot be read or parsed, or is not a JSON object.
        """
        try:
            if not db_path.exists():
                logger.warning(f"Therapy database not found at {db_path}, using defaults")
                return self._get_default_therapy_db()
            with open(db_path, 'r', encoding='utf-8') as f:
                loaded = json.load(f)
        except Exception as e:
            # Deliberately broad: a broken recommendations file must never
            # prevent the mapper from being constructed.
            logger.error(f"Failed to load therapy database: {e}, using defaults")
            return self._get_default_therapy_db()

        # Valid JSON is not necessarily an object; a list or scalar here would
        # make every later section lookup fail, so reject it up front.
        if not isinstance(loaded, dict):
            logger.error(
                f"Therapy database at {db_path} is not a JSON object "
                f"(got {type(loaded).__name__}), using defaults"
            )
            return self._get_default_therapy_db()

        logger.info(f"✅ Loaded therapy database from {db_path}")
        return loaded

    def map_classifier_output(
        self,
        class_id: int,
        confidence: float,
        phoneme: str,
        fluency_label: str = "normal"
    ) -> ErrorDetail:
        """
        Map classifier output to ErrorDetail.

        Args:
            class_id: Classifier output class (0-7). Any other value is logged
                and treated as NORMAL.
            confidence: Model confidence (0.0-1.0)
            phoneme: Expected phoneme symbol
            fluency_label: Fluency label ("normal" or "stutter"). Currently not
                used by this method; kept for interface compatibility.

        Returns:
            ErrorDetail object with error information. ``frame_indices`` is
            left at its default (empty list).
        """
        # Determine error type from class_id; classes 4-7 repeat 0-3 with stutter
        if class_id in (0, 4):
            error_type = ErrorType.NORMAL
        elif class_id in (1, 5):
            error_type = ErrorType.SUBSTITUTION
        elif class_id in (2, 6):
            error_type = ErrorType.OMISSION
        elif class_id in (3, 7):
            error_type = ErrorType.DISTORTION
        else:
            logger.warning(f"Unknown class_id: {class_id}, defaulting to NORMAL")
            error_type = ErrorType.NORMAL

        # Calculate severity from confidence
        # Higher confidence in error = higher severity
        if error_type == ErrorType.NORMAL:
            severity = 0.0
        else:
            # Confidence is the severity proxy; clamp it so the severity stays
            # inside the documented 0.0-1.0 range even if the model output
            # drifts slightly outside it.
            severity = min(max(confidence, 0.0), 1.0)

        # Get wrong sound for substitutions
        wrong_sound = None
        if error_type == ErrorType.SUBSTITUTION:
            wrong_sound = self._map_substitution(phoneme, confidence)

        # Get therapy recommendation
        therapy = self.get_therapy(error_type, phoneme, wrong_sound)

        return ErrorDetail(
            phoneme=phoneme,
            error_type=error_type,
            wrong_sound=wrong_sound,
            severity=severity,
            confidence=confidence,
            therapy=therapy
        )

    def _map_substitution(self, phoneme: str, confidence: float) -> Optional[str]:
        """
        Map substitution error to likely wrong sound.

        Args:
            phoneme: Expected phoneme
            confidence: Model confidence (currently not used in the choice)

        Returns:
            Most likely wrong phoneme (the first entry listed for ``phoneme``
            in ``substitution_map``), or None if the phoneme has no entries
        """
        candidates = self.substitution_map.get(phoneme)
        # An empty candidate list is treated like an unknown phoneme.
        return candidates[0] if candidates else None

    def _lookup_therapy(
        self,
        section: str,
        key: Optional[str],
        phoneme: str
    ) -> Optional[str]:
        """
        Look up recommendation text in one section of the therapy database.

        Args:
            section: Section name ("substitutions", "omissions", "distortions")
            key: Specific entry to try first, or None to go straight to "generic"
            phoneme: Phoneme substituted for "{phoneme}" in the generic text

        Returns:
            The specific entry verbatim if present, otherwise the section's
            "generic" entry with "{phoneme}" filled in, otherwise None.
            Sections or entries of the wrong type are ignored.
        """
        db = self.therapy_db
        entries = db.get(section) if isinstance(db, dict) else None
        if not isinstance(entries, dict):
            return None

        if key is not None:
            specific = entries.get(key)
            if isinstance(specific, str):
                return specific

        generic = entries.get("generic")
        if isinstance(generic, str):
            return generic.replace("{phoneme}", phoneme)
        return None

    def get_therapy(
        self,
        error_type: ErrorType,
        phoneme: str,
        wrong_sound: Optional[str] = None
    ) -> str:
        """
        Get therapy recommendation for an error.

        Lookup order: a specific database entry (``phoneme→wrong_sound`` for
        substitutions, ``phoneme`` for omissions and distortions), then the
        section's "generic" entry, then a built-in sentence.

        Args:
            error_type: Type of error
            phoneme: Expected phoneme
            wrong_sound: For substitutions, the wrong sound produced

        Returns:
            Therapy recommendation text
        """
        if error_type == ErrorType.NORMAL:
            return "No therapy needed - production is correct."

        if error_type == ErrorType.SUBSTITUTION:
            section = "substitutions"
            # Without a known wrong sound only the generic entry can apply.
            key = f"{phoneme}→{wrong_sound}" if wrong_sound else None
            fallback = f"Substitution error for {phoneme}. Practice correct articulator placement."
        elif error_type == ErrorType.OMISSION:
            section = "omissions"
            key = phoneme
            fallback = f"Omission error for {phoneme}. Practice saying the sound separately first."
        elif error_type == ErrorType.DISTORTION:
            section = "distortions"
            key = phoneme
            fallback = f"Distortion error for {phoneme}. Use mirror feedback and watch articulator position."
        else:
            return "Consult with speech-language pathologist for personalized therapy plan."

        recommendation = self._lookup_therapy(section, key, phoneme)
        return fallback if recommendation is None else recommendation

    def get_severity_level(self, severity: float) -> SeverityLevel:
        """
        Convert severity score to severity level.

        Args:
            severity: Severity score (0.0-1.0)

        Returns:
            SeverityLevel enum: NONE for scores of 0.0 or below, LOW below 0.3,
            MEDIUM below 0.7, HIGH otherwise
        """
        # "<=" rather than "==": a zero or (out-of-range) negative score means
        # no error, not a low-severity one.
        if severity <= 0.0:
            return SeverityLevel.NONE
        if severity < 0.3:
            return SeverityLevel.LOW
        if severity < 0.7:
            return SeverityLevel.MEDIUM
        return SeverityLevel.HIGH

    def _get_default_therapy_db(self) -> Dict:
        """Get default therapy database if file not found.

        Returns a freshly built dict on every call, so callers may modify the
        result without affecting later fallbacks.
        """
        return {
            "substitutions": {
                "/s/→/θ/": "Lisp - Use tongue tip placement behind upper teeth. Practice /s/ in isolation.",
                "/r/→/w/": "Rhotacism - Practice tongue position: curl tongue back, avoid lip rounding.",
                "/r/→/l/": "Rhotacism - Focus on tongue tip position vs. tongue body placement.",
                "generic": "Substitution error for {phoneme}. Practice correct articulator placement with mirror feedback."
            },
            "omissions": {
                "/r/": "Practice /r/ in isolation, then in CV syllables (ra, re, ri, ro, ru).",
                "/l/": "Lateral tongue placement - practice with tongue tip up to alveolar ridge.",
                "/s/": "Practice /s/ with tongue tip placement, use mirror to check position.",
                "generic": "Omission error for {phoneme}. Say the sound separately first, then blend into words."
            },
            "distortions": {
                "/s/": "Sibilant clarity - use mirror feedback, ensure tongue tip is up and air stream is central.",
                "/ʃ/": "Fricative voicing exercise - practice /sh/ vs /s/ distinction.",
                "/r/": "Rhotacism - practice tongue position and lip rounding control.",
                "generic": "Distortion error for {phoneme}. Use mirror feedback and watch articulator position carefully."
            }
        }


# Unit test function
def test_error_mapper():
    """Smoke-test ErrorMapper: one call per error type, then the severity buckets.

    Builds an ErrorMapper with the default database path, prints a progress
    line for each check and raises AssertionError on the first failed check.
    """
    print("Testing ErrorMapper...")

    error_mapper = ErrorMapper()

    # Class 0 must come back as NORMAL with zero severity.
    normal = error_mapper.map_classifier_output(0, 0.95, "/k/")
    assert normal.error_type == ErrorType.NORMAL
    assert normal.severity == 0.0
    print(f"✅ Normal error: {normal.error_type}, therapy: {normal.therapy[:50]}...")

    # Class 1 is a substitution; '/s/' has a known wrong sound.
    substitution = error_mapper.map_classifier_output(1, 0.78, "/s/")
    assert substitution.error_type == ErrorType.SUBSTITUTION
    assert substitution.wrong_sound is not None
    print(f"✅ Substitution error: {substitution.error_type}, wrong_sound: {substitution.wrong_sound}")
    print(f"   Therapy: {substitution.therapy[:80]}...")

    # Class 2 is an omission.
    omission = error_mapper.map_classifier_output(2, 0.85, "/r/")
    assert omission.error_type == ErrorType.OMISSION
    print(f"✅ Omission error: {omission.error_type}")
    print(f"   Therapy: {omission.therapy[:80]}...")

    # Class 3 is a distortion.
    distortion = error_mapper.map_classifier_output(3, 0.65, "/s/")
    assert distortion.error_type == ErrorType.DISTORTION
    print(f"✅ Distortion error: {distortion.error_type}")
    print(f"   Therapy: {distortion.therapy[:80]}...")

    # One representative score per severity bucket, checked in ascending order.
    expected_levels = (
        (0.0, SeverityLevel.NONE),
        (0.2, SeverityLevel.LOW),
        (0.5, SeverityLevel.MEDIUM),
        (0.8, SeverityLevel.HIGH),
    )
    for score, level in expected_levels:
        assert error_mapper.get_severity_level(score) == level
    print("✅ Severity level mapping correct")

    print("\n✅ All tests passed!")


# Run the self-test only when this file is executed as a script, not on import.
if __name__ == "__main__":
    test_error_mapper()