""" MDD Engine — Mispronunciation Detection and Diagnosis ===================================================== Architecture (Shahin et al. 2025) ---------------------------------- Your model runs 35 independent CTC decoders, one per phonological feature. Each decoder outputs a sequence of +att(1) / -att(0) labels, with blanks already removed and runs collapsed — so the output length reflects the number of detected phoneme-level events, NOT audio frames. The canonical target comes from the user's typed sentence: sentence → G2P (CMU ARPAbet) → phoneme_sequence_to_feature_sequences() → 35 binary label sequences of length T (number of target phonemes) The problem: the actual decoded sequence per feature may have a DIFFERENT length than T, because the student may have: - deleted phonemes (actual shorter than target) - inserted extras (actual longer than target) - substituted (same length, wrong labels) Solution: Needleman-Wunsch (global sequence alignment) per feature ------------------------------------------------------------------ For each of the 35 features we run a global pairwise alignment between the target binary sequence and the actual binary sequence. This gives us an explicit alignment path with match / mismatch / insertion / deletion ops. We then aggregate across all 35 features to get, per target phoneme position: - which actual position it maps to (or DELETION if no match) - which features are missing (+att in target, -att or gap in actual) - which features are extra (-att in target, +att in actual) - a weighted feature accuracy score This is the standard approach in phonological MDD literature when no frame- level forced alignment is available (see e.g. Lee & Glass 2015, Leung et al. 2019, and the feature-based MDD track of the AIP challenge). 
Input/output contract --------------------- actual_feature_seqs : list[list[int]] — 35 lists, each decoded CTC output Values: 1 (+att) or 0 (-att) Lengths may differ across features and from the canonical length T target_phonemes : list[str] — CMU ARPAbet phoneme sequence from the user's typed sentence, length T Output: MDDResult (see dataclass below) """ from __future__ import annotations import numpy as np from dataclasses import dataclass, field from typing import List, Dict, Tuple, Optional from phonological_features import ( PHONOLOGICAL_FEATURES, phoneme_sequence_to_feature_sequences, phoneme_to_feature_vector, ) # ───────────────────────────────────────────────────────────────────────────── # 1. Feature schema & weights # ───────────────────────────────────────────────────────────────────────────── FEATURE_NAMES: List[str] = PHONOLOGICAL_FEATURES # 35 features, canonical order NUM_FEATURES = len(FEATURE_NAMES) # 35 assert NUM_FEATURES == 35 F2I: Dict[str, int] = {f: i for i, f in enumerate(FEATURE_NAMES)} # Perceptual salience weights — higher = more important mismatch. # Manner errors (wrong sound class) are most disruptive. # Voicing errors are highly salient in English. # Place errors matter but less so than manner. # Length/type distinctions are least salient in L2 MDD. 
FEATURE_WEIGHTS: np.ndarray = np.array([ # Manners (11): consonant sonorant fricative nasal stop 2.0, 1.5, 1.8, 2.0, 2.0, # approximant affricate liquid vowel semivowel continuant 1.5, 1.8, 1.5, 2.0, 1.5, 1.2, # Places (18): alveolar palatal dental glottal labial velar 1.5, 1.4, 1.3, 1.2, 1.5, 1.5, # mid high low front back central 1.8, 1.8, 1.8, 1.6, 1.6, 1.2, # anterior posterior retroflex bilabial coronal dorsal 1.3, 1.3, 1.3, 1.4, 1.3, 1.3, # Others (6): long short monophthong diphthong round voiced 1.0, 1.0, 1.2, 1.2, 1.2, 2.5, ], dtype=np.float32) assert len(FEATURE_WEIGHTS) == 35 # Alignment op codes MATCH = 0 # same label, same position MISMATCH = 1 # different label, same position DELETE = 2 # target has event, actual has gap (deletion error) INSERT = 3 # actual has event, target has gap (insertion error) # NW scoring scheme MATCH_SCORE = 2 MISMATCH_SCORE = -1 GAP_PENALTY = -2 # penalises deletions and insertions equally # ───────────────────────────────────────────────────────────────────────────── # 2. 
# Data classes
# ─────────────────────────────────────────────────────────────────────────────

@dataclass
class AlignedPosition:
    """One position in the target sequence after multi-feature alignment."""
    target_idx: int               # index in target phoneme sequence
    actual_idx: Optional[int]     # index in actual sequence, None = deletion
    op: int                       # MATCH / MISMATCH / DELETE / INSERT
    target_bits: List[int]        # canonical feature vector (35 bits)
    actual_bits: List[int]        # observed feature vector (35 bits, 0 if deleted)
    missing_features: List[str]   # +att in target, -att or gap in actual
    extra_features: List[str]     # -att in target, +att in actual
    feature_accuracy: float       # weighted accuracy 0-1


@dataclass
class PhonemeError:
    """One mispronounced phoneme with its full feature-level diagnosis."""
    position: int                 # index in target sequence
    target_phoneme: str           # ARPAbet label from typed sentence
    missing_features: List[str]   # features the student failed to produce
    extra_features: List[str]     # features the student added erroneously
    is_deletion: bool             # student dropped this phoneme entirely
    feature_accuracy: float       # 0-1
    severity: str                 # "mild" | "moderate" | "severe"


@dataclass
class MDDResult:
    """Complete result of one MDD run over an utterance."""
    utterance_score: float                # 0-100
    phoneme_scores: List[float]           # per target phoneme, 0-1
    errors: List[PhonemeError]
    aligned_positions: List[AlignedPosition]
    feature_error_counts: Dict[str, int]  # aggregated across all phonemes
    deletion_count: int
    insertion_count: int


# ─────────────────────────────────────────────────────────────────────────────
# 3. Needleman-Wunsch per-feature aligner
# ─────────────────────────────────────────────────────────────────────────────

def _nw_align(
    target_seq: List[int],
    actual_seq: List[int],
) -> List[Tuple[Optional[int], Optional[int]]]:
    """
    Global sequence alignment (Needleman-Wunsch) for two binary label sequences.

    Scoring uses the module constants MATCH_SCORE, MISMATCH_SCORE and
    GAP_PENALTY.  Binary values: 1 = +att, 0 = -att.

    Returns
    -------
    A list of (target_idx, actual_idx) pairs in left-to-right order where:
        (i, j)    → match or mismatch at target[i], actual[j]
        (i, None) → deletion:  target[i] has no corresponding actual event
        (None, j) → insertion: actual[j] has no corresponding target event
    """
    T = len(target_seq)
    A = len(actual_seq)

    # Fill the score matrix; row 0 / column 0 are pure-gap prefixes.
    score = np.zeros((T + 1, A + 1), dtype=np.float32)
    score[0, :] = np.arange(A + 1) * GAP_PENALTY
    score[:, 0] = np.arange(T + 1) * GAP_PENALTY

    for i in range(1, T + 1):
        for j in range(1, A + 1):
            same = target_seq[i - 1] == actual_seq[j - 1]
            step = MATCH_SCORE if same else MISMATCH_SCORE
            score[i, j] = max(
                score[i - 1, j - 1] + step,      # match/mismatch
                score[i - 1, j] + GAP_PENALTY,   # deletion
                score[i, j - 1] + GAP_PENALTY,   # insertion
            )

    # Traceback from the bottom-right corner.  Ties are resolved in the order
    # diagonal → deletion → insertion.
    path: List[Tuple[Optional[int], Optional[int]]] = []
    i, j = T, A
    while i > 0 or j > 0:
        if i > 0 and j > 0:
            same = target_seq[i - 1] == actual_seq[j - 1]
            step = MATCH_SCORE if same else MISMATCH_SCORE
            if score[i, j] == score[i - 1, j - 1] + step:
                path.append((i - 1, j - 1))
                i -= 1
                j -= 1
                continue
        if i > 0 and score[i, j] == score[i - 1, j] + GAP_PENALTY:
            path.append((i - 1, None))   # deletion
            i -= 1
        else:
            path.append((None, j - 1))   # insertion
            j -= 1

    path.reverse()
    return path


# ─────────────────────────────────────────────────────────────────────────────
# 4. Multi-feature alignment aggregator
# ─────────────────────────────────────────────────────────────────────────────

def _align_all_features(
    target_feat_seqs: List[List[int]],   # 35 lists, each length T
    actual_feat_seqs: List[List[int]],   # 35 lists, each possibly != T
    T: int,                              # number of target phonemes
) -> List[AlignedPosition]:
    """
    Run NW alignment independently on each feature sequence, then aggregate
    the results per target phoneme position.

    Strategy
    --------
    Each feature produces its own alignment path.  For every target position
    the features vote on which actual index it maps to (None = gap):

      * if there are no non-gap votes, or strictly more than half of the votes
        are gaps, the position is marked DELETE (actual_idx None, all actual
        bits 0);
      * otherwise the plurality non-gap index becomes actual_idx (first index
        seen wins a tie).

    For non-deleted positions each feature's actual bit is the value that
    feature's OWN alignment mapped to this target position (0 when that
    feature aligned the position to a gap); the voted index is only recorded
    as actual_idx and does not change the bits.

    Returns one AlignedPosition per target index 0..T-1.
    """
    # votes[target_idx] → list of actual_idx votes (None = deletion vote)
    votes: List[List[Optional[int]]] = [[] for _ in range(T)]
    # per_feat_map[feat][target_idx] → actual_idx or None
    per_feat_map: List[Dict[int, Optional[int]]] = [
        {} for _ in range(NUM_FEATURES)
    ]

    for feat_i in range(NUM_FEATURES):
        path = _nw_align(target_feat_seqs[feat_i], actual_feat_seqs[feat_i])
        for ti, ai in path:
            if ti is None:
                continue                    # insertion — no target position
            votes[ti].append(ai)            # ai may be None (deletion)
            per_feat_map[feat_i][ti] = ai

    DELETION_VOTE_THRESHOLD = 0.5           # >50% gap votes → mark as DELETE
    # Loop-invariant denominator of the weighted accuracy.
    total_weight = float(FEATURE_WEIGHTS.sum())

    aligned: List[AlignedPosition] = []
    for ti in range(T):
        v = votes[ti]
        non_null = [x for x in v if x is not None]
        deletion_fraction = (len(v) - len(non_null)) / max(len(v), 1)

        if not non_null or deletion_fraction > DELETION_VOTE_THRESHOLD:
            chosen_ai = None
        else:
            # Plurality vote among non-null actual indices.
            counts: Dict[int, int] = {}
            for idx in non_null:
                counts[idx] = counts.get(idx, 0) + 1
            chosen_ai = max(counts, key=counts.__getitem__)

        target_bits = [target_feat_seqs[f][ti] for f in range(NUM_FEATURES)]

        if chosen_ai is None:
            actual_bits = [0] * NUM_FEATURES
            op = DELETE
        else:
            actual_bits = []
            for f in range(NUM_FEATURES):
                seq = actual_feat_seqs[f]
                fa = per_feat_map[f].get(ti)
                # Gap (or out-of-range) for this feature → treat as absent.
                in_range = fa is not None and fa < len(seq)
                actual_bits.append(seq[fa] if in_range else 0)
            op = MATCH if target_bits == actual_bits else MISMATCH

        # Feature-level errors.
        missing = [FEATURE_NAMES[f] for f in range(NUM_FEATURES)
                   if target_bits[f] == 1 and actual_bits[f] == 0]
        extra = [FEATURE_NAMES[f] for f in range(NUM_FEATURES)
                 if target_bits[f] == 0 and actual_bits[f] == 1]

        # Weighted accuracy: fraction of feature weight whose bit agrees.
        correct_weight = sum(
            FEATURE_WEIGHTS[f] for f in range(NUM_FEATURES)
            if target_bits[f] == actual_bits[f]
        )
        accuracy = float(correct_weight / total_weight)

        aligned.append(AlignedPosition(
            target_idx=ti,
            actual_idx=chosen_ai,
            op=op,
            target_bits=target_bits,
            actual_bits=actual_bits,
            missing_features=missing,
            extra_features=extra,
            feature_accuracy=accuracy,
        ))

    return aligned


# ─────────────────────────────────────────────────────────────────────────────
# 5. Insertion detector
# ─────────────────────────────────────────────────────────────────────────────

def _count_insertions(
    actual_feat_seqs: List[List[int]],
    actual_len: int,
    aligned: List[AlignedPosition],
) -> int:
    """
    Count actual indices in range(actual_len) that no aligned target position
    selected as its actual_idx.

    `actual_feat_seqs` is accepted for interface compatibility but is not
    consulted; only `actual_len` and the voted indices in `aligned` are used.
    """
    used_actual = {
        ap.actual_idx for ap in aligned if ap.actual_idx is not None
    }
    return sum(1 for j in range(actual_len) if j not in used_actual)


# ─────────────────────────────────────────────────────────────────────────────
# 6.
# Severity classifier
# ─────────────────────────────────────────────────────────────────────────────

# Lower bounds on weighted feature ACCURACY for each severity band
# (a higher accuracy means a milder error).
_SEV = {"mild": 0.85, "moderate": 0.65}


def _severity(accuracy: float, is_deletion: bool) -> str:
    """
    Map a weighted feature accuracy (0-1) to a severity label.

    Deletions are always "severe".  Otherwise accuracy >= 0.85 is "mild",
    accuracy >= 0.65 is "moderate", and anything lower is "severe".
    """
    if is_deletion:
        return "severe"
    if accuracy >= _SEV["mild"]:
        return "mild"
    if accuracy >= _SEV["moderate"]:
        return "moderate"
    return "severe"


# ─────────────────────────────────────────────────────────────────────────────
# 7. Scorer
# ─────────────────────────────────────────────────────────────────────────────

def _score_utterance(aligned: List[AlignedPosition]) -> Tuple[float, List[float]]:
    """
    Score an aligned utterance.

    Per-phoneme score: weighted feature accuracy (0-1).  Deletions score 0:
    a deleted position carries an all-zero actual vector, so its raw
    feature_accuracy still credits every feature that is -att in the target
    and would otherwise reward dropping the phoneme.

    Utterance score: mean of the per-phoneme scores scaled to 0-100, so
    deletions are penalised most.  An empty alignment scores 0.0 rather than
    NaN.

    Returns
    -------
    (utterance_score, phoneme_scores)
    """
    phoneme_scores = [
        0.0 if ap.op == DELETE else ap.feature_accuracy
        for ap in aligned
    ]
    if not phoneme_scores:
        # np.mean([]) would yield NaN (with a RuntimeWarning).
        return 0.0, phoneme_scores
    utterance_score = float(np.mean(phoneme_scores)) * 100.0
    return utterance_score, phoneme_scores


# ─────────────────────────────────────────────────────────────────────────────
# 8. Error list builder
# ─────────────────────────────────────────────────────────────────────────────

def _build_errors(
    aligned: List[AlignedPosition],
    target_phonemes: List[str],
) -> List[PhonemeError]:
    """
    Turn aligned positions into a list of PhonemeError records.

    Positions whose op is MATCH and that have neither missing nor extra
    features are skipped; every other position (mismatch or deletion) yields
    one PhonemeError carrying the target phoneme label, the feature diffs,
    the position's feature_accuracy and the derived severity.
    """
    errors: List[PhonemeError] = []
    for ap in aligned:
        is_clean = (
            ap.op == MATCH
            and not ap.missing_features
            and not ap.extra_features
        )
        if is_clean:
            continue  # perfectly correct, no error to report

        deleted = ap.op == DELETE
        errors.append(PhonemeError(
            position=ap.target_idx,
            target_phoneme=target_phonemes[ap.target_idx],
            missing_features=ap.missing_features,
            extra_features=ap.extra_features,
            is_deletion=deleted,
            feature_accuracy=ap.feature_accuracy,
            severity=_severity(ap.feature_accuracy, deleted),
        ))
    return errors


# ─────────────────────────────────────────────────────────────────────────────
# 9.
# Aggregate feature error counts
# ─────────────────────────────────────────────────────────────────────────────

def _aggregate(errors: List[PhonemeError]) -> Dict[str, int]:
    """
    Count how often each feature name appears in the missing or extra list of
    any error.

    Returns a dict ordered by descending count; features with equal counts
    keep the order in which they were first seen.
    """
    counts: Dict[str, int] = {}
    for err in errors:
        for name in err.missing_features + err.extra_features:
            counts[name] = counts.get(name, 0) + 1
    return dict(sorted(counts.items(), key=lambda kv: -kv[1]))


# ─────────────────────────────────────────────────────────────────────────────
# 10. Public entry point
# ─────────────────────────────────────────────────────────────────────────────

def run_mdd(
    actual_feature_seqs: List[List[int]],
    target_phonemes: List[str],
) -> MDDResult:
    """
    Full MDD pipeline for a CTC phonological-feature model.

    Parameters
    ----------
    actual_feature_seqs : list of 35 lists of int (0 or 1)
        CTC-decoded output of the model, AFTER blank removal and run-length
        collapsing.  Each list is the decoded +att/−att sequence for one
        feature.  Lengths may differ from each other and from
        len(target_phonemes).  Index order must match
        PHONOLOGICAL_FEATURES / FEATURE_NAMES.

        Concretely, if the model outputs logits of shape (T_audio, 71):
            nodes  0-34 = +att for features 0-34
            nodes 35-69 = -att for features 0-34
            node  70    = blank
        then for feature i the CTC-decoded sequence is a list of 0s and 1s
        (1 = +att node fired, 0 = -att node fired), blanks removed.

    target_phonemes : list of str
        CMU ARPAbet phoneme sequence from the user's typed sentence.
        Obtain via any G2P tool, e.g. g2p_en:
            from g2p_en import G2p
            target_phonemes = G2p()(sentence)

    Returns
    -------
    MDDResult

    Raises
    ------
    AssertionError
        If there are not exactly 35 feature sequences or `target_phonemes`
        is empty.
    """
    assert len(actual_feature_seqs) == 35, \
        f"Expected 35 feature sequences, got {len(actual_feature_seqs)}"
    assert len(target_phonemes) > 0, "target_phonemes must not be empty"

    T = len(target_phonemes)

    # Canonical target feature sequences built from the phoneme labels.
    target_feat_seqs: List[List[int]] = phoneme_sequence_to_feature_sequences(
        target_phonemes
    )

    # Longest decoded sequence; used as the actual length for insertion counting.
    actual_len = max((len(s) for s in actual_feature_seqs), default=0)

    # Step 1: per-feature NW alignment → per target-position feature bits
    aligned = _align_all_features(target_feat_seqs, actual_feature_seqs, T)

    # Step 2: count structural errors
    deletions = sum(1 for ap in aligned if ap.op == DELETE)
    insertions = _count_insertions(actual_feature_seqs, actual_len, aligned)

    # Step 3: score
    utt_score, phoneme_scores = _score_utterance(aligned)

    # Step 4: build error list
    errors = _build_errors(aligned, target_phonemes)

    # Step 5: aggregate feature error counts
    feat_error_counts = _aggregate(errors)

    return MDDResult(
        utterance_score=utt_score,
        phoneme_scores=phoneme_scores,
        errors=errors,
        aligned_positions=aligned,
        feature_error_counts=feat_error_counts,
        deletion_count=deletions,
        insertion_count=insertions,
    )


# ─────────────────────────────────────────────────────────────────────────────
# 11. CTC decode helper (use this on raw model logits)
# ─────────────────────────────────────────────────────────────────────────────

def ctc_decode_feature_seqs(
    logits: np.ndarray,     # (T_audio, 71) — raw model output per frame
    blank_idx: int = 70,
) -> List[List[int]]:
    """
    Greedy CTC decode for a phonological feature model with 71 output nodes.

    For each of the 35 features independently:
      1. Frames where the blank node is the argmax over ALL nodes are dropped
         and break the current run.
      2. On every other frame the label is 1 if the +att node (feat_i) is
         >= the -att node (feat_i + 35), else 0.
      3. Consecutive identical labels with no blank frame between them are
         collapsed into one.

    The whole-frame argmax is computed once for all features rather than once
    per feature.

    Parameters
    ----------
    logits : np.ndarray (T_audio, 71)
        Raw model output before softmax.  Probabilities work too — only
        comparisons and argmax are used.  Any array-like accepted by
        np.asarray is allowed.
    blank_idx : int
        Index of the shared blank node (default 70).

    Returns
    -------
    List of 35 lists of int (0 or 1), CTC-decoded.

    Raises
    ------
    ValueError
        If `logits` is not 2-D or has too few columns to hold the 70 feature
        nodes and the blank node.
    """
    n_feat = 35
    logits = np.asarray(logits)
    min_width = max(blank_idx, 2 * n_feat - 1) + 1
    if logits.ndim != 2 or logits.shape[1] < min_width:
        raise ValueError(
            f"logits must have shape (T_audio, >= {min_width}), "
            f"got {logits.shape}"
        )

    # (T_audio,) — frames on which blank wins overall.
    is_blank = logits.argmax(axis=1) == blank_idx

    # (T_audio, 35) — 1 where +att >= -att, 0 otherwise; -1 marks blank frames.
    labels = (logits[:, :n_feat] >= logits[:, n_feat:2 * n_feat]).astype(np.int8)
    labels[is_blank] = -1

    # Label of the previous frame (-1 before the first frame, i.e. "no run").
    prev = np.full_like(labels, -1)
    prev[1:] = labels[:-1]

    # Emit a label when the frame is non-blank and starts a new run.
    emit = (labels != -1) & (labels != prev)

    return [labels[emit[:, f], f].tolist() for f in range(n_feat)]