# Source: zlaqa-version-c-ai-enginee / models / phoneme_mapper.py
# Author: anfastech
# Commit message: "New: Phoneme-level speech pathology diagnosis MVP with real-time streaming"
# Commit: 1cd6149
"""
Phoneme Mapper for Speech Pathology Analysis
This module provides grapheme-to-phoneme (G2P) conversion and alignment
of phonemes to audio frames for phone-level error detection.
"""
import logging
from typing import List, Tuple, Optional, Dict
from dataclasses import dataclass
import numpy as np
# Create the module logger first so the import fallback below can use it.
logger = logging.getLogger(__name__)

# g2p_en is an optional dependency: record whether it could be imported so
# PhonemeMapper.__init__ can fail with a clear message instead of a NameError.
try:
    import g2p_en
    G2P_AVAILABLE = True
except ImportError:
    G2P_AVAILABLE = False
    # Warn through the module logger rather than logging.warning(): the
    # module-level logging functions call basicConfig() on the root logger
    # when it has no handlers, which would silently configure logging for
    # the whole application just because this module was imported.
    logger.warning("g2p_en not available. Install with: pip install g2p-en")
@dataclass
class PhonemeSegment:
    """
    One phoneme together with where it sits on the audio timeline.

    Times are expressed in seconds and frame positions as integer frame
    indices. The class is a plain record: it stores the values it is given
    and does not check that they are mutually consistent.

    Attributes:
        phoneme: Phoneme symbol wrapped in slashes (e.g., '/r/', '/k/')
        start_time: Time at which the phoneme begins, in seconds
        end_time: Time at which the phoneme ends, in seconds
        duration: Length of the phoneme in seconds (callers pass
            end_time - start_time)
        frame_start: Index of the first frame covered by the phoneme
        frame_end: Index one past the last covered frame (exclusive)
    """

    phoneme: str
    start_time: float
    end_time: float
    duration: float
    frame_start: int
    frame_end: int
