Spaces:
Sleeping
Sleeping
Commit Β·
0515ef3
1
Parent(s): c35de4f
initial app deploy
Browse files- app.py +287 -0
- feedback_generator.py +690 -0
- mdd_engine.py +534 -0
- phonological_features.py +253 -0
- requirements.txt +17 -0
- wav2vec2_phonological.py +267 -0
app.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Mispronunciation Detection & Diagnosis β HuggingFace Space
|
| 3 |
+
===========================================================
|
| 4 |
+
Wires together:
|
| 5 |
+
1. PhonologicalWav2Vec2 (your best_model.pt, loaded once at cold start)
|
| 6 |
+
2. MDD engine (per-feature NW alignment β errors + score)
|
| 7 |
+
3. Feedback generator (rule engine + optional LLM rewriter)
|
| 8 |
+
|
| 9 |
+
Environment variables to set in Space β Settings β Variables and secrets:
|
| 10 |
+
HF_TOKEN (secret) β read token for your private model repo
|
| 11 |
+
HF_MODEL_REPO (variable) β e.g. "Backlighteu/phonological-mdd"
|
| 12 |
+
HF_MODEL_FILENAME (variable) β e.g. "best_model.pt" (default)
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import json
|
| 17 |
+
import torch
|
| 18 |
+
import numpy as np
|
| 19 |
+
import gradio as gr
|
| 20 |
+
import librosa
|
| 21 |
+
|
| 22 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
| 23 |
+
from transformers import Wav2Vec2FeatureExtractor
|
| 24 |
+
|
| 25 |
+
from wav2vec2_phonological import PhonologicalWav2Vec2
|
| 26 |
+
from mdd_engine import run_mdd
|
| 27 |
+
from feedback_generator import generate_feedback
|
| 28 |
+
from phonological_features import (
|
| 29 |
+
CMU_39_PHONEMES,
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 33 |
+
# 1. Model β loaded once at cold start, reused for every request
|
| 34 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 35 |
+
|
| 36 |
+
_model = None
|
| 37 |
+
_feature_extractor = None
|
| 38 |
+
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 39 |
+
|
| 40 |
+
PRETRAINED_BASE = "facebook/wav2vec2-large-robust"
|
| 41 |
+
MODEL_REPO = os.environ.get("HF_MODEL_REPO", "Backlighteu/phonological-mdd")
|
| 42 |
+
MODEL_FILENAME = os.environ.get("HF_MODEL_FILENAME", "best_model.pt")
|
| 43 |
+
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def load_model():
|
| 47 |
+
global _model, _feature_extractor
|
| 48 |
+
|
| 49 |
+
if _model is not None:
|
| 50 |
+
return
|
| 51 |
+
|
| 52 |
+
# Download entire repo into ./model_cache once, then load from disk.
|
| 53 |
+
# hf_hub_download checks cache first β no re-download if already present.
|
| 54 |
+
print(f"[startup] Caching {MODEL_REPO} to ./model_cache ...")
|
| 55 |
+
snapshot_download(
|
| 56 |
+
repo_id=MODEL_REPO,
|
| 57 |
+
token=HF_TOKEN,
|
| 58 |
+
local_dir="./model_cache",
|
| 59 |
+
)
|
| 60 |
+
weights_path = "./model_cache/best_model.pt"
|
| 61 |
+
print(f"[startup] Loading weights from {weights_path}")
|
| 62 |
+
|
| 63 |
+
model = PhonologicalWav2Vec2(
|
| 64 |
+
pretrained_model_name=PRETRAINED_BASE,
|
| 65 |
+
num_output_nodes=71,
|
| 66 |
+
freeze_cnn_encoder=True,
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
state_dict = torch.load(weights_path, map_location=_device)
|
| 70 |
+
model.load_state_dict(state_dict)
|
| 71 |
+
model.to(_device)
|
| 72 |
+
model.eval()
|
| 73 |
+
_model = model
|
| 74 |
+
print(f"[startup] Model ready on {_device}.")
|
| 75 |
+
|
| 76 |
+
print(f"[startup] Loading feature extractor from '{PRETRAINED_BASE}' ...")
|
| 77 |
+
_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(PRETRAINED_BASE)
|
| 78 |
+
print("[startup] Feature extractor ready.")
|
| 79 |
+
|
| 80 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 81 |
+
# 2. Audio β decoded feature sequences
|
| 82 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 83 |
+
|
| 84 |
+
TARGET_SR = 16_000
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def decode_audio(audio_path: str) -> list:
|
| 88 |
+
"""
|
| 89 |
+
Load audio, run the phonological model, return CTC-decoded feature seqs.
|
| 90 |
+
|
| 91 |
+
Returns
|
| 92 |
+
-------
|
| 93 |
+
actual_feature_seqs : list of 35 lists of int (0 or 1)
|
| 94 |
+
CTC-decoded +att / -att sequence for each of the 35 features.
|
| 95 |
+
"""
|
| 96 |
+
load_model()
|
| 97 |
+
|
| 98 |
+
waveform, sr = librosa.load(audio_path, sr=TARGET_SR, mono=True)
|
| 99 |
+
waveform = waveform.astype(np.float32)
|
| 100 |
+
|
| 101 |
+
inputs = _feature_extractor(
|
| 102 |
+
waveform,
|
| 103 |
+
sampling_rate=TARGET_SR,
|
| 104 |
+
return_tensors="pt",
|
| 105 |
+
padding=True,
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
input_values = inputs.input_values.to(_device)
|
| 109 |
+
attention_mask = inputs.get("attention_mask")
|
| 110 |
+
if attention_mask is not None:
|
| 111 |
+
attention_mask = attention_mask.to(_device)
|
| 112 |
+
|
| 113 |
+
with torch.no_grad():
|
| 114 |
+
logits, output_lengths = _model(
|
| 115 |
+
input_values,
|
| 116 |
+
attention_mask,
|
| 117 |
+
apply_spec_augment=False,
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
# model.decode() returns list[B][35][list[bool]] β True=+att, False=-att
|
| 121 |
+
decoded_batch = _model.decode(logits, output_lengths)
|
| 122 |
+
decoded_35 = decoded_batch[0] # [35][list[bool]]
|
| 123 |
+
|
| 124 |
+
# Convert bool β int (1/0)
|
| 125 |
+
actual_feature_seqs = [
|
| 126 |
+
[1 if v else 0 for v in feat_seq]
|
| 127 |
+
for feat_seq in decoded_35
|
| 128 |
+
]
|
| 129 |
+
|
| 130 |
+
return actual_feature_seqs
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
# ββββββββββββββββββοΏ½οΏ½οΏ½ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 134 |
+
# 3. Text β canonical phoneme sequence
|
| 135 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 136 |
+
|
| 137 |
+
_VALID_PHONEMES = set(CMU_39_PHONEMES) | {"sil"}
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def parse_phoneme_input(text: str) -> list:
|
| 141 |
+
"""
|
| 142 |
+
Accept space-separated CMU ARPAbet tokens typed by the user.
|
| 143 |
+
Unknown tokens are skipped with a warning.
|
| 144 |
+
"""
|
| 145 |
+
tokens = text.lower().split()
|
| 146 |
+
valid, skipped = [], []
|
| 147 |
+
for t in tokens:
|
| 148 |
+
if t in _VALID_PHONEMES:
|
| 149 |
+
valid.append(t)
|
| 150 |
+
else:
|
| 151 |
+
skipped.append(t)
|
| 152 |
+
if skipped:
|
| 153 |
+
print(f"[warning] Unrecognised tokens skipped: {skipped}")
|
| 154 |
+
return valid if valid else ["sil"]
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 158 |
+
# 4. Gradio processing function
|
| 159 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 160 |
+
|
| 161 |
+
def process(audio_input, script_text, use_llm, max_issues):
|
| 162 |
+
if audio_input is None:
|
| 163 |
+
return "Please record or upload audio first.", "", "{}"
|
| 164 |
+
|
| 165 |
+
script_text = script_text.strip()
|
| 166 |
+
if not script_text:
|
| 167 |
+
return (
|
| 168 |
+
"Please type the target sentence as ARPAbet phoneme tokens.\n"
|
| 169 |
+
"Example: `dh ae k ae t` for 'the cat'",
|
| 170 |
+
"", "{}",
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
try:
|
| 174 |
+
actual_feature_seqs = decode_audio(audio_input)
|
| 175 |
+
except Exception as e:
|
| 176 |
+
return f"Audio processing error: {e}", "", "{}"
|
| 177 |
+
|
| 178 |
+
target_phonemes = parse_phoneme_input(script_text)
|
| 179 |
+
|
| 180 |
+
try:
|
| 181 |
+
result = run_mdd(
|
| 182 |
+
actual_feature_seqs=actual_feature_seqs,
|
| 183 |
+
target_phonemes=target_phonemes,
|
| 184 |
+
)
|
| 185 |
+
except Exception as e:
|
| 186 |
+
return f"MDD engine error: {e}", "", "{}"
|
| 187 |
+
|
| 188 |
+
feedback_dict = generate_feedback(
|
| 189 |
+
result,
|
| 190 |
+
use_llm=use_llm,
|
| 191 |
+
max_issues=int(max_issues),
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
score = feedback_dict["score"]
|
| 195 |
+
main_feedback = (
|
| 196 |
+
f"**Pronunciation Score: {score}/100**\n\n"
|
| 197 |
+
+ feedback_dict["final_feedback"]
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
detail_lines = ["### Per-phoneme detail\n"]
|
| 201 |
+
for e in feedback_dict["error_summary"]:
|
| 202 |
+
deletion_tag = " *(deleted)*" if e.get("is_deletion") else ""
|
| 203 |
+
detail_lines.append(
|
| 204 |
+
f"- **/{e['target']}/** (pos {e['position']}){deletion_tag}: "
|
| 205 |
+
f"severity=`{e['severity']}`, accuracy={e['accuracy']:.0%}\n"
|
| 206 |
+
f" - Missing: {', '.join(e['missing_features']) or 'β'}\n"
|
| 207 |
+
f" - Extra: {', '.join(e['extra_features']) or 'β'}"
|
| 208 |
+
)
|
| 209 |
+
if not feedback_dict["error_summary"]:
|
| 210 |
+
detail_lines.append("No feature-level errors detected β great pronunciation!")
|
| 211 |
+
|
| 212 |
+
detail_text = "\n".join(detail_lines)
|
| 213 |
+
|
| 214 |
+
json_output = json.dumps({
|
| 215 |
+
"score": feedback_dict["score"],
|
| 216 |
+
"deletion_count": result.deletion_count,
|
| 217 |
+
"insertion_count": result.insertion_count,
|
| 218 |
+
"feature_error_counts": feedback_dict["feature_error_counts"],
|
| 219 |
+
"rules_triggered": feedback_dict["rules_triggered"],
|
| 220 |
+
"target_phonemes": target_phonemes,
|
| 221 |
+
"actual_seq_lengths": [len(s) for s in actual_feature_seqs],
|
| 222 |
+
}, indent=2)
|
| 223 |
+
|
| 224 |
+
return main_feedback, detail_text, json_output
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 228 |
+
# 5. Gradio UI
|
| 229 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 230 |
+
|
| 231 |
+
VALID_PHONEME_LIST = ", ".join(sorted(CMU_39_PHONEMES))
|
| 232 |
+
|
| 233 |
+
with gr.Blocks(title="Pronunciation Coach", theme=gr.themes.Soft()) as demo:
|
| 234 |
+
gr.Markdown(
|
| 235 |
+
"""
|
| 236 |
+
# Pronunciation Coach
|
| 237 |
+
Speak a sentence, type what you meant to say as **ARPAbet phoneme tokens**,
|
| 238 |
+
and get phonological-feature-level feedback with articulation tips.
|
| 239 |
+
"""
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
with gr.Row():
|
| 243 |
+
with gr.Column(scale=1):
|
| 244 |
+
audio_input = gr.Audio(
|
| 245 |
+
sources=["microphone", "upload"],
|
| 246 |
+
type="filepath",
|
| 247 |
+
label="Your speech",
|
| 248 |
+
)
|
| 249 |
+
script_input = gr.Textbox(
|
| 250 |
+
label="Target sentence β space-separated ARPAbet tokens",
|
| 251 |
+
placeholder="e.g. dh ae k ae t (= 'the cat')",
|
| 252 |
+
lines=2,
|
| 253 |
+
)
|
| 254 |
+
with gr.Accordion("Valid phoneme tokens", open=False):
|
| 255 |
+
gr.Markdown(f"`{VALID_PHONEME_LIST}`")
|
| 256 |
+
with gr.Row():
|
| 257 |
+
use_llm = gr.Checkbox(value=False, label="LLM feedback rewriter")
|
| 258 |
+
max_issues = gr.Slider(1, 5, value=3, step=1, label="Max issues shown")
|
| 259 |
+
submit_btn = gr.Button("Analyse", variant="primary")
|
| 260 |
+
|
| 261 |
+
with gr.Column(scale=2):
|
| 262 |
+
feedback_out = gr.Markdown(label="Coaching feedback")
|
| 263 |
+
with gr.Accordion("Per-phoneme detail", open=False):
|
| 264 |
+
detail_out = gr.Markdown()
|
| 265 |
+
with gr.Accordion("Raw JSON (developers)", open=False):
|
| 266 |
+
json_out = gr.Code(language="json")
|
| 267 |
+
|
| 268 |
+
submit_btn.click(
|
| 269 |
+
fn=process,
|
| 270 |
+
inputs=[audio_input, script_input, use_llm, max_issues],
|
| 271 |
+
outputs=[feedback_out, detail_out, json_out],
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
gr.Markdown(
|
| 275 |
+
"""
|
| 276 |
+
---
|
| 277 |
+
**How to enter the target sentence:**
|
| 278 |
+
Convert your sentence to ARPAbet using the
|
| 279 |
+
[CMU Pronouncing Dictionary](http://www.speech.cs.cmu.edu/cgi-bin/cmudict)
|
| 280 |
+
then paste the space-separated tokens here.
|
| 281 |
+
Example: *"the cat sat"* β `dh ax k ae t s ae t`
|
| 282 |
+
"""
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
if __name__ == "__main__":
|
| 287 |
+
demo.launch()
|
feedback_generator.py
ADDED
|
@@ -0,0 +1,690 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Feedback Generator
|
| 3 |
+
==================
|
| 4 |
+
Two-layer system:
|
| 5 |
+
Layer 1 β Rule engine: maps specific feature errors to expert articulatory cues
|
| 6 |
+
Layer 2 β LLM rewriter: takes rule outputs and rewrites them into natural,
|
| 7 |
+
encouraging coach-like language via a lightweight local model
|
| 8 |
+
(or cloud fallback).
|
| 9 |
+
|
| 10 |
+
The rule templates are the ground truth; the LLM only adds warmth and fluency.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
import os
|
| 15 |
+
import json
|
| 16 |
+
import textwrap
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
from typing import List, Dict, Optional, Tuple
|
| 19 |
+
from mdd_engine import PhonemeError, MDDResult, FEATURE_NAMES
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 23 |
+
# 1. Articulatory feedback rule bank
|
| 24 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 25 |
+
# Each rule = {trigger_features, direction, tip, drill, self_check}
|
| 26 |
+
# direction: "missing" | "extra" | "both"
|
| 27 |
+
|
| 28 |
+
FEATURE_RULES: List[Dict] = [
|
| 29 |
+
# ββ VOICING (Others group) ββββββββββββββββββββββββββββββββββββββββββββ
|
| 30 |
+
{
|
| 31 |
+
"features": ["voiced"],
|
| 32 |
+
"direction": "missing",
|
| 33 |
+
"tip": (
|
| 34 |
+
"Your vocal cords are not vibrating when they should be. "
|
| 35 |
+
"Place two fingers lightly on your throat (the Adam's apple area). "
|
| 36 |
+
"Now say the sound β if you feel vibration, you've got it. "
|
| 37 |
+
"Try humming first ('mmm'), then slide into the target sound."
|
| 38 |
+
),
|
| 39 |
+
"drill": "Practice pairs: /f/ β /v/, /s/ β /z/, /p/ β /b/. "
|
| 40 |
+
"Feel the buzz turn on for the second sound each time.",
|
| 41 |
+
"self_check": "Put your hand on your throat. You should feel a gentle buzz.",
|
| 42 |
+
},
|
| 43 |
+
{
|
| 44 |
+
"features": ["voiced"],
|
| 45 |
+
"direction": "extra",
|
| 46 |
+
"tip": (
|
| 47 |
+
"You are voicing a sound that should be voiceless β your vocal cords "
|
| 48 |
+
"are buzzing when they should be still. "
|
| 49 |
+
"Whisper the sound first to train your cords to stay quiet, "
|
| 50 |
+
"then gradually add breath pressure without the buzz."
|
| 51 |
+
),
|
| 52 |
+
"drill": "Whisper-shout drill: whisper /p/, /t/, /k/, /f/, /s/ ten times.",
|
| 53 |
+
"self_check": "Put your hand on your throat. It should feel still, no vibration.",
|
| 54 |
+
},
|
| 55 |
+
|
| 56 |
+
# ββ MANNER: STOP βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 57 |
+
{
|
| 58 |
+
"features": ["stop"],
|
| 59 |
+
"direction": "missing",
|
| 60 |
+
"tip": (
|
| 61 |
+
"This sound needs a full closure in your mouth β air must be completely "
|
| 62 |
+
"blocked and then released in a burst. "
|
| 63 |
+
"Your tongue or lips are not making a tight enough seal, letting air trickle "
|
| 64 |
+
"through instead of building up pressure."
|
| 65 |
+
),
|
| 66 |
+
"drill": "Tap your fingers on the desk for each stop: /p/ β /t/ β /k/. "
|
| 67 |
+
"Feel the 'pop' as pressure releases each time.",
|
| 68 |
+
"self_check": "Before the release, you should feel air pressure building behind the closure.",
|
| 69 |
+
},
|
| 70 |
+
{
|
| 71 |
+
"features": ["stop"],
|
| 72 |
+
"direction": "extra",
|
| 73 |
+
"tip": (
|
| 74 |
+
"You are closing your airway completely when the sound should be continuous. "
|
| 75 |
+
"Relax the articulators and keep a small opening so air can flow through "
|
| 76 |
+
"without a burst."
|
| 77 |
+
),
|
| 78 |
+
"drill": "Say /s/ and /f/ β feel the continuous uninterrupted airflow, no pop.",
|
| 79 |
+
"self_check": "You should hear no 'pop' or sudden release β just steady air.",
|
| 80 |
+
},
|
| 81 |
+
|
| 82 |
+
# ββ MANNER: FRICATIVE ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 83 |
+
{
|
| 84 |
+
"features": ["fricative"],
|
| 85 |
+
"direction": "missing",
|
| 86 |
+
"tip": (
|
| 87 |
+
"This sound requires turbulent airflow β a hissing or buzzing quality. "
|
| 88 |
+
"Narrow the passage between your tongue (or lips) and the articulators just enough "
|
| 89 |
+
"that the air becomes turbulent. Too wide gives a vowel; full closure gives a stop."
|
| 90 |
+
),
|
| 91 |
+
"drill": "Hold /s/, /f/, /sh/ for three full seconds each. Feel the continuous friction.",
|
| 92 |
+
"self_check": "You should hear a clear hissing or buzzing sound throughout, not silence or a pop.",
|
| 93 |
+
},
|
| 94 |
+
|
| 95 |
+
# ββ MANNER: NASAL βββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 96 |
+
{
|
| 97 |
+
"features": ["nasal"],
|
| 98 |
+
"direction": "missing",
|
| 99 |
+
"tip": (
|
| 100 |
+
"This sound requires airflow through your nose. "
|
| 101 |
+
"Pinch your nostrils closed β if the sound changes dramatically, "
|
| 102 |
+
"you were accidentally blocking nasal airflow. "
|
| 103 |
+
"Let air flow freely through your nose as you make the sound."
|
| 104 |
+
),
|
| 105 |
+
"drill": "Alternate: hum 'mmm' (nasal), then 'bbb' (not nasal). Feel the difference.",
|
| 106 |
+
"self_check": "Pinch your nose lightly β a nasal sound will feel 'stuffed up' when blocked.",
|
| 107 |
+
},
|
| 108 |
+
{
|
| 109 |
+
"features": ["nasal"],
|
| 110 |
+
"direction": "extra",
|
| 111 |
+
"tip": (
|
| 112 |
+
"Your sound has unwanted nasality β air is leaking through your nose. "
|
| 113 |
+
"Practice lifting the soft palate by saying 'uh-oh' firmly, then keep that "
|
| 114 |
+
"lifted feeling while producing the target sound."
|
| 115 |
+
),
|
| 116 |
+
"drill": "Say 'back β bank', 'bad β band'. The first word of each pair is not nasal.",
|
| 117 |
+
"self_check": "Hold a mirror under your nose β it should not fog up.",
|
| 118 |
+
},
|
| 119 |
+
|
| 120 |
+
# ββ MANNER: AFFRICATE ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 121 |
+
{
|
| 122 |
+
"features": ["affricate"],
|
| 123 |
+
"direction": "missing",
|
| 124 |
+
"tip": (
|
| 125 |
+
"An affricate starts with a complete closure (like a stop) then releases "
|
| 126 |
+
"into a fricative β think of /ch/ in 'church' or /jh/ in 'judge'. "
|
| 127 |
+
"You are either skipping the closure or the friction release. "
|
| 128 |
+
"Make sure you feel both: a tight seal followed by a hissing release."
|
| 129 |
+
),
|
| 130 |
+
"drill": "Say 'ch-ch-ch' rapidly, feeling the tap-and-hiss for each one.",
|
| 131 |
+
"self_check": "You should feel a brief closure then turbulent airflow β two phases in one sound.",
|
| 132 |
+
},
|
| 133 |
+
|
| 134 |
+
# ββ MANNER: APPROXIMANT / LIQUID βββββββββββββββββββββββββββββββββββββ
|
| 135 |
+
{
|
| 136 |
+
"features": ["approximant", "liquid"],
|
| 137 |
+
"direction": "missing",
|
| 138 |
+
"tip": (
|
| 139 |
+
"This sound (/l/, /r/, /w/, /y/) needs your articulators to approach each other "
|
| 140 |
+
"closely without fully touching or creating friction. "
|
| 141 |
+
"Relax the contact β you may be pressing too hard and creating a stop, "
|
| 142 |
+
"or not shaping your mouth precisely enough."
|
| 143 |
+
),
|
| 144 |
+
"drill": "Say 'la-la-la' for /l/ and 'ra-ra-ra' for /r/ slowly, keeping the tongue light.",
|
| 145 |
+
"self_check": "There should be no pop and no hiss β just a smooth, resonant glide.",
|
| 146 |
+
},
|
| 147 |
+
|
| 148 |
+
# ββ MANNER: CONTINUANT βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 149 |
+
{
|
| 150 |
+
"features": ["continuant"],
|
| 151 |
+
"direction": "missing",
|
| 152 |
+
"tip": (
|
| 153 |
+
"This sound should have continuous, uninterrupted airflow β it is not a stop. "
|
| 154 |
+
"Keep your airway open and let air flow through for the full duration of the sound."
|
| 155 |
+
),
|
| 156 |
+
"drill": "Sustain /s/, /m/, /l/ or /v/ for three seconds without any interruption.",
|
| 157 |
+
"self_check": "You should be able to hold the sound indefinitely without cutting off air.",
|
| 158 |
+
},
|
| 159 |
+
|
| 160 |
+
# ββ PLACE: BILABIAL ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 161 |
+
{
|
| 162 |
+
"features": ["bilabial"],
|
| 163 |
+
"direction": "missing",
|
| 164 |
+
"tip": (
|
| 165 |
+
"This sound needs both lips pressed firmly together (/p/, /b/, /m/). "
|
| 166 |
+
"You may be making it with only one lip or further back in the mouth. "
|
| 167 |
+
"Press your lips together completely before releasing."
|
| 168 |
+
),
|
| 169 |
+
"drill": "Say 'pa-ba-ma' ten times, exaggerating full lip closure each time.",
|
| 170 |
+
"self_check": "Watch yourself in a mirror β both lips should close completely.",
|
| 171 |
+
},
|
| 172 |
+
|
| 173 |
+
# ββ PLACE: LABIAL (labiodental /f/, /v/) ββββββββββββββββββββββββββββ
|
| 174 |
+
{
|
| 175 |
+
"features": ["labial"],
|
| 176 |
+
"direction": "missing",
|
| 177 |
+
"tip": (
|
| 178 |
+
"This sound needs your lips to be active β either both lips together (bilabial: /p/, /b/, /m/) "
|
| 179 |
+
"or upper teeth touching the lower lip (labiodental: /f/, /v/). "
|
| 180 |
+
"You may be making the sound too far back with the tongue."
|
| 181 |
+
),
|
| 182 |
+
"drill": "Exaggerate lip contact. Say 'pop', 'bob', 'mom', 'five', 'very' in front of a mirror.",
|
| 183 |
+
"self_check": "Watch yourself in a mirror β you should see clear lip movement.",
|
| 184 |
+
},
|
| 185 |
+
|
| 186 |
+
# ββ PLACE: DENTAL ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 187 |
+
{
|
| 188 |
+
"features": ["dental"],
|
| 189 |
+
"direction": "missing",
|
| 190 |
+
"tip": (
|
| 191 |
+
"This sound (/th/, /dh/) requires your tongue tip to be right at or between your teeth. "
|
| 192 |
+
"Stick your tongue tip just between your upper and lower front teeth "
|
| 193 |
+
"and let air flow over it."
|
| 194 |
+
),
|
| 195 |
+
"drill": "Say 'think' and 'this' slowly, deliberately placing your tongue between your teeth each time.",
|
| 196 |
+
"self_check": "You should feel your tongue tip touching the edges of your front teeth.",
|
| 197 |
+
},
|
| 198 |
+
|
| 199 |
+
# ββ PLACE: ALVEOLAR ββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 200 |
+
{
|
| 201 |
+
"features": ["alveolar"],
|
| 202 |
+
"direction": "missing",
|
| 203 |
+
"tip": (
|
| 204 |
+
"Your tongue tip needs to touch the alveolar ridge β the hard bump just behind "
|
| 205 |
+
"your upper front teeth. "
|
| 206 |
+
"This is the target for /t/, /d/, /n/, /s/, /z/, /l/. "
|
| 207 |
+
"You may be placing your tongue too far back or too far forward."
|
| 208 |
+
),
|
| 209 |
+
"drill": "Touch the ridge behind your upper teeth with your tongue tip and feel it. "
|
| 210 |
+
"Now tap /t/ ten times, always returning to that exact spot.",
|
| 211 |
+
"self_check": "Is your tongue tip touching the hard ridge β not the teeth and not the palate?",
|
| 212 |
+
},
|
| 213 |
+
|
| 214 |
+
# ββ PLACE: PALATAL ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 215 |
+
{
|
| 216 |
+
"features": ["palatal"],
|
| 217 |
+
"direction": "missing",
|
| 218 |
+
"tip": (
|
| 219 |
+
"This sound (/sh/, /zh/, /ch/, /jh/, /y/) is made with the tongue body raised "
|
| 220 |
+
"toward the hard palate β the hard, bony roof just behind the alveolar ridge. "
|
| 221 |
+
"Move your tongue further back from the teeth and arch it upward."
|
| 222 |
+
),
|
| 223 |
+
"drill": "Say 'she', 'measure', 'church' β feel your tongue body rise toward the hard palate.",
|
| 224 |
+
"self_check": "You should feel your tongue broadly touching or approaching the middle of the roof.",
|
| 225 |
+
},
|
| 226 |
+
|
| 227 |
+
# ββ PLACE: VELAR ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 228 |
+
{
|
| 229 |
+
"features": ["velar"],
|
| 230 |
+
"direction": "missing",
|
| 231 |
+
"tip": (
|
| 232 |
+
"This sound (/k/, /g/, /ng/) is made at the back of your mouth, with the back of your tongue "
|
| 233 |
+
"touching the soft palate (velum). "
|
| 234 |
+
"Try gargling β that back-of-tongue raised position is exactly what you need."
|
| 235 |
+
),
|
| 236 |
+
"drill": "Say 'king', 'ring', 'sing' β focus on the back-of-tongue closure each time.",
|
| 237 |
+
"self_check": "You should feel the back of your tongue lift and meet the soft palate.",
|
| 238 |
+
},
|
| 239 |
+
|
| 240 |
+
# ββ PLACE: GLOTTAL ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 241 |
+
{
|
| 242 |
+
"features": ["glottal"],
|
| 243 |
+
"direction": "missing",
|
| 244 |
+
"tip": (
|
| 245 |
+
"This sound (/hh/) is made deep in the throat at the vocal folds. "
|
| 246 |
+
"Think of fogging up a mirror β breathe out gently with a completely open throat. "
|
| 247 |
+
"No tongue or lip constriction should be involved."
|
| 248 |
+
),
|
| 249 |
+
"drill": "Say 'hi', 'hat', 'hot' β the /h/ should feel like a breath, not a friction sound.",
|
| 250 |
+
"self_check": "Place a hand on your throat β you should feel warmth from breath, not a hiss.",
|
| 251 |
+
},
|
| 252 |
+
|
| 253 |
+
# ββ PLACE: RETROFLEX βββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 254 |
+
{
|
| 255 |
+
"features": ["retroflex"],
|
| 256 |
+
"direction": "missing",
|
| 257 |
+
"tip": (
|
| 258 |
+
"This sound (/r/ in English, /er/) requires your tongue tip to curl back toward "
|
| 259 |
+
"the back of the alveolar ridge without touching anything, or to bunch up in the "
|
| 260 |
+
"center of your mouth. "
|
| 261 |
+
"Say 'uh' then slowly curl your tongue tip upward and backward."
|
| 262 |
+
),
|
| 263 |
+
"drill": "Practice: 'uh' β curl tongue β 'er'. Hold 'er' for three seconds.",
|
| 264 |
+
"self_check": "Your tongue tip should point upward or backward but NOT touch the roof.",
|
| 265 |
+
},
|
| 266 |
+
|
| 267 |
+
# ββ PLACE: CORONAL βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 268 |
+
{
|
| 269 |
+
"features": ["coronal"],
|
| 270 |
+
"direction": "missing",
|
| 271 |
+
"tip": (
|
| 272 |
+
"Coronal sounds are made with the front part (blade or tip) of the tongue β "
|
| 273 |
+
"this covers /t/, /d/, /s/, /z/, /n/, /l/, /sh/, /th/, and /r/. "
|
| 274 |
+
"Make sure your tongue front is active and positioned correctly for this sound."
|
| 275 |
+
),
|
| 276 |
+
"drill": "Say 'tip', 'dip', 'sip', 'nip' β feel the tongue tip or blade doing the work.",
|
| 277 |
+
"self_check": "Is your tongue front β tip or blade β the part making contact?",
|
| 278 |
+
},
|
| 279 |
+
|
| 280 |
+
# ββ PLACE: DORSAL ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 281 |
+
{
|
| 282 |
+
"features": ["dorsal"],
|
| 283 |
+
"direction": "missing",
|
| 284 |
+
"tip": (
|
| 285 |
+
"Dorsal sounds (/k/, /g/, /ng/, /w/, /y/) involve the back (body or root) of the tongue. "
|
| 286 |
+
"Your tongue body needs to arch toward the velum or palate. "
|
| 287 |
+
"You may be using your tongue tip when the back of the tongue should lead."
|
| 288 |
+
),
|
| 289 |
+
"drill": "Say 'key', 'go', 'sing' β feel the back hump of your tongue rise each time.",
|
| 290 |
+
"self_check": "The front of your tongue should be relaxed; the back should be doing the work.",
|
| 291 |
+
},
|
| 292 |
+
|
| 293 |
+
# ββ VOWEL HEIGHT ββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 294 |
+
{
|
| 295 |
+
"features": ["high"],
|
| 296 |
+
"direction": "missing",
|
| 297 |
+
"tip": (
|
| 298 |
+
"This vowel needs your tongue to be high in your mouth. "
|
| 299 |
+
"Think of 'ee' in 'feet' or 'oo' in 'food' β the tongue is raised close to the palate. "
|
| 300 |
+
"Raise your tongue toward the roof of your mouth as you say the vowel."
|
| 301 |
+
),
|
| 302 |
+
"drill": "Slide from 'ah' (low, jaw open) β 'ee' (high, jaw nearly closed) and feel the tongue rise.",
|
| 303 |
+
"self_check": "Your jaw should be mostly closed; the tongue should be near the roof.",
|
| 304 |
+
},
|
| 305 |
+
{
|
| 306 |
+
"features": ["mid"],
|
| 307 |
+
"direction": "missing",
|
| 308 |
+
"tip": (
|
| 309 |
+
"This vowel needs a mid-height tongue position β halfway between fully raised and fully lowered. "
|
| 310 |
+
"Think of 'eh' in 'bed' or 'oh' in 'boat'. "
|
| 311 |
+
"Relax your jaw to a half-open position."
|
| 312 |
+
),
|
| 313 |
+
"drill": "Slide 'ee' (high) β 'eh' (mid) β 'ah' (low) and stop at the middle position.",
|
| 314 |
+
"self_check": "Your jaw should be half open β neither clenched nor dropped wide.",
|
| 315 |
+
},
|
| 316 |
+
{
|
| 317 |
+
"features": ["low"],
|
| 318 |
+
"direction": "missing",
|
| 319 |
+
"tip": (
|
| 320 |
+
"This vowel needs your tongue to drop down and your jaw to open wide. "
|
| 321 |
+
"Think of 'ah' in 'father' or 'ae' in 'cat' β the tongue is flat and low. "
|
| 322 |
+
"Let your jaw drop and your tongue rest at the bottom of your mouth."
|
| 323 |
+
),
|
| 324 |
+
"drill": "Say 'ah' like a doctor's exam β exaggerate the open jaw and flat tongue.",
|
| 325 |
+
"self_check": "Your jaw should be open wide; your tongue should feel flat at the bottom.",
|
| 326 |
+
},
|
| 327 |
+
|
| 328 |
+
# ββ VOWEL BACKNESS βββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 329 |
+
{
|
| 330 |
+
"features": ["front"],
|
| 331 |
+
"direction": "missing",
|
| 332 |
+
"tip": (
|
| 333 |
+
"This vowel should be made with your tongue pushed toward the front of your mouth. "
|
| 334 |
+
"Smile slightly β this naturally pulls the tongue body forward."
|
| 335 |
+
),
|
| 336 |
+
"drill": "Say 'ee β ay β eh' and feel your tongue staying at the front for all three.",
|
| 337 |
+
"self_check": "You should feel tension or contact toward the front of your mouth.",
|
| 338 |
+
},
|
| 339 |
+
{
|
| 340 |
+
"features": ["back"],
|
| 341 |
+
"direction": "missing",
|
| 342 |
+
"tip": (
|
| 343 |
+
"This vowel should be made with your tongue retracted toward the back of your mouth. "
|
| 344 |
+
"Round your lips slightly and pull your tongue body backward as you say the vowel."
|
| 345 |
+
),
|
| 346 |
+
"drill": "Say 'oo β oh β aw' β feel your tongue pulling back and the lips rounding each time.",
|
| 347 |
+
"self_check": "You should feel the back of your tongue arch upward and backward.",
|
| 348 |
+
},
|
| 349 |
+
{
|
| 350 |
+
"features": ["central"],
|
| 351 |
+
"direction": "missing",
|
| 352 |
+
"tip": (
|
| 353 |
+
"This vowel (like the schwa /Ι/ in 'about') should be made with a completely neutral, "
|
| 354 |
+
"centered tongue β not pushed forward or pulled back. "
|
| 355 |
+
"Relax all tension in your jaw, lips, and tongue."
|
| 356 |
+
),
|
| 357 |
+
"drill": "Say 'uh' with a completely relaxed, drooping jaw and limp tongue.",
|
| 358 |
+
"self_check": "Your mouth should feel effortless, tongue neither front nor back.",
|
| 359 |
+
},
|
| 360 |
+
|
| 361 |
+
# ββ LIP ROUNDING (Others group: 'round') βββββββββββββββββββββββββββββ
|
| 362 |
+
{
|
| 363 |
+
"features": ["round"],
|
| 364 |
+
"direction": "missing",
|
| 365 |
+
"tip": (
|
| 366 |
+
"This sound requires rounded, protruded lips β like you are blowing out a candle. "
|
| 367 |
+
"Form an 'oo' shape with your lips before and during the sound."
|
| 368 |
+
),
|
| 369 |
+
"drill": "Exaggerate lip rounding: say 'oo β oh β aw' with very pursed lips.",
|
| 370 |
+
"self_check": "Look in a mirror β your lips should form a clear circle or oval.",
|
| 371 |
+
},
|
| 372 |
+
{
|
| 373 |
+
"features": ["round"],
|
| 374 |
+
"direction": "extra",
|
| 375 |
+
"tip": (
|
| 376 |
+
"You are rounding your lips when they should be spread or neutral. "
|
| 377 |
+
"Spread your lips into a slight smile and keep them flat as you say the sound."
|
| 378 |
+
),
|
| 379 |
+
"drill": "Say 'ee β ih β eh' with a relaxed smile β no lip rounding at all.",
|
| 380 |
+
"self_check": "Your lips should be flat or slightly spread, not puckered.",
|
| 381 |
+
},
|
| 382 |
+
|
| 383 |
+
# ββ VOWEL LENGTH (Others group: 'long' / 'short') βββββββββββββββββββββ
|
| 384 |
+
{
|
| 385 |
+
"features": ["long"],
|
| 386 |
+
"direction": "missing",
|
| 387 |
+
"tip": (
|
| 388 |
+
"This vowel should be noticeably longer in duration. "
|
| 389 |
+
"English long vowels (/iy/, /uw/, /aa/, /ao/, /ae/, /er/) are roughly twice "
|
| 390 |
+
"as long as their short counterparts. Stretch it out."
|
| 391 |
+
),
|
| 392 |
+
"drill": "Say 'beat' and hold the vowel: 'beeeeat'. Then compare with the short 'bit'.",
|
| 393 |
+
"self_check": "Record yourself β the vowel should sound stretched, not clipped.",
|
| 394 |
+
},
|
| 395 |
+
{
|
| 396 |
+
"features": ["short"],
|
| 397 |
+
"direction": "missing",
|
| 398 |
+
"tip": (
|
| 399 |
+
"This vowel should be brief and clipped. "
|
| 400 |
+
"Short vowels (/ih/, /eh/, /ah/, /uh/) are reduced in duration. "
|
| 401 |
+
"Don't let the vowel linger β move quickly to the next sound."
|
| 402 |
+
),
|
| 403 |
+
"drill": "Say 'bit', 'bet', 'but', 'book' β snap off each vowel quickly.",
|
| 404 |
+
"self_check": "The vowel should feel brief. If you can hold it comfortably, it's too long.",
|
| 405 |
+
},
|
| 406 |
+
|
| 407 |
+
# ββ VOWEL TYPE (Others group: 'monophthong' / 'diphthong') ββββββββββ
|
| 408 |
+
{
|
| 409 |
+
"features": ["monophthong"],
|
| 410 |
+
"direction": "missing",
|
| 411 |
+
"tip": (
|
| 412 |
+
"This vowel should be pure and steady β your tongue and lips should hold the same "
|
| 413 |
+
"position throughout. You may be letting the vowel glide (diphthongize). "
|
| 414 |
+
"Keep your tongue and jaw completely still from start to finish."
|
| 415 |
+
),
|
| 416 |
+
"drill": "Hold /aa/, /iy/, or /uw/ for three seconds without any movement.",
|
| 417 |
+
"self_check": "The vowel quality should be identical at the beginning and end β no glide.",
|
| 418 |
+
},
|
| 419 |
+
{
|
| 420 |
+
"features": ["diphthong"],
|
| 421 |
+
"direction": "missing",
|
| 422 |
+
"tip": (
|
| 423 |
+
"This vowel should glide from one position to another β it is a diphthong. "
|
| 424 |
+
"English diphthongs like /ay/ ('bite'), /aw/ ('bout'), /oy/ ('boy'), "
|
| 425 |
+
"/ey/ ('bait'), /ow/ ('boat') have a clear movement. "
|
| 426 |
+
"Let your tongue and jaw glide smoothly to the second target."
|
| 427 |
+
),
|
| 428 |
+
"drill": "Say 'buy β bow β boy β bay β boat' slowly and feel the glide in each vowel.",
|
| 429 |
+
"self_check": "The vowel should sound like it is moving, not fixed in one place.",
|
| 430 |
+
},
|
| 431 |
+
]
|
| 432 |
+
|
| 433 |
+
# Build a fast lookup: feature β list of applicable rules
|
| 434 |
+
_RULE_INDEX: Dict[str, List[Dict]] = {}
|
| 435 |
+
for rule in FEATURE_RULES:
|
| 436 |
+
for feat in rule["features"]:
|
| 437 |
+
_RULE_INDEX.setdefault(feat, []).append(rule)
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 441 |
+
# 2. Rule matcher
|
| 442 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 443 |
+
|
| 444 |
+
@dataclass
|
| 445 |
+
class RuleFeedback:
|
| 446 |
+
feature: str
|
| 447 |
+
direction: str # "missing" | "extra"
|
| 448 |
+
tip: str
|
| 449 |
+
drill: str
|
| 450 |
+
self_check: str
|
| 451 |
+
count: int = 1 # how many phonemes triggered this rule
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
def match_rules(errors: List[PhonemeError]) -> List[RuleFeedback]:
|
| 455 |
+
"""
|
| 456 |
+
Given a list of phoneme errors, find the most relevant feedback rules.
|
| 457 |
+
Rules are deduplicated and sorted by frequency of occurrence.
|
| 458 |
+
"""
|
| 459 |
+
triggered: Dict[Tuple[str, str], RuleFeedback] = {}
|
| 460 |
+
|
| 461 |
+
for error in errors:
|
| 462 |
+
for feat in error.missing_features:
|
| 463 |
+
for rule in _RULE_INDEX.get(feat, []):
|
| 464 |
+
if rule["direction"] in ("missing", "both"):
|
| 465 |
+
key = (feat, "missing")
|
| 466 |
+
if key in triggered:
|
| 467 |
+
triggered[key].count += 1
|
| 468 |
+
else:
|
| 469 |
+
triggered[key] = RuleFeedback(
|
| 470 |
+
feature=feat,
|
| 471 |
+
direction="missing",
|
| 472 |
+
tip=rule["tip"],
|
| 473 |
+
drill=rule["drill"],
|
| 474 |
+
self_check=rule["self_check"],
|
| 475 |
+
)
|
| 476 |
+
|
| 477 |
+
for feat in error.extra_features:
|
| 478 |
+
for rule in _RULE_INDEX.get(feat, []):
|
| 479 |
+
if rule["direction"] in ("extra", "both"):
|
| 480 |
+
key = (feat, "extra")
|
| 481 |
+
if key in triggered:
|
| 482 |
+
triggered[key].count += 1
|
| 483 |
+
else:
|
| 484 |
+
triggered[key] = RuleFeedback(
|
| 485 |
+
feature=feat,
|
| 486 |
+
direction="extra",
|
| 487 |
+
tip=rule["tip"],
|
| 488 |
+
drill=rule["drill"],
|
| 489 |
+
self_check=rule["self_check"],
|
| 490 |
+
)
|
| 491 |
+
|
| 492 |
+
# Sort by occurrence count descending
|
| 493 |
+
return sorted(triggered.values(), key=lambda r: -r.count)
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 497 |
+
# 3. Template-based fallback feedback (no LLM needed)
|
| 498 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 499 |
+
|
| 500 |
+
def format_feedback_template(
|
| 501 |
+
result: MDDResult,
|
| 502 |
+
rules: List[RuleFeedback],
|
| 503 |
+
max_issues: int = 3,
|
| 504 |
+
) -> str:
|
| 505 |
+
"""Structured text feedback without LLM β always available."""
|
| 506 |
+
lines = []
|
| 507 |
+
score = result.utterance_score
|
| 508 |
+
|
| 509 |
+
# Score header
|
| 510 |
+
if score >= 85:
|
| 511 |
+
lines.append(f"π Great pronunciation! Score: {score:.0f}/100")
|
| 512 |
+
elif score >= 65:
|
| 513 |
+
lines.append(f"π Good effort! Score: {score:.0f}/100 β a few things to polish.")
|
| 514 |
+
elif score >= 45:
|
| 515 |
+
lines.append(f"π Score: {score:.0f}/100 β let's work on some key areas.")
|
| 516 |
+
else:
|
| 517 |
+
lines.append(f"πͺ Score: {score:.0f}/100 β keep practicing, you'll get there!")
|
| 518 |
+
|
| 519 |
+
if not rules:
|
| 520 |
+
lines.append("\nNo significant feature errors detected. Well done!")
|
| 521 |
+
return "\n".join(lines)
|
| 522 |
+
|
| 523 |
+
lines.append(f"\nI found {len(result.errors)} phoneme(s) that need attention.\n")
|
| 524 |
+
|
| 525 |
+
for i, rule in enumerate(rules[:max_issues]):
|
| 526 |
+
direction_word = "missing" if rule.direction == "missing" else "extra"
|
| 527 |
+
lines.append(f"β Issue {i+1}: [{rule.feature}] feature {direction_word}")
|
| 528 |
+
lines.append(f" π‘ {rule.tip}")
|
| 529 |
+
lines.append(f" ποΈ Drill: {rule.drill}")
|
| 530 |
+
lines.append(f" β
Self-check: {rule.self_check}\n")
|
| 531 |
+
|
| 532 |
+
return "\n".join(lines)
|
| 533 |
+
|
| 534 |
+
|
| 535 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 536 |
+
# 4. LLM-enhanced feedback
|
| 537 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 538 |
+
|
| 539 |
+
LLM_SYSTEM_PROMPT = """You are a warm, encouraging English pronunciation coach.
|
| 540 |
+
Your student just attempted to say a sentence and you've identified specific
|
| 541 |
+
phonological feature errors. Your task is to rewrite the structured feedback
|
| 542 |
+
into a single natural, conversational coaching response.
|
| 543 |
+
|
| 544 |
+
Rules:
|
| 545 |
+
- Keep ALL the articulatory tips and self-checks intact β do not omit or soften them.
|
| 546 |
+
- Write as if speaking to the student directly.
|
| 547 |
+
- Be encouraging but honest.
|
| 548 |
+
- Limit response to 200 words maximum.
|
| 549 |
+
- Do not add new advice not present in the structured feedback.
|
| 550 |
+
- Start with a brief overall assessment, then naturally weave in the tips.
|
| 551 |
+
- End with one motivating sentence.
|
| 552 |
+
"""
|
| 553 |
+
|
| 554 |
+
def generate_llm_feedback(
|
| 555 |
+
structured_feedback: str,
|
| 556 |
+
score: float,
|
| 557 |
+
model_name: str = "Qwen/Qwen2.5-0.5B-Instruct", # lightweight default
|
| 558 |
+
use_cloud_fallback: bool = True,
|
| 559 |
+
) -> str:
|
| 560 |
+
"""
|
| 561 |
+
Rewrites structured feedback into natural coaching language.
|
| 562 |
+
|
| 563 |
+
Tries (in order):
|
| 564 |
+
1. Local transformers model (if available)
|
| 565 |
+
2. Cloud LLM API (if use_cloud_fallback=True and API key set)
|
| 566 |
+
3. Returns structured_feedback unchanged as graceful degradation
|
| 567 |
+
"""
|
| 568 |
+
prompt = f"""Here is structured pronunciation feedback for a student who scored {score:.0f}/100:
|
| 569 |
+
|
| 570 |
+
{structured_feedback}
|
| 571 |
+
|
| 572 |
+
Please rewrite this as a warm, natural coaching response."""
|
| 573 |
+
|
| 574 |
+
# --- Try local model first ---
|
| 575 |
+
try:
|
| 576 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 577 |
+
import torch
|
| 578 |
+
|
| 579 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
| 580 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 581 |
+
model_name,
|
| 582 |
+
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
|
| 583 |
+
device_map="auto",
|
| 584 |
+
)
|
| 585 |
+
|
| 586 |
+
messages = [
|
| 587 |
+
{"role": "system", "content": LLM_SYSTEM_PROMPT},
|
| 588 |
+
{"role": "user", "content": prompt},
|
| 589 |
+
]
|
| 590 |
+
text = tokenizer.apply_chat_template(
|
| 591 |
+
messages, tokenize=False, add_generation_prompt=True
|
| 592 |
+
)
|
| 593 |
+
inputs = tokenizer([text], return_tensors="pt").to(model.device)
|
| 594 |
+
with torch.no_grad():
|
| 595 |
+
output = model.generate(
|
| 596 |
+
**inputs,
|
| 597 |
+
max_new_tokens=256,
|
| 598 |
+
temperature=0.7,
|
| 599 |
+
do_sample=True,
|
| 600 |
+
pad_token_id=tokenizer.eos_token_id,
|
| 601 |
+
)
|
| 602 |
+
response = tokenizer.decode(
|
| 603 |
+
output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True
|
| 604 |
+
)
|
| 605 |
+
return response.strip()
|
| 606 |
+
|
| 607 |
+
except Exception as local_err:
|
| 608 |
+
print(f"[Local LLM] Not available: {local_err}")
|
| 609 |
+
|
| 610 |
+
# --- Cloud fallback (OpenAI-compatible API) ---
|
| 611 |
+
if use_cloud_fallback:
|
| 612 |
+
api_key = os.environ.get("OPENAI_API_KEY") or os.environ.get("LLM_API_KEY")
|
| 613 |
+
api_base = os.environ.get("LLM_API_BASE", "https://api.openai.com/v1")
|
| 614 |
+
cloud_model = os.environ.get("LLM_MODEL", "gpt-4o-mini")
|
| 615 |
+
|
| 616 |
+
if api_key:
|
| 617 |
+
try:
|
| 618 |
+
import httpx
|
| 619 |
+
headers = {
|
| 620 |
+
"Authorization": f"Bearer {api_key}",
|
| 621 |
+
"Content-Type": "application/json",
|
| 622 |
+
}
|
| 623 |
+
body = {
|
| 624 |
+
"model": cloud_model,
|
| 625 |
+
"messages": [
|
| 626 |
+
{"role": "system", "content": LLM_SYSTEM_PROMPT},
|
| 627 |
+
{"role": "user", "content": prompt},
|
| 628 |
+
],
|
| 629 |
+
"max_tokens": 300,
|
| 630 |
+
"temperature": 0.7,
|
| 631 |
+
}
|
| 632 |
+
r = httpx.post(f"{api_base}/chat/completions", json=body, headers=headers, timeout=15)
|
| 633 |
+
r.raise_for_status()
|
| 634 |
+
return r.json()["choices"][0]["message"]["content"].strip()
|
| 635 |
+
except Exception as cloud_err:
|
| 636 |
+
print(f"[Cloud LLM] Failed: {cloud_err}")
|
| 637 |
+
|
| 638 |
+
# --- Graceful degradation ---
|
| 639 |
+
return structured_feedback
|
| 640 |
+
|
| 641 |
+
|
| 642 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 643 |
+
# 5. Main feedback pipeline
|
| 644 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββ
|
| 645 |
+
|
| 646 |
+
def generate_feedback(
|
| 647 |
+
result: MDDResult,
|
| 648 |
+
use_llm: bool = True,
|
| 649 |
+
max_issues: int = 3,
|
| 650 |
+
) -> Dict:
|
| 651 |
+
"""
|
| 652 |
+
Full feedback pipeline. Returns a dict with keys:
|
| 653 |
+
score, template_feedback, final_feedback, error_summary, rules_triggered
|
| 654 |
+
"""
|
| 655 |
+
rules = match_rules(result.errors)
|
| 656 |
+
template_fb = format_feedback_template(result, rules, max_issues)
|
| 657 |
+
|
| 658 |
+
if use_llm and rules:
|
| 659 |
+
final_fb = generate_llm_feedback(template_fb, result.utterance_score)
|
| 660 |
+
else:
|
| 661 |
+
final_fb = template_fb
|
| 662 |
+
|
| 663 |
+
error_summary = [
|
| 664 |
+
{
|
| 665 |
+
"position": e.position,
|
| 666 |
+
"target": e.target_phoneme,
|
| 667 |
+
"produced": e.produced_phoneme,
|
| 668 |
+
"missing_features": e.missing_features,
|
| 669 |
+
"extra_features": e.extra_features,
|
| 670 |
+
"accuracy": round(e.feature_accuracy, 3),
|
| 671 |
+
"severity": e.severity,
|
| 672 |
+
}
|
| 673 |
+
for e in result.errors
|
| 674 |
+
]
|
| 675 |
+
|
| 676 |
+
return {
|
| 677 |
+
"score": round(result.utterance_score, 1),
|
| 678 |
+
"template_feedback": template_fb,
|
| 679 |
+
"final_feedback": final_fb,
|
| 680 |
+
"error_summary": error_summary,
|
| 681 |
+
"feature_error_counts": result.feature_error_counts,
|
| 682 |
+
"rules_triggered": [
|
| 683 |
+
{
|
| 684 |
+
"feature": r.feature,
|
| 685 |
+
"direction": r.direction,
|
| 686 |
+
"occurrences": r.count,
|
| 687 |
+
}
|
| 688 |
+
for r in rules
|
| 689 |
+
],
|
| 690 |
+
}
|
mdd_engine.py
ADDED
|
@@ -0,0 +1,534 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
MDD Engine β Mispronunciation Detection and Diagnosis
|
| 3 |
+
=====================================================
|
| 4 |
+
|
| 5 |
+
Architecture (Shahin et al. 2025)
|
| 6 |
+
----------------------------------
|
| 7 |
+
Your model runs 35 independent CTC decoders, one per phonological feature.
|
| 8 |
+
Each decoder outputs a sequence of +att(1) / -att(0) labels, with blanks
|
| 9 |
+
already removed and runs collapsed β so the output length reflects the number
|
| 10 |
+
of detected phoneme-level events, NOT audio frames.
|
| 11 |
+
|
| 12 |
+
The canonical target comes from the user's typed sentence:
|
| 13 |
+
sentence β G2P (CMU ARPAbet) β phoneme_sequence_to_feature_sequences()
|
| 14 |
+
β 35 binary label sequences of length T (number of target phonemes)
|
| 15 |
+
|
| 16 |
+
The problem: the actual decoded sequence per feature may have a DIFFERENT
|
| 17 |
+
length than T, because the student may have:
|
| 18 |
+
- deleted phonemes (actual shorter than target)
|
| 19 |
+
- inserted extras (actual longer than target)
|
| 20 |
+
- substituted (same length, wrong labels)
|
| 21 |
+
|
| 22 |
+
Solution: Needleman-Wunsch (global sequence alignment) per feature
|
| 23 |
+
------------------------------------------------------------------
|
| 24 |
+
For each of the 35 features we run a global pairwise alignment between the
|
| 25 |
+
target binary sequence and the actual binary sequence. This gives us an
|
| 26 |
+
explicit alignment path with match / mismatch / insertion / deletion ops.
|
| 27 |
+
|
| 28 |
+
We then aggregate across all 35 features to get, per target phoneme position:
|
| 29 |
+
- which actual position it maps to (or DELETION if no match)
|
| 30 |
+
- which features are missing (+att in target, -att or gap in actual)
|
| 31 |
+
- which features are extra (-att in target, +att in actual)
|
| 32 |
+
- a weighted feature accuracy score
|
| 33 |
+
|
| 34 |
+
This is the standard approach in phonological MDD literature when no frame-
|
| 35 |
+
level forced alignment is available (see e.g. Lee & Glass 2015, Leung et al.
|
| 36 |
+
2019, and the feature-based MDD track of the AIP challenge).
|
| 37 |
+
|
| 38 |
+
Input/output contract
|
| 39 |
+
---------------------
|
| 40 |
+
actual_feature_seqs : list[list[int]] β 35 lists, each decoded CTC output
|
| 41 |
+
Values: 1 (+att) or 0 (-att)
|
| 42 |
+
Lengths may differ across features
|
| 43 |
+
and from the canonical length T
|
| 44 |
+
|
| 45 |
+
target_phonemes : list[str] β CMU ARPAbet phoneme sequence from
|
| 46 |
+
the user's typed sentence, length T
|
| 47 |
+
|
| 48 |
+
Output: MDDResult (see dataclass below)
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
from __future__ import annotations
|
| 52 |
+
|
| 53 |
+
import numpy as np
|
| 54 |
+
from dataclasses import dataclass, field
|
| 55 |
+
from typing import List, Dict, Tuple, Optional
|
| 56 |
+
|
| 57 |
+
from phonological_features import (
|
| 58 |
+
PHONOLOGICAL_FEATURES,
|
| 59 |
+
phoneme_sequence_to_feature_sequences,
|
| 60 |
+
phoneme_to_feature_vector,
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 64 |
+
# 1. Feature schema & weights
|
| 65 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 66 |
+
|
| 67 |
+
FEATURE_NAMES: List[str] = PHONOLOGICAL_FEATURES # 35 features, canonical order
|
| 68 |
+
NUM_FEATURES = len(FEATURE_NAMES) # 35
|
| 69 |
+
assert NUM_FEATURES == 35
|
| 70 |
+
|
| 71 |
+
F2I: Dict[str, int] = {f: i for i, f in enumerate(FEATURE_NAMES)}
|
| 72 |
+
|
| 73 |
+
# Perceptual salience weights β higher = more important mismatch.
|
| 74 |
+
# Manner errors (wrong sound class) are most disruptive.
|
| 75 |
+
# Voicing errors are highly salient in English.
|
| 76 |
+
# Place errors matter but less so than manner.
|
| 77 |
+
# Length/type distinctions are least salient in L2 MDD.
|
| 78 |
+
FEATURE_WEIGHTS: np.ndarray = np.array([
|
| 79 |
+
# Manners (11): consonant sonorant fricative nasal stop
|
| 80 |
+
2.0, 1.5, 1.8, 2.0, 2.0,
|
| 81 |
+
# approximant affricate liquid vowel semivowel continuant
|
| 82 |
+
1.5, 1.8, 1.5, 2.0, 1.5, 1.2,
|
| 83 |
+
# Places (18): alveolar palatal dental glottal labial velar
|
| 84 |
+
1.5, 1.4, 1.3, 1.2, 1.5, 1.5,
|
| 85 |
+
# mid high low front back central
|
| 86 |
+
1.8, 1.8, 1.8, 1.6, 1.6, 1.2,
|
| 87 |
+
# anterior posterior retroflex bilabial coronal dorsal
|
| 88 |
+
1.3, 1.3, 1.3, 1.4, 1.3, 1.3,
|
| 89 |
+
# Others (6): long short monophthong diphthong round voiced
|
| 90 |
+
1.0, 1.0, 1.2, 1.2, 1.2, 2.5,
|
| 91 |
+
], dtype=np.float32)
|
| 92 |
+
|
| 93 |
+
assert len(FEATURE_WEIGHTS) == 35
|
| 94 |
+
|
| 95 |
+
# Alignment op codes
|
| 96 |
+
MATCH = 0 # same label, same position
|
| 97 |
+
MISMATCH = 1 # different label, same position
|
| 98 |
+
DELETE = 2 # target has event, actual has gap (deletion error)
|
| 99 |
+
INSERT = 3 # actual has event, target has gap (insertion error)
|
| 100 |
+
|
| 101 |
+
# NW scoring scheme
|
| 102 |
+
MATCH_SCORE = 2
|
| 103 |
+
MISMATCH_SCORE = -1
|
| 104 |
+
GAP_PENALTY = -2 # penalises deletions and insertions equally
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 108 |
+
# 2. Data classes
|
| 109 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 110 |
+
|
| 111 |
+
@dataclass
|
| 112 |
+
class AlignedPosition:
|
| 113 |
+
"""One position in the target sequence after multi-feature alignment."""
|
| 114 |
+
target_idx: int # index in target phoneme sequence
|
| 115 |
+
actual_idx: Optional[int] # index in actual sequence, None = deletion
|
| 116 |
+
op: int # MATCH / MISMATCH / DELETE / INSERT
|
| 117 |
+
target_bits: List[int] # canonical feature vector (35 bits)
|
| 118 |
+
actual_bits: List[int] # observed feature vector (35 bits, 0 if deleted)
|
| 119 |
+
missing_features: List[str] # +att in target, -att or gap in actual
|
| 120 |
+
extra_features: List[str] # -att in target, +att in actual
|
| 121 |
+
feature_accuracy: float # weighted accuracy 0-1
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
@dataclass
|
| 125 |
+
class PhonemeError:
|
| 126 |
+
"""One mispronounced phoneme with its full feature-level diagnosis."""
|
| 127 |
+
position: int # index in target sequence
|
| 128 |
+
target_phoneme: str # ARPAbet label from typed sentence
|
| 129 |
+
missing_features: List[str] # features the student failed to produce
|
| 130 |
+
extra_features: List[str] # features the student added erroneously
|
| 131 |
+
is_deletion: bool # student dropped this phoneme entirely
|
| 132 |
+
feature_accuracy: float # 0-1
|
| 133 |
+
severity: str # "mild" | "moderate" | "severe"
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
@dataclass
|
| 137 |
+
class MDDResult:
|
| 138 |
+
utterance_score: float # 0-100
|
| 139 |
+
phoneme_scores: List[float] # per target phoneme, 0-1
|
| 140 |
+
errors: List[PhonemeError]
|
| 141 |
+
aligned_positions: List[AlignedPosition]
|
| 142 |
+
feature_error_counts: Dict[str, int] # aggregated across all phonemes
|
| 143 |
+
deletion_count: int
|
| 144 |
+
insertion_count: int
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 148 |
+
# 3. Needleman-Wunsch per-feature aligner
|
| 149 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 150 |
+
|
| 151 |
+
def _nw_align(target_seq: List[int],
|
| 152 |
+
actual_seq: List[int]) -> List[Tuple[Optional[int], Optional[int]]]:
|
| 153 |
+
"""
|
| 154 |
+
Global sequence alignment (Needleman-Wunsch) for two binary label sequences.
|
| 155 |
+
|
| 156 |
+
Returns a list of (target_idx, actual_idx) pairs where:
|
| 157 |
+
(i, j) β match or mismatch at target[i], actual[j]
|
| 158 |
+
(i, None) β deletion: target[i] has no corresponding actual event
|
| 159 |
+
(None, j) β insertion: actual[j] has no corresponding target event
|
| 160 |
+
|
| 161 |
+
Binary values: 1 = +att, 0 = -att
|
| 162 |
+
"""
|
| 163 |
+
T = len(target_seq)
|
| 164 |
+
A = len(actual_seq)
|
| 165 |
+
|
| 166 |
+
# Fill score matrix
|
| 167 |
+
score = np.zeros((T + 1, A + 1), dtype=np.float32)
|
| 168 |
+
score[0, :] = np.arange(A + 1) * GAP_PENALTY
|
| 169 |
+
score[:, 0] = np.arange(T + 1) * GAP_PENALTY
|
| 170 |
+
|
| 171 |
+
for i in range(1, T + 1):
|
| 172 |
+
for j in range(1, A + 1):
|
| 173 |
+
s = MATCH_SCORE if target_seq[i-1] == actual_seq[j-1] else MISMATCH_SCORE
|
| 174 |
+
score[i, j] = max(
|
| 175 |
+
score[i-1, j-1] + s, # match/mismatch
|
| 176 |
+
score[i-1, j] + GAP_PENALTY, # deletion
|
| 177 |
+
score[i, j-1] + GAP_PENALTY, # insertion
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
# Traceback
|
| 181 |
+
path: List[Tuple[Optional[int], Optional[int]]] = []
|
| 182 |
+
i, j = T, A
|
| 183 |
+
while i > 0 or j > 0:
|
| 184 |
+
if i > 0 and j > 0:
|
| 185 |
+
s = MATCH_SCORE if target_seq[i-1] == actual_seq[j-1] else MISMATCH_SCORE
|
| 186 |
+
if score[i, j] == score[i-1, j-1] + s:
|
| 187 |
+
path.append((i-1, j-1))
|
| 188 |
+
i -= 1; j -= 1
|
| 189 |
+
continue
|
| 190 |
+
if i > 0 and score[i, j] == score[i-1, j] + GAP_PENALTY:
|
| 191 |
+
path.append((i-1, None)) # deletion
|
| 192 |
+
i -= 1
|
| 193 |
+
else:
|
| 194 |
+
path.append((None, j-1)) # insertion
|
| 195 |
+
j -= 1
|
| 196 |
+
|
| 197 |
+
path.reverse()
|
| 198 |
+
return path
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 202 |
+
# 4. Multi-feature alignment aggregator
|
| 203 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 204 |
+
|
| 205 |
+
def _align_all_features(
|
| 206 |
+
target_feat_seqs: List[List[int]], # 35 lists, each length T
|
| 207 |
+
actual_feat_seqs: List[List[int]], # 35 lists, each possibly != T
|
| 208 |
+
T: int, # number of target phonemes
|
| 209 |
+
) -> List[AlignedPosition]:
|
| 210 |
+
"""
|
| 211 |
+
Run NW alignment independently on each of 35 feature sequences, then
|
| 212 |
+
aggregate the results per target phoneme position.
|
| 213 |
+
|
| 214 |
+
Strategy
|
| 215 |
+
--------
|
| 216 |
+
Each feature gives its own alignment path. We collect, for each target
|
| 217 |
+
position i, a vote over all 35 features about what actual position it
|
| 218 |
+
maps to. The plurality actual index wins. If the majority vote is "gap"
|
| 219 |
+
(deletion), the position is marked as a deletion.
|
| 220 |
+
|
| 221 |
+
Then per position we reconstruct the actual feature bits from the voted
|
| 222 |
+
actual index across all features.
|
| 223 |
+
"""
|
| 224 |
+
# votes[target_idx] β list of actual_idx votes (None = deletion vote)
|
| 225 |
+
votes: List[List[Optional[int]]] = [[] for _ in range(T)]
|
| 226 |
+
# per_feature_actual_idx[feat][target_idx] β actual_idx or None
|
| 227 |
+
per_feat_map: List[Dict[int, Optional[int]]] = [
|
| 228 |
+
{} for _ in range(NUM_FEATURES)
|
| 229 |
+
]
|
| 230 |
+
|
| 231 |
+
for feat_i in range(NUM_FEATURES):
|
| 232 |
+
t_seq = target_feat_seqs[feat_i] # length T
|
| 233 |
+
a_seq = actual_feat_seqs[feat_i] # length may differ
|
| 234 |
+
|
| 235 |
+
path = _nw_align(t_seq, a_seq)
|
| 236 |
+
|
| 237 |
+
for (ti, ai) in path:
|
| 238 |
+
if ti is None:
|
| 239 |
+
continue # insertion β no target position, skip
|
| 240 |
+
votes[ti].append(ai) # ai may be None (deletion)
|
| 241 |
+
per_feat_map[feat_i][ti] = ai
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
# Resolve votes per target position
|
| 245 |
+
aligned: List[AlignedPosition] = []
|
| 246 |
+
|
| 247 |
+
DELETION_VOTE_THRESHOLD = 0.5 # >50% gap votes β mark as DELETE
|
| 248 |
+
|
| 249 |
+
for ti in range(T):
|
| 250 |
+
v = votes[ti]
|
| 251 |
+
non_null = [x for x in v if x is not None]
|
| 252 |
+
null_count = len(v) - len(non_null)
|
| 253 |
+
deletion_fraction = null_count / max(len(v), 1)
|
| 254 |
+
|
| 255 |
+
if not non_null or deletion_fraction > DELETION_VOTE_THRESHOLD:
|
| 256 |
+
chosen_ai = None
|
| 257 |
+
else:
|
| 258 |
+
# Plurality vote among non-null actual indices
|
| 259 |
+
counts: Dict[int, int] = {}
|
| 260 |
+
for idx in non_null:
|
| 261 |
+
counts[idx] = counts.get(idx, 0) + 1
|
| 262 |
+
chosen_ai = max(counts, key=counts.__getitem__)
|
| 263 |
+
# Build target and actual bit vectors for this position
|
| 264 |
+
target_bits = [target_feat_seqs[f][ti] for f in range(NUM_FEATURES)]
|
| 265 |
+
|
| 266 |
+
if chosen_ai is not None:
|
| 267 |
+
actual_bits = []
|
| 268 |
+
for f in range(NUM_FEATURES):
|
| 269 |
+
# Use per-feature actual value if this feature agrees on chosen_ai
|
| 270 |
+
feat_ai = per_feat_map[f].get(ti, None)
|
| 271 |
+
if feat_ai == chosen_ai:
|
| 272 |
+
actual_bits.append(actual_feat_seqs[f][feat_ai]
|
| 273 |
+
if feat_ai < len(actual_feat_seqs[f]) else 0)
|
| 274 |
+
else:
|
| 275 |
+
# Feature disagrees on the position β use its own aligned value
|
| 276 |
+
fa = per_feat_map[f].get(ti, None)
|
| 277 |
+
if fa is not None and fa < len(actual_feat_seqs[f]):
|
| 278 |
+
actual_bits.append(actual_feat_seqs[f][fa])
|
| 279 |
+
else:
|
| 280 |
+
actual_bits.append(0) # treat as absent
|
| 281 |
+
op = MATCH if target_bits == actual_bits else MISMATCH
|
| 282 |
+
else:
|
| 283 |
+
actual_bits = [0] * NUM_FEATURES
|
| 284 |
+
op = DELETE
|
| 285 |
+
|
| 286 |
+
# Compute feature-level errors
|
| 287 |
+
missing = [FEATURE_NAMES[f] for f in range(NUM_FEATURES)
|
| 288 |
+
if target_bits[f] == 1 and actual_bits[f] == 0]
|
| 289 |
+
extra = [FEATURE_NAMES[f] for f in range(NUM_FEATURES)
|
| 290 |
+
if target_bits[f] == 0 and actual_bits[f] == 1]
|
| 291 |
+
|
| 292 |
+
# Weighted accuracy: fraction of weighted features correctly produced
|
| 293 |
+
correct_weight = sum(
|
| 294 |
+
FEATURE_WEIGHTS[f]
|
| 295 |
+
for f in range(NUM_FEATURES)
|
| 296 |
+
if target_bits[f] == actual_bits[f]
|
| 297 |
+
)
|
| 298 |
+
total_weight = float(FEATURE_WEIGHTS.sum())
|
| 299 |
+
accuracy = float(correct_weight / total_weight)
|
| 300 |
+
|
| 301 |
+
aligned.append(AlignedPosition(
|
| 302 |
+
target_idx=ti,
|
| 303 |
+
actual_idx=chosen_ai,
|
| 304 |
+
op=op,
|
| 305 |
+
target_bits=target_bits,
|
| 306 |
+
actual_bits=actual_bits,
|
| 307 |
+
missing_features=missing,
|
| 308 |
+
extra_features=extra,
|
| 309 |
+
feature_accuracy=accuracy,
|
| 310 |
+
))
|
| 311 |
+
|
| 312 |
+
return aligned
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 316 |
+
# 5. Insertion detector
|
| 317 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 318 |
+
|
| 319 |
+
def _count_insertions(
|
| 320 |
+
actual_feat_seqs: List[List[int]],
|
| 321 |
+
actual_len: int,
|
| 322 |
+
aligned: List[AlignedPosition],
|
| 323 |
+
) -> int:
|
| 324 |
+
"""
|
| 325 |
+
Count actual positions that were voted as insertions (not mapped to any
|
| 326 |
+
target position) by the majority of features.
|
| 327 |
+
"""
|
| 328 |
+
used_actual = set(
|
| 329 |
+
ap.actual_idx for ap in aligned if ap.actual_idx is not None
|
| 330 |
+
)
|
| 331 |
+
inserted = set(range(actual_len)) - used_actual
|
| 332 |
+
return len(inserted)
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
# βββββββββοΏ½οΏ½βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 336 |
+
# 6. Severity classifier
|
| 337 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 338 |
+
|
| 339 |
+
# Thresholds on weighted feature error rate
|
| 340 |
+
_SEV = {"mild": 0.85, "moderate": 0.65} # accuracy thresholds (higher = easier)
|
| 341 |
+
|
| 342 |
+
def _severity(accuracy: float, is_deletion: bool) -> str:
|
| 343 |
+
if is_deletion:
|
| 344 |
+
return "severe"
|
| 345 |
+
if accuracy >= _SEV["mild"]:
|
| 346 |
+
return "mild"
|
| 347 |
+
if accuracy >= _SEV["moderate"]:
|
| 348 |
+
return "moderate"
|
| 349 |
+
return "severe"
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 353 |
+
# 7. Scorer
|
| 354 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 355 |
+
|
| 356 |
+
def _score_utterance(aligned: List[AlignedPosition]) -> Tuple[float, List[float]]:
|
| 357 |
+
"""
|
| 358 |
+
Per-phoneme score: weighted feature accuracy (0-1).
|
| 359 |
+
Deletions score 0.
|
| 360 |
+
Utterance score: weighted mean, penalising deletions most.
|
| 361 |
+
"""
|
| 362 |
+
phoneme_scores = [ap.feature_accuracy for ap in aligned]
|
| 363 |
+
utterance_score = float(np.mean(phoneme_scores)) * 100.0
|
| 364 |
+
return utterance_score, phoneme_scores
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 368 |
+
# 8. Error list builder
|
| 369 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 370 |
+
|
| 371 |
+
def _build_errors(
|
| 372 |
+
aligned: List[AlignedPosition],
|
| 373 |
+
target_phonemes: List[str],
|
| 374 |
+
) -> List[PhonemeError]:
|
| 375 |
+
errors = []
|
| 376 |
+
for ap in aligned:
|
| 377 |
+
if ap.op == MATCH and not ap.missing_features and not ap.extra_features:
|
| 378 |
+
continue # perfectly correct, no error to report
|
| 379 |
+
|
| 380 |
+
errors.append(PhonemeError(
|
| 381 |
+
position=ap.target_idx,
|
| 382 |
+
target_phoneme=target_phonemes[ap.target_idx],
|
| 383 |
+
missing_features=ap.missing_features,
|
| 384 |
+
extra_features=ap.extra_features,
|
| 385 |
+
is_deletion=(ap.op == DELETE),
|
| 386 |
+
feature_accuracy=ap.feature_accuracy,
|
| 387 |
+
severity=_severity(ap.feature_accuracy, ap.op == DELETE),
|
| 388 |
+
))
|
| 389 |
+
return errors
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 393 |
+
# 9. Aggregate feature error counts
|
| 394 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 395 |
+
|
| 396 |
+
def _aggregate(errors: List[PhonemeError]) -> Dict[str, int]:
|
| 397 |
+
counts: Dict[str, int] = {}
|
| 398 |
+
for e in errors:
|
| 399 |
+
for f in e.missing_features + e.extra_features:
|
| 400 |
+
counts[f] = counts.get(f, 0) + 1
|
| 401 |
+
return dict(sorted(counts.items(), key=lambda x: -x[1]))
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 405 |
+
# 10. Public entry point
|
| 406 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 407 |
+
|
| 408 |
+
def run_mdd(
|
| 409 |
+
actual_feature_seqs: List[List[int]],
|
| 410 |
+
target_phonemes: List[str],
|
| 411 |
+
) -> MDDResult:
|
| 412 |
+
"""
|
| 413 |
+
Full MDD pipeline for a CTC phonological-feature model.
|
| 414 |
+
|
| 415 |
+
Parameters
|
| 416 |
+
----------
|
| 417 |
+
actual_feature_seqs : list of 35 lists of int (0 or 1)
|
| 418 |
+
CTC-decoded output of your model, AFTER blank removal and run-length
|
| 419 |
+
collapsing. Each list is the decoded +att/βatt sequence for one feature.
|
| 420 |
+
Lengths may differ from each other and from len(target_phonemes).
|
| 421 |
+
Index order must match PHONOLOGICAL_FEATURES / FEATURE_NAMES.
|
| 422 |
+
|
| 423 |
+
Concretely, if your model outputs logits of shape (T_audio, 71):
|
| 424 |
+
nodes 0-34 = +att for features 0-34
|
| 425 |
+
nodes 35-69 = -att for features 0-34
|
| 426 |
+
node 70 = blank
|
| 427 |
+
Then for feature i, the CTC-decoded sequence is a list of 0s and 1s
|
| 428 |
+
(1 = +att node fired, 0 = -att node fired), blanks removed.
|
| 429 |
+
|
| 430 |
+
target_phonemes : list of str
|
| 431 |
+
CMU ARPAbet phoneme sequence from the user's typed sentence.
|
| 432 |
+
Obtain via any G2P tool, e.g. g2p_en:
|
| 433 |
+
from g2p_en import G2p
|
| 434 |
+
target_phonemes = G2p()(sentence)
|
| 435 |
+
|
| 436 |
+
Returns
|
| 437 |
+
-------
|
| 438 |
+
MDDResult
|
| 439 |
+
"""
|
| 440 |
+
assert len(actual_feature_seqs) == 35, \
|
| 441 |
+
f"Expected 35 feature sequences, got {len(actual_feature_seqs)}"
|
| 442 |
+
assert len(target_phonemes) > 0, "target_phonemes must not be empty"
|
| 443 |
+
|
| 444 |
+
T = len(target_phonemes)
|
| 445 |
+
|
| 446 |
+
# Build canonical target feature sequences from the phoneme labels
|
| 447 |
+
target_feat_seqs: List[List[int]] = phoneme_sequence_to_feature_sequences(
|
| 448 |
+
target_phonemes
|
| 449 |
+
) # 35 lists, each of length T
|
| 450 |
+
|
| 451 |
+
# Actual lengths (for insertion counting)
|
| 452 |
+
actual_len = max((len(s) for s in actual_feature_seqs), default=0)
|
| 453 |
+
|
| 454 |
+
# Step 1: per-feature NW alignment β per target-position feature bits
|
| 455 |
+
aligned = _align_all_features(target_feat_seqs, actual_feature_seqs, T)
|
| 456 |
+
|
| 457 |
+
# Step 2: count structural errors
|
| 458 |
+
deletions = sum(1 for ap in aligned if ap.op == DELETE)
|
| 459 |
+
insertions = _count_insertions(actual_feature_seqs, actual_len, aligned)
|
| 460 |
+
|
| 461 |
+
# Step 3: score
|
| 462 |
+
utt_score, phoneme_scores = _score_utterance(aligned)
|
| 463 |
+
|
| 464 |
+
# Step 4: build error list
|
| 465 |
+
errors = _build_errors(aligned, target_phonemes)
|
| 466 |
+
|
| 467 |
+
# Step 5: aggregate feature error counts
|
| 468 |
+
feat_error_counts = _aggregate(errors)
|
| 469 |
+
|
| 470 |
+
return MDDResult(
|
| 471 |
+
utterance_score=utt_score,
|
| 472 |
+
phoneme_scores=phoneme_scores,
|
| 473 |
+
errors=errors,
|
| 474 |
+
aligned_positions=aligned,
|
| 475 |
+
feature_error_counts=feat_error_counts,
|
| 476 |
+
deletion_count=deletions,
|
| 477 |
+
insertion_count=insertions,
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 482 |
+
# 11. CTC decode helper (use this on raw model logits)
|
| 483 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 484 |
+
|
| 485 |
+
def ctc_decode_feature_seqs(
|
| 486 |
+
logits: np.ndarray, # (T_audio, 71) β raw model output per frame
|
| 487 |
+
blank_idx: int = 70,
|
| 488 |
+
) -> List[List[int]]:
|
| 489 |
+
"""
|
| 490 |
+
Greedy CTC decode for a phonological feature model with 71 output nodes.
|
| 491 |
+
|
| 492 |
+
For each of the 35 features independently:
|
| 493 |
+
1. At each frame, pick argmax between pos_node (feat_i) and neg_node (feat_i+35)
|
| 494 |
+
(ignoring blank).
|
| 495 |
+
2. Collapse runs and remove frames where blank wins overall.
|
| 496 |
+
3. Return the sequence of 1s (+att) and 0s (-att).
|
| 497 |
+
|
| 498 |
+
Parameters
|
| 499 |
+
----------
|
| 500 |
+
logits : np.ndarray (T_audio, 71)
|
| 501 |
+
Raw model output before softmax. If you've already applied softmax,
|
| 502 |
+
pass probabilities β the argmax logic is identical.
|
| 503 |
+
blank_idx : int
|
| 504 |
+
Index of the shared blank node (default 70).
|
| 505 |
+
|
| 506 |
+
Returns
|
| 507 |
+
-------
|
| 508 |
+
List of 35 lists of int (0 or 1), CTC-decoded.
|
| 509 |
+
"""
|
| 510 |
+
T_audio = logits.shape[0]
|
| 511 |
+
feature_seqs: List[List[int]] = [[] for _ in range(35)]
|
| 512 |
+
|
| 513 |
+
for feat_i in range(35):
|
| 514 |
+
pos_node = feat_i # +att node
|
| 515 |
+
neg_node = feat_i + 35 # -att node
|
| 516 |
+
|
| 517 |
+
prev_label = None
|
| 518 |
+
for t in range(T_audio):
|
| 519 |
+
frame = logits[t]
|
| 520 |
+
best_overall = int(np.argmax(frame))
|
| 521 |
+
|
| 522 |
+
if best_overall == blank_idx:
|
| 523 |
+
prev_label = None # blank resets run
|
| 524 |
+
continue
|
| 525 |
+
|
| 526 |
+
# Among pos/neg for this feature, pick the winner
|
| 527 |
+
label = 1 if frame[pos_node] >= frame[neg_node] else 0
|
| 528 |
+
|
| 529 |
+
# CTC run-length collapse
|
| 530 |
+
if label != prev_label:
|
| 531 |
+
feature_seqs[feat_i].append(label)
|
| 532 |
+
prev_label = label
|
| 533 |
+
|
| 534 |
+
return feature_seqs
|
phonological_features.py
ADDED
|
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
phonological_features.py
|
| 3 |
+
========================
|
| 4 |
+
Defines the 35 phonological features from Table 1 of Shahin et al. (2025)
|
| 5 |
+
and provides the phoneme-to-feature mapping for the 39-phoneme CMU set.
|
| 6 |
+
|
| 7 |
+
Feature categories (paper Table 1):
|
| 8 |
+
Manners: consonant, sonorant, fricative, nasal, stop, approximant,
|
| 9 |
+
affricate, liquid, vowel, semivowel, continuant
|
| 10 |
+
Places: alveolar, palatal, dental, glottal, labial, velar, mid, high,
|
| 11 |
+
low, front, back, central, anterior, posterior, retroflex,
|
| 12 |
+
bilabial, coronal, dorsal
|
| 13 |
+
Others: long, short, monophthong, diphthong, round, voiced
|
| 14 |
+
|
| 15 |
+
The model output has 71 nodes: 35 (+att) + 35 (-att) + 1 (shared blank).
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 19 |
+
# The 35 phonological features (paper Table 1), in a fixed canonical order
|
| 20 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 21 |
+
PHONOLOGICAL_FEATURES = [
|
| 22 |
+
# Manners (11)
|
| 23 |
+
"consonant", "sonorant", "fricative", "nasal", "stop",
|
| 24 |
+
"approximant", "affricate", "liquid", "vowel", "semivowel", "continuant",
|
| 25 |
+
# Places (18)
|
| 26 |
+
"alveolar", "palatal", "dental", "glottal", "labial", "velar",
|
| 27 |
+
"mid", "high", "low", "front", "back", "central",
|
| 28 |
+
"anterior", "posterior", "retroflex", "bilabial", "coronal", "dorsal",
|
| 29 |
+
# Others (6)
|
| 30 |
+
"long", "short", "monophthong", "diphthong", "round", "voiced",
|
| 31 |
+
]
|
| 32 |
+
assert len(PHONOLOGICAL_FEATURES) == 35, "Must have exactly 35 features"
|
| 33 |
+
|
| 34 |
+
FEATURE_TO_IDX = {feat: i for i, feat in enumerate(PHONOLOGICAL_FEATURES)}
|
| 35 |
+
NUM_FEATURES = len(PHONOLOGICAL_FEATURES)
|
| 36 |
+
|
| 37 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 38 |
+
# Output node layout (paper Section 3.3):
|
| 39 |
+
# nodes 0..34 β +att for features 0..34
|
| 40 |
+
# nodes 35..69 β -att for features 0..34
|
| 41 |
+
# node 70 β shared blank
|
| 42 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 43 |
+
NUM_OUTPUT_NODES = 71 # 35 + 35 + 1
|
| 44 |
+
BLANK_IDX = 70
|
| 45 |
+
|
| 46 |
+
def feature_idx_to_pos_node(feat_idx: int) -> int:
|
| 47 |
+
"""Return output node index for +att of a given feature."""
|
| 48 |
+
return feat_idx
|
| 49 |
+
|
| 50 |
+
def feature_idx_to_neg_node(feat_idx: int) -> int:
|
| 51 |
+
"""Return output node index for -att of a given feature."""
|
| 52 |
+
return feat_idx + NUM_FEATURES
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 56 |
+
# CMU 39-phoneme set (TIMIT 61β39 reduced set used in the paper)
|
| 57 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 58 |
+
CMU_39_PHONEMES = [
|
| 59 |
+
"aa", "ae", "ah", "aw", "ay","ao",
|
| 60 |
+
"b", "ch", "d", "dh", "eh",
|
| 61 |
+
"er", "ey", "f", "g", "hh",
|
| 62 |
+
"ih", "iy", "jh", "k", "l",
|
| 63 |
+
"m", "n", "ng", "ow", "oy",
|
| 64 |
+
"p", "r", "s", "sh", "t",
|
| 65 |
+
"th", "uh", "uw", "v", "w",
|
| 66 |
+
"y", "z", "zh",
|
| 67 |
+
]
|
| 68 |
+
PHONEME_TO_IDX = {p: i for i, p in enumerate(CMU_39_PHONEMES)}
|
| 69 |
+
NUM_PHONEMES = len(CMU_39_PHONEMES) # 39
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 73 |
+
# Phoneme β phonological feature binary vector
|
| 74 |
+
# Each phoneme maps to a dict {feature_name: True/False}.
|
| 75 |
+
# Derived from standard phonological feature charts (Chomsky & Halle 1968,
|
| 76 |
+
# as referenced in the paper).
|
| 77 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 78 |
+
def _p(features_present: list[str]) -> dict[str, bool]:
|
| 79 |
+
"""Helper: build feature dict from list of present features."""
|
| 80 |
+
return {f: (f in features_present) for f in PHONOLOGICAL_FEATURES}
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
PHONEME_FEATURES: dict[str, dict[str, bool]] = {
|
| 84 |
+
# ββ Stops ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 85 |
+
"p": _p(["consonant", "stop", "labial", "anterior", "bilabial"]),
|
| 86 |
+
"b": _p(["consonant", "stop", "labial", "anterior", "bilabial",
|
| 87 |
+
"voiced"]),
|
| 88 |
+
"t": _p(["consonant", "stop", "alveolar", "anterior", "coronal"]),
|
| 89 |
+
"d": _p(["consonant", "stop", "alveolar", "anterior", "coronal",
|
| 90 |
+
"voiced"]),
|
| 91 |
+
"k": _p(["consonant", "stop", "velar", "posterior", "dorsal"]),
|
| 92 |
+
"g": _p(["consonant", "stop", "velar", "posterior", "dorsal",
|
| 93 |
+
"voiced"]),
|
| 94 |
+
|
| 95 |
+
# ββ Fricatives βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 96 |
+
|
| 97 |
+
"f": _p(["consonant", "fricative", "continuant", "labial", "anterior"]),
|
| 98 |
+
"v": _p(["consonant", "fricative", "continuant", "labial", "anterior", "voiced"]),
|
| 99 |
+
"th": _p(["consonant", "fricative", "continuant", "dental", "anterior",
|
| 100 |
+
"coronal"]),
|
| 101 |
+
"dh": _p(["consonant", "fricative", "continuant", "dental", "anterior",
|
| 102 |
+
"coronal", "voiced"]),
|
| 103 |
+
"s": _p(["consonant", "fricative", "continuant", "alveolar", "anterior",
|
| 104 |
+
"coronal"]),
|
| 105 |
+
"z": _p(["consonant", "fricative", "continuant", "alveolar", "anterior",
|
| 106 |
+
"coronal", "voiced"]),
|
| 107 |
+
|
| 108 |
+
"sh": _p(["consonant", "fricative", "continuant", "palatal", "posterior",
|
| 109 |
+
"coronal"]),
|
| 110 |
+
|
| 111 |
+
"zh": _p(["consonant", "fricative", "continuant", "palatal", "posterior",
|
| 112 |
+
"coronal", "voiced"]),
|
| 113 |
+
|
| 114 |
+
"hh": _p(["consonant", "fricative", "continuant", "glottal", "posterior",
|
| 115 |
+
"dorsal"]),
|
| 116 |
+
|
| 117 |
+
# ββ Affricates βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 118 |
+
|
| 119 |
+
"ch": _p(["consonant", "affricate", "palatal", "posterior", "coronal"]),
|
| 120 |
+
"jh": _p(["consonant", "affricate", "palatal", "posterior", "coronal",
|
| 121 |
+
"voiced"]),
|
| 122 |
+
|
| 123 |
+
# ββ Nasals βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 124 |
+
|
| 125 |
+
"m": _p(["consonant", "sonorant", "nasal", "continuant", "labial",
|
| 126 |
+
"anterior", "bilabial", "voiced"]),
|
| 127 |
+
"n": _p(["consonant", "sonorant", "nasal", "continuant", "alveolar",
|
| 128 |
+
"anterior", "coronal", "voiced"]),
|
| 129 |
+
"ng": _p(["consonant", "sonorant", "nasal", "continuant", "velar",
|
| 130 |
+
"posterior", "dorsal", "voiced"]),
|
| 131 |
+
|
| 132 |
+
# ββ Liquids ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 133 |
+
"l": _p(["consonant", "sonorant", "approximant", "liquid", "continuant",
|
| 134 |
+
"alveolar", "anterior", "coronal", "voiced"]),
|
| 135 |
+
|
| 136 |
+
"r": _p(["consonant", "sonorant", "approximant", "liquid", "continuant",
|
| 137 |
+
"alveolar", "anterior", "retroflex", "coronal", "voiced"]),
|
| 138 |
+
|
| 139 |
+
# ββ Semivowels (Glides) ββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 140 |
+
|
| 141 |
+
"w": _p(["sonorant", "approximant", "semivowel", "continuant", "labial",
|
| 142 |
+
"high", "anterior", "bilabial", "round", "voiced"]),
|
| 143 |
+
"y": _p(["sonorant", "approximant", "semivowel", "continuant", "palatal",
|
| 144 |
+
"high", "posterior", "coronal", "voiced"]),
|
| 145 |
+
|
| 146 |
+
# ββ Short Monophthong Vowels βββββββββββββββββββββββββββββββββββββββββββ
|
| 147 |
+
"ih": _p(["sonorant", "vowel", "continuant", "high", "front",
|
| 148 |
+
"short", "monophthong", "voiced"]),
|
| 149 |
+
|
| 150 |
+
"eh": _p(["sonorant", "vowel", "mid", "front",
|
| 151 |
+
"short", "monophthong", "voiced"]),
|
| 152 |
+
|
| 153 |
+
"ae": _p(["sonorant", "vowel", "continuant", "low", "front",
|
| 154 |
+
"long", "monophthong", "voiced"]),
|
| 155 |
+
|
| 156 |
+
"ah": _p(["sonorant", "vowel", "continuant", "mid", "back",
|
| 157 |
+
"short", "monophthong", "voiced"]),
|
| 158 |
+
|
| 159 |
+
"uh": _p(["sonorant", "vowel", "continuant", "high", "back",
|
| 160 |
+
"short", "monophthong", "round", "voiced"]),
|
| 161 |
+
|
| 162 |
+
# ββ Long Monophthong Vowels ββββββββββββββββββββββββββββββββββββββββββββ
|
| 163 |
+
"iy": _p(["sonorant", "vowel", "continuant", "high", "front",
|
| 164 |
+
"long", "monophthong", "voiced"]),
|
| 165 |
+
"aa": _p(["sonorant", "vowel", "continuant", "low", "back",
|
| 166 |
+
"long", "monophthong", "voiced"]),
|
| 167 |
+
"ao": _p(["sonorant", "vowel", "continuant", "mid", "back",
|
| 168 |
+
"long", "monophthong", "round", "voiced"]),
|
| 169 |
+
"er": _p(["sonorant", "vowel", "continuant", "mid", "central",
|
| 170 |
+
"retroflex", "short", "monophthong", "voiced"]),
|
| 171 |
+
"uw": _p(["sonorant", "vowel", "continuant", "high", "back",
|
| 172 |
+
"long", "monophthong", "round", "voiced"]),
|
| 173 |
+
|
| 174 |
+
# ββ Diphthongs βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 175 |
+
"ey": _p(["sonorant", "vowel", "continuant", "mid", "front",
|
| 176 |
+
"long", "diphthong", "voiced"]),
|
| 177 |
+
|
| 178 |
+
"aw": _p(["sonorant", "vowel", "continuant", "low", "central",
|
| 179 |
+
"long", "diphthong", "round", "voiced"]),
|
| 180 |
+
|
| 181 |
+
"ay": _p(["sonorant", "vowel", "low", "central",
|
| 182 |
+
"long", "diphthong", "voiced"]),
|
| 183 |
+
|
| 184 |
+
"oy": _p(["sonorant", "vowel", "continuant", "mid", "back",
|
| 185 |
+
"long", "diphthong", "round", "voiced"]),
|
| 186 |
+
|
| 187 |
+
"ow": _p(["sonorant", "vowel", "continuant", "mid", "central",
|
| 188 |
+
"long", "diphthong", "round", "voiced"]),
|
| 189 |
+
|
| 190 |
+
# ββ Silence ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 191 |
+
# Paper: "All silence labels were further removed leaving silence frames
|
| 192 |
+
# to be handled by the blank label."
|
| 193 |
+
"sil": _p([]), # all features absent; treated as blank during training
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
# Verify all 39 phonemes are covered.
|
| 197 |
+
# "sil" is intentionally extra β it is a fallback/blank placeholder, not a
|
| 198 |
+
# speech target, so it lives in PHONEME_FEATURES but not in CMU_39_PHONEMES.
|
| 199 |
+
_expected = set(CMU_39_PHONEMES) | {"sil"}
|
| 200 |
+
assert set(PHONEME_FEATURES.keys()) == _expected, (
|
| 201 |
+
f"Missing from PHONEME_FEATURES : {_expected - set(PHONEME_FEATURES.keys())}\n"
|
| 202 |
+
f"Unexpected in PHONEME_FEATURES: {set(PHONEME_FEATURES.keys()) - _expected}"
|
| 203 |
+
)
|
| 204 |
+
assert NUM_PHONEMES == 39, f"Expected 39 phonemes, got {NUM_PHONEMES}"
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def phoneme_to_feature_vector(phoneme: str) -> list[bool]:
|
| 208 |
+
"""Return a binary list of length 35 for a given phoneme."""
|
| 209 |
+
feat_dict = PHONEME_FEATURES.get(phoneme, PHONEME_FEATURES["sil"])
|
| 210 |
+
return [feat_dict[f] for f in PHONOLOGICAL_FEATURES]
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def phoneme_sequence_to_feature_sequences(
|
| 214 |
+
phonemes: list[str],
|
| 215 |
+
) -> list[list[int]]:
|
| 216 |
+
"""
|
| 217 |
+
Convert a phoneme sequence to N=35 binary label sequences.
|
| 218 |
+
|
| 219 |
+
Returns:
|
| 220 |
+
feature_seqs: list of 35 lists, each containing +att(1) or -att(0)
|
| 221 |
+
integers for each phoneme position.
|
| 222 |
+
"""
|
| 223 |
+
feature_seqs = [[] for _ in range(NUM_FEATURES)]
|
| 224 |
+
for ph in phonemes:
|
| 225 |
+
vec = phoneme_to_feature_vector(ph)
|
| 226 |
+
for feat_idx, present in enumerate(vec):
|
| 227 |
+
feature_seqs[feat_idx].append(1 if present else 0)
|
| 228 |
+
return feature_seqs
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def feature_sequences_to_ctc_labels(
|
| 232 |
+
feature_seqs: list[list[int]],
|
| 233 |
+
) -> list[list[int]]:
|
| 234 |
+
"""
|
| 235 |
+
Convert binary feature sequences (0/1) to CTC label indices.
|
| 236 |
+
|
| 237 |
+
For category i:
|
| 238 |
+
- +att β node index i (feature_idx_to_pos_node)
|
| 239 |
+
- -att β node index i + 35 (feature_idx_to_neg_node)
|
| 240 |
+
|
| 241 |
+
Returns:
|
| 242 |
+
ctc_labels: list of 35 lists of node indices (int)
|
| 243 |
+
"""
|
| 244 |
+
ctc_labels = []
|
| 245 |
+
for feat_idx, seq in enumerate(feature_seqs):
|
| 246 |
+
label_seq = []
|
| 247 |
+
for val in seq:
|
| 248 |
+
if val == 1:
|
| 249 |
+
label_seq.append(feature_idx_to_pos_node(feat_idx))
|
| 250 |
+
else:
|
| 251 |
+
label_seq.append(feature_idx_to_neg_node(feat_idx))
|
| 252 |
+
ctc_labels.append(label_seq)
|
| 253 |
+
return ctc_labels
|
requirements.txt
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Core
|
| 2 |
+
gradio>=4.0.0
|
| 3 |
+
numpy>=1.24.0
|
| 4 |
+
scipy>=1.10.0
|
| 5 |
+
|
| 6 |
+
# Model
|
| 7 |
+
torch>=2.0.0
|
| 8 |
+
transformers>=4.40.0
|
| 9 |
+
huggingface_hub>=0.20.0
|
| 10 |
+
|
| 11 |
+
# Audio
|
| 12 |
+
librosa>=0.10.0
|
| 13 |
+
soundfile>=0.12.0
|
| 14 |
+
|
| 15 |
+
# Optional LLM rewriter
|
| 16 |
+
accelerate>=0.27.0
|
| 17 |
+
httpx>=0.25.0
|
wav2vec2_phonological.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
wav2vec2_phonological.py
|
| 3 |
+
========================
|
| 4 |
+
Phonological feature detection model strictly following Figure 2 of
|
| 5 |
+
Shahin et al. (Speech Communication, 2025).
|
| 6 |
+
|
| 7 |
+
Architecture (Fig. 2):
|
| 8 |
+
Raw Speech
|
| 9 |
+
β
|
| 10 |
+
βΌ
|
| 11 |
+
wav2vec2.0 (pre-trained, CNN encoder frozen)
|
| 12 |
+
ββ CNN Feature Extractor [FROZEN]
|
| 13 |
+
ββ Transformer [FINE-TUNED]
|
| 14 |
+
β
|
| 15 |
+
βΌ
|
| 16 |
+
Linear Layer (hidden_size β 71 nodes)
|
| 17 |
+
β
|
| 18 |
+
βΌ
|
| 19 |
+
SCTC-SB Loss (during training)
|
| 20 |
+
OR
|
| 21 |
+
argmax per category (during inference)
|
| 22 |
+
|
| 23 |
+
Output nodes:
|
| 24 |
+
0..34 β +att_i (presence of feature i)
|
| 25 |
+
35..69 β -att_i (absence of feature i)
|
| 26 |
+
70 β shared blank
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
import torch
|
| 30 |
+
import torch.nn as nn
|
| 31 |
+
from transformers import Wav2Vec2Model
|
| 32 |
+
from typing import Optional
|
| 33 |
+
|
| 34 |
+
from phonological_features import NUM_FEATURES, NUM_OUTPUT_NODES, BLANK_IDX
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class PhonologicalWav2Vec2(nn.Module):
|
| 38 |
+
"""
|
| 39 |
+
wav2vec2-based phonological feature detection model.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
pretrained_model_name (str): HuggingFace model ID.
|
| 43 |
+
Paper uses 'facebook/wav2vec2-large-robust' (best performing).
|
| 44 |
+
num_output_nodes (int): 71 = 35(+att) + 35(-att) + 1(blank).
|
| 45 |
+
freeze_cnn_encoder (bool): Whether to freeze CNN feature extractor.
|
| 46 |
+
Paper Section 5.2: "its parameters were fixed during fine-tuning".
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
def __init__(
|
| 50 |
+
self,
|
| 51 |
+
pretrained_model_name: str = "facebook/wav2vec2-large-robust",
|
| 52 |
+
num_output_nodes: int = NUM_OUTPUT_NODES,
|
| 53 |
+
freeze_cnn_encoder: bool = True,
|
| 54 |
+
):
|
| 55 |
+
super().__init__()
|
| 56 |
+
self.num_output_nodes = num_output_nodes
|
| 57 |
+
self.num_features = NUM_FEATURES
|
| 58 |
+
self.blank_idx = BLANK_IDX
|
| 59 |
+
|
| 60 |
+
# ββ Load pre-trained wav2vec2 βββββββββββββββββββββββββββββββββββββ
|
| 61 |
+
print(f"[PhonologicalWav2Vec2] Loading '{pretrained_model_name}' ...")
|
| 62 |
+
self.wav2vec2 = Wav2Vec2Model.from_pretrained(pretrained_model_name)
|
| 63 |
+
|
| 64 |
+
# ββ Freeze CNN encoder (feature extractor) ββββββββββββββββββββββββ
|
| 65 |
+
# Paper: "Except for the CNN encoder layer, the whole network was
|
| 66 |
+
# then fine-tuned"
|
| 67 |
+
if freeze_cnn_encoder:
|
| 68 |
+
self.wav2vec2.feature_extractor._freeze_parameters()
|
| 69 |
+
print("[PhonologicalWav2Vec2] CNN encoder FROZEN.")
|
| 70 |
+
|
| 71 |
+
# ββ Linear projection head (Fig. 2) ββββββββββββββββββββββββββββββ
|
| 72 |
+
# "A linear layer was added on top of the transformer module with
|
| 73 |
+
# number of nodes equals to the number of target phonological-features"
|
| 74 |
+
hidden_size = self.wav2vec2.config.hidden_size
|
| 75 |
+
self.classifier = nn.Linear(hidden_size, num_output_nodes)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
print(f"[PhonologicalWav2Vec2] hidden_size={hidden_size}, "
|
| 79 |
+
f"output_nodes={num_output_nodes}")
|
| 80 |
+
|
| 81 |
+
def forward(
|
| 82 |
+
self,
|
| 83 |
+
input_values: torch.Tensor,
|
| 84 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 85 |
+
apply_spec_augment: bool = False,
|
| 86 |
+
) -> tuple[torch.Tensor, torch.Tensor]:
|
| 87 |
+
"""
|
| 88 |
+
Forward pass.
|
| 89 |
+
|
| 90 |
+
Paper Section 5.2: "SpecAugment was applied to the output of the
|
| 91 |
+
CNN encoder to add more variations to the training data."
|
| 92 |
+
|
| 93 |
+
Wav2Vec2Model natively applies SpecAugment between the CNN encoder
|
| 94 |
+
and the transformer when mask_time_indices is provided.
|
| 95 |
+
|
| 96 |
+
Returns:
|
| 97 |
+
logits : (B, T_frames, 71) raw logits
|
| 98 |
+
output_lengths: (B,) number of valid frames per batch item
|
| 99 |
+
"""
|
| 100 |
+
# Build mask_time_indices for SpecAugment if training
|
| 101 |
+
mask_time_indices = None
|
| 102 |
+
if apply_spec_augment and self.training:
|
| 103 |
+
# Compute output frame lengths to know valid T per item
|
| 104 |
+
if attention_mask is not None:
|
| 105 |
+
feat_lengths = self._get_feat_extract_output_lengths(
|
| 106 |
+
attention_mask.sum(dim=1)
|
| 107 |
+
)
|
| 108 |
+
else:
|
| 109 |
+
B, T_audio = input_values.shape
|
| 110 |
+
ones = torch.ones(B, dtype=torch.long, device=input_values.device) * T_audio
|
| 111 |
+
feat_lengths = self._get_feat_extract_output_lengths(ones)
|
| 112 |
+
|
| 113 |
+
B = input_values.shape[0]
|
| 114 |
+
T = int(feat_lengths.max().item())
|
| 115 |
+
|
| 116 |
+
# Build boolean mask: mask up to 10% of valid frames per utterance
|
| 117 |
+
mask_time_indices = torch.zeros(B, T, dtype=torch.bool,
|
| 118 |
+
device=input_values.device)
|
| 119 |
+
t_len = max(1, int(T * 0.10))
|
| 120 |
+
for b in range(B):
|
| 121 |
+
valid = int(feat_lengths[b].item())
|
| 122 |
+
if valid > t_len:
|
| 123 |
+
t0 = torch.randint(0, valid - t_len, (1,)).item()
|
| 124 |
+
mask_time_indices[b, t0:t0 + t_len] = True
|
| 125 |
+
|
| 126 |
+
outputs = self.wav2vec2(
|
| 127 |
+
input_values=input_values,
|
| 128 |
+
attention_mask=attention_mask,
|
| 129 |
+
mask_time_indices=mask_time_indices,
|
| 130 |
+
output_hidden_states=False,
|
| 131 |
+
)
|
| 132 |
+
hidden_states = outputs.last_hidden_state # (B, T, 1024)
|
| 133 |
+
logits = self.classifier(hidden_states)
|
| 134 |
+
|
| 135 |
+
if attention_mask is not None:
|
| 136 |
+
output_lengths = self._get_feat_extract_output_lengths(
|
| 137 |
+
attention_mask.sum(dim=1)
|
| 138 |
+
)
|
| 139 |
+
else:
|
| 140 |
+
B, T_audio = input_values.shape
|
| 141 |
+
ones = torch.ones(B, dtype=torch.long, device=input_values.device) * T_audio
|
| 142 |
+
output_lengths = self._get_feat_extract_output_lengths(ones)
|
| 143 |
+
|
| 144 |
+
return logits, output_lengths
|
| 145 |
+
|
| 146 |
+
def _get_feat_extract_output_lengths(
|
| 147 |
+
self, input_lengths: torch.Tensor
|
| 148 |
+
) -> torch.Tensor:
|
| 149 |
+
return self.wav2vec2._get_feat_extract_output_lengths(input_lengths)
|
| 150 |
+
|
| 151 |
+
@torch.no_grad()
|
| 152 |
+
def decode(
|
| 153 |
+
self,
|
| 154 |
+
logits: torch.Tensor, # (B, T, 71) or (T, 71) for single item
|
| 155 |
+
output_lengths: Optional[torch.Tensor] = None, # (B,) valid frame counts
|
| 156 |
+
) -> list[list[list[int]]]:
|
| 157 |
+
"""
|
| 158 |
+
Greedy CTC decoding per category.
|
| 159 |
+
|
| 160 |
+
For each feature category i, applies argmax over the 3-node slice
|
| 161 |
+
[pos_i, neg_i, blank] and collapses repeated labels + blanks.
|
| 162 |
+
|
| 163 |
+
Paper Section 3.3, Eq. 7:
|
| 164 |
+
h_i(x) = argmax_j y^t_{i,j}
|
| 165 |
+
|
| 166 |
+
Args:
|
| 167 |
+
logits : (B, T, 71) raw model logits
|
| 168 |
+
output_lengths: (B,) number of valid (non-padded) frames per item.
|
| 169 |
+
If None, all T frames are used (may include padding noise).
|
| 170 |
+
|
| 171 |
+
Returns:
|
| 172 |
+
decoded: [B][35] list of decoded label sequences
|
| 173 |
+
Each label sequence contains +att(True) or -att(False)
|
| 174 |
+
"""
|
| 175 |
+
if logits.dim() == 2:
|
| 176 |
+
logits = logits.unsqueeze(0)
|
| 177 |
+
|
| 178 |
+
B, T, _ = logits.shape
|
| 179 |
+
decoded_batch = []
|
| 180 |
+
|
| 181 |
+
for b in range(B):
|
| 182 |
+
valid_T = T if output_lengths is None else int(output_lengths[b].item())
|
| 183 |
+
decoded_features = []
|
| 184 |
+
for feat_idx in range(self.num_features):
|
| 185 |
+
pos_node = feat_idx
|
| 186 |
+
neg_node = feat_idx + self.num_features
|
| 187 |
+
|
| 188 |
+
# Extract 3-node slice over valid frames only: (valid_T, 3)
|
| 189 |
+
cat_logits = torch.stack([
|
| 190 |
+
logits[b, :valid_T, pos_node],
|
| 191 |
+
logits[b, :valid_T, neg_node],
|
| 192 |
+
logits[b, :valid_T, self.blank_idx],
|
| 193 |
+
], dim=-1) # (valid_T, 3)
|
| 194 |
+
|
| 195 |
+
# Argmax: 0=+att, 1=-att, 2=blank
|
| 196 |
+
preds = cat_logits.argmax(dim=-1) # (valid_T,)
|
| 197 |
+
|
| 198 |
+
# CTC collapse: remove blanks and repeated labels
|
| 199 |
+
collapsed = []
|
| 200 |
+
prev = -1
|
| 201 |
+
for p in preds.tolist():
|
| 202 |
+
if p == 2: # blank
|
| 203 |
+
prev = -1
|
| 204 |
+
continue
|
| 205 |
+
if p != prev:
|
| 206 |
+
collapsed.append(p == 0) # True=+att, False=-att
|
| 207 |
+
prev = p
|
| 208 |
+
|
| 209 |
+
decoded_features.append(collapsed)
|
| 210 |
+
decoded_batch.append(decoded_features)
|
| 211 |
+
|
| 212 |
+
return decoded_batch
|
| 213 |
+
|
| 214 |
+
def count_parameters(self) -> dict:
|
| 215 |
+
"""Count trainable vs frozen parameters."""
|
| 216 |
+
total = sum(p.numel() for p in self.parameters())
|
| 217 |
+
trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
|
| 218 |
+
frozen = total - trainable
|
| 219 |
+
return {"total": total, "trainable": trainable, "frozen": frozen}
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 223 |
+
# Phoneme-level baseline model (for comparison, paper Section 3)
|
| 224 |
+
# Same architecture but with 40 output nodes (39 phonemes + 1 blank)
|
| 225 |
+
# and standard CTC loss
|
| 226 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 227 |
+
class PhonemeLevelWav2Vec2(nn.Module):
|
| 228 |
+
"""
|
| 229 |
+
Phoneme-level MDD baseline (paper Section 3, Fig. 1 top branch).
|
| 230 |
+
Uses standard CTC with 39 phonemes + blank.
|
| 231 |
+
"""
|
| 232 |
+
|
| 233 |
+
def __init__(
|
| 234 |
+
self,
|
| 235 |
+
pretrained_model_name: str = "facebook/wav2vec2-large-robust",
|
| 236 |
+
num_phonemes: int = 39,
|
| 237 |
+
freeze_cnn_encoder: bool = True,
|
| 238 |
+
):
|
| 239 |
+
super().__init__()
|
| 240 |
+
self.num_phonemes = num_phonemes
|
| 241 |
+
self.blank_idx = num_phonemes # index 39 = blank
|
| 242 |
+
|
| 243 |
+
self.wav2vec2 = Wav2Vec2Model.from_pretrained(pretrained_model_name)
|
| 244 |
+
if freeze_cnn_encoder:
|
| 245 |
+
self.wav2vec2.feature_extractor._freeze_parameters()
|
| 246 |
+
|
| 247 |
+
hidden_size = self.wav2vec2.config.hidden_size
|
| 248 |
+
# 40 nodes: 39 phonemes + 1 blank
|
| 249 |
+
self.classifier = nn.Linear(hidden_size, num_phonemes + 1)
|
| 250 |
+
|
| 251 |
+
def forward(self, input_values, attention_mask=None):
|
| 252 |
+
outputs = self.wav2vec2(
|
| 253 |
+
input_values=input_values,
|
| 254 |
+
attention_mask=attention_mask,
|
| 255 |
+
)
|
| 256 |
+
hidden_states = outputs.last_hidden_state
|
| 257 |
+
logits = self.classifier(hidden_states)
|
| 258 |
+
|
| 259 |
+
if attention_mask is not None:
|
| 260 |
+
output_lengths = self.wav2vec2._get_feat_extract_output_lengths(
|
| 261 |
+
attention_mask.sum(dim=1))
|
| 262 |
+
else:
|
| 263 |
+
B, T = input_values.shape
|
| 264 |
+
ones = torch.ones(B, dtype=torch.long, device=input_values.device) * T
|
| 265 |
+
output_lengths = self.wav2vec2._get_feat_extract_output_lengths(ones)
|
| 266 |
+
|
| 267 |
+
return logits, output_lengths
|