Spaces:
Sleeping
Sleeping
| """ | |
| MDD Engine — Mispronunciation Detection and Diagnosis | |
| ===================================================== | |
| Architecture (Shahin et al. 2025) | |
| ---------------------------------- | |
| Your model runs 35 independent CTC decoders, one per phonological feature. | |
| Each decoder outputs a sequence of +att(1) / -att(0) labels, with blanks | |
| already removed and runs collapsed — so the output length reflects the number | |
| of detected phoneme-level events, NOT audio frames. | |
| The canonical target comes from the user's typed sentence: | |
| sentence → G2P (CMU ARPAbet) → phoneme_sequence_to_feature_sequences() | |
| → 35 binary label sequences of length T (number of target phonemes) | |
| The problem: the actual decoded sequence per feature may have a DIFFERENT | |
| length than T, because the student may have: | |
| - deleted phonemes (actual shorter than target) | |
| - inserted extras (actual longer than target) | |
| - substituted (same length, wrong labels) | |
| Solution: Needleman-Wunsch (global sequence alignment) per feature | |
| ------------------------------------------------------------------ | |
| For each of the 35 features we run a global pairwise alignment between the | |
| target binary sequence and the actual binary sequence. This gives us an | |
| explicit alignment path with match / mismatch / insertion / deletion ops. | |
| We then aggregate across all 35 features to get, per target phoneme position: | |
| - which actual position it maps to (or DELETION if no match) | |
| - which features are missing (+att in target, -att or gap in actual) | |
| - which features are extra (-att in target, +att in actual) | |
| - a weighted feature accuracy score | |
| This is the standard approach in phonological MDD literature when no frame- | |
| level forced alignment is available (see e.g. Lee & Glass 2015, Leung et al. | |
| 2019, and the feature-based MDD track of the AIP challenge). | |
| Input/output contract | |
| --------------------- | |
| actual_feature_seqs : list[list[int]] — 35 lists, each decoded CTC output | |
| Values: 1 (+att) or 0 (-att) | |
| Lengths may differ across features | |
| and from the canonical length T | |
| target_phonemes : list[str] — CMU ARPAbet phoneme sequence from | |
| the user's typed sentence, length T | |
| Output: MDDResult (see dataclass below) | |
| """ | |
| from __future__ import annotations | |
| import numpy as np | |
| from dataclasses import dataclass, field | |
| from typing import List, Dict, Tuple, Optional | |
| from phonological_features import ( | |
| PHONOLOGICAL_FEATURES, | |
| phoneme_sequence_to_feature_sequences, | |
| phoneme_to_feature_vector, | |
| ) | |
| # ───────────────────────────────────────────────────────────────────────────── | |
| # 1. Feature schema & weights | |
| # ───────────────────────────────────────────────────────────────────────────── | |
# Canonical feature ordering; every per-feature list in this module is indexed
# in this order.
FEATURE_NAMES: List[str] = PHONOLOGICAL_FEATURES  # 35 features, canonical order
NUM_FEATURES = len(FEATURE_NAMES)  # 35
assert NUM_FEATURES == 35
# Reverse lookup: feature name -> index into FEATURE_NAMES.
F2I: Dict[str, int] = {f: i for i, f in enumerate(FEATURE_NAMES)}

# Perceptual salience weights — higher = more important mismatch.
# Manner errors (wrong sound class) are most disruptive.
# Voicing errors are highly salient in English.
# Place errors matter but less so than manner.
# Length/type distinctions are least salient in L2 MDD.
# NOTE(review): the per-column feature names in the comments below assume
# PHONOLOGICAL_FEATURES is ordered manners, places, others exactly as listed —
# confirm against phonological_features.py.
FEATURE_WEIGHTS: np.ndarray = np.array([
    # Manners (11): consonant sonorant fricative nasal stop
    2.0, 1.5, 1.8, 2.0, 2.0,
    # approximant affricate liquid vowel semivowel continuant
    1.5, 1.8, 1.5, 2.0, 1.5, 1.2,
    # Places (18): alveolar palatal dental glottal labial velar
    1.5, 1.4, 1.3, 1.2, 1.5, 1.5,
    # mid high low front back central
    1.8, 1.8, 1.8, 1.6, 1.6, 1.2,
    # anterior posterior retroflex bilabial coronal dorsal
    1.3, 1.3, 1.3, 1.4, 1.3, 1.3,
    # Others (6): long short monophthong diphthong round voiced
    1.0, 1.0, 1.2, 1.2, 1.2, 2.5,
], dtype=np.float32)
assert len(FEATURE_WEIGHTS) == 35

