File size: 13,773 Bytes
278e294
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1cd6149
 
 
 
 
 
 
 
 
 
 
 
278e294
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
"""
Phoneme Mapper for Speech Pathology Analysis

This module provides grapheme-to-phoneme (G2P) conversion and alignment
of phonemes to audio frames for phone-level error detection.
"""

import logging
from typing import List, Tuple, Optional, Dict
from dataclasses import dataclass
import numpy as np

# Module logger is created first so the import guard below can report through
# it instead of through the root logger.
logger = logging.getLogger(__name__)

# g2p_en is an optional dependency: the module stays importable without it and
# PhonemeMapper.__init__ raises ImportError when G2P_AVAILABLE is False.
try:
    import g2p_en
    G2P_AVAILABLE = True
except ImportError:
    G2P_AVAILABLE = False
    # Deliberately not logging.warning(): the root-level helper calls
    # basicConfig() when the root logger has no handlers, which would
    # configure application-wide logging as a side effect of importing this file.
    logger.warning("g2p_en not available. Install with: pip install g2p-en")


@dataclass
class PhonemeSegment:
    """
    One phoneme together with its position in the audio.

    The record is a plain value holder: nothing is validated or derived on
    construction, so ``duration`` and the frame indices are whatever the
    creator supplies.

    Attributes:
        phoneme: Phoneme symbol wrapped in slashes (e.g., '/r/', '/k/')
        start_time: Onset of the phoneme, in seconds
        end_time: Offset of the phoneme, in seconds
        duration: Length of the phoneme, in seconds
        frame_start: Index of the first frame of the segment
        frame_end: Index one past the last frame of the segment (exclusive)
    """
    # Symbol
    phoneme: str
    # Timing, in seconds
    start_time: float
    end_time: float
    duration: float
    # Frame span, half-open: [frame_start, frame_end)
    frame_start: int
    frame_end: int


