from typing import Dict, List, Optional, Tuple

import Levenshtein
import numpy as np
import torch
import torchaudio.functional as F
from fastdtw import fastdtw


class PronunciationScorer:
    """Compare a predicted pronunciation against a reference utterance.

    Provides phoneme-level alignment and accuracy, per-phoneme duration
    metrics, pitch-contour and energy-envelope ("stress") similarity, CTC
    forced alignment, and Goodness-of-Pronunciation (GoP) scoring.
    """

    def __init__(self):
        # Per-metric weights. NOTE(review): no method in this class reads
        # them; presumably a caller combines the metric scores -- confirm.
        self.weights = {
            'phoneme': 0.5,
            'duration': 0.2,
            'stress': 0.2,
            'pitch': 0.1
        }

    # ------------------------------------------------------------------
    # Shared signal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _to_mono_array(waveform) -> np.ndarray:
        """Convert a tensor / array / sequence waveform into a squeezed NumPy array.

        Tensors are detached and moved to CPU first, because ``np.array`` cannot
        read CUDA tensors or tensors that require grad. Non-float input (e.g.
        int16 PCM) is promoted to float64 so squaring samples cannot overflow.
        Only singleton dimensions are removed; multi-channel audio stays 2-D.
        """
        if isinstance(waveform, torch.Tensor):
            waveform = waveform.detach().cpu().numpy()
        arr = np.asarray(waveform).squeeze()
        if not np.issubdtype(arr.dtype, np.floating):
            arr = arr.astype(np.float64)
        return arr

    @staticmethod
    def _autocorr_f0(waveform: np.ndarray, sr, frame_size: int, hop_length: int,
                     voicing_threshold: float) -> List[float]:
        """Estimate F0 per frame with plain autocorrelation (pure NumPy).

        Args:
            waveform: 1-D signal.
            sr: Sample rate in Hz.
            frame_size: Samples per analysis frame.
            hop_length: Samples between consecutive frame starts.
            voicing_threshold: A frame counts as voiced when its autocorrelation
                peak exceeds ``voicing_threshold * corr[0]``.

        Returns:
            One value per frame: the pitch in Hz, or 0.0 for silent, unvoiced
            or undetectable frames. Lags for 50-500 Hz are searched.
        """
        # Lag 0 is the energy term, never a period; also avoids division by 0.
        min_lag = max(1, int(sr / 500))
        # A frame's autocorrelation only has ``frame_size`` lags; clamp so long
        # periods at high sample rates do not make every frame "unvoiced".
        max_lag = min(int(sr / 50), frame_size)
        f0 = []
        # "+ 1" so the last full frame (and a signal of exactly one frame) is kept.
        for start in range(0, len(waveform) - frame_size + 1, hop_length):
            frame = waveform[start:start + frame_size]
            frame = frame - np.mean(frame)
            if np.std(frame) < 1e-4:  # near-silent frame
                f0.append(0.0)
                continue
            # Keep the non-negative lags of the full autocorrelation.
            corr = np.correlate(frame, frame, mode='full')[len(frame) - 1:]
            search_region = corr[min_lag:max_lag]
            if len(search_region) == 0:
                f0.append(0.0)
                continue
            peak_lag = int(np.argmax(search_region)) + min_lag
            if corr[peak_lag] > voicing_threshold * corr[0]:
                f0.append(float(sr / peak_lag))
            else:
                f0.append(0.0)
        return f0

    # ------------------------------------------------------------------
    # Phoneme alignment
    # ------------------------------------------------------------------

    def _get_alignment_ops(self, pred: List[str], ref: List[str]) -> List[Tuple[str, str]]:
        """Align ``pred`` against ``ref`` using Levenshtein edit operations.

        Returns:
            ``(pred_phoneme, ref_phoneme)`` pairs in order. A deleted reference
            phoneme appears as ``('-', ref)``; an inserted predicted phoneme
            appears as ``(pred, '-')``.
        """
        aligned = []
        i, j = 0, 0  # cursors into ref and pred respectively
        # editops(source=ref, destination=pred): op = (name, ref_pos, pred_pos)
        for op in Levenshtein.editops(ref, pred):
            # Everything before this edit is an exact match.
            while i < op[1] and j < op[2]:
                aligned.append((pred[j], ref[i]))
                i += 1
                j += 1
            if op[0] == 'replace':
                aligned.append((pred[op[2]], ref[op[1]]))
                i += 1
                j += 1
            elif op[0] == 'delete':
                aligned.append(('-', ref[op[1]]))
                i += 1
            elif op[0] == 'insert':
                aligned.append((pred[op[2]], '-'))
                j += 1
        # Trailing exact matches after the last edit.
        while i < len(ref) and j < len(pred):
            aligned.append((pred[j], ref[i]))
            i += 1
            j += 1
        return aligned

    def phoneme_accuracy(self, pred: List[str], ref: List[str]) -> Tuple[float, List[Tuple[str, str]]]:
        """Return ``(accuracy, aligned_pairs)``.

        Accuracy is the number of exactly matching aligned pairs divided by the
        number of reference phonemes (0.0 when the reference is empty).
        """
        aligned = self._get_alignment_ops(pred, ref)
        correct = sum(1 for p, r in aligned if p == r)
        total_ref = len([r for _, r in aligned if r != '-'])
        return (correct / total_ref) if total_ref > 0 else 0.0, aligned

    def get_error_stats(self, aligned: List[Tuple[str, str]]) -> Dict[str, int]:
        """Count substitutions, insertions and deletions in an alignment."""
        stats = {'sub': 0, 'ins': 0, 'del': 0}
        for p, r in aligned:
            if p == '-' and r != '-':
                stats['del'] += 1
            elif p != '-' and r == '-':
                stats['ins'] += 1
            elif p != r:
                stats['sub'] += 1
        return stats

    # ------------------------------------------------------------------
    # Duration
    # ------------------------------------------------------------------

    def duration_score(self, pred_times: List[Tuple[float, float]],
                       ref_times: List[Tuple[float, float]],
                       aligned_pairs: List[Tuple[str, str]]) -> Dict[str, float]:
        """Compare durations of phonemes present in both sequences.

        ``pred_times`` / ``ref_times`` hold one ``(start, end)`` span per
        predicted / reference phoneme, in the same order as the phonemes
        (assumed seconds, given the ``* 1000`` conversion to ms).

        Returns:
            ``{'accuracy': 0-1 score, 'avg_ratio': mean pred/ref duration
            ratio, 'error_ms': mean absolute duration error in ms}``.
            If the timing lists are shorter than the aligned phonemes, only
            the pairs that have timings are scored.
        """
        if not pred_times or not ref_times:
            return {'accuracy': 0.0, 'avg_ratio': 1.0, 'error_ms': 0.0}

        scores = []
        ratios = []
        errors = []
        pred_idx, ref_idx = 0, 0

        for p_phn, r_phn in aligned_pairs:
            # Only compare when both phonemes exist
            if p_phn != '-' and r_phn != '-':
                if pred_idx >= len(pred_times) or ref_idx >= len(ref_times):
                    # Indices only grow, so no later pair can have timings either.
                    break
                p_start, p_end = pred_times[pred_idx]
                r_start, r_end = ref_times[ref_idx]
                p_dur = p_end - p_start
                r_dur = r_end - r_start

                if r_dur > 0:
                    ratio = p_dur / r_dur
                    ratios.append(ratio)
                    errors.append(abs(p_dur - r_dur) * 1000)
                    # Accuracy score (1 - normalized error), clipped to [0, 1]
                    norm_error = min(1, abs(1 - ratio))
                    scores.append(1 - norm_error)

                pred_idx += 1
                ref_idx += 1
            else:
                # A gap consumes a timing entry only on the side that has a phoneme.
                if p_phn == '-':
                    ref_idx += 1
                if r_phn == '-':
                    pred_idx += 1

        if not scores:
            return {'accuracy': 0.0, 'avg_ratio': 1.0, 'error_ms': 0.0}

        return {
            'accuracy': sum(scores) / len(scores),
            'avg_ratio': sum(ratios) / len(ratios),
            'error_ms': sum(errors) / len(errors)
        }

    # ------------------------------------------------------------------
    # Pitch
    # ------------------------------------------------------------------

    def _extract_pitch_contour(self, waveform, sr, phoneme_times):
        """Return voiced F0 values (Hz) for each ``(start, end)`` span (seconds).

        Uses autocorrelation on 512-sample frames with a 512-sample hop (pure
        NumPy, avoiding Numba/LLVM-based pitch trackers). Each element of the
        result is a NumPy array of the voiced frame pitches inside that span
        (possibly empty). Returns ``[]`` if extraction fails.
        """
        hop_length = 512
        try:
            signal = self._to_mono_array(waveform)
            f0 = np.array(self._autocorr_f0(signal, sr, hop_length, hop_length, 0.3))

            pitch_contours = []
            for start, end in phoneme_times:
                start_idx = int(start * sr / hop_length)
                end_idx = int(end * sr / hop_length)
                segment = f0[start_idx:end_idx]
                # Filter out unvoiced frames (0 values)
                pitch_contours.append(segment[segment > 0])
        except Exception as e:
            print(f"Pitch extraction error: {e}")
            return []
        return pitch_contours

    def _extract_continuous_pitch(self, waveform, sr) -> List[Optional[float]]:
        """Extract a continuous pitch contour in Hz.

        Frames are 512 samples with a 320-sample hop (20 ms at 16 kHz).
        Unvoiced frames are ``None`` so the list serializes cleanly to JSON.
        Returns ``[]`` if extraction fails.
        """
        try:
            signal = self._to_mono_array(waveform)
            f0 = self._autocorr_f0(signal, sr, frame_size=512, hop_length=320,
                                   voicing_threshold=0.25)
            # Map 0.0 values to None for clean JSON serialization
            return [val if val > 0.0 else None for val in f0]
        except Exception as e:
            print(f"Error in continuous pitch extraction: {e}")
            return []

    def pitch_score(self, pred_waveform, ref_waveform, sr, aligned_pairs, pred_times, ref_times):
        """Compare pitch contours with DTW and return both trajectories.

        ``aligned_pairs``, ``pred_times`` and ``ref_times`` are accepted for
        interface compatibility but are not used.

        Returns:
            ``similarity`` (0-1, DTW on z-normalized voiced pitch),
            ``error_hz`` (absolute difference of mean voiced pitch),
            ``correlation`` (Pearson r of the truncated voiced sequences),
            ``trajectory`` / ``reference_trajectory`` (per-frame Hz or None).
            With too little voiced data, similarity/correlation keep the 0.8
            baseline and error_hz is 0.0.
        """
        # Continuous trajectories for visualization; the extractor handles
        # tensors (including CUDA / grad-requiring ones) and plain arrays.
        trajectory = self._extract_continuous_pitch(pred_waveform, sr)
        reference_trajectory = self._extract_continuous_pitch(ref_waveform, sr)

        # Filter out None values to get clean voiced trajectories
        pred_voiced = np.array([p for p in trajectory if p is not None], dtype=np.float32)
        ref_voiced = np.array([r for r in reference_trajectory if r is not None], dtype=np.float32)

        similarity = 0.8  # default baseline
        correlation = 0.8
        error_hz = 0.0

        if len(pred_voiced) > 3 and len(ref_voiced) > 3:
            try:
                # z-scores compare contour shape rather than absolute register
                # (e.g. male vs female speakers).
                pred_std = np.std(pred_voiced)
                ref_std = np.std(ref_voiced)
                pred_norm = (pred_voiced - np.mean(pred_voiced)) / (pred_std if pred_std > 1e-4 else 1e-4)
                ref_norm = (ref_voiced - np.mean(ref_voiced)) / (ref_std if ref_std > 1e-4 else 1e-4)

                dtw_dist, _ = fastdtw(pred_norm, ref_norm)
                # Normalize DTW distance by length into a 0-1 similarity
                norm_factor = max(len(pred_norm), len(ref_norm))
                similarity = 1 / (1 + (dtw_dist / (norm_factor if norm_factor > 0 else 1.0)))

                # Absolute difference in mean pitch
                error_hz = abs(np.mean(pred_voiced) - np.mean(ref_voiced))

                # Correlation of the sequences truncated to a common length
                min_len = min(len(pred_voiced), len(ref_voiced))
                if min_len > 1:
                    corr = np.corrcoef(pred_voiced[:min_len], ref_voiced[:min_len])[0, 1]
                    correlation = 0.0 if np.isnan(corr) else corr
            except Exception as e:
                print(f"Warning: pitch score computation failed: {e}")

        return {
            'similarity': float(similarity),
            'error_hz': float(error_hz),
            'correlation': float(correlation),
            'trajectory': trajectory,
            'reference_trajectory': reference_trajectory
        }

    # ------------------------------------------------------------------
    # Stress / energy
    # ------------------------------------------------------------------

    def _extract_energy_envelope(self, waveform, sr) -> List[float]:
        """Extract the RMS energy envelope, normalized so its maximum is 1.

        Frames are 512 samples with a 320-sample hop (20 ms at 16 kHz).
        Returns ``[]`` if extraction fails.
        """
        try:
            signal = self._to_mono_array(waveform)
            hop_length = 320
            frame_size = 512

            energy = np.array(
                [np.sqrt(np.mean(signal[s:s + frame_size] ** 2))
                 # "+ 1" so the last full frame is included.
                 for s in range(0, len(signal) - frame_size + 1, hop_length)],
                dtype=np.float32,
            )

            # Normalize to 0-1 so recordings at different gains are comparable
            if len(energy) > 0:
                max_val = np.max(energy)
                if max_val > 1e-6:
                    energy = energy / max_val

            return energy.tolist()
        except Exception as e:
            print(f"Error extracting energy envelope: {e}")
            return []

    def stress_score(self, pred_waveform, ref_waveform, sr) -> Dict[str, float]:
        """Compare energy envelopes (a loudness proxy for stress) using DTW.

        Returns ``{'accuracy': 0-1 similarity, 'error_stats': {...}}``. The
        error counts are placeholders that are always 0. Accuracy falls back
        to 0.8 when either envelope has 5 or fewer frames or on failure.
        """
        try:
            pred_energy = self._extract_energy_envelope(pred_waveform, sr)
            ref_energy = self._extract_energy_envelope(ref_waveform, sr)

            similarity = 0.8  # default baseline fallback
            if len(pred_energy) > 5 and len(ref_energy) > 5:
                dtw_dist, _ = fastdtw(np.array(pred_energy), np.array(ref_energy))
                norm_factor = max(len(pred_energy), len(ref_energy))
                similarity = 1 / (1 + (dtw_dist / (norm_factor if norm_factor > 0 else 1.0)))

            return {
                'accuracy': float(similarity),
                'error_stats': {
                    'missing_stress': 0,
                    'extra_stress': 0,
                    'wrong_stress': 0
                }
            }
        except Exception as e:
            print(f"Error computing stress score: {e}")
            return {'accuracy': 0.8, 'error_stats': {}}

    # ------------------------------------------------------------------
    # Public scoring entry point
    # ------------------------------------------------------------------

    def compute_scores(self, pred_phonemes: List[str], ref_phonemes: List[str],
                       pred_times: Optional[List[Tuple[float, float]]] = None,
                       ref_times: Optional[List[Tuple[float, float]]] = None,
                       pred_waveform=None, ref_waveform=None,
                       sr: Optional[int] = None) -> Dict:
        """Run all available metrics.

        Always returns ``phoneme``, ``error_stats`` and ``aligned_pairs``.
        ``duration`` is added when both timing lists are non-empty; ``pitch``
        and ``stress`` are added when both waveforms and ``sr`` are given.
        """
        accuracy, aligned = self.phoneme_accuracy(pred_phonemes, ref_phonemes)
        results = {
            'phoneme': accuracy,
            'error_stats': self.get_error_stats(aligned),
            'aligned_pairs': aligned
        }

        if pred_times and ref_times:
            results['duration'] = self.duration_score(pred_times, ref_times, aligned)

        if all(x is not None for x in [pred_waveform, ref_waveform, sr]):
            results['pitch'] = self.pitch_score(
                pred_waveform, ref_waveform, sr, aligned, pred_times, ref_times)
            results['stress'] = self.stress_score(
                pred_waveform, ref_waveform, sr)

        return results

    # ------------------------------------------------------------------
    # CTC forced alignment and GoP
    # ------------------------------------------------------------------

    @staticmethod
    def _linear_alignment(T: int, L: int) -> List[Tuple[int, int]]:
        """Split ``T`` frames evenly into ``L`` inclusive ``(start, end)`` intervals."""
        step = T / max(L, 1)
        intervals = []
        for idx in range(L):
            s = int(idx * step)
            e = int((idx + 1) * step) - 1
            intervals.append((s, max(s, e)))  # never let end precede start
        return intervals

    @staticmethod
    def _path_to_intervals(path: List[int], targets_list: List[int],
                           blank_id: int) -> List[Tuple[int, int]]:
        """Collapse a frame-level CTC path into inclusive frame spans per target token.

        Moves to the next target when the path emits that target's token and
        either it differs from the current token or a blank separated them
        (CTC requires a blank between identical neighbouring tokens).
        May return fewer spans than targets.
        """
        L = len(targets_list)
        intervals = []
        target_idx = 0
        start_frame = None
        end_frame = None
        saw_blank = False

        for t, token in enumerate(path):
            if token == blank_id:
                saw_blank = True
                continue

            next_idx = target_idx + 1
            if (next_idx < L
                    and token == targets_list[next_idx]
                    and start_frame is not None
                    and (targets_list[next_idx] != targets_list[target_idx] or saw_blank)):
                # Close the current token and open the next one.
                intervals.append((start_frame, end_frame))
                target_idx = next_idx
                start_frame = t
                end_frame = t
                saw_blank = False
            elif target_idx < L and token == targets_list[target_idx]:
                # Same token continues (or starts, for the first target).
                if start_frame is None:
                    start_frame = t
                end_frame = t
                saw_blank = False

        if start_frame is not None:
            intervals.append((start_frame, end_frame))
        return intervals

    def ctc_forced_align(self, log_probs: torch.Tensor, targets: torch.Tensor,
                         blank_id: int = 0) -> List[Tuple[int, int]]:
        """Compute CTC forced alignment for batch_size=1.

        Args:
            log_probs: Tensor of shape (1, Time, Vocab); log-softmax is applied
                here, so raw logits are also accepted.
            targets: Tensor of shape (1, Target_Len).
            blank_id: Index of the blank token.

        Returns:
            Exactly ``Target_Len`` inclusive ``(start_frame, end_frame)``
            pairs, one per target token. Falls back to an even linear split
            when the inputs violate CTC constraints or alignment fails.
        """
        _, T, _ = log_probs.shape
        L = targets.shape[1]

        # Move inputs to CPU to avoid CUDA kernel/driver binary compatibility
        # segfaults and multi-GPU device mapping issues in torchaudio's C++
        # extension. Detach so autograd never reaches the C++ op; use int64
        # targets to match the int64 length tensors below.
        log_probs_cpu = log_probs.detach().cpu()
        targets_cpu = targets.detach().cpu().long()
        targets_list = targets_cpu[0].tolist()

        # Validate constraints to prevent C++ out-of-bounds/assertion crashes:
        # 1. Target sequence cannot be empty
        # 2. Input frames must be >= target length
        # 3. Target sequence must not contain the blank/pad token
        if L == 0 or T < L or blank_id in targets_list:
            print(f"Warning: CTC alignment constraints violated (T={T}, L={L}, blank_in_target={blank_id in targets_list}). Falling back to linear alignment.")
            return self._linear_alignment(T, L)

        input_lengths = torch.tensor([T], dtype=torch.long, device="cpu")
        target_lengths = torch.tensor([L], dtype=torch.long, device="cpu")

        # Log softmax along vocab dimension
        log_probs_norm = torch.log_softmax(log_probs_cpu, dim=-1)

        try:
            alignments, _ = F.forced_align(
                log_probs_norm,
                targets_cpu,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                blank=blank_id
            )
            path = alignments[0].tolist()
            intervals = self._path_to_intervals(path, targets_list, blank_id)
        except Exception as e:
            # e.g. too few frames for the repeats in the targets
            print(f"Warning: torchaudio forced_align failed: {e}. Falling back to linear alignment.")
            return self._linear_alignment(T, L)

        # Guarantee exactly one interval per target.
        while len(intervals) < L:
            intervals.append(intervals[-1] if intervals else (0, T - 1))
        return intervals[:L]

    def compute_gop(self, log_probs: torch.Tensor, targets: torch.Tensor,
                    intervals: List[Tuple[int, int]], vocab_tokens: List[str],
                    blank_id: int = 0, frame_stride_ms: float = 20.0,
                    threshold: float = 0.40) -> List[Dict]:
        """Compute Goodness of Pronunciation (GoP) per target token.

        For each target, takes the maximum softmax probability of the target
        token over its aligned frames, excluding frames whose argmax is blank
        (all frames are used if every frame is blank).

        Args:
            log_probs: Tensor of shape (1, Time, Vocab) (logits or log-probs;
                softmax is applied here).
            targets: Tensor of shape (1, Target_Len).
            intervals: Inclusive ``(start_frame, end_frame)`` per target,
                e.g. from ``ctc_forced_align``.
            vocab_tokens: Labels for the result, indexed by target position.
            blank_id: Index of the blank token.
            frame_stride_ms: Duration of one model frame in milliseconds.
            threshold: Minimum GoP probability for ``is_correct``.

        Returns:
            One dict per target with ``phoneme``, ``start_ms``, ``end_ms``,
            ``gop_prob`` and ``is_correct``.
        """
        # Argmax predictions across all frames to identify blank frames
        pred_ids = torch.argmax(log_probs[0], dim=-1).cpu().numpy()
        probs = torch.softmax(log_probs[0], dim=-1)
        num_frames = probs.shape[0]

        L = targets.shape[1]
        targets_list = targets[0].cpu().numpy().tolist()
        results = []

        for idx in range(L):
            token_id = targets_list[idx]
            # NOTE(review): indexed by target position, not by token_id; this
            # assumes vocab_tokens lists the target phonemes in order -- confirm.
            phoneme = vocab_tokens[idx] if idx < len(vocab_tokens) else str(token_id)
            s_frame, e_frame = intervals[idx]

            # Clamp to frames that exist so a stale interval cannot raise IndexError.
            segment = range(max(s_frame, 0), min(e_frame, num_frames - 1) + 1)

            # Blank-exclusion: drop frames whose argmax prediction is blank
            valid_frames = [f for f in segment if pred_ids[f] != blank_id]
            # If every frame is blank, evaluate all frames in the segment
            if not valid_frames:
                valid_frames = list(segment)

            # Max-pooling over the valid frames
            if valid_frames:
                token_probs = probs[valid_frames, token_id]
                gop_prob = float(torch.max(token_probs).item())
            else:
                gop_prob = 1e-8

            results.append({
                "phoneme": phoneme,
                "start_ms": float(s_frame * frame_stride_ms),
                "end_ms": float((e_frame + 1) * frame_stride_ms),
                "gop_prob": gop_prob,
                "is_correct": bool(gop_prob >= threshold)
            })

        return results