# Alignment op codes
MATCH = 0  # same label, same position
MISMATCH = 1  # different label, same position
DELETE = 2  # target has event, actual has gap (deletion error)
INSERT = 3  # actual has event, target has gap (insertion error)

# NW scoring scheme (used by the Needleman-Wunsch aligner in this module)
MATCH_SCORE = 2
MISMATCH_SCORE = -1
GAP_PENALTY = -2  # penalises deletions and insertions equally
| # ───────────────────────────────────────────────────────────────────────────── | |
| # 2. Data classes | |
| # ───────────────────────────────────────────────────────────────────────────── | |
@dataclass
class AlignedPosition:
    """One position in the target sequence after multi-feature alignment.

    Declared as a dataclass so it can be built with keyword arguments; a
    plain class with bare annotations has no generated ``__init__`` and
    keyword construction would raise ``TypeError``.
    """

    target_idx: int  # index in target phoneme sequence
    actual_idx: Optional[int]  # index in actual sequence, None = deletion
    op: int  # MATCH / MISMATCH / DELETE / INSERT
    target_bits: List[int]  # canonical feature vector (35 bits)
    actual_bits: List[int]  # observed feature vector (35 bits, 0 if deleted)
    missing_features: List[str]  # +att in target, -att or gap in actual
    extra_features: List[str]  # -att in target, +att in actual
    feature_accuracy: float  # weighted accuracy 0-1
@dataclass
class PhonemeError:
    """One mispronounced phoneme with its full feature-level diagnosis.

    Declared as a dataclass so it can be built with keyword arguments; a
    plain class with bare annotations has no generated ``__init__``.
    """

    position: int  # index in target sequence
    target_phoneme: str  # ARPAbet label from typed sentence
    missing_features: List[str]  # features the student failed to produce
    extra_features: List[str]  # features the student added erroneously
    is_deletion: bool  # student dropped this phoneme entirely
    feature_accuracy: float  # 0-1
    severity: str  # "mild" | "moderate" | "severe"
@dataclass
class MDDResult:
    """Complete output of one MDD run over a single utterance.

    Declared as a dataclass so it can be built with keyword arguments; a
    plain class with bare annotations has no generated ``__init__``.
    """

    utterance_score: float  # 0-100
    phoneme_scores: List[float]  # per target phoneme, 0-1
    errors: List[PhonemeError]
    aligned_positions: List[AlignedPosition]
    feature_error_counts: Dict[str, int]  # aggregated across all phonemes
    deletion_count: int
    insertion_count: int
| # ───────────────────────────────────────────────────────────────────────────── | |
| # 3. Needleman-Wunsch per-feature aligner | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def _nw_align(target_seq: List[int],
              actual_seq: List[int]) -> List[Tuple[Optional[int], Optional[int]]]:
    """
    Globally align two binary label sequences with Needleman-Wunsch.

    The result is the alignment path in left-to-right order, as a list of
    (target_idx, actual_idx) pairs:
        (i, j)     target[i] paired with actual[j] (match or mismatch)
        (i, None)  deletion: target[i] has no counterpart in actual
        (None, j)  insertion: actual[j] has no counterpart in target

    Scoring uses the module-level MATCH_SCORE, MISMATCH_SCORE and GAP_PENALTY.
    When several predecessors tie during traceback, the diagonal move is
    preferred, then deletion, then insertion.
    """
    n_target = len(target_seq)
    n_actual = len(actual_seq)

    def pair_score(row: int, col: int) -> int:
        # Substitution score for DP cell (row, col); cells are 1-based.
        if target_seq[row - 1] == actual_seq[col - 1]:
            return MATCH_SCORE
        return MISMATCH_SCORE

    # DP table; the first row and column hold pure-gap prefixes.
    table = np.zeros((n_target + 1, n_actual + 1), dtype=np.float32)
    table[0, :] = np.arange(n_actual + 1) * GAP_PENALTY
    table[:, 0] = np.arange(n_target + 1) * GAP_PENALTY

    for row in range(1, n_target + 1):
        for col in range(1, n_actual + 1):
            diagonal = table[row - 1, col - 1] + pair_score(row, col)
            from_above = table[row - 1, col] + GAP_PENALTY   # deletion
            from_left = table[row, col - 1] + GAP_PENALTY    # insertion
            table[row, col] = max(diagonal, from_above, from_left)

    # Walk back from the bottom-right corner, collecting moves in reverse.
    backwards: List[Tuple[Optional[int], Optional[int]]] = []
    row, col = n_target, n_actual
    while row > 0 or col > 0:
        took_diagonal = (
            row > 0
            and col > 0
            and table[row, col] == table[row - 1, col - 1] + pair_score(row, col)
        )
        if took_diagonal:
            backwards.append((row - 1, col - 1))
            row -= 1
            col -= 1
        elif row > 0 and table[row, col] == table[row - 1, col] + GAP_PENALTY:
            backwards.append((row - 1, None))
            row -= 1
        else:
            backwards.append((None, col - 1))
            col -= 1

    return backwards[::-1]