class PhonemeMapper:
    """
    Maps text to phonemes and aligns them to fixed-length audio frames.

    Grapheme-to-phoneme conversion is delegated to the g2p_en library, so the
    phoneme symbols are whatever g2p_en emits (ARPAbet symbols with stress
    digits, e.g. 'K', 'AE1', 'T'), wrapped in slashes. No acoustic alignment
    is performed: phonemes are spread uniformly over the given (or estimated)
    duration and each frame is labelled with the phoneme it overlaps most.

    Example:
        >>> mapper = PhonemeMapper()
        >>> phonemes = mapper.text_to_phonemes("cat")
        >>> # e.g. [('/K/', 0.0), ('/AE1/', 0.08), ('/T/', 0.16)]
        >>> frame_phonemes = mapper.align_phonemes_to_frames(phonemes, num_frames=15, frame_duration_ms=20)
        >>> # e.g. ['/K/', '/K/', '/K/', '/K/', '/AE1/', '/AE1/', ...]
    """

    # Slack (in frames) added before truncating time / frame_duration to an
    # integer index, so that ratios such as 0.3 / 0.1 == 2.9999999999999996
    # land on frame 3 instead of frame 2.
    _FRAME_INDEX_EPS = 1e-9

    def __init__(self, frame_duration_ms: int = 20, sample_rate: int = 16000):
        """
        Initialize the PhonemeMapper.

        Args:
            frame_duration_ms: Duration of each frame in milliseconds (default: 20ms)
            sample_rate: Audio sample rate in Hz (default: 16000); stored only

        Raises:
            ImportError: If g2p_en is not available
            ValueError: If frame_duration_ms is not positive
        """
        if not G2P_AVAILABLE:
            raise ImportError(
                "g2p_en library is required. Install with: pip install g2p-en"
            )
        # A zero or negative frame length would later cause a division by
        # zero (or negative frame indices) in get_phoneme_boundaries().
        if frame_duration_ms <= 0:
            raise ValueError(
                f"frame_duration_ms must be positive, got {frame_duration_ms}"
            )

        # Ensure NLTK data is available (required by g2p_en). This is a
        # best-effort step: any failure is logged and the G2P load below
        # decides whether construction can proceed.
        try:
            import nltk
            try:
                nltk.data.find('taggers/averaged_perceptron_tagger_eng')
            except LookupError:
                logger.info("Downloading NLTK averaged_perceptron_tagger_eng...")
                nltk.download('averaged_perceptron_tagger_eng', quiet=True)
                logger.info("✅ NLTK data downloaded")
        except Exception as e:
            logger.warning(f"⚠️ Could not download NLTK data: {e}")

        try:
            self.g2p = g2p_en.G2p()
            logger.info("✅ G2P model loaded successfully")
        except Exception as e:
            logger.error(f"❌ Failed to load G2P model: {e}")
            raise

        self.frame_duration_ms = frame_duration_ms
        self.frame_duration_s = frame_duration_ms / 1000.0
        self.sample_rate = sample_rate

        # Fixed per-phoneme duration used when the caller gives no audio
        # duration (typical English phonemes last roughly 50-100ms).
        self.avg_phoneme_duration_ms = 80
        self.avg_phoneme_duration_s = self.avg_phoneme_duration_ms / 1000.0

        logger.info(f"PhonemeMapper initialized: frame_duration={frame_duration_ms}ms, "
                    f"avg_phoneme_duration={self.avg_phoneme_duration_ms}ms")

    @staticmethod
    def _is_phoneme_token(token: str) -> bool:
        """Return True if the G2P token contains a letter (i.e. is not punctuation/space)."""
        return any(ch.isalpha() for ch in token)

    def _frame_index(self, time_s: float) -> int:
        """Return the index of the frame containing time_s, tolerant of float rounding."""
        return int(time_s / self.frame_duration_s + self._FRAME_INDEX_EPS)

    def text_to_phonemes(
        self,
        text: str,
        duration: Optional[float] = None
    ) -> List[Tuple[str, float]]:
        """
        Convert text to phonemes with timing information.

        Whitespace and punctuation tokens produced by the G2P model are
        dropped; only tokens containing at least one letter are kept.

        Args:
            text: Input text string (e.g., "robot", "cat")
            duration: Optional audio duration in seconds. If provided, phonemes
                are distributed evenly across this duration. If None, the
                duration is estimated as 80ms per phoneme.

        Returns:
            List of tuples: [(phoneme, start_time), ...]
            - phoneme: Phoneme symbol with slashes (e.g., '/K/', '/AE1/')
            - start_time: Start time in seconds
            An empty list is returned for empty text or when no phonemes
            remain after filtering.

        Raises:
            RuntimeError: If conversion fails for any reason

        Example:
            >>> mapper = PhonemeMapper()
            >>> phonemes = mapper.text_to_phonemes("cat")
            >>> # e.g. [('/K/', 0.0), ('/AE1/', 0.08), ('/T/', 0.16)]
        """
        if not text or not text.strip():
            logger.warning("Empty text provided, returning empty phoneme list")
            return []

        try:
            # Convert to phonemes using g2p_en
            raw_tokens = self.g2p(text.lower().strip())

            # Drop empty strings, word-separator spaces and punctuation
            # (g2p_en passes characters such as ',' and '!' through verbatim).
            phoneme_list = [
                tok for tok in raw_tokens
                if tok and self._is_phoneme_token(tok)
            ]
            if not phoneme_list:
                logger.warning(f"No phonemes extracted from text: '{text}'")
                return []

            # Add slashes if not present
            formatted_phonemes = []
            for p in phoneme_list:
                if not p.startswith('/'):
                    p = '/' + p
                if not p.endswith('/'):
                    p = p + '/'
                formatted_phonemes.append(p)
            logger.debug(f"Extracted {len(formatted_phonemes)} phonemes from '{text}': {formatted_phonemes}")

            # Total time to cover: the caller's duration, or an estimate.
            if duration is None:
                total_duration = len(formatted_phonemes) * self.avg_phoneme_duration_s
            else:
                total_duration = duration

            # Distribute phonemes evenly across the duration.
            phoneme_duration = total_duration / len(formatted_phonemes)
            phoneme_times = [
                (phoneme, i * phoneme_duration)
                for i, phoneme in enumerate(formatted_phonemes)
            ]

            logger.info(f"Converted '{text}' to {len(phoneme_times)} phonemes over {total_duration:.2f}s")
            return phoneme_times
        except Exception as e:
            logger.error(f"Error converting text to phonemes: {e}", exc_info=True)
            raise RuntimeError(f"Failed to convert text to phonemes: {e}") from e

    def align_phonemes_to_frames(
        self,
        phoneme_times: List[Tuple[str, float]],
        num_frames: int,
        frame_duration_ms: Optional[int] = None
    ) -> List[str]:
        """
        Align phonemes to audio frames.

        Each frame gets assigned the phoneme whose time span overlaps the
        frame the most. A frame that overlaps no phoneme (e.g. past the end
        of the last phoneme) takes the phoneme whose midpoint is closest to
        the frame's midpoint.

        Args:
            phoneme_times: List of (phoneme, start_time) tuples from text_to_phonemes()
            num_frames: Total number of frames in the audio
            frame_duration_ms: Optional frame duration override; None (or 0)
                means use the duration configured on the instance

        Returns:
            List of phonemes, one per frame: ['/R/', '/R/', '/AA1/', ...].
            If phoneme_times is empty, a list of num_frames empty strings.

        Example:
            >>> mapper = PhonemeMapper()
            >>> phonemes = [('/K/', 0.0), ('/AE1/', 0.08), ('/T/', 0.16)]
            >>> frames = mapper.align_phonemes_to_frames(phonemes, num_frames=15, frame_duration_ms=20)
            >>> # Returns: ['/K/', '/K/', '/K/', '/K/', '/AE1/', '/AE1/', '/AE1/', '/AE1/', '/T/', ...]
        """
        if not phoneme_times:
            logger.warning("No phonemes provided, returning empty frame list")
            return [''] * num_frames

        frame_duration_s = (frame_duration_ms / 1000.0) if frame_duration_ms else self.frame_duration_s

        # The last phoneme has no successor to bound it, so give it the same
        # length as the first phoneme (or the default when there is only one).
        if len(phoneme_times) > 1:
            last_duration = phoneme_times[1][1] - phoneme_times[0][1]
        else:
            last_duration = self.avg_phoneme_duration_s

        # Build (phoneme, start, end) spans; each phoneme ends where the next starts.
        spans = []
        for i, (phoneme, start_time) in enumerate(phoneme_times):
            if i < len(phoneme_times) - 1:
                end_time = phoneme_times[i + 1][1]
            else:
                end_time = start_time + last_duration
            spans.append((phoneme, start_time, end_time))

        # Map each frame to a phoneme
        frame_phonemes = []
        for frame_idx in range(num_frames):
            frame_start_time = frame_idx * frame_duration_s
            frame_end_time = (frame_idx + 1) * frame_duration_s

            # Find phoneme with most overlap (first one wins on ties).
            best_phoneme = ''
            max_overlap = 0.0
            for phoneme, seg_start, seg_end in spans:
                overlap = max(
                    0.0,
                    min(frame_end_time, seg_end) - max(frame_start_time, seg_start)
                )
                if overlap > max_overlap:
                    max_overlap = overlap
                    best_phoneme = phoneme

            # If no overlap, use the phoneme whose midpoint is closest.
            if not best_phoneme:
                frame_center_time = frame_start_time + (frame_duration_s / 2.0)
                closest = min(
                    spans,
                    key=lambda s: abs(frame_center_time - (s[1] + (s[2] - s[1]) / 2))
                )
                best_phoneme = closest[0]

            frame_phonemes.append(best_phoneme)

        logger.debug(f"Aligned {len(phoneme_times)} phonemes to {num_frames} frames")
        return frame_phonemes

    def get_phoneme_boundaries(
        self,
        phoneme_times: List[Tuple[str, float]],
        duration: float
    ) -> List["PhonemeSegment"]:
        """
        Get detailed phoneme boundary information.

        Each phoneme ends where the next one starts; the last phoneme ends at
        duration. Frame indices are derived from the instance's frame
        duration, with a small tolerance so that times lying exactly on a
        frame boundary are not pushed into the previous frame by float error.

        Args:
            phoneme_times: List of (phoneme, start_time) tuples
            duration: Total audio duration in seconds

        Returns:
            List of PhonemeSegment objects with timing and frame information
        """
        segments = []
        for i, (phoneme, start_time) in enumerate(phoneme_times):
            if i < len(phoneme_times) - 1:
                end_time = phoneme_times[i + 1][1]
            else:
                end_time = duration
            segments.append(PhonemeSegment(
                phoneme=phoneme,
                start_time=start_time,
                end_time=end_time,
                duration=end_time - start_time,
                frame_start=self._frame_index(start_time),
                frame_end=self._frame_index(end_time)
            ))
        return segments

    def map_text_to_frames(
        self,
        text: str,
        num_frames: int,
        audio_duration: Optional[float] = None
    ) -> List[str]:
        """
        Complete pipeline: text → phonemes → frame alignment.

        Args:
            text: Input text string
            num_frames: Number of audio frames
            audio_duration: Optional audio duration in seconds; when omitted
                the phoneme timeline is estimated at 80ms per phoneme

        Returns:
            List of phonemes, one per frame (all empty strings if the text
            yields no phonemes)
        """
        # Convert text to phonemes
        phoneme_times = self.text_to_phonemes(text, duration=audio_duration)
        if not phoneme_times:
            return [''] * num_frames

        # Align to frames using the instance's frame duration
        return self.align_phonemes_to_frames(phoneme_times, num_frames)
