"""
MDD Engine β€” Mispronunciation Detection and Diagnosis
=====================================================
Architecture (Shahin et al. 2025)
----------------------------------
Your model runs 35 independent CTC decoders, one per phonological feature.
Each decoder outputs a sequence of +att(1) / -att(0) labels, with blanks
already removed and runs collapsed β€” so the output length reflects the number
of detected phoneme-level events, NOT audio frames.
The canonical target comes from the user's typed sentence:
sentence β†’ G2P (CMU ARPAbet) β†’ phoneme_sequence_to_feature_sequences()
β†’ 35 binary label sequences of length T (number of target phonemes)
The problem: the actual decoded sequence per feature may have a DIFFERENT
length than T, because the student may have:
- deleted phonemes (actual shorter than target)
- inserted extras (actual longer than target)
- substituted (same length, wrong labels)
Solution: Needleman-Wunsch (global sequence alignment) per feature
------------------------------------------------------------------
For each of the 35 features we run a global pairwise alignment between the
target binary sequence and the actual binary sequence. This gives us an
explicit alignment path with match / mismatch / insertion / deletion ops.
We then aggregate across all 35 features to get, per target phoneme position:
- which actual position it maps to (or DELETION if no match)
- which features are missing (+att in target, -att or gap in actual)
- which features are extra (-att in target, +att in actual)
- a weighted feature accuracy score
This is the standard approach in phonological MDD literature when no frame-
level forced alignment is available (see e.g. Lee & Glass 2015, Leung et al.
2019, and the feature-based MDD track of the AIP challenge).
Input/output contract
---------------------
actual_feature_seqs : list[list[int]] β€” 35 lists, each decoded CTC output
Values: 1 (+att) or 0 (-att)
Lengths may differ across features
and from the canonical length T
target_phonemes : list[str] β€” CMU ARPAbet phoneme sequence from
the user's typed sentence, length T
Output: MDDResult (see dataclass below)
"""
from __future__ import annotations
import numpy as np
from dataclasses import dataclass, field
from typing import List, Dict, Tuple, Optional
from phonological_features import (
    PHONOLOGICAL_FEATURES,
    phoneme_sequence_to_feature_sequences,
    phoneme_to_feature_vector,
)
# ─────────────────────────────────────────────────────────────────────────────
# 1. Feature schema & weights
# ─────────────────────────────────────────────────────────────────────────────
FEATURE_NAMES: List[str] = PHONOLOGICAL_FEATURES  # 35 features, canonical order
NUM_FEATURES = len(FEATURE_NAMES)  # 35
assert NUM_FEATURES == 35

F2I: Dict[str, int] = {f: i for i, f in enumerate(FEATURE_NAMES)}

# Perceptual salience weights: higher = more important mismatch.
#   Manner errors (wrong sound class) are most disruptive.
#   Voicing errors are highly salient in English.
#   Place errors matter, but less so than manner.
#   Length/type distinctions are least salient in L2 MDD.
FEATURE_WEIGHTS: np.ndarray = np.array([
    # Manners (11): consonant sonorant fricative nasal stop
    2.0, 1.5, 1.8, 2.0, 2.0,
    #   approximant affricate liquid vowel semivowel continuant
    1.5, 1.8, 1.5, 2.0, 1.5, 1.2,
    # Places (18): alveolar palatal dental glottal labial velar
    1.5, 1.4, 1.3, 1.2, 1.5, 1.5,
    #   mid high low front back central
    1.8, 1.8, 1.8, 1.6, 1.6, 1.2,
    #   anterior posterior retroflex bilabial coronal dorsal
    1.3, 1.3, 1.3, 1.4, 1.3, 1.3,
    # Others (6): long short monophthong diphthong round voiced
    1.0, 1.0, 1.2, 1.2, 1.2, 2.5,
], dtype=np.float32)
assert len(FEATURE_WEIGHTS) == 35
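
# Worked example (illustrative sketch, not called by the pipeline): how
# FEATURE_WEIGHTS turns bit agreement into the per-phoneme weighted accuracy
# computed in _align_all_features, i.e. the summed weight of features whose
# target and actual bits agree, divided by the total weight. The helper name
# _example_weighted_accuracy is hypothetical, and the toy 3-feature weights in
# the note below are made up for illustration.
from typing import Sequence

import numpy as np


def _example_weighted_accuracy(target_bits: Sequence[int],
                               actual_bits: Sequence[int],
                               weights: Sequence[float]) -> float:
    """Weighted fraction of features whose target and actual bits agree."""
    t = np.asarray(target_bits)
    a = np.asarray(actual_bits)
    w = np.asarray(weights, dtype=np.float64)
    agree = t == a  # boolean mask of correctly produced features
    return float(w[agree].sum() / w.sum())

# With weights (2.0, 1.0, 1.0), target [1, 0, 1] and actual [1, 1, 1], only
# the second feature is wrong, so accuracy = (2.0 + 1.0) / 4.0 = 0.75. An
# error on a heavy feature (e.g. "voiced", weight 2.5) costs proportionally more.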
# Alignment op codes
MATCH = 0     # same label, same position
MISMATCH = 1  # different label, same position
DELETE = 2    # target has event, actual has gap (deletion error)
INSERT = 3    # actual has event, target has gap (insertion error)

# NW scoring scheme
MATCH_SCORE = 2
MISMATCH_SCORE = -1
GAP_PENALTY = -2  # penalises deletions and insertions equally
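
# Worked example (illustrative sketch): how the NW scoring scheme above scores
# one candidate alignment path. Each aligned pair (i, j) earns MATCH_SCORE (2)
# or MISMATCH_SCORE (-1), and each gap, (i, None) or (None, j), costs
# GAP_PENALTY (-2). Needleman-Wunsch returns the path that maximises this
# total. _example_path_score is a hypothetical helper; its defaults simply
# restate the constants above.
from typing import List, Optional, Sequence, Tuple


def _example_path_score(target: Sequence[int],
                        actual: Sequence[int],
                        path: List[Tuple[Optional[int], Optional[int]]],
                        match: int = 2,
                        mismatch: int = -1,
                        gap: int = -2) -> int:
    total = 0
    for ti, ai in path:
        if ti is None or ai is None:
            total += gap       # insertion or deletion
        elif target[ti] == actual[ai]:
            total += match
        else:
            total += mismatch
    return total

# Target [1, 0, 1] vs actual [1, 1] (the -att event was dropped):
#   [(0, 0), (1, None), (2, 1)] scores 2 - 2 + 2 = 2   (one deletion)
#   [(0, 0), (1, 1), (2, None)] scores 2 - 1 - 2 = -1  (mismatch + deletion)
# so the path that explains the data as a single deletion is preferred.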
# ─────────────────────────────────────────────────────────────────────────────
# 2. Data classes
# ─────────────────────────────────────────────────────────────────────────────
@dataclass
class AlignedPosition:
    """One position in the target sequence after multi-feature alignment."""
    target_idx: int              # index in target phoneme sequence
    actual_idx: Optional[int]    # voted index in actual sequence, None = deletion
    op: int                      # MATCH / MISMATCH / DELETE (INSERT unused here)
    target_bits: List[int]       # canonical feature vector (35 bits)
    actual_bits: List[int]       # observed feature vector (35 bits, all 0 if deleted)
    missing_features: List[str]  # +att in target, -att or gap in actual
    extra_features: List[str]    # -att in target, +att in actual
    feature_accuracy: float      # weighted accuracy 0-1


@dataclass
class PhonemeError:
    """One mispronounced phoneme with its full feature-level diagnosis."""
    position: int                # index in target sequence
    target_phoneme: str          # ARPAbet label from typed sentence
    missing_features: List[str]  # features the student failed to produce
    extra_features: List[str]    # features the student added erroneously
    is_deletion: bool            # student dropped this phoneme entirely
    feature_accuracy: float      # 0-1
    severity: str                # "mild" | "moderate" | "severe"


@dataclass
class MDDResult:
    utterance_score: float                   # 0-100
    phoneme_scores: List[float]              # per target phoneme, 0-1
    errors: List[PhonemeError]
    aligned_positions: List[AlignedPosition]
    feature_error_counts: Dict[str, int]     # aggregated across all phonemes
    deletion_count: int
    insertion_count: int
# ─────────────────────────────────────────────────────────────────────────────
# 3. Needleman-Wunsch per-feature aligner
# ─────────────────────────────────────────────────────────────────────────────
def _nw_align(target_seq: List[int],
              actual_seq: List[int]) -> List[Tuple[Optional[int], Optional[int]]]:
    """
    Global sequence alignment (Needleman-Wunsch) for two binary label sequences.

    Returns a list of (target_idx, actual_idx) pairs where:
        (i, j)    → match or mismatch at target[i], actual[j]
        (i, None) → deletion: target[i] has no corresponding actual event
        (None, j) → insertion: actual[j] has no corresponding target event

    Binary values: 1 = +att, 0 = -att
    """
    T = len(target_seq)
    A = len(actual_seq)

    # Fill score matrix
    score = np.zeros((T + 1, A + 1), dtype=np.float32)
    score[0, :] = np.arange(A + 1) * GAP_PENALTY
    score[:, 0] = np.arange(T + 1) * GAP_PENALTY
    for i in range(1, T + 1):
        for j in range(1, A + 1):
            s = MATCH_SCORE if target_seq[i-1] == actual_seq[j-1] else MISMATCH_SCORE
            score[i, j] = max(
                score[i-1, j-1] + s,          # match/mismatch
                score[i-1, j] + GAP_PENALTY,  # deletion
                score[i, j-1] + GAP_PENALTY,  # insertion
            )

    # Traceback (on ties: diagonal first, then deletion, then insertion)
    path: List[Tuple[Optional[int], Optional[int]]] = []
    i, j = T, A
    while i > 0 or j > 0:
        if i > 0 and j > 0:
            s = MATCH_SCORE if target_seq[i-1] == actual_seq[j-1] else MISMATCH_SCORE
            if score[i, j] == score[i-1, j-1] + s:
                path.append((i-1, j-1))
                i -= 1
                j -= 1
                continue
        if i > 0 and score[i, j] == score[i-1, j] + GAP_PENALTY:
            path.append((i-1, None))  # deletion
            i -= 1
        else:
            path.append((None, j-1))  # insertion
            j -= 1
    path.reverse()
    return path
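
# Illustrative sketch: translating one per-feature NW path into the op codes
# defined above (MATCH=0, MISMATCH=1, DELETE=2, INSERT=3). The pipeline itself
# does not build per-feature op lists; it only votes on actual indices.
# _example_path_to_ops is a hypothetical helper shown for clarity.
from typing import List, Optional, Sequence, Tuple


def _example_path_to_ops(target: Sequence[int],
                         actual: Sequence[int],
                         path: List[Tuple[Optional[int], Optional[int]]]) -> List[int]:
    ops: List[int] = []
    for ti, ai in path:
        if ai is None:
            ops.append(2)  # DELETE: target event with no actual event
        elif ti is None:
            ops.append(3)  # INSERT: actual event with no target event
        elif target[ti] == actual[ai]:
            ops.append(0)  # MATCH
        else:
            ops.append(1)  # MISMATCH
    return ops

# For target [1, 0, 1], actual [1, 1] and path [(0, 0), (1, None), (2, 1)],
# the ops are [MATCH, DELETE, MATCH], i.e. [0, 2, 0].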
# ─────────────────────────────────────────────────────────────────────────────
# 4. Multi-feature alignment aggregator
# ─────────────────────────────────────────────────────────────────────────────
def _align_all_features(
    target_feat_seqs: List[List[int]],  # 35 lists, each length T
    actual_feat_seqs: List[List[int]],  # 35 lists, each possibly != T
    T: int,                             # number of target phonemes
) -> List[AlignedPosition]:
    """
    Run NW alignment independently on each of the 35 feature sequences, then
    aggregate the results per target phoneme position.

    Strategy
    --------
    Each feature gives its own alignment path. For each target position i we
    collect a vote from every feature on which actual position it maps to.
    If more than half of the votes are "gap" (deletion), the position is
    marked as a deletion; otherwise the plurality actual index wins.

    Per position, each feature's actual bit is then read from that feature's
    own aligned index (0 if the feature aligned target i to a gap).
    """
    # votes[target_idx] → list of actual_idx votes (None = deletion vote)
    votes: List[List[Optional[int]]] = [[] for _ in range(T)]
    # per_feat_map[feat][target_idx] → actual_idx or None
    per_feat_map: List[Dict[int, Optional[int]]] = [
        {} for _ in range(NUM_FEATURES)
    ]

    for feat_i in range(NUM_FEATURES):
        t_seq = target_feat_seqs[feat_i]  # length T
        a_seq = actual_feat_seqs[feat_i]  # length may differ
        path = _nw_align(t_seq, a_seq)
        for (ti, ai) in path:
            if ti is None:
                continue                  # insertion: no target position, skip
            votes[ti].append(ai)          # ai may be None (deletion)
            per_feat_map[feat_i][ti] = ai

    # Resolve votes per target position
    aligned: List[AlignedPosition] = []
    DELETION_VOTE_THRESHOLD = 0.5         # >50% gap votes → mark as DELETE
    total_weight = float(FEATURE_WEIGHTS.sum())

    for ti in range(T):
        v = votes[ti]
        non_null = [x for x in v if x is not None]
        null_count = len(v) - len(non_null)
        deletion_fraction = null_count / max(len(v), 1)

        if not non_null or deletion_fraction > DELETION_VOTE_THRESHOLD:
            chosen_ai = None
        else:
            # Plurality vote among non-null actual indices
            counts: Dict[int, int] = {}
            for idx in non_null:
                counts[idx] = counts.get(idx, 0) + 1
            chosen_ai = max(counts, key=counts.__getitem__)

        # Build target and actual bit vectors for this position
        target_bits = [target_feat_seqs[f][ti] for f in range(NUM_FEATURES)]
        if chosen_ai is not None:
            # Each feature reports the bit at its OWN aligned index, whether or
            # not it agrees with the voted chosen_ai; a feature that aligned
            # this target position to a gap contributes 0 (treated as absent).
            actual_bits = []
            for f in range(NUM_FEATURES):
                fa = per_feat_map[f].get(ti)
                if fa is not None and fa < len(actual_feat_seqs[f]):
                    actual_bits.append(actual_feat_seqs[f][fa])
                else:
                    actual_bits.append(0)
            op = MATCH if target_bits == actual_bits else MISMATCH
        else:
            actual_bits = [0] * NUM_FEATURES
            op = DELETE

        # Compute feature-level errors
        missing = [FEATURE_NAMES[f] for f in range(NUM_FEATURES)
                   if target_bits[f] == 1 and actual_bits[f] == 0]
        extra = [FEATURE_NAMES[f] for f in range(NUM_FEATURES)
                 if target_bits[f] == 0 and actual_bits[f] == 1]

        # Weighted accuracy: fraction of total feature weight correctly produced
        correct_weight = sum(
            FEATURE_WEIGHTS[f]
            for f in range(NUM_FEATURES)
            if target_bits[f] == actual_bits[f]
        )
        accuracy = float(correct_weight / total_weight)

        aligned.append(AlignedPosition(
            target_idx=ti,
            actual_idx=chosen_ai,
            op=op,
            target_bits=target_bits,
            actual_bits=actual_bits,
            missing_features=missing,
            extra_features=extra,
            feature_accuracy=accuracy,
        ))

    return aligned
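
# Illustrative sketch of the vote-resolution rule used in _align_all_features,
# isolated from the alignment so it can be checked on toy votes: strictly more
# than `threshold` gap votes (None) means DELETE, otherwise the most common
# actual index wins. _example_resolve_votes is a hypothetical helper. Ties are
# broken by first occurrence (Counter.most_common keeps insertion order for
# equal counts), matching max() over the insertion-ordered dict above.
from collections import Counter
from typing import List, Optional


def _example_resolve_votes(votes: List[Optional[int]],
                           threshold: float = 0.5) -> Optional[int]:
    non_null = [v for v in votes if v is not None]
    if not non_null:
        return None                                # no feature saw an event
    gap_fraction = (len(votes) - len(non_null)) / len(votes)
    if gap_fraction > threshold:
        return None                                # majority voted "gap"
    return Counter(non_null).most_common(1)[0][0]  # plurality actual index

# With 35 features, a position becomes a deletion only when at least 18
# features align it to a gap (18/35 > 0.5, while 17/35 < 0.5).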
# ─────────────────────────────────────────────────────────────────────────────
# 5. Insertion detector
# ─────────────────────────────────────────────────────────────────────────────
def _count_insertions(
    actual_feat_seqs: List[List[int]],
    actual_len: int,
    aligned: List[AlignedPosition],
) -> int:
    """
    Count actual positions (indices 0 .. actual_len - 1, where actual_len is
    the length of the longest decoded feature sequence) that no target
    position selected as its voted actual index. `actual_feat_seqs` is not
    used by the current implementation.
    """
    used_actual = set(
        ap.actual_idx for ap in aligned if ap.actual_idx is not None
    )
    inserted = set(range(actual_len)) - used_actual
    return len(inserted)
# ─────────────────────────────────────────────────────────────────────────────
# 6. Severity classifier
# ─────────────────────────────────────────────────────────────────────────────
# Minimum weighted feature accuracy for each severity label
# (below "moderate" → "severe"; deletions are always "severe")
_SEV = {"mild": 0.85, "moderate": 0.65}


def _severity(accuracy: float, is_deletion: bool) -> str:
    if is_deletion:
        return "severe"
    if accuracy >= _SEV["mild"]:
        return "mild"
    if accuracy >= _SEV["moderate"]:
        return "moderate"
    return "severe"
# ─────────────────────────────────────────────────────────────────────────────
# 7. Scorer
# ─────────────────────────────────────────────────────────────────────────────
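def _score_utterance(aligned: List[AlignedPosition]) -> Tuple[float, List[float]]:
    """
    Per-phoneme score: weighted feature accuracy (0-1); deletions score 0.
    Utterance score: unweighted mean of the per-phoneme scores, times 100.
    Because deletions contribute 0, they are penalised most.
    """
    # AlignedPosition.feature_accuracy is NOT 0 for a deletion (the target's
    # -att bits still agree with the all-zero actual vector), so zero it here.
    phoneme_scores = [0.0 if ap.op == DELETE else ap.feature_accuracy
                      for ap in aligned]
    utterance_score = float(np.mean(phoneme_scores)) * 100.0
    return utterance_score, phoneme_scores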
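
# Worked example (illustrative sketch) of the utterance-score arithmetic
# described above, using plain floats instead of AlignedPosition objects:
# per-phoneme accuracies are averaged with deletions counted as 0, then
# scaled to 0-100. _example_utterance_score is a hypothetical helper.
from typing import Sequence


def _example_utterance_score(accuracies: Sequence[float],
                             is_deletion: Sequence[bool]) -> float:
    scores = [0.0 if deleted else acc
              for acc, deleted in zip(accuracies, is_deletion)]
    return 100.0 * sum(scores) / len(scores)

# Three phonemes with accuracies 1.0, 0.8 and 0.9, where the third was
# deleted: (1.0 + 0.8 + 0.0) / 3 * 100 = 60.0, up to float rounding.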
# ─────────────────────────────────────────────────────────────────────────────
# 8. Error list builder
# ─────────────────────────────────────────────────────────────────────────────
def _build_errors(
    aligned: List[AlignedPosition],
    target_phonemes: List[str],
) -> List[PhonemeError]:
    errors = []
    for ap in aligned:
        if ap.op == MATCH and not ap.missing_features and not ap.extra_features:
            continue  # perfectly correct, no error to report
        errors.append(PhonemeError(
            position=ap.target_idx,
            target_phoneme=target_phonemes[ap.target_idx],
            missing_features=ap.missing_features,
            extra_features=ap.extra_features,
            is_deletion=(ap.op == DELETE),
            feature_accuracy=ap.feature_accuracy,
            severity=_severity(ap.feature_accuracy, ap.op == DELETE),
        ))
    return errors
# ─────────────────────────────────────────────────────────────────────────────
# 9. Aggregate feature error counts
# ─────────────────────────────────────────────────────────────────────────────
def _aggregate(errors: List[PhonemeError]) -> Dict[str, int]:
    counts: Dict[str, int] = {}
    for e in errors:
        for f in e.missing_features + e.extra_features:
            counts[f] = counts.get(f, 0) + 1
    return dict(sorted(counts.items(), key=lambda x: -x[1]))
# ─────────────────────────────────────────────────────────────────────────────
# 10. Public entry point
# ─────────────────────────────────────────────────────────────────────────────
def run_mdd(
    actual_feature_seqs: List[List[int]],
    target_phonemes: List[str],
) -> MDDResult:
    """
    Full MDD pipeline for a CTC phonological-feature model.

    Parameters
    ----------
    actual_feature_seqs : list of 35 lists of int (0 or 1)
        CTC-decoded output of your model, AFTER blank removal and run-length
        collapsing. Each list is the decoded +att/-att sequence for one feature.
        Lengths may differ from each other and from len(target_phonemes).
        Index order must match PHONOLOGICAL_FEATURES / FEATURE_NAMES.

        Concretely, if your model outputs logits of shape (T_audio, 71):
            nodes 0-34  = +att for features 0-34
            nodes 35-69 = -att for features 0-34
            node 70     = blank
        then for feature i the CTC-decoded sequence is a list of 0s and 1s
        (1 = +att node fired, 0 = -att node fired), with blanks removed.
        ctc_decode_feature_seqs() in section 11 produces this format.

    target_phonemes : list of str
        CMU ARPAbet phoneme sequence from the user's typed sentence.
        Obtain via any G2P tool, e.g. g2p_en:

            from g2p_en import G2p
            target_phonemes = G2p()(sentence)

        Note that g2p_en also emits " " word-boundary tokens and punctuation,
        and marks vowel stress with digits (e.g. "AH0"). Drop the non-phoneme
        tokens (and strip stress digits if phonological_features expects bare
        ARPAbet) before calling run_mdd.

    Returns
    -------
    MDDResult
    """
    assert len(actual_feature_seqs) == 35, \
        f"Expected 35 feature sequences, got {len(actual_feature_seqs)}"
    assert len(target_phonemes) > 0, "target_phonemes must not be empty"

    T = len(target_phonemes)

    # Build canonical target feature sequences from the phoneme labels
    target_feat_seqs: List[List[int]] = phoneme_sequence_to_feature_sequences(
        target_phonemes
    )  # 35 lists, each of length T

    # Longest decoded sequence (used for insertion counting)
    actual_len = max((len(s) for s in actual_feature_seqs), default=0)

    # Step 1: per-feature NW alignment → per target-position feature bits
    aligned = _align_all_features(target_feat_seqs, actual_feature_seqs, T)

    # Step 2: count structural errors
    deletions = sum(1 for ap in aligned if ap.op == DELETE)
    insertions = _count_insertions(actual_feature_seqs, actual_len, aligned)

    # Step 3: score
    utt_score, phoneme_scores = _score_utterance(aligned)

    # Step 4: build error list
    errors = _build_errors(aligned, target_phonemes)

    # Step 5: aggregate feature error counts
    feat_error_counts = _aggregate(errors)

    return MDDResult(
        utterance_score=utt_score,
        phoneme_scores=phoneme_scores,
        errors=errors,
        aligned_positions=aligned,
        feature_error_counts=feat_error_counts,
        deletion_count=deletions,
        insertion_count=insertions,
    )
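
# Illustrative sketch (hypothetical helper, not part of the original API):
# cleaning raw g2p_en output into the bare phoneme list run_mdd expects.
# ASSUMPTION: phonological_features uses stress-free ARPAbet symbols such as
# "AH" rather than "AH0"; if it accepts stress digits, pass
# strip_stress=False. Any token that is not an uppercase ARPAbet symbol
# (optionally ending in a stress digit 0-2), such as " " or ".", is dropped.
import re
from typing import List


def _example_clean_g2p(tokens: List[str], strip_stress: bool = True) -> List[str]:
    phonemes: List[str] = []
    for tok in tokens:
        tok = tok.strip()
        if not re.fullmatch(r"[A-Z]+[0-2]?", tok):
            continue  # word boundary or punctuation
        phonemes.append(re.sub(r"[0-2]$", "", tok) if strip_stress else tok)
    return phonemes

# e.g. _example_clean_g2p(["DH", "AH0", " ", "K", "AE1", "T", "."])
# returns ["DH", "AH", "K", "AE", "T"].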
# ─────────────────────────────────────────────────────────────────────────────
# 11. CTC decode helper (use this on raw model logits)
# ─────────────────────────────────────────────────────────────────────────────
def ctc_decode_feature_seqs(
    logits: np.ndarray,  # (T_audio, 71): raw model output per frame
    blank_idx: int = 70,
) -> List[List[int]]:
    """
    Greedy CTC decode for a phonological feature model with 71 output nodes.

    For each of the 35 features independently:
      1. At each frame, pick the argmax between pos_node (feat_i) and
         neg_node (feat_i + 35), ignoring blank.
      2. Collapse runs and remove frames where blank wins overall.
      3. Return the sequence of 1s (+att) and 0s (-att).

    Parameters
    ----------
    logits : np.ndarray (T_audio, 71)
        Raw model output before softmax. If you've already applied softmax,
        pass probabilities: the argmax logic is identical.
    blank_idx : int
        Index of the shared blank node (default 70).

    Returns
    -------
    List of 35 lists of int (0 or 1), CTC-decoded.
    """
    T_audio = logits.shape[0]
    feature_seqs: List[List[int]] = [[] for _ in range(35)]

    for feat_i in range(35):
        pos_node = feat_i       # +att node
        neg_node = feat_i + 35  # -att node
        prev_label = None
        for t in range(T_audio):
            frame = logits[t]
            best_overall = int(np.argmax(frame))
            if best_overall == blank_idx:
                prev_label = None  # blank resets run
                continue
            # Among pos/neg for this feature, pick the winner
            label = 1 if frame[pos_node] >= frame[neg_node] else 0
            # CTC run-length collapse
            if label != prev_label:
                feature_seqs[feat_i].append(label)
            prev_label = label

    return feature_seqs
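
# Reference sketch (hypothetical helper): the standard greedy CTC collapse
# rule applied to a single best path, which ctc_decode_feature_seqs mirrors
# per feature: merge consecutive repeats, then drop blanks. A blank between
# two identical labels keeps them as separate events, which is why the
# decoder above resets prev_label on blank frames.
from typing import List, Sequence


def _example_ctc_collapse(best_path: Sequence[int], blank: int) -> List[int]:
    out: List[int] = []
    prev = None
    for label in best_path:
        if label != prev and label != blank:
            out.append(label)
        prev = label
    return out

# With labels 1 = +att, 0 = -att and blank = 2:
# _example_ctc_collapse([1, 1, 2, 1, 0, 0, 2], blank=2) returns [1, 1, 0].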