| # ───────────────────────────────────────────────────────────────────────────── | |
| # 4. Multi-feature alignment aggregator | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def _align_all_features(
    target_feat_seqs: List[List[int]],  # 35 lists, each length T
    actual_feat_seqs: List[List[int]],  # 35 lists, each possibly != T
    T: int,                             # number of target phonemes
) -> List[AlignedPosition]:
    """
    Align every feature independently, then merge the results per target
    phoneme position.

    Strategy
    --------
    Each feature is aligned on its own with ``_nw_align``. For target
    position ``i`` every feature casts a vote for the actual index it mapped
    ``i`` to (``None`` = gap). If no feature mapped the position, or more
    than half of the votes are gaps, the position is a deletion. Otherwise
    the plurality actual index is recorded as ``actual_idx`` (ties go to the
    index voted first).

    The observed bit of each feature is always taken from that feature's OWN
    alignment, whether or not it agrees with the plurality index; a feature
    that aligned the position to a gap contributes 0. The plurality index is
    used for reporting only.

    Parameters
    ----------
    target_feat_seqs : per-feature canonical label sequences, indexed
        ``[feature][target_idx]``.
    actual_feat_seqs : per-feature decoded label sequences, indexed
        ``[feature][actual_idx]``.
    T : number of target phoneme positions to report.

    Returns
    -------
    One ``AlignedPosition`` per target index ``0..T-1``, in order.
    """
    DELETION_VOTE_THRESHOLD = 0.5  # >50% gap votes -> mark as DELETE

    # votes[target_idx] -> actual_idx votes, one per feature (None = gap)
    votes: List[List[Optional[int]]] = [[] for _ in range(T)]
    # per_feat_map[feat][target_idx] -> actual_idx or None
    per_feat_map: List[Dict[int, Optional[int]]] = [
        {} for _ in range(NUM_FEATURES)
    ]
    for feat_i in range(NUM_FEATURES):
        path = _nw_align(target_feat_seqs[feat_i], actual_feat_seqs[feat_i])
        for ti, ai in path:
            if ti is None:
                continue  # insertion: no target position to attribute it to
            votes[ti].append(ai)
            per_feat_map[feat_i][ti] = ai

    # Loop-invariant normaliser for the weighted accuracy.
    total_weight = float(FEATURE_WEIGHTS.sum())

    aligned: List[AlignedPosition] = []
    for ti in range(T):
        position_votes = votes[ti]
        non_null = [x for x in position_votes if x is not None]
        null_count = len(position_votes) - len(non_null)
        deletion_fraction = null_count / max(len(position_votes), 1)

        chosen_ai: Optional[int]
        if not non_null or deletion_fraction > DELETION_VOTE_THRESHOLD:
            chosen_ai = None
        else:
            # Plurality vote among the non-gap actual indices.
            counts: Dict[int, int] = {}
            for idx in non_null:
                counts[idx] = counts.get(idx, 0) + 1
            chosen_ai = max(counts, key=counts.__getitem__)

        target_bits = [target_feat_seqs[f][ti] for f in range(NUM_FEATURES)]
        if chosen_ai is not None:
            actual_bits = []
            for f in range(NUM_FEATURES):
                # Indices come straight from this feature's alignment path,
                # so they are always in range for actual_feat_seqs[f].
                feat_ai = per_feat_map[f].get(ti)
                actual_bits.append(
                    actual_feat_seqs[f][feat_ai] if feat_ai is not None else 0
                )
            op = MATCH if target_bits == actual_bits else MISMATCH
        else:
            actual_bits = [0] * NUM_FEATURES
            op = DELETE

        # Feature-level diagnosis.
        missing = [FEATURE_NAMES[f] for f in range(NUM_FEATURES)
                   if target_bits[f] == 1 and actual_bits[f] == 0]
        extra = [FEATURE_NAMES[f] for f in range(NUM_FEATURES)
                 if target_bits[f] == 0 and actual_bits[f] == 1]

        # Weighted accuracy: share of total feature weight whose bit matches.
        correct_weight = sum(
            FEATURE_WEIGHTS[f]
            for f in range(NUM_FEATURES)
            if target_bits[f] == actual_bits[f]
        )
        accuracy = float(correct_weight / total_weight)

        aligned.append(AlignedPosition(
            target_idx=ti,
            actual_idx=chosen_ai,
            op=op,
            target_bits=target_bits,
            actual_bits=actual_bits,
            missing_features=missing,
            extra_features=extra,
            feature_accuracy=accuracy,
        ))
    return aligned