class PhonemeMapper:
    """
    Maps text to phonemes and aligns them to audio frames.
    
    Uses g2p_en library for English grapheme-to-phoneme conversion.
    Aligns phonemes to 20ms frames for phone-level analysis.
    
    Example:
        >>> mapper = PhonemeMapper()
        >>> phonemes = mapper.text_to_phonemes("robot")
        >>> # Returns: [('/r/', 0.0), ('/o/', 0.1), ('/b/', 0.2), ('/o/', 0.3), ('/t/', 0.4)]
        >>> frame_phonemes = mapper.align_phonemes_to_frames(phonemes, num_frames=25, frame_duration_ms=20)
        >>> # Returns: ['/r/', '/r/', '/r/', '/o/', '/o/', '/b/', '/b/', ...]
    """
    
    def __init__(self, frame_duration_ms: int = 20, sample_rate: int = 16000):
        """
        Initialize the PhonemeMapper.
        
        Args:
            frame_duration_ms: Duration of each frame in milliseconds (default: 20ms)
            sample_rate: Audio sample rate in Hz (default: 16000)
        
        Raises:
            ImportError: If g2p_en is not available
        """
        if not G2P_AVAILABLE:
            raise ImportError(
                "g2p_en library is required. Install with: pip install g2p-en"
            )
        
        # Ensure NLTK data is available (required by g2p_en)
        try:
            import nltk
            try:
                nltk.data.find('taggers/averaged_perceptron_tagger_eng')
            except LookupError:
                logger.info("Downloading NLTK averaged_perceptron_tagger_eng...")
                nltk.download('averaged_perceptron_tagger_eng', quiet=True)
                logger.info("✅ NLTK data downloaded")
        except Exception as e:
            logger.warning(f"⚠️ Could not download NLTK data: {e}")
        
        try:
            self.g2p = g2p_en.G2p()
            logger.info("✅ G2P model loaded successfully")
        except Exception as e:
            logger.error(f"❌ Failed to load G2P model: {e}")
            raise
        
        self.frame_duration_ms = frame_duration_ms
        self.frame_duration_s = frame_duration_ms / 1000.0
        self.sample_rate = sample_rate
        
        # Average phoneme duration (typical English: 50-100ms)
        # We'll use 80ms as default, but adjust based on text length
        self.avg_phoneme_duration_ms = 80
        self.avg_phoneme_duration_s = self.avg_phoneme_duration_ms / 1000.0
        
        logger.info(f"PhonemeMapper initialized: frame_duration={frame_duration_ms}ms, "
                   f"avg_phoneme_duration={self.avg_phoneme_duration_ms}ms")
    
    def text_to_phonemes(
        self,
        text: str,
        duration: Optional[float] = None
    ) -> List[Tuple[str, float]]:
        """
        Convert text to phonemes with timing information.
        
        Args:
            text: Input text string (e.g., "robot", "cat")
            duration: Optional audio duration in seconds. If provided, phonemes
                     are distributed evenly across this duration. If None, uses
                     estimated duration based on phoneme count.
        
        Returns:
            List of tuples: [(phoneme, start_time), ...]
            - phoneme: Phoneme symbol with slashes (e.g., '/r/', '/k/')
            - start_time: Start time in seconds
        
        Example:
            >>> mapper = PhonemeMapper()
            >>> phonemes = mapper.text_to_phonemes("cat")
            >>> # Returns: [('/k/', 0.0), ('/æ/', 0.08), ('/t/', 0.16)]
        """
        if not text or not text.strip():
            logger.warning("Empty text provided, returning empty phoneme list")
            return []
        
        try:
            # Convert to phonemes using g2p_en
            phoneme_list = self.g2p(text.lower().strip())
            
            # Filter out punctuation and empty strings
            phoneme_list = [p for p in phoneme_list if p and p.strip() and not p.isspace()]
            
            if not phoneme_list:
                logger.warning(f"No phonemes extracted from text: '{text}'")
                return []
            
            # Add slashes if not present
            formatted_phonemes = []
            for p in phoneme_list:
                if not p.startswith('/'):
                    p = '/' + p
                if not p.endswith('/'):
                    p = p + '/'
                formatted_phonemes.append(p)
            
            logger.debug(f"Extracted {len(formatted_phonemes)} phonemes from '{text}': {formatted_phonemes}")
            
            # Calculate timing
            if duration is None:
                # Estimate duration: avg_phoneme_duration * num_phonemes
                total_duration = len(formatted_phonemes) * self.avg_phoneme_duration_s
            else:
                total_duration = duration
            
            # Distribute phonemes evenly across duration
            if len(formatted_phonemes) == 1:
                phoneme_duration = total_duration
            else:
                phoneme_duration = total_duration / len(formatted_phonemes)
            
            # Create phoneme-time pairs
            phoneme_times = []
            for i, phoneme in enumerate(formatted_phonemes):
                start_time = i * phoneme_duration
                phoneme_times.append((phoneme, start_time))
            
            logger.info(f"Converted '{text}' to {len(phoneme_times)} phonemes over {total_duration:.2f}s")
            
            return phoneme_times
            
        except Exception as e:
            logger.error(f"Error converting text to phonemes: {e}", exc_info=True)
            raise RuntimeError(f"Failed to convert text to phonemes: {e}") from e
    
    def align_phonemes_to_frames(
        self,
        phoneme_times: List[Tuple[str, float]],
        num_frames: int,
        frame_duration_ms: Optional[int] = None
    ) -> List[str]:
        """
        Align phonemes to audio frames.
        
        Each frame gets assigned the phoneme that overlaps with its time window.
        If multiple phonemes overlap, uses the one with the most overlap.
        
        Args:
            phoneme_times: List of (phoneme, start_time) tuples from text_to_phonemes()
            num_frames: Total number of frames in the audio
            frame_duration_ms: Optional frame duration override
        
        Returns:
            List of phonemes, one per frame: ['/r/', '/r/', '/o/', '/b/', ...]
        
        Example:
            >>> mapper = PhonemeMapper()
            >>> phonemes = [('/k/', 0.0), ('/æ/', 0.08), ('/t/', 0.16)]
            >>> frames = mapper.align_phonemes_to_frames(phonemes, num_frames=15, frame_duration_ms=20)
            >>> # Returns: ['/k/', '/k/', '/k/', '/k/', '/æ/', '/æ/', '/æ/', '/æ/', '/t/', ...]
        """
        if not phoneme_times:
            logger.warning("No phonemes provided, returning empty frame list")
            return [''] * num_frames
        
        frame_duration_s = (frame_duration_ms / 1000.0) if frame_duration_ms else self.frame_duration_s
        
        # Calculate phoneme end times (assume equal duration for simplicity)
        phoneme_segments = []
        for i, (phoneme, start_time) in enumerate(phoneme_times):
            if i < len(phoneme_times) - 1:
                end_time = phoneme_times[i + 1][1]
            else:
                # Last phoneme: estimate duration
                if len(phoneme_times) > 1:
                    avg_duration = phoneme_times[1][1] - phoneme_times[0][1]
                else:
                    avg_duration = self.avg_phoneme_duration_s
                end_time = start_time + avg_duration
            
            phoneme_segments.append(PhonemeSegment(
                phoneme=phoneme,
                start_time=start_time,
                end_time=end_time,
                duration=end_time - start_time,
                frame_start=-1,  # Will be calculated
                frame_end=-1
            ))
        
        # Map each frame to a phoneme
        frame_phonemes = []
        for frame_idx in range(num_frames):
            frame_start_time = frame_idx * frame_duration_s
            frame_end_time = (frame_idx + 1) * frame_duration_s
            frame_center_time = frame_start_time + (frame_duration_s / 2.0)
            
            # Find phoneme with most overlap
            best_phoneme = ''
            max_overlap = 0.0
            
            for seg in phoneme_segments:
                # Calculate overlap
                overlap_start = max(frame_start_time, seg.start_time)
                overlap_end = min(frame_end_time, seg.end_time)
                overlap = max(0.0, overlap_end - overlap_start)
                
                if overlap > max_overlap:
                    max_overlap = overlap
                    best_phoneme = seg.phoneme
            
            # If no overlap, use closest phoneme
            if not best_phoneme:
                closest_seg = min(
                    phoneme_segments,
                    key=lambda s: abs(frame_center_time - (s.start_time + s.duration / 2))
                )
                best_phoneme = closest_seg.phoneme
            
            frame_phonemes.append(best_phoneme)
        
        logger.debug(f"Aligned {len(phoneme_times)} phonemes to {num_frames} frames")
        
        return frame_phonemes
    
    def get_phoneme_boundaries(
        self,
        phoneme_times: List[Tuple[str, float]],
        duration: float
    ) -> List[PhonemeSegment]:
        """
        Get detailed phoneme boundary information.
        
        Args:
            phoneme_times: List of (phoneme, start_time) tuples
            duration: Total audio duration in seconds
        
        Returns:
            List of PhonemeSegment objects with timing and frame information
        """
        segments = []
        
        for i, (phoneme, start_time) in enumerate(phoneme_times):
            if i < len(phoneme_times) - 1:
                end_time = phoneme_times[i + 1][1]
            else:
                end_time = duration
            
            frame_start = int(start_time / self.frame_duration_s)
            frame_end = int(end_time / self.frame_duration_s)
            
            segments.append(PhonemeSegment(
                phoneme=phoneme,
                start_time=start_time,
                end_time=end_time,
                duration=end_time - start_time,
                frame_start=frame_start,
                frame_end=frame_end
            ))
        
        return segments
    
    def map_text_to_frames(
        self,
        text: str,
        num_frames: int,
        audio_duration: Optional[float] = None
    ) -> List[str]:
        """
        Complete pipeline: text → phonemes → frame alignment.
        
        Args:
            text: Input text string
            num_frames: Number of audio frames
            audio_duration: Optional audio duration in seconds
        
        Returns:
            List of phonemes, one per frame
        """
        # Convert text to phonemes
        phoneme_times = self.text_to_phonemes(text, duration=audio_duration)
        
        if not phoneme_times:
            return [''] * num_frames
        
        # Align to frames
        frame_phonemes = self.align_phonemes_to_frames(phoneme_times, num_frames)
        
        return frame_phonemes


