Spaces:
Sleeping
Sleeping
from typing import Dict, List, Optional, Tuple

import Levenshtein
import numpy as np
import torch
import torchaudio.functional as F
from fastdtw import fastdtw
class PronunciationScorer:
    """Scores a predicted pronunciation against a reference.

    The class bundles phoneme-level alignment/accuracy, per-phoneme duration
    comparison, pitch-contour and energy-envelope (stress) similarity, CTC
    forced alignment and Goodness-of-Pronunciation (GoP) scoring.
    """

    # Marker used inside aligned pairs for "no phoneme on this side".
    _GAP = '-'

    def __init__(self):
        # Relative weights per metric. No method of this class reads them;
        # they are kept as a public attribute for callers.
        self.weights = {
            'phoneme': 0.5,
            'duration': 0.2,
            'stress': 0.2,
            'pitch': 0.1
        }

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _as_numpy(waveform):
        """Return ``waveform`` as a squeezed NumPy array with at least 1 dim.

        Accepts torch tensors (detached and moved to CPU first), NumPy arrays
        and plain sequences.
        """
        if isinstance(waveform, torch.Tensor):
            waveform = waveform.detach().cpu().numpy()
        elif not isinstance(waveform, np.ndarray):
            waveform = np.array(waveform)
        # atleast_1d keeps len() usable for single-sample input.
        return np.atleast_1d(waveform.squeeze())

    @staticmethod
    def _autocorr_f0(waveform, sr, frame_size, hop_length, voicing_threshold):
        """Frame-wise autocorrelation pitch tracker (pure NumPy).

        Args:
            waveform: 1-D NumPy array of samples.
            sr: sample rate in Hz.
            frame_size: number of samples per analysis frame.
            hop_length: number of samples between frame starts.
            voicing_threshold: a frame counts as voiced only when the
                autocorrelation peak exceeds this fraction of ``corr[0]``.

        Returns:
            List of floats, one per frame: pitch in Hz, or 0.0 for frames
            treated as unvoiced.
        """
        # Lag search window corresponding to 500 Hz .. 50 Hz.
        min_lag = int(sr / 500)
        max_lag = int(sr / 50)
        f0 = []
        for start_sample in range(0, len(waveform) - frame_size, hop_length):
            frame = waveform[start_sample:start_sample + frame_size]
            # Zero-mean the frame
            frame = frame - np.mean(frame)
            if np.std(frame) < 1e-4:
                f0.append(0.0)
                continue
            corr = np.correlate(frame, frame, mode='full')
            corr = corr[len(corr) // 2:]  # keep non-negative lags only
            if len(corr) <= max_lag:
                f0.append(0.0)
                continue
            search_region = corr[min_lag:max_lag]
            if len(search_region) == 0:
                f0.append(0.0)
                continue
            peak_lag = int(np.argmax(search_region)) + min_lag
            # peak_lag can only be 0 when min_lag is 0 (sr < 500); such a
            # frame is treated as unvoiced instead of dividing by zero.
            if peak_lag > 0 and corr[peak_lag] > voicing_threshold * corr[0]:
                f0.append(float(sr / peak_lag))
            else:
                f0.append(0.0)
        return f0

    @staticmethod
    def _linear_intervals(num_frames, num_targets):
        """Split ``num_frames`` evenly into ``num_targets`` inclusive frame intervals."""
        step = num_frames / max(num_targets, 1)
        intervals = []
        for idx in range(num_targets):
            s = int(idx * step)
            e = int((idx + 1) * step) - 1
            # When step < 1 the computed end can fall before the start.
            intervals.append((s, max(s, e)))
        return intervals

    # ------------------------------------------------------------------
    # Phoneme alignment / accuracy
    # ------------------------------------------------------------------
    def _get_alignment_ops(self, pred: List[str], ref: List[str]) -> List[Tuple[str, str]]:
        """
        Returns aligned (pred, ref) phoneme pairs with gaps marked as '-'
        using Levenshtein edit operations (ref is the source, pred the
        destination of the edit script).
        """
        aligned = []
        ref_pos, pred_pos = 0, 0
        for op_name, src_pos, dst_pos in Levenshtein.editops(ref, pred):
            # Add matching phonemes before this edit
            while ref_pos < src_pos and pred_pos < dst_pos:
                aligned.append((pred[pred_pos], ref[ref_pos]))
                ref_pos += 1
                pred_pos += 1
            # Handle the edit operation
            if op_name == 'replace':
                aligned.append((pred[dst_pos], ref[src_pos]))
                ref_pos += 1
                pred_pos += 1
            elif op_name == 'delete':
                # Present in ref, missing from pred.
                aligned.append((self._GAP, ref[src_pos]))
                ref_pos += 1
            elif op_name == 'insert':
                # Present in pred, absent from ref.
                aligned.append((pred[dst_pos], self._GAP))
                pred_pos += 1
        # Add remaining matching phonemes
        while ref_pos < len(ref) and pred_pos < len(pred):
            aligned.append((pred[pred_pos], ref[ref_pos]))
            ref_pos += 1
            pred_pos += 1
        return aligned

    def phoneme_accuracy(self, pred: List[str], ref: List[str]) -> Tuple[float, List[Tuple[str, str]]]:
        """
        Returns:
        - accuracy score (0-1): exact matches divided by the number of
          reference phonemes (0.0 when the reference is empty)
        - aligned phoneme pairs with gaps
        """
        aligned = self._get_alignment_ops(pred, ref)
        correct = sum(1 for p, r in aligned if p == r)
        total_ref = sum(1 for _, r in aligned if r != self._GAP)
        accuracy = (correct / total_ref) if total_ref > 0 else 0.0
        return accuracy, aligned

    def get_error_stats(self, aligned: List[Tuple[str, str]]) -> Dict[str, int]:
        """Returns counts of substitutions, insertions, deletions"""
        stats = {'sub': 0, 'ins': 0, 'del': 0}
        for p, r in aligned:
            if p == self._GAP and r != self._GAP:
                stats['del'] += 1
            elif p != self._GAP and r == self._GAP:
                stats['ins'] += 1
            elif p != r:
                stats['sub'] += 1
        return stats

    # ------------------------------------------------------------------
    # Duration
    # ------------------------------------------------------------------
    def duration_score(self,
                       pred_times: List[Tuple[float, float]],
                       ref_times: List[Tuple[float, float]],
                       aligned_pairs: List[Tuple[str, str]]) -> Dict[str, float]:
        """
        Calculate duration metrics for aligned phonemes.

        ``pred_times`` / ``ref_times`` are (start, end) pairs consumed in
        order, one per non-gap phoneme on the respective side of
        ``aligned_pairs``. Pairs whose timing entry is missing (timing list
        shorter than the phoneme sequence) are skipped instead of raising.

        Returns:
        {
            'accuracy': 0-1 score,
            'avg_ratio': average duration ratio,
            'error_ms': average error in milliseconds
        }
        """
        empty_result = {'accuracy': 0.0, 'avg_ratio': 1.0, 'error_ms': 0.0}
        if not pred_times or not ref_times:
            return empty_result
        scores = []
        ratios = []
        errors = []
        pred_idx, ref_idx = 0, 0
        for p_phn, r_phn in aligned_pairs:
            # Only compare when both phonemes exist
            if p_phn != '-' and r_phn != '-':
                if pred_idx < len(pred_times) and ref_idx < len(ref_times):
                    p_start, p_end = pred_times[pred_idx]
                    r_start, r_end = ref_times[ref_idx]
                    p_dur = p_end - p_start
                    r_dur = r_end - r_start
                    # A non-positive reference duration has no usable ratio.
                    if r_dur > 0:
                        ratio = p_dur / r_dur
                        ratios.append(ratio)
                        errors.append(abs(p_dur - r_dur) * 1000)
                        # Accuracy score (1 - normalized error)
                        scores.append(1 - min(1, abs(1 - ratio)))
                pred_idx += 1
                ref_idx += 1
            else:
                # A gap consumes a timing entry only on the side that has
                # a phoneme.
                if p_phn == '-':
                    ref_idx += 1
                if r_phn == '-':
                    pred_idx += 1
        if not scores:
            return empty_result
        return {
            'accuracy': sum(scores) / len(scores),
            'avg_ratio': sum(ratios) / len(ratios),
            'error_ms': sum(errors) / len(errors)
        }

    # ------------------------------------------------------------------
    # Pitch
    # ------------------------------------------------------------------
    def _extract_pitch_contour(self, waveform, sr, phoneme_times):
        """Extract per-phoneme pitch using a pure-NumPy autocorrelation tracker.

        Args:
            waveform: audio samples (torch tensor, NumPy array or sequence).
            sr: sample rate in Hz.
            phoneme_times: iterable of (start, end) pairs in seconds.

        Returns:
            One NumPy array of voiced pitch values (Hz) per entry of
            ``phoneme_times``; an empty list if extraction fails.
        """
        pitch_contours = []
        try:
            samples = self._as_numpy(waveform)
            hop_length = 512  # frames are non-overlapping: frame size == hop
            f0 = np.array(self._autocorr_f0(samples, sr, hop_length, hop_length, 0.3))
            # Extract per-phoneme segments
            for start, end in phoneme_times:
                start_idx = int(start * sr / hop_length)
                end_idx = int(end * sr / hop_length)
                segment = f0[start_idx:end_idx]
                # Filter out unvoiced frames (0 values)
                pitch_contours.append(segment[segment > 0])
        except Exception as e:
            print(f"Pitch extraction error: {e}")
            return []
        return pitch_contours

    def _extract_continuous_pitch(self, waveform, sr) -> List[float]:
        """
        Extracts a continuous pitch contour using 512-sample frames with a
        320-sample hop (20ms at a 16 kHz sample rate).

        Returns:
            List with one entry per frame: pitch in Hz, or None for unvoiced
            frames (None serializes cleanly to JSON). Empty list on failure.
        """
        try:
            samples = self._as_numpy(waveform)
            f0 = self._autocorr_f0(samples, sr, 512, 320, 0.25)
            # Map 0.0 values to None for clean JSON serialization
            return [val if val > 0.0 else None for val in f0]
        except Exception as e:
            print(f"Error in continuous pitch extraction: {e}")
            return []

    def pitch_score(self, pred_waveform, ref_waveform, sr, aligned_pairs, pred_times, ref_times):
        """Compare pitch contours using DTW and return trajectories.

        ``aligned_pairs``, ``pred_times`` and ``ref_times`` are accepted for
        interface compatibility but are not used by the computation.

        Returns:
            Dict with 'similarity' (0-1, DTW based), 'error_hz' (absolute
            difference of mean voiced pitch), 'correlation' and the two
            per-frame trajectories. When either side has 3 or fewer voiced
            frames the baseline values (0.8 / 0.0 / 0.8) are returned.
        """
        pred_waveform = self._as_numpy(pred_waveform)
        ref_waveform = self._as_numpy(ref_waveform)
        # Calculate continuous trajectories for visualization
        trajectory = self._extract_continuous_pitch(pred_waveform, sr)
        reference_trajectory = self._extract_continuous_pitch(ref_waveform, sr)
        # Filter out None values to get clean voiced trajectories
        pred_voiced = np.array([p for p in trajectory if p is not None], dtype=np.float32)
        ref_voiced = np.array([r for r in reference_trajectory if r is not None], dtype=np.float32)
        similarity = 0.8  # default baseline
        correlation = 0.8
        error_hz = 0.0
        if len(pred_voiced) > 3 and len(ref_voiced) > 3:
            try:
                # Normalize pitch to z-scores to compare relative shape rather
                # than absolute register (male vs female)
                pred_std = np.std(pred_voiced)
                ref_std = np.std(ref_voiced)
                pred_norm = (pred_voiced - np.mean(pred_voiced)) / (pred_std if pred_std > 1e-4 else 1e-4)
                ref_norm = (ref_voiced - np.mean(ref_voiced)) / (ref_std if ref_std > 1e-4 else 1e-4)
                dtw_dist, _ = fastdtw(pred_norm, ref_norm)
                # Normalize DTW distance to 0-1 similarity based on size
                norm_factor = max(len(pred_norm), len(ref_norm))
                similarity = 1 / (1 + (dtw_dist / (norm_factor if norm_factor > 0 else 1.0)))
                # Absolute difference in mean pitch
                error_hz = abs(np.mean(pred_voiced) - np.mean(ref_voiced))
                # Correlation of the sequences truncated to a common length
                min_len = min(len(pred_voiced), len(ref_voiced))
                if min_len > 1:
                    corr = np.corrcoef(pred_voiced[:min_len], ref_voiced[:min_len])[0, 1]
                    correlation = 0.0 if np.isnan(corr) else corr
            except Exception as e:
                print(f"Warning: pitch score computation failed: {e}")
        return {
            'similarity': float(similarity),
            'error_hz': float(error_hz),
            'correlation': float(correlation),
            'trajectory': trajectory,
            'reference_trajectory': reference_trajectory
        }

    # ------------------------------------------------------------------
    # Stress (energy envelope)
    # ------------------------------------------------------------------
    def _extract_energy_envelope(self, waveform, sr) -> List[float]:
        """
        Extracts the RMS energy envelope of the waveform using 512-sample
        frames with a 320-sample hop, normalized so the loudest frame is 1.0.
        ``sr`` is accepted for interface symmetry but not used.

        Returns:
            List of per-frame RMS values; empty list on failure.
        """
        try:
            samples = self._as_numpy(waveform)
            # Integer PCM (e.g. int16) would overflow when squared; promote
            # integer/bool input to float before computing the RMS.
            if samples.dtype.kind in 'iub':
                samples = samples.astype(np.float64)
            hop_length = 320
            frame_size = 512
            energy = []
            for start_sample in range(0, len(samples) - frame_size, hop_length):
                frame = samples[start_sample:start_sample + frame_size]
                energy.append(float(np.sqrt(np.mean(frame ** 2))))
            # Normalize to 0-1 range to align scale
            energy = np.array(energy, dtype=np.float32)
            if len(energy) > 0:
                max_val = np.max(energy)
                if max_val > 1e-6:
                    energy = energy / max_val
            return energy.tolist()
        except Exception as e:
            print(f"Error extracting energy envelope: {e}")
            return []

    def stress_score(self, pred_waveform, ref_waveform, sr) -> Dict[str, float]:
        """
        Compare dynamic stress (energy/loudness intensity envelopes) between
        prediction and reference using DTW.

        Returns:
            {'accuracy': 0-1 similarity, 'error_stats': {...}}. The
            error_stats counters are always 0 here; when either envelope has
            5 or fewer frames, or on failure, accuracy is the 0.8 baseline.
        """
        try:
            pred_energy = self._extract_energy_envelope(self._as_numpy(pred_waveform), sr)
            ref_energy = self._extract_energy_envelope(self._as_numpy(ref_waveform), sr)
            similarity = 0.8  # default baseline fallback
            if len(pred_energy) > 5 and len(ref_energy) > 5:
                dtw_dist, _ = fastdtw(np.array(pred_energy), np.array(ref_energy))
                norm_factor = max(len(pred_energy), len(ref_energy))
                similarity = 1 / (1 + (dtw_dist / (norm_factor if norm_factor > 0 else 1.0)))
            return {
                'accuracy': float(similarity),
                'error_stats': {
                    'missing_stress': 0,
                    'extra_stress': 0,
                    'wrong_stress': 0
                }
            }
        except Exception as e:
            print(f"Error computing stress score: {e}")
            return {'accuracy': 0.8, 'error_stats': {}}

    # ------------------------------------------------------------------
    # Combined interface
    # ------------------------------------------------------------------
    def compute_scores(self,
                       pred_phonemes: List[str],
                       ref_phonemes: List[str],
                       pred_times: Optional[List[Tuple[float, float]]] = None,
                       ref_times: Optional[List[Tuple[float, float]]] = None,
                       pred_waveform=None,
                       ref_waveform=None,
                       sr: Optional[int] = None) -> Dict:
        """Enhanced scoring interface with pitch and stress analysis.

        Always returns 'phoneme', 'error_stats' and 'aligned_pairs'.
        'duration' is added when both timing lists are non-empty; 'pitch'
        and 'stress' are added when both waveforms and ``sr`` are given.
        """
        accuracy, aligned = self.phoneme_accuracy(pred_phonemes, ref_phonemes)
        results = {
            'phoneme': accuracy,
            'error_stats': self.get_error_stats(aligned),
            'aligned_pairs': aligned
        }
        if pred_times and ref_times:
            results['duration'] = self.duration_score(pred_times, ref_times, aligned)
        if all(x is not None for x in (pred_waveform, ref_waveform, sr)):
            results['pitch'] = self.pitch_score(
                pred_waveform, ref_waveform, sr, aligned, pred_times, ref_times)
            results['stress'] = self.stress_score(
                pred_waveform, ref_waveform, sr)
        return results

    # ------------------------------------------------------------------
    # CTC forced alignment / GoP
    # ------------------------------------------------------------------
    def ctc_forced_align(self, log_probs: torch.Tensor, targets: torch.Tensor, blank_id: int = 0) -> List[Tuple[int, int]]:
        """
        Computes CTC forced alignment for batch_size=1.
        Args:
            log_probs: Tensor of shape (1, Time, Vocab)
            targets: Tensor of shape (1, Target_Len)
            blank_id: Index of blank token
        Returns:
            List of (start_frame, end_frame) matching each token in targets.
            Falls back to an even linear split of the frames when the CTC
            constraints are violated or torchaudio's forced_align raises.
        """
        _, T, _ = log_probs.shape
        L = targets.shape[1]
        # Move inputs to CPU to avoid CUDA kernel/driver binary compatibility segfaults
        # and multi-GPU device mapping issues in torchaudio's C++ extension.
        log_probs_cpu = log_probs.cpu()
        targets_cpu = targets.cpu()
        targets_list = targets_cpu[0].numpy().tolist()
        # Validate constraints to prevent C++ out-of-bounds/assertion crashes
        # 1. Target sequence cannot be empty
        # 2. Input frames must be >= target length
        # 3. Target sequence must not contain the blank/pad token
        if L == 0 or T < L or blank_id in targets_list:
            print(f"Warning: CTC alignment constraints violated (T={T}, L={L}, blank_in_target={blank_id in targets_list}). Falling back to linear alignment.")
            return self._linear_intervals(T, L)
        input_lengths = torch.tensor([T], dtype=torch.long, device="cpu")
        target_lengths = torch.tensor([L], dtype=torch.long, device="cpu")
        # Log softmax along vocab dimension
        log_probs_norm = torch.log_softmax(log_probs_cpu, dim=-1)
        try:
            # torchaudio forced_align on CPU
            alignments, scores = F.forced_align(
                log_probs_norm,
                targets_cpu,
                input_lengths=input_lengths,
                target_lengths=target_lengths,
                blank=blank_id
            )
            path = alignments[0].numpy().tolist()
            # Extract intervals using a state machine over the frame path
            intervals = []
            target_idx = 0
            start_frame = None
            end_frame = None
            saw_blank = False
            for t in range(T):
                token = path[t]
                if token == blank_id:
                    saw_blank = True
                    continue
                has_next = target_idx + 1 < L
                # Advance to the next target when its token appears. For a
                # repeated token (next == current) a blank must have been
                # seen in between, otherwise it is still the same target.
                if (has_next and token == targets_list[target_idx + 1]
                        and start_frame is not None
                        and (targets_list[target_idx + 1] != targets_list[target_idx] or saw_blank)):
                    intervals.append((start_frame, end_frame))
                    target_idx += 1
                    start_frame = t
                    end_frame = t
                    saw_blank = False
                elif target_idx < L and token == targets_list[target_idx]:
                    if start_frame is None:
                        start_frame = t
                    end_frame = t
                    saw_blank = False
            if start_frame is not None:
                intervals.append((start_frame, end_frame))
        except Exception as e:
            print(f"Warning: torchaudio forced_align failed: {e}. Falling back to linear alignment.")
            return self._linear_intervals(T, L)
        # Fallback padding: repeat the last interval so every target has one
        while len(intervals) < L:
            if intervals:
                intervals.append(intervals[-1])
            else:
                intervals.append((0, T - 1))
        return intervals[:L]

    def compute_gop(self,
                    log_probs: torch.Tensor,
                    targets: torch.Tensor,
                    intervals: List[Tuple[int, int]],
                    vocab_tokens: List[str],
                    blank_id: int = 0,
                    frame_stride_ms: float = 20.0,
                    threshold: float = 0.40) -> List[Dict]:
        """
        Computes Goodness of Pronunciation (GoP) using max-pooling and blank-exclusion.

        Args:
            log_probs: Tensor indexed as (1, Time, Vocab); only batch item 0 is used.
            targets: Tensor indexed as (1, Target_Len); only batch item 0 is used.
            intervals: inclusive (start_frame, end_frame) per target.
            vocab_tokens: labels, looked up by target position.
            blank_id: Index of blank token.
            frame_stride_ms: duration of one frame, used for start_ms/end_ms.
            threshold: minimum GoP probability for ``is_correct``.
        Returns:
            One dict per target with keys 'phoneme', 'start_ms', 'end_ms',
            'gop_prob' and 'is_correct'.
        """
        # Argmax predictions across all frames to identify blank frames
        pred_ids = torch.argmax(log_probs[0], dim=-1).cpu().numpy()
        probs = torch.softmax(log_probs[0], dim=-1)
        num_frames = probs.shape[0]
        L = targets.shape[1]
        targets_list = targets[0].cpu().numpy().tolist()
        results = []
        for idx in range(L):
            token_id = targets_list[idx]
            # NOTE(review): the label is looked up by target position (idx),
            # not by token_id, so vocab_tokens is presumably the per-target
            # label list rather than the model vocabulary - confirm with caller.
            phoneme = vocab_tokens[idx] if idx < len(vocab_tokens) else str(token_id)
            s_frame, e_frame = intervals[idx]
            # Clamp to the available frames so an out-of-range interval
            # cannot raise while indexing.
            segment = range(max(s_frame, 0), min(e_frame, num_frames - 1) + 1)
            # Blank-Exclusion: drop frames whose argmax prediction is blank_id.
            # If all frames in the segment are blank, evaluate all of them.
            valid_frames = [f for f in segment if pred_ids[f] != blank_id] or list(segment)
            # Max-Pooling: take the maximum probability inside the valid frames
            if valid_frames:
                gop_prob = float(torch.max(probs[valid_frames, token_id]).item())
            else:
                gop_prob = 1e-8
            results.append({
                "phoneme": phoneme,
                "start_ms": float(s_frame * frame_stride_ms),
                "end_ms": float((e_frame + 1) * frame_stride_ms),
                "gop_prob": gop_prob,
                "is_correct": bool(gop_prob >= threshold)
            })
        return results