| # ───────────────────────────────────────────────────────────────────────────── | |
| # 5. Insertion detector | |
| # ───────────────────────────────────────────────────────────────────────────── | |
| def _count_insertions( | |
| actual_feat_seqs: List[List[int]], | |
| actual_len: int, | |
| aligned: List[AlignedPosition], | |
| ) -> int: | |
| """ | |
| Count actual positions that were voted as insertions (not mapped to any | |
| target position) by the majority of features. | |
| """ | |
| used_actual = set( | |
| ap.actual_idx for ap in aligned if ap.actual_idx is not None | |
| ) | |
| inserted = set(range(actual_len)) - used_actual | |
| return len(inserted) | |
| # ───────────────────────────────────────────────────────────────────────────── | |
| # 6. Severity classifier | |
| # ───────────────────────────────────────────────────────────────────────────── | |
| # Thresholds on weighted feature error rate | |
| _SEV = {"mild": 0.85, "moderate": 0.65} # accuracy thresholds (higher = easier) | |
| def _severity(accuracy: float, is_deletion: bool) -> str: | |
| if is_deletion: | |
| return "severe" | |
| if accuracy >= _SEV["mild"]: | |
| return "mild" | |
| if accuracy >= _SEV["moderate"]: | |
| return "moderate" | |
| return "severe" | |
| # ───────────────────────────────────────────────────────────────────────────── | |
| # 7. Scorer | |
| # ───────────────────────────────────────────────────────────────────────────── | |
| def _score_utterance(aligned: List[AlignedPosition]) -> Tuple[float, List[float]]: | |
| """ | |
| Per-phoneme score: weighted feature accuracy (0-1). | |
| Deletions score 0. | |
| Utterance score: weighted mean, penalising deletions most. | |
| """ | |
| phoneme_scores = [ap.feature_accuracy for ap in aligned] | |
| utterance_score = float(np.mean(phoneme_scores)) * 100.0 | |
| return utterance_score, phoneme_scores | |
| # ───────────────────────────────────────────────────────────────────────────── | |
| # 8. Error list builder | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def _build_errors(
    aligned: List[AlignedPosition],
    target_phonemes: List[str],
) -> List[PhonemeError]:
    """Build one ``PhonemeError`` for every aligned position that is not clean.

    A position is skipped only when its op is MATCH and it has neither
    missing nor extra features. Everything else is reported with the target
    phoneme label at that index and a severity derived from its accuracy.
    """
    report = []
    for pos in aligned:
        has_feature_diff = bool(pos.missing_features or pos.extra_features)
        if pos.op == MATCH and not has_feature_diff:
            continue  # nothing to diagnose at this position
        dropped = pos.op == DELETE
        report.append(PhonemeError(
            position=pos.target_idx,
            target_phoneme=target_phonemes[pos.target_idx],
            missing_features=pos.missing_features,
            extra_features=pos.extra_features,
            is_deletion=dropped,
            feature_accuracy=pos.feature_accuracy,
            severity=_severity(pos.feature_accuracy, dropped),
        ))
    return report