# Manual smoke test (run via the __main__ guard)
def test_phoneme_mapper():
    """
    Exercise PhonemeMapper end to end and print the intermediate results.

    Covers phoneme extraction for a single word, frame alignment of that
    result, and the combined text-to-frames pipeline. A missing g2p_en
    installation is reported and swallowed; any other failure is reported
    and re-raised.
    """
    print("Testing PhonemeMapper...")
    try:
        mapper = PhonemeMapper(frame_duration_ms=20)

        # Step 1: a single word must yield at least one phoneme.
        print("\n1. Testing 'robot':")
        robot_phonemes = mapper.text_to_phonemes("robot")
        print(f" Phonemes: {robot_phonemes}")
        assert len(robot_phonemes) > 0, "Should extract phonemes"

        # Step 2: alignment must return exactly one label per frame.
        print("\n2. Testing frame alignment:")
        robot_frames = mapper.align_phonemes_to_frames(robot_phonemes, num_frames=25)
        print(f" Frame phonemes (first 10): {robot_frames[:10]}")
        assert len(robot_frames) == 25, "Should have 25 frames"

        # Step 3: the one-call pipeline must also return one label per frame.
        print("\n3. Testing complete pipeline with 'cat':")
        cat_frames = mapper.map_text_to_frames("cat", num_frames=15)
        print(f" Frame phonemes: {cat_frames}")
        assert len(cat_frames) == 15, "Should have 15 frames"

        print("\n✅ All tests passed!")
    except ImportError as missing_dep:
        # PhonemeMapper() raises ImportError when g2p_en is not installed.
        print(f"❌ G2P library not available: {missing_dep}")
        print(" Install with: pip install g2p-en")
    except Exception as failure:
        print(f"❌ Test failed: {failure}")
        raise
# Run the smoke test only when this file is executed as a script.
if __name__ == "__main__":
    test_phoneme_mapper()