heldtomaturity commited on
Commit
0515ef3
Β·
1 Parent(s): c35de4f

initial app deploy

Browse files
app.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Mispronunciation Detection & Diagnosis β€” HuggingFace Space
3
+ ===========================================================
4
+ Wires together:
5
+ 1. PhonologicalWav2Vec2 (your best_model.pt, loaded once at cold start)
6
+ 2. MDD engine (per-feature NW alignment β†’ errors + score)
7
+ 3. Feedback generator (rule engine + optional LLM rewriter)
8
+
9
+ Environment variables to set in Space β†’ Settings β†’ Variables and secrets:
10
+ HF_TOKEN (secret) β€” read token for your private model repo
11
+ HF_MODEL_REPO (variable) β€” e.g. "Backlighteu/phonological-mdd"
12
+ HF_MODEL_FILENAME (variable) β€” e.g. "best_model.pt" (default)
13
+ """
14
+
15
+ import os
16
+ import json
17
+ import torch
18
+ import numpy as np
19
+ import gradio as gr
20
+ import librosa
21
+
22
+ from huggingface_hub import hf_hub_download, snapshot_download
23
+ from transformers import Wav2Vec2FeatureExtractor
24
+
25
+ from wav2vec2_phonological import PhonologicalWav2Vec2
26
+ from mdd_engine import run_mdd
27
+ from feedback_generator import generate_feedback
28
+ from phonological_features import (
29
+ CMU_39_PHONEMES,
30
+ )
31
+
32
+ # ─────────────────────────────────────────────────────────────────────────────
33
+ # 1. Model β€” loaded once at cold start, reused for every request
34
+ # ─────────────────────────────────────────────────────────────────────────────
35
+
36
+ _model = None
37
+ _feature_extractor = None
38
+ _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
39
+
40
+ PRETRAINED_BASE = "facebook/wav2vec2-large-robust"
41
+ MODEL_REPO = os.environ.get("HF_MODEL_REPO", "Backlighteu/phonological-mdd")
42
+ MODEL_FILENAME = os.environ.get("HF_MODEL_FILENAME", "best_model.pt")
43
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
44
+
45
+
46
+ def load_model():
47
+ global _model, _feature_extractor
48
+
49
+ if _model is not None:
50
+ return
51
+
52
+ # Download entire repo into ./model_cache once, then load from disk.
53
+ # hf_hub_download checks cache first β€” no re-download if already present.
54
+ print(f"[startup] Caching {MODEL_REPO} to ./model_cache ...")
55
+ snapshot_download(
56
+ repo_id=MODEL_REPO,
57
+ token=HF_TOKEN,
58
+ local_dir="./model_cache",
59
+ )
60
+ weights_path = "./model_cache/best_model.pt"
61
+ print(f"[startup] Loading weights from {weights_path}")
62
+
63
+ model = PhonologicalWav2Vec2(
64
+ pretrained_model_name=PRETRAINED_BASE,
65
+ num_output_nodes=71,
66
+ freeze_cnn_encoder=True,
67
+ )
68
+
69
+ state_dict = torch.load(weights_path, map_location=_device)
70
+ model.load_state_dict(state_dict)
71
+ model.to(_device)
72
+ model.eval()
73
+ _model = model
74
+ print(f"[startup] Model ready on {_device}.")
75
+
76
+ print(f"[startup] Loading feature extractor from '{PRETRAINED_BASE}' ...")
77
+ _feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(PRETRAINED_BASE)
78
+ print("[startup] Feature extractor ready.")
79
+
80
+ # ─────────────────────────────────────────────────────────────────────────────
81
+ # 2. Audio β†’ decoded feature sequences
82
+ # ─────────────────────────────────────────────────────────────────────────────
83
+
84
+ TARGET_SR = 16_000
85
+
86
+
87
+ def decode_audio(audio_path: str) -> list:
88
+ """
89
+ Load audio, run the phonological model, return CTC-decoded feature seqs.
90
+
91
+ Returns
92
+ -------
93
+ actual_feature_seqs : list of 35 lists of int (0 or 1)
94
+ CTC-decoded +att / -att sequence for each of the 35 features.
95
+ """
96
+ load_model()
97
+
98
+ waveform, sr = librosa.load(audio_path, sr=TARGET_SR, mono=True)
99
+ waveform = waveform.astype(np.float32)
100
+
101
+ inputs = _feature_extractor(
102
+ waveform,
103
+ sampling_rate=TARGET_SR,
104
+ return_tensors="pt",
105
+ padding=True,
106
+ )
107
+
108
+ input_values = inputs.input_values.to(_device)
109
+ attention_mask = inputs.get("attention_mask")
110
+ if attention_mask is not None:
111
+ attention_mask = attention_mask.to(_device)
112
+
113
+ with torch.no_grad():
114
+ logits, output_lengths = _model(
115
+ input_values,
116
+ attention_mask,
117
+ apply_spec_augment=False,
118
+ )
119
+
120
+ # model.decode() returns list[B][35][list[bool]] β€” True=+att, False=-att
121
+ decoded_batch = _model.decode(logits, output_lengths)
122
+ decoded_35 = decoded_batch[0] # [35][list[bool]]
123
+
124
+ # Convert bool β†’ int (1/0)
125
+ actual_feature_seqs = [
126
+ [1 if v else 0 for v in feat_seq]
127
+ for feat_seq in decoded_35
128
+ ]
129
+
130
+ return actual_feature_seqs
131
+
132
+
133
+ # ──────────────────���──────────────────────────────────────────────────────────
134
+ # 3. Text β†’ canonical phoneme sequence
135
+ # ─────────────────────────────────────────────────────────────────────────────
136
+
137
+ _VALID_PHONEMES = set(CMU_39_PHONEMES) | {"sil"}
138
+
139
+
140
+ def parse_phoneme_input(text: str) -> list:
141
+ """
142
+ Accept space-separated CMU ARPAbet tokens typed by the user.
143
+ Unknown tokens are skipped with a warning.
144
+ """
145
+ tokens = text.lower().split()
146
+ valid, skipped = [], []
147
+ for t in tokens:
148
+ if t in _VALID_PHONEMES:
149
+ valid.append(t)
150
+ else:
151
+ skipped.append(t)
152
+ if skipped:
153
+ print(f"[warning] Unrecognised tokens skipped: {skipped}")
154
+ return valid if valid else ["sil"]
155
+
156
+
157
+ # ─────────────────────────────────────────────────────────────────────────────
158
+ # 4. Gradio processing function
159
+ # ─────────────────────────────────────────────────────────────────────────────
160
+
161
+ def process(audio_input, script_text, use_llm, max_issues):
162
+ if audio_input is None:
163
+ return "Please record or upload audio first.", "", "{}"
164
+
165
+ script_text = script_text.strip()
166
+ if not script_text:
167
+ return (
168
+ "Please type the target sentence as ARPAbet phoneme tokens.\n"
169
+ "Example: `dh ae k ae t` for 'the cat'",
170
+ "", "{}",
171
+ )
172
+
173
+ try:
174
+ actual_feature_seqs = decode_audio(audio_input)
175
+ except Exception as e:
176
+ return f"Audio processing error: {e}", "", "{}"
177
+
178
+ target_phonemes = parse_phoneme_input(script_text)
179
+
180
+ try:
181
+ result = run_mdd(
182
+ actual_feature_seqs=actual_feature_seqs,
183
+ target_phonemes=target_phonemes,
184
+ )
185
+ except Exception as e:
186
+ return f"MDD engine error: {e}", "", "{}"
187
+
188
+ feedback_dict = generate_feedback(
189
+ result,
190
+ use_llm=use_llm,
191
+ max_issues=int(max_issues),
192
+ )
193
+
194
+ score = feedback_dict["score"]
195
+ main_feedback = (
196
+ f"**Pronunciation Score: {score}/100**\n\n"
197
+ + feedback_dict["final_feedback"]
198
+ )
199
+
200
+ detail_lines = ["### Per-phoneme detail\n"]
201
+ for e in feedback_dict["error_summary"]:
202
+ deletion_tag = " *(deleted)*" if e.get("is_deletion") else ""
203
+ detail_lines.append(
204
+ f"- **/{e['target']}/** (pos {e['position']}){deletion_tag}: "
205
+ f"severity=`{e['severity']}`, accuracy={e['accuracy']:.0%}\n"
206
+ f" - Missing: {', '.join(e['missing_features']) or 'β€”'}\n"
207
+ f" - Extra: {', '.join(e['extra_features']) or 'β€”'}"
208
+ )
209
+ if not feedback_dict["error_summary"]:
210
+ detail_lines.append("No feature-level errors detected β€” great pronunciation!")
211
+
212
+ detail_text = "\n".join(detail_lines)
213
+
214
+ json_output = json.dumps({
215
+ "score": feedback_dict["score"],
216
+ "deletion_count": result.deletion_count,
217
+ "insertion_count": result.insertion_count,
218
+ "feature_error_counts": feedback_dict["feature_error_counts"],
219
+ "rules_triggered": feedback_dict["rules_triggered"],
220
+ "target_phonemes": target_phonemes,
221
+ "actual_seq_lengths": [len(s) for s in actual_feature_seqs],
222
+ }, indent=2)
223
+
224
+ return main_feedback, detail_text, json_output
225
+
226
+
227
+ # ─────────────────────────────────────────────────────────────────────────────
228
+ # 5. Gradio UI
229
+ # ─────────────────────────────────────────────────────────────────────────────
230
+
231
+ VALID_PHONEME_LIST = ", ".join(sorted(CMU_39_PHONEMES))
232
+
233
+ with gr.Blocks(title="Pronunciation Coach", theme=gr.themes.Soft()) as demo:
234
+ gr.Markdown(
235
+ """
236
+ # Pronunciation Coach
237
+ Speak a sentence, type what you meant to say as **ARPAbet phoneme tokens**,
238
+ and get phonological-feature-level feedback with articulation tips.
239
+ """
240
+ )
241
+
242
+ with gr.Row():
243
+ with gr.Column(scale=1):
244
+ audio_input = gr.Audio(
245
+ sources=["microphone", "upload"],
246
+ type="filepath",
247
+ label="Your speech",
248
+ )
249
+ script_input = gr.Textbox(
250
+ label="Target sentence β€” space-separated ARPAbet tokens",
251
+ placeholder="e.g. dh ae k ae t (= 'the cat')",
252
+ lines=2,
253
+ )
254
+ with gr.Accordion("Valid phoneme tokens", open=False):
255
+ gr.Markdown(f"`{VALID_PHONEME_LIST}`")
256
+ with gr.Row():
257
+ use_llm = gr.Checkbox(value=False, label="LLM feedback rewriter")
258
+ max_issues = gr.Slider(1, 5, value=3, step=1, label="Max issues shown")
259
+ submit_btn = gr.Button("Analyse", variant="primary")
260
+
261
+ with gr.Column(scale=2):
262
+ feedback_out = gr.Markdown(label="Coaching feedback")
263
+ with gr.Accordion("Per-phoneme detail", open=False):
264
+ detail_out = gr.Markdown()
265
+ with gr.Accordion("Raw JSON (developers)", open=False):
266
+ json_out = gr.Code(language="json")
267
+
268
+ submit_btn.click(
269
+ fn=process,
270
+ inputs=[audio_input, script_input, use_llm, max_issues],
271
+ outputs=[feedback_out, detail_out, json_out],
272
+ )
273
+
274
+ gr.Markdown(
275
+ """
276
+ ---
277
+ **How to enter the target sentence:**
278
+ Convert your sentence to ARPAbet using the
279
+ [CMU Pronouncing Dictionary](http://www.speech.cs.cmu.edu/cgi-bin/cmudict)
280
+ then paste the space-separated tokens here.
281
+ Example: *"the cat sat"* β†’ `dh ax k ae t s ae t`
282
+ """
283
+ )
284
+
285
+
286
+ if __name__ == "__main__":
287
+ demo.launch()
feedback_generator.py ADDED
@@ -0,0 +1,690 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Feedback Generator
3
+ ==================
4
+ Two-layer system:
5
+ Layer 1 β€” Rule engine: maps specific feature errors to expert articulatory cues
6
+ Layer 2 β€” LLM rewriter: takes rule outputs and rewrites them into natural,
7
+ encouraging coach-like language via a lightweight local model
8
+ (or cloud fallback).
9
+
10
+ The rule templates are the ground truth; the LLM only adds warmth and fluency.
11
+ """
12
+
13
+ from __future__ import annotations
14
+ import os
15
+ import json
16
+ import textwrap
17
+ from dataclasses import dataclass
18
+ from typing import List, Dict, Optional, Tuple
19
+ from mdd_engine import PhonemeError, MDDResult, FEATURE_NAMES
20
+
21
+
22
+ # ──────────────────────────────────────────────
23
+ # 1. Articulatory feedback rule bank
24
+ # ──────────────────────────────────────────────
25
+ # Each rule = {trigger_features, direction, tip, drill, self_check}
26
+ # direction: "missing" | "extra" | "both"
27
+
28
+ FEATURE_RULES: List[Dict] = [
29
+ # ── VOICING (Others group) ────────────────────────────────────────────
30
+ {
31
+ "features": ["voiced"],
32
+ "direction": "missing",
33
+ "tip": (
34
+ "Your vocal cords are not vibrating when they should be. "
35
+ "Place two fingers lightly on your throat (the Adam's apple area). "
36
+ "Now say the sound β€” if you feel vibration, you've got it. "
37
+ "Try humming first ('mmm'), then slide into the target sound."
38
+ ),
39
+ "drill": "Practice pairs: /f/ β†’ /v/, /s/ β†’ /z/, /p/ β†’ /b/. "
40
+ "Feel the buzz turn on for the second sound each time.",
41
+ "self_check": "Put your hand on your throat. You should feel a gentle buzz.",
42
+ },
43
+ {
44
+ "features": ["voiced"],
45
+ "direction": "extra",
46
+ "tip": (
47
+ "You are voicing a sound that should be voiceless β€” your vocal cords "
48
+ "are buzzing when they should be still. "
49
+ "Whisper the sound first to train your cords to stay quiet, "
50
+ "then gradually add breath pressure without the buzz."
51
+ ),
52
+ "drill": "Whisper-shout drill: whisper /p/, /t/, /k/, /f/, /s/ ten times.",
53
+ "self_check": "Put your hand on your throat. It should feel still, no vibration.",
54
+ },
55
+
56
+ # ── MANNER: STOP ─────────────────────────────────────────────────────
57
+ {
58
+ "features": ["stop"],
59
+ "direction": "missing",
60
+ "tip": (
61
+ "This sound needs a full closure in your mouth β€” air must be completely "
62
+ "blocked and then released in a burst. "
63
+ "Your tongue or lips are not making a tight enough seal, letting air trickle "
64
+ "through instead of building up pressure."
65
+ ),
66
+ "drill": "Tap your fingers on the desk for each stop: /p/ – /t/ – /k/. "
67
+ "Feel the 'pop' as pressure releases each time.",
68
+ "self_check": "Before the release, you should feel air pressure building behind the closure.",
69
+ },
70
+ {
71
+ "features": ["stop"],
72
+ "direction": "extra",
73
+ "tip": (
74
+ "You are closing your airway completely when the sound should be continuous. "
75
+ "Relax the articulators and keep a small opening so air can flow through "
76
+ "without a burst."
77
+ ),
78
+ "drill": "Say /s/ and /f/ β€” feel the continuous uninterrupted airflow, no pop.",
79
+ "self_check": "You should hear no 'pop' or sudden release β€” just steady air.",
80
+ },
81
+
82
+ # ── MANNER: FRICATIVE ────────────────────────────────────────────────
83
+ {
84
+ "features": ["fricative"],
85
+ "direction": "missing",
86
+ "tip": (
87
+ "This sound requires turbulent airflow β€” a hissing or buzzing quality. "
88
+ "Narrow the passage between your tongue (or lips) and the articulators just enough "
89
+ "that the air becomes turbulent. Too wide gives a vowel; full closure gives a stop."
90
+ ),
91
+ "drill": "Hold /s/, /f/, /sh/ for three full seconds each. Feel the continuous friction.",
92
+ "self_check": "You should hear a clear hissing or buzzing sound throughout, not silence or a pop.",
93
+ },
94
+
95
+ # ── MANNER: NASAL ─────────────────────────────────────────────────────
96
+ {
97
+ "features": ["nasal"],
98
+ "direction": "missing",
99
+ "tip": (
100
+ "This sound requires airflow through your nose. "
101
+ "Pinch your nostrils closed β€” if the sound changes dramatically, "
102
+ "you were accidentally blocking nasal airflow. "
103
+ "Let air flow freely through your nose as you make the sound."
104
+ ),
105
+ "drill": "Alternate: hum 'mmm' (nasal), then 'bbb' (not nasal). Feel the difference.",
106
+ "self_check": "Pinch your nose lightly β€” a nasal sound will feel 'stuffed up' when blocked.",
107
+ },
108
+ {
109
+ "features": ["nasal"],
110
+ "direction": "extra",
111
+ "tip": (
112
+ "Your sound has unwanted nasality β€” air is leaking through your nose. "
113
+ "Practice lifting the soft palate by saying 'uh-oh' firmly, then keep that "
114
+ "lifted feeling while producing the target sound."
115
+ ),
116
+ "drill": "Say 'back β€” bank', 'bad β€” band'. The first word of each pair is not nasal.",
117
+ "self_check": "Hold a mirror under your nose β€” it should not fog up.",
118
+ },
119
+
120
+ # ── MANNER: AFFRICATE ────────────────────────────────────────────────
121
+ {
122
+ "features": ["affricate"],
123
+ "direction": "missing",
124
+ "tip": (
125
+ "An affricate starts with a complete closure (like a stop) then releases "
126
+ "into a fricative β€” think of /ch/ in 'church' or /jh/ in 'judge'. "
127
+ "You are either skipping the closure or the friction release. "
128
+ "Make sure you feel both: a tight seal followed by a hissing release."
129
+ ),
130
+ "drill": "Say 'ch-ch-ch' rapidly, feeling the tap-and-hiss for each one.",
131
+ "self_check": "You should feel a brief closure then turbulent airflow β€” two phases in one sound.",
132
+ },
133
+
134
+ # ── MANNER: APPROXIMANT / LIQUID ─────────────────────────────────────
135
+ {
136
+ "features": ["approximant", "liquid"],
137
+ "direction": "missing",
138
+ "tip": (
139
+ "This sound (/l/, /r/, /w/, /y/) needs your articulators to approach each other "
140
+ "closely without fully touching or creating friction. "
141
+ "Relax the contact β€” you may be pressing too hard and creating a stop, "
142
+ "or not shaping your mouth precisely enough."
143
+ ),
144
+ "drill": "Say 'la-la-la' for /l/ and 'ra-ra-ra' for /r/ slowly, keeping the tongue light.",
145
+ "self_check": "There should be no pop and no hiss β€” just a smooth, resonant glide.",
146
+ },
147
+
148
+ # ── MANNER: CONTINUANT ───────────────────────────────────────────────
149
+ {
150
+ "features": ["continuant"],
151
+ "direction": "missing",
152
+ "tip": (
153
+ "This sound should have continuous, uninterrupted airflow β€” it is not a stop. "
154
+ "Keep your airway open and let air flow through for the full duration of the sound."
155
+ ),
156
+ "drill": "Sustain /s/, /m/, /l/ or /v/ for three seconds without any interruption.",
157
+ "self_check": "You should be able to hold the sound indefinitely without cutting off air.",
158
+ },
159
+
160
+ # ── PLACE: BILABIAL ──────────────────────────────────────────────────
161
+ {
162
+ "features": ["bilabial"],
163
+ "direction": "missing",
164
+ "tip": (
165
+ "This sound needs both lips pressed firmly together (/p/, /b/, /m/). "
166
+ "You may be making it with only one lip or further back in the mouth. "
167
+ "Press your lips together completely before releasing."
168
+ ),
169
+ "drill": "Say 'pa-ba-ma' ten times, exaggerating full lip closure each time.",
170
+ "self_check": "Watch yourself in a mirror β€” both lips should close completely.",
171
+ },
172
+
173
+ # ── PLACE: LABIAL (labiodental /f/, /v/) ────────────────────────────
174
+ {
175
+ "features": ["labial"],
176
+ "direction": "missing",
177
+ "tip": (
178
+ "This sound needs your lips to be active β€” either both lips together (bilabial: /p/, /b/, /m/) "
179
+ "or upper teeth touching the lower lip (labiodental: /f/, /v/). "
180
+ "You may be making the sound too far back with the tongue."
181
+ ),
182
+ "drill": "Exaggerate lip contact. Say 'pop', 'bob', 'mom', 'five', 'very' in front of a mirror.",
183
+ "self_check": "Watch yourself in a mirror β€” you should see clear lip movement.",
184
+ },
185
+
186
+ # ── PLACE: DENTAL ────────────────────────────────────────────────────
187
+ {
188
+ "features": ["dental"],
189
+ "direction": "missing",
190
+ "tip": (
191
+ "This sound (/th/, /dh/) requires your tongue tip to be right at or between your teeth. "
192
+ "Stick your tongue tip just between your upper and lower front teeth "
193
+ "and let air flow over it."
194
+ ),
195
+ "drill": "Say 'think' and 'this' slowly, deliberately placing your tongue between your teeth each time.",
196
+ "self_check": "You should feel your tongue tip touching the edges of your front teeth.",
197
+ },
198
+
199
+ # ── PLACE: ALVEOLAR ──────────────────────────────────────────────────
200
+ {
201
+ "features": ["alveolar"],
202
+ "direction": "missing",
203
+ "tip": (
204
+ "Your tongue tip needs to touch the alveolar ridge β€” the hard bump just behind "
205
+ "your upper front teeth. "
206
+ "This is the target for /t/, /d/, /n/, /s/, /z/, /l/. "
207
+ "You may be placing your tongue too far back or too far forward."
208
+ ),
209
+ "drill": "Touch the ridge behind your upper teeth with your tongue tip and feel it. "
210
+ "Now tap /t/ ten times, always returning to that exact spot.",
211
+ "self_check": "Is your tongue tip touching the hard ridge β€” not the teeth and not the palate?",
212
+ },
213
+
214
+ # ── PLACE: PALATAL ────────────────────────────────────────────────────
215
+ {
216
+ "features": ["palatal"],
217
+ "direction": "missing",
218
+ "tip": (
219
+ "This sound (/sh/, /zh/, /ch/, /jh/, /y/) is made with the tongue body raised "
220
+ "toward the hard palate β€” the hard, bony roof just behind the alveolar ridge. "
221
+ "Move your tongue further back from the teeth and arch it upward."
222
+ ),
223
+ "drill": "Say 'she', 'measure', 'church' β€” feel your tongue body rise toward the hard palate.",
224
+ "self_check": "You should feel your tongue broadly touching or approaching the middle of the roof.",
225
+ },
226
+
227
+ # ── PLACE: VELAR ──────────────────────────────────────────────────────
228
+ {
229
+ "features": ["velar"],
230
+ "direction": "missing",
231
+ "tip": (
232
+ "This sound (/k/, /g/, /ng/) is made at the back of your mouth, with the back of your tongue "
233
+ "touching the soft palate (velum). "
234
+ "Try gargling β€” that back-of-tongue raised position is exactly what you need."
235
+ ),
236
+ "drill": "Say 'king', 'ring', 'sing' β€” focus on the back-of-tongue closure each time.",
237
+ "self_check": "You should feel the back of your tongue lift and meet the soft palate.",
238
+ },
239
+
240
+ # ── PLACE: GLOTTAL ────────────────────────────────────────────────────
241
+ {
242
+ "features": ["glottal"],
243
+ "direction": "missing",
244
+ "tip": (
245
+ "This sound (/hh/) is made deep in the throat at the vocal folds. "
246
+ "Think of fogging up a mirror β€” breathe out gently with a completely open throat. "
247
+ "No tongue or lip constriction should be involved."
248
+ ),
249
+ "drill": "Say 'hi', 'hat', 'hot' β€” the /h/ should feel like a breath, not a friction sound.",
250
+ "self_check": "Place a hand on your throat β€” you should feel warmth from breath, not a hiss.",
251
+ },
252
+
253
+ # ── PLACE: RETROFLEX ─────────────────────────────────────────────────
254
+ {
255
+ "features": ["retroflex"],
256
+ "direction": "missing",
257
+ "tip": (
258
+ "This sound (/r/ in English, /er/) requires your tongue tip to curl back toward "
259
+ "the back of the alveolar ridge without touching anything, or to bunch up in the "
260
+ "center of your mouth. "
261
+ "Say 'uh' then slowly curl your tongue tip upward and backward."
262
+ ),
263
+ "drill": "Practice: 'uh' β†’ curl tongue β†’ 'er'. Hold 'er' for three seconds.",
264
+ "self_check": "Your tongue tip should point upward or backward but NOT touch the roof.",
265
+ },
266
+
267
+ # ── PLACE: CORONAL ───────────────────────────────────────────────────
268
+ {
269
+ "features": ["coronal"],
270
+ "direction": "missing",
271
+ "tip": (
272
+ "Coronal sounds are made with the front part (blade or tip) of the tongue β€” "
273
+ "this covers /t/, /d/, /s/, /z/, /n/, /l/, /sh/, /th/, and /r/. "
274
+ "Make sure your tongue front is active and positioned correctly for this sound."
275
+ ),
276
+ "drill": "Say 'tip', 'dip', 'sip', 'nip' β€” feel the tongue tip or blade doing the work.",
277
+ "self_check": "Is your tongue front β€” tip or blade β€” the part making contact?",
278
+ },
279
+
280
+ # ── PLACE: DORSAL ────────────────────────────────────────────────────
281
+ {
282
+ "features": ["dorsal"],
283
+ "direction": "missing",
284
+ "tip": (
285
+ "Dorsal sounds (/k/, /g/, /ng/, /w/, /y/) involve the back (body or root) of the tongue. "
286
+ "Your tongue body needs to arch toward the velum or palate. "
287
+ "You may be using your tongue tip when the back of the tongue should lead."
288
+ ),
289
+ "drill": "Say 'key', 'go', 'sing' β€” feel the back hump of your tongue rise each time.",
290
+ "self_check": "The front of your tongue should be relaxed; the back should be doing the work.",
291
+ },
292
+
293
+ # ── VOWEL HEIGHT ──────────────────────────────────────────────────────
294
+ {
295
+ "features": ["high"],
296
+ "direction": "missing",
297
+ "tip": (
298
+ "This vowel needs your tongue to be high in your mouth. "
299
+ "Think of 'ee' in 'feet' or 'oo' in 'food' β€” the tongue is raised close to the palate. "
300
+ "Raise your tongue toward the roof of your mouth as you say the vowel."
301
+ ),
302
+ "drill": "Slide from 'ah' (low, jaw open) β†’ 'ee' (high, jaw nearly closed) and feel the tongue rise.",
303
+ "self_check": "Your jaw should be mostly closed; the tongue should be near the roof.",
304
+ },
305
+ {
306
+ "features": ["mid"],
307
+ "direction": "missing",
308
+ "tip": (
309
+ "This vowel needs a mid-height tongue position β€” halfway between fully raised and fully lowered. "
310
+ "Think of 'eh' in 'bed' or 'oh' in 'boat'. "
311
+ "Relax your jaw to a half-open position."
312
+ ),
313
+ "drill": "Slide 'ee' (high) β†’ 'eh' (mid) β†’ 'ah' (low) and stop at the middle position.",
314
+ "self_check": "Your jaw should be half open β€” neither clenched nor dropped wide.",
315
+ },
316
+ {
317
+ "features": ["low"],
318
+ "direction": "missing",
319
+ "tip": (
320
+ "This vowel needs your tongue to drop down and your jaw to open wide. "
321
+ "Think of 'ah' in 'father' or 'ae' in 'cat' β€” the tongue is flat and low. "
322
+ "Let your jaw drop and your tongue rest at the bottom of your mouth."
323
+ ),
324
+ "drill": "Say 'ah' like a doctor's exam β€” exaggerate the open jaw and flat tongue.",
325
+ "self_check": "Your jaw should be open wide; your tongue should feel flat at the bottom.",
326
+ },
327
+
328
+ # ── VOWEL BACKNESS ───────────────────────────────────────────────────
329
+ {
330
+ "features": ["front"],
331
+ "direction": "missing",
332
+ "tip": (
333
+ "This vowel should be made with your tongue pushed toward the front of your mouth. "
334
+ "Smile slightly β€” this naturally pulls the tongue body forward."
335
+ ),
336
+ "drill": "Say 'ee – ay – eh' and feel your tongue staying at the front for all three.",
337
+ "self_check": "You should feel tension or contact toward the front of your mouth.",
338
+ },
339
+ {
340
+ "features": ["back"],
341
+ "direction": "missing",
342
+ "tip": (
343
+ "This vowel should be made with your tongue retracted toward the back of your mouth. "
344
+ "Round your lips slightly and pull your tongue body backward as you say the vowel."
345
+ ),
346
+ "drill": "Say 'oo – oh – aw' β€” feel your tongue pulling back and the lips rounding each time.",
347
+ "self_check": "You should feel the back of your tongue arch upward and backward.",
348
+ },
349
+ {
350
+ "features": ["central"],
351
+ "direction": "missing",
352
+ "tip": (
353
+ "This vowel (like the schwa /Ι™/ in 'about') should be made with a completely neutral, "
354
+ "centered tongue β€” not pushed forward or pulled back. "
355
+ "Relax all tension in your jaw, lips, and tongue."
356
+ ),
357
+ "drill": "Say 'uh' with a completely relaxed, drooping jaw and limp tongue.",
358
+ "self_check": "Your mouth should feel effortless, tongue neither front nor back.",
359
+ },
360
+
361
+ # ── LIP ROUNDING (Others group: 'round') ─────────────────────────────
362
+ {
363
+ "features": ["round"],
364
+ "direction": "missing",
365
+ "tip": (
366
+ "This sound requires rounded, protruded lips β€” like you are blowing out a candle. "
367
+ "Form an 'oo' shape with your lips before and during the sound."
368
+ ),
369
+ "drill": "Exaggerate lip rounding: say 'oo – oh – aw' with very pursed lips.",
370
+ "self_check": "Look in a mirror β€” your lips should form a clear circle or oval.",
371
+ },
372
+ {
373
+ "features": ["round"],
374
+ "direction": "extra",
375
+ "tip": (
376
+ "You are rounding your lips when they should be spread or neutral. "
377
+ "Spread your lips into a slight smile and keep them flat as you say the sound."
378
+ ),
379
+ "drill": "Say 'ee – ih – eh' with a relaxed smile β€” no lip rounding at all.",
380
+ "self_check": "Your lips should be flat or slightly spread, not puckered.",
381
+ },
382
+
383
+ # ── VOWEL LENGTH (Others group: 'long' / 'short') ─────────────────────
384
+ {
385
+ "features": ["long"],
386
+ "direction": "missing",
387
+ "tip": (
388
+ "This vowel should be noticeably longer in duration. "
389
+ "English long vowels (/iy/, /uw/, /aa/, /ao/, /ae/, /er/) are roughly twice "
390
+ "as long as their short counterparts. Stretch it out."
391
+ ),
392
+ "drill": "Say 'beat' and hold the vowel: 'beeeeat'. Then compare with the short 'bit'.",
393
+ "self_check": "Record yourself β€” the vowel should sound stretched, not clipped.",
394
+ },
395
+ {
396
+ "features": ["short"],
397
+ "direction": "missing",
398
+ "tip": (
399
+ "This vowel should be brief and clipped. "
400
+ "Short vowels (/ih/, /eh/, /ah/, /uh/) are reduced in duration. "
401
+ "Don't let the vowel linger β€” move quickly to the next sound."
402
+ ),
403
+ "drill": "Say 'bit', 'bet', 'but', 'book' β€” snap off each vowel quickly.",
404
+ "self_check": "The vowel should feel brief. If you can hold it comfortably, it's too long.",
405
+ },
406
+
407
+ # ── VOWEL TYPE (Others group: 'monophthong' / 'diphthong') ──────────
408
+ {
409
+ "features": ["monophthong"],
410
+ "direction": "missing",
411
+ "tip": (
412
+ "This vowel should be pure and steady β€” your tongue and lips should hold the same "
413
+ "position throughout. You may be letting the vowel glide (diphthongize). "
414
+ "Keep your tongue and jaw completely still from start to finish."
415
+ ),
416
+ "drill": "Hold /aa/, /iy/, or /uw/ for three seconds without any movement.",
417
+ "self_check": "The vowel quality should be identical at the beginning and end β€” no glide.",
418
+ },
419
+ {
420
+ "features": ["diphthong"],
421
+ "direction": "missing",
422
+ "tip": (
423
+ "This vowel should glide from one position to another β€” it is a diphthong. "
424
+ "English diphthongs like /ay/ ('bite'), /aw/ ('bout'), /oy/ ('boy'), "
425
+ "/ey/ ('bait'), /ow/ ('boat') have a clear movement. "
426
+ "Let your tongue and jaw glide smoothly to the second target."
427
+ ),
428
+ "drill": "Say 'buy – bow – boy – bay – boat' slowly and feel the glide in each vowel.",
429
+ "self_check": "The vowel should sound like it is moving, not fixed in one place.",
430
+ },
431
+ ]
432
+
433
+ # Build a fast lookup: feature β†’ list of applicable rules
434
+ _RULE_INDEX: Dict[str, List[Dict]] = {}
435
+ for rule in FEATURE_RULES:
436
+ for feat in rule["features"]:
437
+ _RULE_INDEX.setdefault(feat, []).append(rule)
438
+
439
+
440
+ # ──────────────────────────────────────────────
441
+ # 2. Rule matcher
442
+ # ──────────────────────────────────────────────
443
+
444
+ @dataclass
445
+ class RuleFeedback:
446
+ feature: str
447
+ direction: str # "missing" | "extra"
448
+ tip: str
449
+ drill: str
450
+ self_check: str
451
+ count: int = 1 # how many phonemes triggered this rule
452
+
453
+
454
+ def match_rules(errors: List[PhonemeError]) -> List[RuleFeedback]:
455
+ """
456
+ Given a list of phoneme errors, find the most relevant feedback rules.
457
+ Rules are deduplicated and sorted by frequency of occurrence.
458
+ """
459
+ triggered: Dict[Tuple[str, str], RuleFeedback] = {}
460
+
461
+ for error in errors:
462
+ for feat in error.missing_features:
463
+ for rule in _RULE_INDEX.get(feat, []):
464
+ if rule["direction"] in ("missing", "both"):
465
+ key = (feat, "missing")
466
+ if key in triggered:
467
+ triggered[key].count += 1
468
+ else:
469
+ triggered[key] = RuleFeedback(
470
+ feature=feat,
471
+ direction="missing",
472
+ tip=rule["tip"],
473
+ drill=rule["drill"],
474
+ self_check=rule["self_check"],
475
+ )
476
+
477
+ for feat in error.extra_features:
478
+ for rule in _RULE_INDEX.get(feat, []):
479
+ if rule["direction"] in ("extra", "both"):
480
+ key = (feat, "extra")
481
+ if key in triggered:
482
+ triggered[key].count += 1
483
+ else:
484
+ triggered[key] = RuleFeedback(
485
+ feature=feat,
486
+ direction="extra",
487
+ tip=rule["tip"],
488
+ drill=rule["drill"],
489
+ self_check=rule["self_check"],
490
+ )
491
+
492
+ # Sort by occurrence count descending
493
+ return sorted(triggered.values(), key=lambda r: -r.count)
494
+
495
+
496
+ # ──────────────────────────────────────────────
497
+ # 3. Template-based fallback feedback (no LLM needed)
498
+ # ──────────────────────────────────────────────
499
+
500
+ def format_feedback_template(
501
+ result: MDDResult,
502
+ rules: List[RuleFeedback],
503
+ max_issues: int = 3,
504
+ ) -> str:
505
+ """Structured text feedback without LLM β€” always available."""
506
+ lines = []
507
+ score = result.utterance_score
508
+
509
+ # Score header
510
+ if score >= 85:
511
+ lines.append(f"πŸŽ‰ Great pronunciation! Score: {score:.0f}/100")
512
+ elif score >= 65:
513
+ lines.append(f"πŸ‘ Good effort! Score: {score:.0f}/100 β€” a few things to polish.")
514
+ elif score >= 45:
515
+ lines.append(f"πŸ“š Score: {score:.0f}/100 β€” let's work on some key areas.")
516
+ else:
517
+ lines.append(f"πŸ’ͺ Score: {score:.0f}/100 β€” keep practicing, you'll get there!")
518
+
519
+ if not rules:
520
+ lines.append("\nNo significant feature errors detected. Well done!")
521
+ return "\n".join(lines)
522
+
523
+ lines.append(f"\nI found {len(result.errors)} phoneme(s) that need attention.\n")
524
+
525
+ for i, rule in enumerate(rules[:max_issues]):
526
+ direction_word = "missing" if rule.direction == "missing" else "extra"
527
+ lines.append(f"β€” Issue {i+1}: [{rule.feature}] feature {direction_word}")
528
+ lines.append(f" πŸ’‘ {rule.tip}")
529
+ lines.append(f" πŸ‹οΈ Drill: {rule.drill}")
530
+ lines.append(f" βœ… Self-check: {rule.self_check}\n")
531
+
532
+ return "\n".join(lines)
533
+
534
+
535
+ # ──────────────────────────────────────────────
536
+ # 4. LLM-enhanced feedback
537
+ # ──────────────────────────────────────────────
538
+
539
+ LLM_SYSTEM_PROMPT = """You are a warm, encouraging English pronunciation coach.
540
+ Your student just attempted to say a sentence and you've identified specific
541
+ phonological feature errors. Your task is to rewrite the structured feedback
542
+ into a single natural, conversational coaching response.
543
+
544
+ Rules:
545
+ - Keep ALL the articulatory tips and self-checks intact β€” do not omit or soften them.
546
+ - Write as if speaking to the student directly.
547
+ - Be encouraging but honest.
548
+ - Limit response to 200 words maximum.
549
+ - Do not add new advice not present in the structured feedback.
550
+ - Start with a brief overall assessment, then naturally weave in the tips.
551
+ - End with one motivating sentence.
552
+ """
553
+
554
+ def generate_llm_feedback(
555
+ structured_feedback: str,
556
+ score: float,
557
+ model_name: str = "Qwen/Qwen2.5-0.5B-Instruct", # lightweight default
558
+ use_cloud_fallback: bool = True,
559
+ ) -> str:
560
+ """
561
+ Rewrites structured feedback into natural coaching language.
562
+
563
+ Tries (in order):
564
+ 1. Local transformers model (if available)
565
+ 2. Cloud LLM API (if use_cloud_fallback=True and API key set)
566
+ 3. Returns structured_feedback unchanged as graceful degradation
567
+ """
568
+ prompt = f"""Here is structured pronunciation feedback for a student who scored {score:.0f}/100:
569
+
570
+ {structured_feedback}
571
+
572
+ Please rewrite this as a warm, natural coaching response."""
573
+
574
+ # --- Try local model first ---
575
+ try:
576
+ from transformers import AutoTokenizer, AutoModelForCausalLM
577
+ import torch
578
+
579
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
580
+ model = AutoModelForCausalLM.from_pretrained(
581
+ model_name,
582
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
583
+ device_map="auto",
584
+ )
585
+
586
+ messages = [
587
+ {"role": "system", "content": LLM_SYSTEM_PROMPT},
588
+ {"role": "user", "content": prompt},
589
+ ]
590
+ text = tokenizer.apply_chat_template(
591
+ messages, tokenize=False, add_generation_prompt=True
592
+ )
593
+ inputs = tokenizer([text], return_tensors="pt").to(model.device)
594
+ with torch.no_grad():
595
+ output = model.generate(
596
+ **inputs,
597
+ max_new_tokens=256,
598
+ temperature=0.7,
599
+ do_sample=True,
600
+ pad_token_id=tokenizer.eos_token_id,
601
+ )
602
+ response = tokenizer.decode(
603
+ output[0][inputs.input_ids.shape[1]:], skip_special_tokens=True
604
+ )
605
+ return response.strip()
606
+
607
+ except Exception as local_err:
608
+ print(f"[Local LLM] Not available: {local_err}")
609
+
610
+ # --- Cloud fallback (OpenAI-compatible API) ---
611
+ if use_cloud_fallback:
612
+ api_key = os.environ.get("OPENAI_API_KEY") or os.environ.get("LLM_API_KEY")
613
+ api_base = os.environ.get("LLM_API_BASE", "https://api.openai.com/v1")
614
+ cloud_model = os.environ.get("LLM_MODEL", "gpt-4o-mini")
615
+
616
+ if api_key:
617
+ try:
618
+ import httpx
619
+ headers = {
620
+ "Authorization": f"Bearer {api_key}",
621
+ "Content-Type": "application/json",
622
+ }
623
+ body = {
624
+ "model": cloud_model,
625
+ "messages": [
626
+ {"role": "system", "content": LLM_SYSTEM_PROMPT},
627
+ {"role": "user", "content": prompt},
628
+ ],
629
+ "max_tokens": 300,
630
+ "temperature": 0.7,
631
+ }
632
+ r = httpx.post(f"{api_base}/chat/completions", json=body, headers=headers, timeout=15)
633
+ r.raise_for_status()
634
+ return r.json()["choices"][0]["message"]["content"].strip()
635
+ except Exception as cloud_err:
636
+ print(f"[Cloud LLM] Failed: {cloud_err}")
637
+
638
+ # --- Graceful degradation ---
639
+ return structured_feedback
640
+
641
+
642
+ # ──────────────────────────────────────────────
643
+ # 5. Main feedback pipeline
644
+ # ──────────────────────────────────────────────
645
+
646
+ def generate_feedback(
647
+ result: MDDResult,
648
+ use_llm: bool = True,
649
+ max_issues: int = 3,
650
+ ) -> Dict:
651
+ """
652
+ Full feedback pipeline. Returns a dict with keys:
653
+ score, template_feedback, final_feedback, error_summary, rules_triggered
654
+ """
655
+ rules = match_rules(result.errors)
656
+ template_fb = format_feedback_template(result, rules, max_issues)
657
+
658
+ if use_llm and rules:
659
+ final_fb = generate_llm_feedback(template_fb, result.utterance_score)
660
+ else:
661
+ final_fb = template_fb
662
+
663
+ error_summary = [
664
+ {
665
+ "position": e.position,
666
+ "target": e.target_phoneme,
667
+ "produced": e.produced_phoneme,
668
+ "missing_features": e.missing_features,
669
+ "extra_features": e.extra_features,
670
+ "accuracy": round(e.feature_accuracy, 3),
671
+ "severity": e.severity,
672
+ }
673
+ for e in result.errors
674
+ ]
675
+
676
+ return {
677
+ "score": round(result.utterance_score, 1),
678
+ "template_feedback": template_fb,
679
+ "final_feedback": final_fb,
680
+ "error_summary": error_summary,
681
+ "feature_error_counts": result.feature_error_counts,
682
+ "rules_triggered": [
683
+ {
684
+ "feature": r.feature,
685
+ "direction": r.direction,
686
+ "occurrences": r.count,
687
+ }
688
+ for r in rules
689
+ ],
690
+ }
mdd_engine.py ADDED
@@ -0,0 +1,534 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MDD Engine β€” Mispronunciation Detection and Diagnosis
3
+ =====================================================
4
+
5
+ Architecture (Shahin et al. 2025)
6
+ ----------------------------------
7
+ Your model runs 35 independent CTC decoders, one per phonological feature.
8
+ Each decoder outputs a sequence of +att(1) / -att(0) labels, with blanks
9
+ already removed and runs collapsed β€” so the output length reflects the number
10
+ of detected phoneme-level events, NOT audio frames.
11
+
12
+ The canonical target comes from the user's typed sentence:
13
+ sentence β†’ G2P (CMU ARPAbet) β†’ phoneme_sequence_to_feature_sequences()
14
+ β†’ 35 binary label sequences of length T (number of target phonemes)
15
+
16
+ The problem: the actual decoded sequence per feature may have a DIFFERENT
17
+ length than T, because the student may have:
18
+ - deleted phonemes (actual shorter than target)
19
+ - inserted extras (actual longer than target)
20
+ - substituted (same length, wrong labels)
21
+
22
+ Solution: Needleman-Wunsch (global sequence alignment) per feature
23
+ ------------------------------------------------------------------
24
+ For each of the 35 features we run a global pairwise alignment between the
25
+ target binary sequence and the actual binary sequence. This gives us an
26
+ explicit alignment path with match / mismatch / insertion / deletion ops.
27
+
28
+ We then aggregate across all 35 features to get, per target phoneme position:
29
+ - which actual position it maps to (or DELETION if no match)
30
+ - which features are missing (+att in target, -att or gap in actual)
31
+ - which features are extra (-att in target, +att in actual)
32
+ - a weighted feature accuracy score
33
+
34
+ This is the standard approach in phonological MDD literature when no frame-
35
+ level forced alignment is available (see e.g. Lee & Glass 2015, Leung et al.
36
+ 2019, and the feature-based MDD track of the AIP challenge).
37
+
38
+ Input/output contract
39
+ ---------------------
40
+ actual_feature_seqs : list[list[int]] β€” 35 lists, each decoded CTC output
41
+ Values: 1 (+att) or 0 (-att)
42
+ Lengths may differ across features
43
+ and from the canonical length T
44
+
45
+ target_phonemes : list[str] β€” CMU ARPAbet phoneme sequence from
46
+ the user's typed sentence, length T
47
+
48
+ Output: MDDResult (see dataclass below)
49
+ """
50
+
51
+ from __future__ import annotations
52
+
53
+ import numpy as np
54
+ from dataclasses import dataclass, field
55
+ from typing import List, Dict, Tuple, Optional
56
+
57
+ from phonological_features import (
58
+ PHONOLOGICAL_FEATURES,
59
+ phoneme_sequence_to_feature_sequences,
60
+ phoneme_to_feature_vector,
61
+ )
62
+
63
+ # ─────────────────────────────────────────────────────────────────────────────
64
+ # 1. Feature schema & weights
65
+ # ─────────────────────────────────────────────────────────────────────────────
66
+
67
+ FEATURE_NAMES: List[str] = PHONOLOGICAL_FEATURES # 35 features, canonical order
68
+ NUM_FEATURES = len(FEATURE_NAMES) # 35
69
+ assert NUM_FEATURES == 35
70
+
71
+ F2I: Dict[str, int] = {f: i for i, f in enumerate(FEATURE_NAMES)}
72
+
73
+ # Perceptual salience weights β€” higher = more important mismatch.
74
+ # Manner errors (wrong sound class) are most disruptive.
75
+ # Voicing errors are highly salient in English.
76
+ # Place errors matter but less so than manner.
77
+ # Length/type distinctions are least salient in L2 MDD.
78
+ FEATURE_WEIGHTS: np.ndarray = np.array([
79
+ # Manners (11): consonant sonorant fricative nasal stop
80
+ 2.0, 1.5, 1.8, 2.0, 2.0,
81
+ # approximant affricate liquid vowel semivowel continuant
82
+ 1.5, 1.8, 1.5, 2.0, 1.5, 1.2,
83
+ # Places (18): alveolar palatal dental glottal labial velar
84
+ 1.5, 1.4, 1.3, 1.2, 1.5, 1.5,
85
+ # mid high low front back central
86
+ 1.8, 1.8, 1.8, 1.6, 1.6, 1.2,
87
+ # anterior posterior retroflex bilabial coronal dorsal
88
+ 1.3, 1.3, 1.3, 1.4, 1.3, 1.3,
89
+ # Others (6): long short monophthong diphthong round voiced
90
+ 1.0, 1.0, 1.2, 1.2, 1.2, 2.5,
91
+ ], dtype=np.float32)
92
+
93
+ assert len(FEATURE_WEIGHTS) == 35
94
+
95
+ # Alignment op codes
96
+ MATCH = 0 # same label, same position
97
+ MISMATCH = 1 # different label, same position
98
+ DELETE = 2 # target has event, actual has gap (deletion error)
99
+ INSERT = 3 # actual has event, target has gap (insertion error)
100
+
101
+ # NW scoring scheme
102
+ MATCH_SCORE = 2
103
+ MISMATCH_SCORE = -1
104
+ GAP_PENALTY = -2 # penalises deletions and insertions equally
105
+
106
+
107
+ # ─────────────────────────────────────────────────────────────────────────────
108
+ # 2. Data classes
109
+ # ─────────────────────────────────────────────────────────────────────────────
110
+
111
+ @dataclass
112
+ class AlignedPosition:
113
+ """One position in the target sequence after multi-feature alignment."""
114
+ target_idx: int # index in target phoneme sequence
115
+ actual_idx: Optional[int] # index in actual sequence, None = deletion
116
+ op: int # MATCH / MISMATCH / DELETE / INSERT
117
+ target_bits: List[int] # canonical feature vector (35 bits)
118
+ actual_bits: List[int] # observed feature vector (35 bits, 0 if deleted)
119
+ missing_features: List[str] # +att in target, -att or gap in actual
120
+ extra_features: List[str] # -att in target, +att in actual
121
+ feature_accuracy: float # weighted accuracy 0-1
122
+
123
+
124
+ @dataclass
125
+ class PhonemeError:
126
+ """One mispronounced phoneme with its full feature-level diagnosis."""
127
+ position: int # index in target sequence
128
+ target_phoneme: str # ARPAbet label from typed sentence
129
+ missing_features: List[str] # features the student failed to produce
130
+ extra_features: List[str] # features the student added erroneously
131
+ is_deletion: bool # student dropped this phoneme entirely
132
+ feature_accuracy: float # 0-1
133
+ severity: str # "mild" | "moderate" | "severe"
134
+
135
+
136
+ @dataclass
137
+ class MDDResult:
138
+ utterance_score: float # 0-100
139
+ phoneme_scores: List[float] # per target phoneme, 0-1
140
+ errors: List[PhonemeError]
141
+ aligned_positions: List[AlignedPosition]
142
+ feature_error_counts: Dict[str, int] # aggregated across all phonemes
143
+ deletion_count: int
144
+ insertion_count: int
145
+
146
+
147
+ # ─────────────────────────────────────────────────────────────────────────────
148
+ # 3. Needleman-Wunsch per-feature aligner
149
+ # ─────────────────────────────────────────────────────────────────────────────
150
+
151
+ def _nw_align(target_seq: List[int],
152
+ actual_seq: List[int]) -> List[Tuple[Optional[int], Optional[int]]]:
153
+ """
154
+ Global sequence alignment (Needleman-Wunsch) for two binary label sequences.
155
+
156
+ Returns a list of (target_idx, actual_idx) pairs where:
157
+ (i, j) β†’ match or mismatch at target[i], actual[j]
158
+ (i, None) β†’ deletion: target[i] has no corresponding actual event
159
+ (None, j) β†’ insertion: actual[j] has no corresponding target event
160
+
161
+ Binary values: 1 = +att, 0 = -att
162
+ """
163
+ T = len(target_seq)
164
+ A = len(actual_seq)
165
+
166
+ # Fill score matrix
167
+ score = np.zeros((T + 1, A + 1), dtype=np.float32)
168
+ score[0, :] = np.arange(A + 1) * GAP_PENALTY
169
+ score[:, 0] = np.arange(T + 1) * GAP_PENALTY
170
+
171
+ for i in range(1, T + 1):
172
+ for j in range(1, A + 1):
173
+ s = MATCH_SCORE if target_seq[i-1] == actual_seq[j-1] else MISMATCH_SCORE
174
+ score[i, j] = max(
175
+ score[i-1, j-1] + s, # match/mismatch
176
+ score[i-1, j] + GAP_PENALTY, # deletion
177
+ score[i, j-1] + GAP_PENALTY, # insertion
178
+ )
179
+
180
+ # Traceback
181
+ path: List[Tuple[Optional[int], Optional[int]]] = []
182
+ i, j = T, A
183
+ while i > 0 or j > 0:
184
+ if i > 0 and j > 0:
185
+ s = MATCH_SCORE if target_seq[i-1] == actual_seq[j-1] else MISMATCH_SCORE
186
+ if score[i, j] == score[i-1, j-1] + s:
187
+ path.append((i-1, j-1))
188
+ i -= 1; j -= 1
189
+ continue
190
+ if i > 0 and score[i, j] == score[i-1, j] + GAP_PENALTY:
191
+ path.append((i-1, None)) # deletion
192
+ i -= 1
193
+ else:
194
+ path.append((None, j-1)) # insertion
195
+ j -= 1
196
+
197
+ path.reverse()
198
+ return path
199
+
200
+
201
+ # ─────────────────────────────────────────────────────────────────────────────
202
+ # 4. Multi-feature alignment aggregator
203
+ # ─────────────────────────────────────────────────────────────────────────────
204
+
205
+ def _align_all_features(
206
+ target_feat_seqs: List[List[int]], # 35 lists, each length T
207
+ actual_feat_seqs: List[List[int]], # 35 lists, each possibly != T
208
+ T: int, # number of target phonemes
209
+ ) -> List[AlignedPosition]:
210
+ """
211
+ Run NW alignment independently on each of 35 feature sequences, then
212
+ aggregate the results per target phoneme position.
213
+
214
+ Strategy
215
+ --------
216
+ Each feature gives its own alignment path. We collect, for each target
217
+ position i, a vote over all 35 features about what actual position it
218
+ maps to. The plurality actual index wins. If the majority vote is "gap"
219
+ (deletion), the position is marked as a deletion.
220
+
221
+ Then per position we reconstruct the actual feature bits from the voted
222
+ actual index across all features.
223
+ """
224
+ # votes[target_idx] β†’ list of actual_idx votes (None = deletion vote)
225
+ votes: List[List[Optional[int]]] = [[] for _ in range(T)]
226
+ # per_feature_actual_idx[feat][target_idx] β†’ actual_idx or None
227
+ per_feat_map: List[Dict[int, Optional[int]]] = [
228
+ {} for _ in range(NUM_FEATURES)
229
+ ]
230
+
231
+ for feat_i in range(NUM_FEATURES):
232
+ t_seq = target_feat_seqs[feat_i] # length T
233
+ a_seq = actual_feat_seqs[feat_i] # length may differ
234
+
235
+ path = _nw_align(t_seq, a_seq)
236
+
237
+ for (ti, ai) in path:
238
+ if ti is None:
239
+ continue # insertion β€” no target position, skip
240
+ votes[ti].append(ai) # ai may be None (deletion)
241
+ per_feat_map[feat_i][ti] = ai
242
+
243
+
244
+ # Resolve votes per target position
245
+ aligned: List[AlignedPosition] = []
246
+
247
+ DELETION_VOTE_THRESHOLD = 0.5 # >50% gap votes β†’ mark as DELETE
248
+
249
+ for ti in range(T):
250
+ v = votes[ti]
251
+ non_null = [x for x in v if x is not None]
252
+ null_count = len(v) - len(non_null)
253
+ deletion_fraction = null_count / max(len(v), 1)
254
+
255
+ if not non_null or deletion_fraction > DELETION_VOTE_THRESHOLD:
256
+ chosen_ai = None
257
+ else:
258
+ # Plurality vote among non-null actual indices
259
+ counts: Dict[int, int] = {}
260
+ for idx in non_null:
261
+ counts[idx] = counts.get(idx, 0) + 1
262
+ chosen_ai = max(counts, key=counts.__getitem__)
263
+ # Build target and actual bit vectors for this position
264
+ target_bits = [target_feat_seqs[f][ti] for f in range(NUM_FEATURES)]
265
+
266
+ if chosen_ai is not None:
267
+ actual_bits = []
268
+ for f in range(NUM_FEATURES):
269
+ # Use per-feature actual value if this feature agrees on chosen_ai
270
+ feat_ai = per_feat_map[f].get(ti, None)
271
+ if feat_ai == chosen_ai:
272
+ actual_bits.append(actual_feat_seqs[f][feat_ai]
273
+ if feat_ai < len(actual_feat_seqs[f]) else 0)
274
+ else:
275
+ # Feature disagrees on the position β€” use its own aligned value
276
+ fa = per_feat_map[f].get(ti, None)
277
+ if fa is not None and fa < len(actual_feat_seqs[f]):
278
+ actual_bits.append(actual_feat_seqs[f][fa])
279
+ else:
280
+ actual_bits.append(0) # treat as absent
281
+ op = MATCH if target_bits == actual_bits else MISMATCH
282
+ else:
283
+ actual_bits = [0] * NUM_FEATURES
284
+ op = DELETE
285
+
286
+ # Compute feature-level errors
287
+ missing = [FEATURE_NAMES[f] for f in range(NUM_FEATURES)
288
+ if target_bits[f] == 1 and actual_bits[f] == 0]
289
+ extra = [FEATURE_NAMES[f] for f in range(NUM_FEATURES)
290
+ if target_bits[f] == 0 and actual_bits[f] == 1]
291
+
292
+ # Weighted accuracy: fraction of weighted features correctly produced
293
+ correct_weight = sum(
294
+ FEATURE_WEIGHTS[f]
295
+ for f in range(NUM_FEATURES)
296
+ if target_bits[f] == actual_bits[f]
297
+ )
298
+ total_weight = float(FEATURE_WEIGHTS.sum())
299
+ accuracy = float(correct_weight / total_weight)
300
+
301
+ aligned.append(AlignedPosition(
302
+ target_idx=ti,
303
+ actual_idx=chosen_ai,
304
+ op=op,
305
+ target_bits=target_bits,
306
+ actual_bits=actual_bits,
307
+ missing_features=missing,
308
+ extra_features=extra,
309
+ feature_accuracy=accuracy,
310
+ ))
311
+
312
+ return aligned
313
+
314
+
315
+ # ─────────────────────────────────────────────────────────────────────────────
316
+ # 5. Insertion detector
317
+ # ─────────────────────────────────────────────────────────────────────────────
318
+
319
+ def _count_insertions(
320
+ actual_feat_seqs: List[List[int]],
321
+ actual_len: int,
322
+ aligned: List[AlignedPosition],
323
+ ) -> int:
324
+ """
325
+ Count actual positions that were voted as insertions (not mapped to any
326
+ target position) by the majority of features.
327
+ """
328
+ used_actual = set(
329
+ ap.actual_idx for ap in aligned if ap.actual_idx is not None
330
+ )
331
+ inserted = set(range(actual_len)) - used_actual
332
+ return len(inserted)
333
+
334
+
335
+ # ─────────��───────────────────────────────────────────────────────────────────
336
+ # 6. Severity classifier
337
+ # ─────────────────────────────────────────────────────────────────────────────
338
+
339
+ # Thresholds on weighted feature error rate
340
+ _SEV = {"mild": 0.85, "moderate": 0.65} # accuracy thresholds (higher = easier)
341
+
342
+ def _severity(accuracy: float, is_deletion: bool) -> str:
343
+ if is_deletion:
344
+ return "severe"
345
+ if accuracy >= _SEV["mild"]:
346
+ return "mild"
347
+ if accuracy >= _SEV["moderate"]:
348
+ return "moderate"
349
+ return "severe"
350
+
351
+
352
+ # ─────────────────────────────────────────────────────────────────────────────
353
+ # 7. Scorer
354
+ # ─────────────────────────────────────────────────────────────────────────────
355
+
356
+ def _score_utterance(aligned: List[AlignedPosition]) -> Tuple[float, List[float]]:
357
+ """
358
+ Per-phoneme score: weighted feature accuracy (0-1).
359
+ Deletions score 0.
360
+ Utterance score: weighted mean, penalising deletions most.
361
+ """
362
+ phoneme_scores = [ap.feature_accuracy for ap in aligned]
363
+ utterance_score = float(np.mean(phoneme_scores)) * 100.0
364
+ return utterance_score, phoneme_scores
365
+
366
+
367
+ # ─────────────────────────────────────────────────────────────────────────────
368
+ # 8. Error list builder
369
+ # ─────────────────────────────────────────────────────────────────────────────
370
+
371
+ def _build_errors(
372
+ aligned: List[AlignedPosition],
373
+ target_phonemes: List[str],
374
+ ) -> List[PhonemeError]:
375
+ errors = []
376
+ for ap in aligned:
377
+ if ap.op == MATCH and not ap.missing_features and not ap.extra_features:
378
+ continue # perfectly correct, no error to report
379
+
380
+ errors.append(PhonemeError(
381
+ position=ap.target_idx,
382
+ target_phoneme=target_phonemes[ap.target_idx],
383
+ missing_features=ap.missing_features,
384
+ extra_features=ap.extra_features,
385
+ is_deletion=(ap.op == DELETE),
386
+ feature_accuracy=ap.feature_accuracy,
387
+ severity=_severity(ap.feature_accuracy, ap.op == DELETE),
388
+ ))
389
+ return errors
390
+
391
+
392
+ # ─────────────────────────────────────────────────────────────────────────────
393
+ # 9. Aggregate feature error counts
394
+ # ─────────────────────────────────────────────────────────────────────────────
395
+
396
+ def _aggregate(errors: List[PhonemeError]) -> Dict[str, int]:
397
+ counts: Dict[str, int] = {}
398
+ for e in errors:
399
+ for f in e.missing_features + e.extra_features:
400
+ counts[f] = counts.get(f, 0) + 1
401
+ return dict(sorted(counts.items(), key=lambda x: -x[1]))
402
+
403
+
404
+ # ─────────────────────────────────────────────────────────────────────────────
405
+ # 10. Public entry point
406
+ # ─────────────────────────────────────────────────────────────────────────────
407
+
408
+ def run_mdd(
409
+ actual_feature_seqs: List[List[int]],
410
+ target_phonemes: List[str],
411
+ ) -> MDDResult:
412
+ """
413
+ Full MDD pipeline for a CTC phonological-feature model.
414
+
415
+ Parameters
416
+ ----------
417
+ actual_feature_seqs : list of 35 lists of int (0 or 1)
418
+ CTC-decoded output of your model, AFTER blank removal and run-length
419
+ collapsing. Each list is the decoded +att/βˆ’att sequence for one feature.
420
+ Lengths may differ from each other and from len(target_phonemes).
421
+ Index order must match PHONOLOGICAL_FEATURES / FEATURE_NAMES.
422
+
423
+ Concretely, if your model outputs logits of shape (T_audio, 71):
424
+ nodes 0-34 = +att for features 0-34
425
+ nodes 35-69 = -att for features 0-34
426
+ node 70 = blank
427
+ Then for feature i, the CTC-decoded sequence is a list of 0s and 1s
428
+ (1 = +att node fired, 0 = -att node fired), blanks removed.
429
+
430
+ target_phonemes : list of str
431
+ CMU ARPAbet phoneme sequence from the user's typed sentence.
432
+ Obtain via any G2P tool, e.g. g2p_en:
433
+ from g2p_en import G2p
434
+ target_phonemes = G2p()(sentence)
435
+
436
+ Returns
437
+ -------
438
+ MDDResult
439
+ """
440
+ assert len(actual_feature_seqs) == 35, \
441
+ f"Expected 35 feature sequences, got {len(actual_feature_seqs)}"
442
+ assert len(target_phonemes) > 0, "target_phonemes must not be empty"
443
+
444
+ T = len(target_phonemes)
445
+
446
+ # Build canonical target feature sequences from the phoneme labels
447
+ target_feat_seqs: List[List[int]] = phoneme_sequence_to_feature_sequences(
448
+ target_phonemes
449
+ ) # 35 lists, each of length T
450
+
451
+ # Actual lengths (for insertion counting)
452
+ actual_len = max((len(s) for s in actual_feature_seqs), default=0)
453
+
454
+ # Step 1: per-feature NW alignment β†’ per target-position feature bits
455
+ aligned = _align_all_features(target_feat_seqs, actual_feature_seqs, T)
456
+
457
+ # Step 2: count structural errors
458
+ deletions = sum(1 for ap in aligned if ap.op == DELETE)
459
+ insertions = _count_insertions(actual_feature_seqs, actual_len, aligned)
460
+
461
+ # Step 3: score
462
+ utt_score, phoneme_scores = _score_utterance(aligned)
463
+
464
+ # Step 4: build error list
465
+ errors = _build_errors(aligned, target_phonemes)
466
+
467
+ # Step 5: aggregate feature error counts
468
+ feat_error_counts = _aggregate(errors)
469
+
470
+ return MDDResult(
471
+ utterance_score=utt_score,
472
+ phoneme_scores=phoneme_scores,
473
+ errors=errors,
474
+ aligned_positions=aligned,
475
+ feature_error_counts=feat_error_counts,
476
+ deletion_count=deletions,
477
+ insertion_count=insertions,
478
+ )
479
+
480
+
481
+ # ─────────────────────────────────────────────────────────────────────────────
482
+ # 11. CTC decode helper (use this on raw model logits)
483
+ # ─────────────────────────────────────────────────────────────────────────────
484
+
485
+ def ctc_decode_feature_seqs(
486
+ logits: np.ndarray, # (T_audio, 71) β€” raw model output per frame
487
+ blank_idx: int = 70,
488
+ ) -> List[List[int]]:
489
+ """
490
+ Greedy CTC decode for a phonological feature model with 71 output nodes.
491
+
492
+ For each of the 35 features independently:
493
+ 1. At each frame, pick argmax between pos_node (feat_i) and neg_node (feat_i+35)
494
+ (ignoring blank).
495
+ 2. Collapse runs and remove frames where blank wins overall.
496
+ 3. Return the sequence of 1s (+att) and 0s (-att).
497
+
498
+ Parameters
499
+ ----------
500
+ logits : np.ndarray (T_audio, 71)
501
+ Raw model output before softmax. If you've already applied softmax,
502
+ pass probabilities β€” the argmax logic is identical.
503
+ blank_idx : int
504
+ Index of the shared blank node (default 70).
505
+
506
+ Returns
507
+ -------
508
+ List of 35 lists of int (0 or 1), CTC-decoded.
509
+ """
510
+ T_audio = logits.shape[0]
511
+ feature_seqs: List[List[int]] = [[] for _ in range(35)]
512
+
513
+ for feat_i in range(35):
514
+ pos_node = feat_i # +att node
515
+ neg_node = feat_i + 35 # -att node
516
+
517
+ prev_label = None
518
+ for t in range(T_audio):
519
+ frame = logits[t]
520
+ best_overall = int(np.argmax(frame))
521
+
522
+ if best_overall == blank_idx:
523
+ prev_label = None # blank resets run
524
+ continue
525
+
526
+ # Among pos/neg for this feature, pick the winner
527
+ label = 1 if frame[pos_node] >= frame[neg_node] else 0
528
+
529
+ # CTC run-length collapse
530
+ if label != prev_label:
531
+ feature_seqs[feat_i].append(label)
532
+ prev_label = label
533
+
534
+ return feature_seqs
phonological_features.py ADDED
@@ -0,0 +1,253 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ phonological_features.py
3
+ ========================
4
+ Defines the 35 phonological features from Table 1 of Shahin et al. (2025)
5
+ and provides the phoneme-to-feature mapping for the 39-phoneme CMU set.
6
+
7
+ Feature categories (paper Table 1):
8
+ Manners: consonant, sonorant, fricative, nasal, stop, approximant,
9
+ affricate, liquid, vowel, semivowel, continuant
10
+ Places: alveolar, palatal, dental, glottal, labial, velar, mid, high,
11
+ low, front, back, central, anterior, posterior, retroflex,
12
+ bilabial, coronal, dorsal
13
+ Others: long, short, monophthong, diphthong, round, voiced
14
+
15
+ The model output has 71 nodes: 35 (+att) + 35 (-att) + 1 (shared blank).
16
+ """
17
+
18
+ # ─────────────────────────────────────────────────────────────────────────────
19
+ # The 35 phonological features (paper Table 1), in a fixed canonical order
20
+ # ─────────────────────────────────────────────────────────────────────────────
21
+ PHONOLOGICAL_FEATURES = [
22
+ # Manners (11)
23
+ "consonant", "sonorant", "fricative", "nasal", "stop",
24
+ "approximant", "affricate", "liquid", "vowel", "semivowel", "continuant",
25
+ # Places (18)
26
+ "alveolar", "palatal", "dental", "glottal", "labial", "velar",
27
+ "mid", "high", "low", "front", "back", "central",
28
+ "anterior", "posterior", "retroflex", "bilabial", "coronal", "dorsal",
29
+ # Others (6)
30
+ "long", "short", "monophthong", "diphthong", "round", "voiced",
31
+ ]
32
+ assert len(PHONOLOGICAL_FEATURES) == 35, "Must have exactly 35 features"
33
+
34
+ FEATURE_TO_IDX = {feat: i for i, feat in enumerate(PHONOLOGICAL_FEATURES)}
35
+ NUM_FEATURES = len(PHONOLOGICAL_FEATURES)
36
+
37
+ # ─────────────────────────────────────────────────────────────────────────────
38
+ # Output node layout (paper Section 3.3):
39
+ # nodes 0..34 β†’ +att for features 0..34
40
+ # nodes 35..69 β†’ -att for features 0..34
41
+ # node 70 β†’ shared blank
42
+ # ─────────────────────────────────────────────────────────────────────────────
43
+ NUM_OUTPUT_NODES = 71 # 35 + 35 + 1
44
+ BLANK_IDX = 70
45
+
46
+ def feature_idx_to_pos_node(feat_idx: int) -> int:
47
+ """Return output node index for +att of a given feature."""
48
+ return feat_idx
49
+
50
+ def feature_idx_to_neg_node(feat_idx: int) -> int:
51
+ """Return output node index for -att of a given feature."""
52
+ return feat_idx + NUM_FEATURES
53
+
54
+
55
+ # ─────────────────────────────────────────────────────────────────────────────
56
+ # CMU 39-phoneme set (TIMIT 61β†’39 reduced set used in the paper)
57
+ # ─────────────────────────────────────────────────────────────────────────────
58
+ CMU_39_PHONEMES = [
59
+ "aa", "ae", "ah", "aw", "ay","ao",
60
+ "b", "ch", "d", "dh", "eh",
61
+ "er", "ey", "f", "g", "hh",
62
+ "ih", "iy", "jh", "k", "l",
63
+ "m", "n", "ng", "ow", "oy",
64
+ "p", "r", "s", "sh", "t",
65
+ "th", "uh", "uw", "v", "w",
66
+ "y", "z", "zh",
67
+ ]
68
+ PHONEME_TO_IDX = {p: i for i, p in enumerate(CMU_39_PHONEMES)}
69
+ NUM_PHONEMES = len(CMU_39_PHONEMES) # 39
70
+
71
+
72
+ # ─────────────────────────────────────────────────────────────────────────────
73
+ # Phoneme β†’ phonological feature binary vector
74
+ # Each phoneme maps to a dict {feature_name: True/False}.
75
+ # Derived from standard phonological feature charts (Chomsky & Halle 1968,
76
+ # as referenced in the paper).
77
+ # ─────────────────────────────────────────────────────────────────────────────
78
+ def _p(features_present: list[str]) -> dict[str, bool]:
79
+ """Helper: build feature dict from list of present features."""
80
+ return {f: (f in features_present) for f in PHONOLOGICAL_FEATURES}
81
+
82
+
83
+ PHONEME_FEATURES: dict[str, dict[str, bool]] = {
84
+ # ── Stops ──────────────────────────────────────────────────────────────
85
+ "p": _p(["consonant", "stop", "labial", "anterior", "bilabial"]),
86
+ "b": _p(["consonant", "stop", "labial", "anterior", "bilabial",
87
+ "voiced"]),
88
+ "t": _p(["consonant", "stop", "alveolar", "anterior", "coronal"]),
89
+ "d": _p(["consonant", "stop", "alveolar", "anterior", "coronal",
90
+ "voiced"]),
91
+ "k": _p(["consonant", "stop", "velar", "posterior", "dorsal"]),
92
+ "g": _p(["consonant", "stop", "velar", "posterior", "dorsal",
93
+ "voiced"]),
94
+
95
+ # ── Fricatives ─────────────────────────────────────────────────────────
96
+
97
+ "f": _p(["consonant", "fricative", "continuant", "labial", "anterior"]),
98
+ "v": _p(["consonant", "fricative", "continuant", "labial", "anterior", "voiced"]),
99
+ "th": _p(["consonant", "fricative", "continuant", "dental", "anterior",
100
+ "coronal"]),
101
+ "dh": _p(["consonant", "fricative", "continuant", "dental", "anterior",
102
+ "coronal", "voiced"]),
103
+ "s": _p(["consonant", "fricative", "continuant", "alveolar", "anterior",
104
+ "coronal"]),
105
+ "z": _p(["consonant", "fricative", "continuant", "alveolar", "anterior",
106
+ "coronal", "voiced"]),
107
+
108
+ "sh": _p(["consonant", "fricative", "continuant", "palatal", "posterior",
109
+ "coronal"]),
110
+
111
+ "zh": _p(["consonant", "fricative", "continuant", "palatal", "posterior",
112
+ "coronal", "voiced"]),
113
+
114
+ "hh": _p(["consonant", "fricative", "continuant", "glottal", "posterior",
115
+ "dorsal"]),
116
+
117
+ # ── Affricates ─────────────────────────────────────────────────────────
118
+
119
+ "ch": _p(["consonant", "affricate", "palatal", "posterior", "coronal"]),
120
+ "jh": _p(["consonant", "affricate", "palatal", "posterior", "coronal",
121
+ "voiced"]),
122
+
123
+ # ── Nasals ─────────────────────────────────────────────────────────────
124
+
125
+ "m": _p(["consonant", "sonorant", "nasal", "continuant", "labial",
126
+ "anterior", "bilabial", "voiced"]),
127
+ "n": _p(["consonant", "sonorant", "nasal", "continuant", "alveolar",
128
+ "anterior", "coronal", "voiced"]),
129
+ "ng": _p(["consonant", "sonorant", "nasal", "continuant", "velar",
130
+ "posterior", "dorsal", "voiced"]),
131
+
132
+ # ── Liquids ────────────────────────────────────────────────────────────
133
+ "l": _p(["consonant", "sonorant", "approximant", "liquid", "continuant",
134
+ "alveolar", "anterior", "coronal", "voiced"]),
135
+
136
+ "r": _p(["consonant", "sonorant", "approximant", "liquid", "continuant",
137
+ "alveolar", "anterior", "retroflex", "coronal", "voiced"]),
138
+
139
+ # ── Semivowels (Glides) ────────────────────────────────────────────────
140
+
141
+ "w": _p(["sonorant", "approximant", "semivowel", "continuant", "labial",
142
+ "high", "anterior", "bilabial", "round", "voiced"]),
143
+ "y": _p(["sonorant", "approximant", "semivowel", "continuant", "palatal",
144
+ "high", "posterior", "coronal", "voiced"]),
145
+
146
+ # ── Short Monophthong Vowels ───────────────────────────────────────────
147
+ "ih": _p(["sonorant", "vowel", "continuant", "high", "front",
148
+ "short", "monophthong", "voiced"]),
149
+
150
+ "eh": _p(["sonorant", "vowel", "mid", "front",
151
+ "short", "monophthong", "voiced"]),
152
+
153
+ "ae": _p(["sonorant", "vowel", "continuant", "low", "front",
154
+ "long", "monophthong", "voiced"]),
155
+
156
+ "ah": _p(["sonorant", "vowel", "continuant", "mid", "back",
157
+ "short", "monophthong", "voiced"]),
158
+
159
+ "uh": _p(["sonorant", "vowel", "continuant", "high", "back",
160
+ "short", "monophthong", "round", "voiced"]),
161
+
162
+ # ── Long Monophthong Vowels ────────────────────────────────────────────
163
+ "iy": _p(["sonorant", "vowel", "continuant", "high", "front",
164
+ "long", "monophthong", "voiced"]),
165
+ "aa": _p(["sonorant", "vowel", "continuant", "low", "back",
166
+ "long", "monophthong", "voiced"]),
167
+ "ao": _p(["sonorant", "vowel", "continuant", "mid", "back",
168
+ "long", "monophthong", "round", "voiced"]),
169
+ "er": _p(["sonorant", "vowel", "continuant", "mid", "central",
170
+ "retroflex", "short", "monophthong", "voiced"]),
171
+ "uw": _p(["sonorant", "vowel", "continuant", "high", "back",
172
+ "long", "monophthong", "round", "voiced"]),
173
+
174
+ # ── Diphthongs ─────────────────────────────────────────────────────────
175
+ "ey": _p(["sonorant", "vowel", "continuant", "mid", "front",
176
+ "long", "diphthong", "voiced"]),
177
+
178
+ "aw": _p(["sonorant", "vowel", "continuant", "low", "central",
179
+ "long", "diphthong", "round", "voiced"]),
180
+
181
+ "ay": _p(["sonorant", "vowel", "low", "central",
182
+ "long", "diphthong", "voiced"]),
183
+
184
+ "oy": _p(["sonorant", "vowel", "continuant", "mid", "back",
185
+ "long", "diphthong", "round", "voiced"]),
186
+
187
+ "ow": _p(["sonorant", "vowel", "continuant", "mid", "central",
188
+ "long", "diphthong", "round", "voiced"]),
189
+
190
+ # ── Silence ────────────────────────────────────────────────────────────
191
+ # Paper: "All silence labels were further removed leaving silence frames
192
+ # to be handled by the blank label."
193
+ "sil": _p([]), # all features absent; treated as blank during training
194
+ }
195
+
196
+ # Verify all 39 phonemes are covered.
197
+ # "sil" is intentionally extra β€” it is a fallback/blank placeholder, not a
198
+ # speech target, so it lives in PHONEME_FEATURES but not in CMU_39_PHONEMES.
199
+ _expected = set(CMU_39_PHONEMES) | {"sil"}
200
+ assert set(PHONEME_FEATURES.keys()) == _expected, (
201
+ f"Missing from PHONEME_FEATURES : {_expected - set(PHONEME_FEATURES.keys())}\n"
202
+ f"Unexpected in PHONEME_FEATURES: {set(PHONEME_FEATURES.keys()) - _expected}"
203
+ )
204
+ assert NUM_PHONEMES == 39, f"Expected 39 phonemes, got {NUM_PHONEMES}"
205
+
206
+
207
+ def phoneme_to_feature_vector(phoneme: str) -> list[bool]:
208
+ """Return a binary list of length 35 for a given phoneme."""
209
+ feat_dict = PHONEME_FEATURES.get(phoneme, PHONEME_FEATURES["sil"])
210
+ return [feat_dict[f] for f in PHONOLOGICAL_FEATURES]
211
+
212
+
213
+ def phoneme_sequence_to_feature_sequences(
214
+ phonemes: list[str],
215
+ ) -> list[list[int]]:
216
+ """
217
+ Convert a phoneme sequence to N=35 binary label sequences.
218
+
219
+ Returns:
220
+ feature_seqs: list of 35 lists, each containing +att(1) or -att(0)
221
+ integers for each phoneme position.
222
+ """
223
+ feature_seqs = [[] for _ in range(NUM_FEATURES)]
224
+ for ph in phonemes:
225
+ vec = phoneme_to_feature_vector(ph)
226
+ for feat_idx, present in enumerate(vec):
227
+ feature_seqs[feat_idx].append(1 if present else 0)
228
+ return feature_seqs
229
+
230
+
231
+ def feature_sequences_to_ctc_labels(
232
+ feature_seqs: list[list[int]],
233
+ ) -> list[list[int]]:
234
+ """
235
+ Convert binary feature sequences (0/1) to CTC label indices.
236
+
237
+ For category i:
238
+ - +att β†’ node index i (feature_idx_to_pos_node)
239
+ - -att β†’ node index i + 35 (feature_idx_to_neg_node)
240
+
241
+ Returns:
242
+ ctc_labels: list of 35 lists of node indices (int)
243
+ """
244
+ ctc_labels = []
245
+ for feat_idx, seq in enumerate(feature_seqs):
246
+ label_seq = []
247
+ for val in seq:
248
+ if val == 1:
249
+ label_seq.append(feature_idx_to_pos_node(feat_idx))
250
+ else:
251
+ label_seq.append(feature_idx_to_neg_node(feat_idx))
252
+ ctc_labels.append(label_seq)
253
+ return ctc_labels
requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core
2
+ gradio>=4.0.0
3
+ numpy>=1.24.0
4
+ scipy>=1.10.0
5
+
6
+ # Model
7
+ torch>=2.0.0
8
+ transformers>=4.40.0
9
+ huggingface_hub>=0.20.0
10
+
11
+ # Audio
12
+ librosa>=0.10.0
13
+ soundfile>=0.12.0
14
+
15
+ # Optional LLM rewriter
16
+ accelerate>=0.27.0
17
+ httpx>=0.25.0
wav2vec2_phonological.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ wav2vec2_phonological.py
3
+ ========================
4
+ Phonological feature detection model strictly following Figure 2 of
5
+ Shahin et al. (Speech Communication, 2025).
6
+
7
+ Architecture (Fig. 2):
8
+ Raw Speech
9
+ β”‚
10
+ β–Ό
11
+ wav2vec2.0 (pre-trained, CNN encoder frozen)
12
+ β”œβ”€ CNN Feature Extractor [FROZEN]
13
+ └─ Transformer [FINE-TUNED]
14
+ β”‚
15
+ β–Ό
16
+ Linear Layer (hidden_size β†’ 71 nodes)
17
+ β”‚
18
+ β–Ό
19
+ SCTC-SB Loss (during training)
20
+ OR
21
+ argmax per category (during inference)
22
+
23
+ Output nodes:
24
+ 0..34 β†’ +att_i (presence of feature i)
25
+ 35..69 β†’ -att_i (absence of feature i)
26
+ 70 β†’ shared blank
27
+ """
28
+
29
+ import torch
30
+ import torch.nn as nn
31
+ from transformers import Wav2Vec2Model
32
+ from typing import Optional
33
+
34
+ from phonological_features import NUM_FEATURES, NUM_OUTPUT_NODES, BLANK_IDX
35
+
36
+
37
+ class PhonologicalWav2Vec2(nn.Module):
38
+ """
39
+ wav2vec2-based phonological feature detection model.
40
+
41
+ Args:
42
+ pretrained_model_name (str): HuggingFace model ID.
43
+ Paper uses 'facebook/wav2vec2-large-robust' (best performing).
44
+ num_output_nodes (int): 71 = 35(+att) + 35(-att) + 1(blank).
45
+ freeze_cnn_encoder (bool): Whether to freeze CNN feature extractor.
46
+ Paper Section 5.2: "its parameters were fixed during fine-tuning".
47
+ """
48
+
49
+ def __init__(
50
+ self,
51
+ pretrained_model_name: str = "facebook/wav2vec2-large-robust",
52
+ num_output_nodes: int = NUM_OUTPUT_NODES,
53
+ freeze_cnn_encoder: bool = True,
54
+ ):
55
+ super().__init__()
56
+ self.num_output_nodes = num_output_nodes
57
+ self.num_features = NUM_FEATURES
58
+ self.blank_idx = BLANK_IDX
59
+
60
+ # ── Load pre-trained wav2vec2 ─────────────────────────────────────
61
+ print(f"[PhonologicalWav2Vec2] Loading '{pretrained_model_name}' ...")
62
+ self.wav2vec2 = Wav2Vec2Model.from_pretrained(pretrained_model_name)
63
+
64
+ # ── Freeze CNN encoder (feature extractor) ────────────────────────
65
+ # Paper: "Except for the CNN encoder layer, the whole network was
66
+ # then fine-tuned"
67
+ if freeze_cnn_encoder:
68
+ self.wav2vec2.feature_extractor._freeze_parameters()
69
+ print("[PhonologicalWav2Vec2] CNN encoder FROZEN.")
70
+
71
+ # ── Linear projection head (Fig. 2) ──────────────────────────────
72
+ # "A linear layer was added on top of the transformer module with
73
+ # number of nodes equals to the number of target phonological-features"
74
+ hidden_size = self.wav2vec2.config.hidden_size
75
+ self.classifier = nn.Linear(hidden_size, num_output_nodes)
76
+
77
+
78
+ print(f"[PhonologicalWav2Vec2] hidden_size={hidden_size}, "
79
+ f"output_nodes={num_output_nodes}")
80
+
81
+ def forward(
82
+ self,
83
+ input_values: torch.Tensor,
84
+ attention_mask: Optional[torch.Tensor] = None,
85
+ apply_spec_augment: bool = False,
86
+ ) -> tuple[torch.Tensor, torch.Tensor]:
87
+ """
88
+ Forward pass.
89
+
90
+ Paper Section 5.2: "SpecAugment was applied to the output of the
91
+ CNN encoder to add more variations to the training data."
92
+
93
+ Wav2Vec2Model natively applies SpecAugment between the CNN encoder
94
+ and the transformer when mask_time_indices is provided.
95
+
96
+ Returns:
97
+ logits : (B, T_frames, 71) raw logits
98
+ output_lengths: (B,) number of valid frames per batch item
99
+ """
100
+ # Build mask_time_indices for SpecAugment if training
101
+ mask_time_indices = None
102
+ if apply_spec_augment and self.training:
103
+ # Compute output frame lengths to know valid T per item
104
+ if attention_mask is not None:
105
+ feat_lengths = self._get_feat_extract_output_lengths(
106
+ attention_mask.sum(dim=1)
107
+ )
108
+ else:
109
+ B, T_audio = input_values.shape
110
+ ones = torch.ones(B, dtype=torch.long, device=input_values.device) * T_audio
111
+ feat_lengths = self._get_feat_extract_output_lengths(ones)
112
+
113
+ B = input_values.shape[0]
114
+ T = int(feat_lengths.max().item())
115
+
116
+ # Build boolean mask: mask up to 10% of valid frames per utterance
117
+ mask_time_indices = torch.zeros(B, T, dtype=torch.bool,
118
+ device=input_values.device)
119
+ t_len = max(1, int(T * 0.10))
120
+ for b in range(B):
121
+ valid = int(feat_lengths[b].item())
122
+ if valid > t_len:
123
+ t0 = torch.randint(0, valid - t_len, (1,)).item()
124
+ mask_time_indices[b, t0:t0 + t_len] = True
125
+
126
+ outputs = self.wav2vec2(
127
+ input_values=input_values,
128
+ attention_mask=attention_mask,
129
+ mask_time_indices=mask_time_indices,
130
+ output_hidden_states=False,
131
+ )
132
+ hidden_states = outputs.last_hidden_state # (B, T, 1024)
133
+ logits = self.classifier(hidden_states)
134
+
135
+ if attention_mask is not None:
136
+ output_lengths = self._get_feat_extract_output_lengths(
137
+ attention_mask.sum(dim=1)
138
+ )
139
+ else:
140
+ B, T_audio = input_values.shape
141
+ ones = torch.ones(B, dtype=torch.long, device=input_values.device) * T_audio
142
+ output_lengths = self._get_feat_extract_output_lengths(ones)
143
+
144
+ return logits, output_lengths
145
+
146
+ def _get_feat_extract_output_lengths(
147
+ self, input_lengths: torch.Tensor
148
+ ) -> torch.Tensor:
149
+ return self.wav2vec2._get_feat_extract_output_lengths(input_lengths)
150
+
151
+ @torch.no_grad()
152
+ def decode(
153
+ self,
154
+ logits: torch.Tensor, # (B, T, 71) or (T, 71) for single item
155
+ output_lengths: Optional[torch.Tensor] = None, # (B,) valid frame counts
156
+ ) -> list[list[list[int]]]:
157
+ """
158
+ Greedy CTC decoding per category.
159
+
160
+ For each feature category i, applies argmax over the 3-node slice
161
+ [pos_i, neg_i, blank] and collapses repeated labels + blanks.
162
+
163
+ Paper Section 3.3, Eq. 7:
164
+ h_i(x) = argmax_j y^t_{i,j}
165
+
166
+ Args:
167
+ logits : (B, T, 71) raw model logits
168
+ output_lengths: (B,) number of valid (non-padded) frames per item.
169
+ If None, all T frames are used (may include padding noise).
170
+
171
+ Returns:
172
+ decoded: [B][35] list of decoded label sequences
173
+ Each label sequence contains +att(True) or -att(False)
174
+ """
175
+ if logits.dim() == 2:
176
+ logits = logits.unsqueeze(0)
177
+
178
+ B, T, _ = logits.shape
179
+ decoded_batch = []
180
+
181
+ for b in range(B):
182
+ valid_T = T if output_lengths is None else int(output_lengths[b].item())
183
+ decoded_features = []
184
+ for feat_idx in range(self.num_features):
185
+ pos_node = feat_idx
186
+ neg_node = feat_idx + self.num_features
187
+
188
+ # Extract 3-node slice over valid frames only: (valid_T, 3)
189
+ cat_logits = torch.stack([
190
+ logits[b, :valid_T, pos_node],
191
+ logits[b, :valid_T, neg_node],
192
+ logits[b, :valid_T, self.blank_idx],
193
+ ], dim=-1) # (valid_T, 3)
194
+
195
+ # Argmax: 0=+att, 1=-att, 2=blank
196
+ preds = cat_logits.argmax(dim=-1) # (valid_T,)
197
+
198
+ # CTC collapse: remove blanks and repeated labels
199
+ collapsed = []
200
+ prev = -1
201
+ for p in preds.tolist():
202
+ if p == 2: # blank
203
+ prev = -1
204
+ continue
205
+ if p != prev:
206
+ collapsed.append(p == 0) # True=+att, False=-att
207
+ prev = p
208
+
209
+ decoded_features.append(collapsed)
210
+ decoded_batch.append(decoded_features)
211
+
212
+ return decoded_batch
213
+
214
+ def count_parameters(self) -> dict:
215
+ """Count trainable vs frozen parameters."""
216
+ total = sum(p.numel() for p in self.parameters())
217
+ trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
218
+ frozen = total - trainable
219
+ return {"total": total, "trainable": trainable, "frozen": frozen}
220
+
221
+
222
+ # ─────────────────────────────────────────────────────────────────────────────
223
+ # Phoneme-level baseline model (for comparison, paper Section 3)
224
+ # Same architecture but with 40 output nodes (39 phonemes + 1 blank)
225
+ # and standard CTC loss
226
+ # ─────────────────────────────────────────────────────────────────────────────
227
+ class PhonemeLevelWav2Vec2(nn.Module):
228
+ """
229
+ Phoneme-level MDD baseline (paper Section 3, Fig. 1 top branch).
230
+ Uses standard CTC with 39 phonemes + blank.
231
+ """
232
+
233
+ def __init__(
234
+ self,
235
+ pretrained_model_name: str = "facebook/wav2vec2-large-robust",
236
+ num_phonemes: int = 39,
237
+ freeze_cnn_encoder: bool = True,
238
+ ):
239
+ super().__init__()
240
+ self.num_phonemes = num_phonemes
241
+ self.blank_idx = num_phonemes # index 39 = blank
242
+
243
+ self.wav2vec2 = Wav2Vec2Model.from_pretrained(pretrained_model_name)
244
+ if freeze_cnn_encoder:
245
+ self.wav2vec2.feature_extractor._freeze_parameters()
246
+
247
+ hidden_size = self.wav2vec2.config.hidden_size
248
+ # 40 nodes: 39 phonemes + 1 blank
249
+ self.classifier = nn.Linear(hidden_size, num_phonemes + 1)
250
+
251
+ def forward(self, input_values, attention_mask=None):
252
+ outputs = self.wav2vec2(
253
+ input_values=input_values,
254
+ attention_mask=attention_mask,
255
+ )
256
+ hidden_states = outputs.last_hidden_state
257
+ logits = self.classifier(hidden_states)
258
+
259
+ if attention_mask is not None:
260
+ output_lengths = self.wav2vec2._get_feat_extract_output_lengths(
261
+ attention_mask.sum(dim=1))
262
+ else:
263
+ B, T = input_values.shape
264
+ ones = torch.ones(B, dtype=torch.long, device=input_values.device) * T
265
+ output_lengths = self.wav2vec2._get_feat_extract_output_lengths(ones)
266
+
267
+ return logits, output_lengths