| # ───────────────────────────────────────────────────────────────────────────── | |
| # 9. Aggregate feature error counts | |
| # ───────────────────────────────────────────────────────────────────────────── | |
| def _aggregate(errors: List[PhonemeError]) -> Dict[str, int]: | |
| counts: Dict[str, int] = {} | |
| for e in errors: | |
| for f in e.missing_features + e.extra_features: | |
| counts[f] = counts.get(f, 0) + 1 | |
| return dict(sorted(counts.items(), key=lambda x: -x[1])) | |
| # ───────────────────────────────────────────────────────────────────────────── | |
| # 10. Public entry point | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def run_mdd(
    actual_feature_seqs: List[List[int]],
    target_phonemes: List[str],
) -> MDDResult:
    """
    Run the complete MDD pipeline for a CTC phonological-feature model.

    Parameters
    ----------
    actual_feature_seqs : list of 35 lists of int (0 or 1)
        CTC-decoded model output, AFTER blank removal and run-length
        collapsing. Each list is the decoded +att/-att sequence for one
        feature. Lengths may differ from each other and from
        len(target_phonemes). Index order must match
        PHONOLOGICAL_FEATURES / FEATURE_NAMES.

        Concretely, if the model outputs logits of shape (T_audio, 71):
            nodes 0-34  = +att for features 0-34
            nodes 35-69 = -att for features 0-34
            node 70     = blank
        then for feature i the decoded sequence is a list of 0s and 1s
        (1 = +att node fired, 0 = -att node fired), blanks removed.
    target_phonemes : list of str
        CMU ARPAbet phoneme sequence from the user's typed sentence.
        Obtain via any G2P tool, e.g. g2p_en:
            from g2p_en import G2p
            target_phonemes = G2p()(sentence)

    Returns
    -------
    MDDResult

    Raises
    ------
    AssertionError
        If there are not exactly 35 feature sequences, or if
        ``target_phonemes`` is empty.
    """
    assert len(actual_feature_seqs) == 35, \
        f"Expected 35 feature sequences, got {len(actual_feature_seqs)}"
    assert len(target_phonemes) > 0, "target_phonemes must not be empty"

    n_target = len(target_phonemes)

    # Canonical per-feature label sequences derived from the phoneme labels.
    canonical_seqs: List[List[int]] = phoneme_sequence_to_feature_sequences(
        target_phonemes
    )

    # Longest decoded sequence; bounds the index range scanned for insertions.
    longest_actual = max(map(len, actual_feature_seqs), default=0)

    # Per-feature alignment merged into one record per target position.
    positions = _align_all_features(canonical_seqs, actual_feature_seqs, n_target)

    # Structural error counts.
    n_deleted = sum(pos.op == DELETE for pos in positions)
    n_inserted = _count_insertions(actual_feature_seqs, longest_actual, positions)

    # Scores, then the per-phoneme diagnosis and its feature-level tally.
    overall, per_phoneme = _score_utterance(positions)
    diagnosed = _build_errors(positions, target_phonemes)
    tally = _aggregate(diagnosed)

    return MDDResult(
        utterance_score=overall,
        phoneme_scores=per_phoneme,
        errors=diagnosed,
        aligned_positions=positions,
        feature_error_counts=tally,
        deletion_count=n_deleted,
        insertion_count=n_inserted,
    )
| # ───────────────────────────────────────────────────────────────────────────── | |
| # 11. CTC decode helper (use this on raw model logits) | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def ctc_decode_feature_seqs(
    logits: np.ndarray,  # (T_audio, 71) — raw model output per frame
    blank_idx: int = 70,
) -> List[List[int]]:
    """
    Greedy CTC decode for a phonological feature model with 71 output nodes.

    For each of the 35 features independently:
      1. Frames where the blank node is the overall argmax are dropped and
         end the current run.
      2. On every other frame the label is 1 if the +att node (feat_i) is
         >= the -att node (feat_i + 35), else 0.
      3. Consecutive identical labels are collapsed into one event.

    The overall argmax and the +att/-att comparison do not depend on the
    feature being decoded, so both are computed once for all frames instead
    of once per feature.

    Parameters
    ----------
    logits : np.ndarray (T_audio, 71)
        Raw model output before softmax. Probabilities work too, since only
        comparisons and argmax are used. Anything ``np.asarray`` accepts is
        allowed.
    blank_idx : int
        Index of the shared blank node (default 70).

    Returns
    -------
    List of 35 lists of int (0 or 1), CTC-decoded.

    Raises
    ------
    ValueError
        If ``logits`` is not 2-D or has fewer than 70 columns.
    """
    logits = np.asarray(logits)
    if logits.ndim != 2 or logits.shape[1] < 70:
        raise ValueError(
            "logits must have shape (T_audio, >=70); "
            f"got {logits.shape}"
        )

    # Frame-level quantities shared by all features.
    blank_wins = (np.argmax(logits, axis=1) == blank_idx).tolist()
    # plus_wins[t, f] is True when +att beats (or ties) -att for feature f.
    plus_wins = logits[:, :35] >= logits[:, 35:70]

    feature_seqs: List[List[int]] = []
    for feat_i in range(35):
        frame_labels = plus_wins[:, feat_i].astype(int).tolist()
        decoded: List[int] = []
        prev_label = None
        for is_blank, label in zip(blank_wins, frame_labels):
            if is_blank:
                prev_label = None  # blank resets the run
                continue
            # CTC run-length collapse
            if label != prev_label:
                decoded.append(label)
            prev_label = label
        feature_seqs.append(decoded)
    return feature_seqs