# Unit test function
def test_phoneme_mapper():
    """
    Smoke-test PhonemeMapper end to end, printing progress to stdout.

    Runs three checks: phoneme extraction for "robot", alignment of those
    phonemes to 25 frames, and the full pipeline for "cat" over 15 frames.
    A missing g2p_en install (ImportError) is reported and swallowed; any
    other failure is reported and re-raised.
    """
    print("Testing PhonemeMapper...")

    try:
        pm = PhonemeMapper(frame_duration_ms=20)

        # Check 1: a single word yields at least one phoneme
        print("\n1. Testing 'robot':")
        robot_phonemes = pm.text_to_phonemes("robot")
        print(f"   Phonemes: {robot_phonemes}")
        assert len(robot_phonemes) > 0, "Should extract phonemes"

        # Check 2: alignment returns exactly one label per requested frame
        print("\n2. Testing frame alignment:")
        robot_frames = pm.align_phonemes_to_frames(robot_phonemes, num_frames=25)
        print(f"   Frame phonemes (first 10): {robot_frames[:10]}")
        assert len(robot_frames) == 25, "Should have 25 frames"

        # Check 3: the one-call pipeline also returns one label per frame
        print("\n3. Testing complete pipeline with 'cat':")
        cat_labels = pm.map_text_to_frames("cat", num_frames=15)
        print(f"   Frame phonemes: {cat_labels}")
        assert len(cat_labels) == 15, "Should have 15 frames"

        print("\n✅ All tests passed!")

    except ImportError as missing_dep:
        print(f"❌ G2P library not available: {missing_dep}")
        print("   Install with: pip install g2p-en")
    except Exception as failure:
        print(f"❌ Test failed: {failure}")
        raise


# Run the smoke test when this file is executed as a script; importing the
# module does not trigger it.
if __name__ == "__main__":
    test_phoneme_mapper()