Minh Toàn commited on
Commit
4138b08
·
verified ·
1 Parent(s): 46f6b62

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. demo_all_tracks_gradio.ipynb +482 -0
  2. demo_all_tracks_gradio_pipeline.py +390 -0
  3. demo_run_from_hf.ipynb +119 -0
  4. demo_run_from_hf_pipeline.py +59 -0
  5. track1/demo_track1_gradio.ipynb +144 -0
  6. track1/demo_track1_gradio_pipeline.py +100 -0
  7. track1/track1_baseline.ipynb +108 -0
  8. track1/track1_baseline_pipeline.py +64 -0
  9. track2/demo_track2_emotion_gradio.ipynb +533 -0
  10. track2/demo_track2_emotion_gradio_pipeline.py +431 -0
  11. track2/demo_track2_gradio.ipynb +175 -0
  12. track2/demo_track2_gradio_pipeline.py +120 -0
  13. track2/exp02_train_emos.ipynb +542 -0
  14. track2/exp02_train_emos_pipeline.py +407 -0
  15. track2/exp03_emos_sailer.ipynb +392 -0
  16. track2/exp03_emos_sailer_pipeline.py +264 -0
  17. track2/exp04_fusion.ipynb +790 -0
  18. track2/exp04_fusion_pipeline.py +652 -0
  19. track2/exp05_vad_audeering.ipynb +443 -0
  20. track2/exp05_vad_audeering_pipeline.py +303 -0
  21. track2/exp06_qmos_train.ipynb +628 -0
  22. track2/exp06_qmos_train_pipeline.py +502 -0
  23. track2/exp07_fusion_qmos.ipynb +780 -0
  24. track2/exp07_fusion_qmos_pipeline.py +654 -0
  25. track2/exp08_finetune_emotion.ipynb +820 -0
  26. track2/exp08_finetune_emotion_pipeline.py +673 -0
  27. track2/exp08b_finetune_resume.ipynb +782 -0
  28. track2/exp08b_finetune_resume_pipeline.py +642 -0
  29. track2/exp09a_qmos_utmosv2_probe.ipynb +339 -0
  30. track2/exp09a_qmos_utmosv2_probe_pipeline.py +239 -0
  31. track2/exp10_finetune_audeering.ipynb +691 -0
  32. track2/exp10_finetune_audeering_pipeline.py +553 -0
  33. track2/exp11_finetune_joint.ipynb +805 -0
  34. track2/exp11_finetune_joint_pipeline.py +665 -0
  35. track2/exp12_wavlm_scratch.ipynb +690 -0
  36. track2/exp12_wavlm_scratch_pipeline.py +564 -0
  37. track2/exp13_finetune_qmos.ipynb +733 -0
  38. track2/exp13_finetune_qmos_pipeline.py +607 -0
  39. track2/exp14_mamba_head.ipynb +952 -0
  40. track2/exp14_mamba_head_pipeline.py +798 -0
  41. track2/exp15_predict.ipynb +698 -0
  42. track2/exp15_predict_pipeline.py +554 -0
  43. track2/exp15_wavlm_mamba_emotion.ipynb +1081 -0
  44. track2/exp15_wavlm_mamba_emotion_pipeline.py +920 -0
  45. track2/exp16_llm_judge.ipynb +650 -0
  46. track2/exp16_llm_judge_pipeline.py +480 -0
  47. track2/track2_baseline.ipynb +130 -0
  48. track2/track2_baseline_pipeline.py +321 -0
  49. track2/track2_prepare_data.ipynb +249 -0
  50. track2/track2_prepare_data_pipeline.py +164 -0
demo_all_tracks_gradio.ipynb ADDED
@@ -0,0 +1,482 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "739ac809",
6
+ "metadata": {},
7
+ "source": [
8
+ "# VMC2026 — Demo Gradio GỘP 3 TRACK (1 link cho mentor)\n",
9
+ "\n",
10
+ "Gộp 3 demo lẻ (`track1/`, `track2/`, `track3/`) vào **1 app Gradio 3 tab**:\n",
11
+ "- **Track 1** · Speech Enhancement → **ACR** (chất lượng A) + **CCR** (so A vs B). Model: URGENT-MOS.\n",
12
+ "- **Track 2** · Emotional TTS → **EMOS / CAT / VAD**. Model TỐT NHẤT = **exp08** (WavLM fine-tune + audeering).\n",
13
+ "- **Track 3** · Speaker/Accent → **spk_sim / acc_sim**. Model: ECAPA fine-tuned (baseline BTC).\n",
14
+ "\n",
15
+ "> **Lazy-load:** mỗi track chỉ nạp model khi bạn bấm \"Dự đoán\" ở tab đó → tab nào thiếu checkpoint/repo\n",
16
+ "> chỉ báo lỗi trong tab đó, KHÔNG sập cả app. Track 1 & 3 chỉ cần Internet; Track 2 cần thêm checkpoint exp08.\n",
17
+ "\n",
18
+ "### Cách chạy trên Kaggle\n",
19
+ "1. Settings → **GPU T4 + Internet On**.\n",
20
+ "2. (Cho Track 2) Add Input: dataset Track 2 (`sets/train.csv`, `wav/`, `metadata.csv`) + dataset chứa\n",
21
+ " `ft_emotion_full_20epoch.pt` (slug `toanminh222/cache-exp8`). Thiếu thì 2 tab kia vẫn chạy.\n",
22
+ "3. **Run All** → cell cuối in link `*.gradio.live` (sống ~72h) → gửi mentor."
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "markdown",
27
+ "id": "6f7119e0",
28
+ "metadata": {},
29
+ "source": [
30
+ "## 1. Cài đặt gói (1 lần cho cả 3 track)"
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": null,
36
+ "id": "4da07abf",
37
+ "metadata": {
38
+ "lines_to_next_cell": 1
39
+ },
40
+ "outputs": [],
41
+ "source": [
42
+ "!pip install -q gradio librosa soundfile speechbrain torchaudio loralib scipy scikit-learn pandas tqdm\n",
43
+ "\n",
44
+ "import os, sys, glob, subprocess\n",
45
+ "\n",
46
+ "def pip_install(*pkgs):\n",
47
+ " subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *pkgs], check=False)\n",
48
+ "\n",
49
+ "# Cài nhẹ (Kaggle có sẵn torch/transformers/numpy → KHÔNG đụng numpy để tránh lệch ABI)\n",
50
+ "pip_install(\"gradio\", \"librosa\", \"soundfile\", \"speechbrain\", \"torchaudio\",\n",
51
+ " \"loralib\", \"scipy\", \"scikit-learn\", \"pandas\", \"tqdm\")\n",
52
+ "\n",
53
+ "import librosa\n",
54
+ "import numpy as np\n",
55
+ "import torch\n",
56
+ "\n",
57
+ "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
58
+ "SR = 16000\n",
59
+ "print(\"Device:\", DEVICE, (\"✅ \" + torch.cuda.get_device_name(0)) if DEVICE == \"cuda\" else \"⚠️ CPU (chậm)\")\n",
60
+ "\n",
61
+ "def _stem(p):\n",
62
+ " return os.path.splitext(os.path.basename(str(p)))[0]\n",
63
+ "\n",
64
+ "def _scalar(x):\n",
65
+ " return float(x.item()) if hasattr(x, \"item\") else float(x)"
66
+ ]
67
+ },
68
+ {
69
+ "cell_type": "markdown",
70
+ "id": "8ce385ee",
71
+ "metadata": {},
72
+ "source": [
73
+ "## 2. TRACK 1 — URGENT-MOS (ACR + CCR) · lazy-load"
74
+ ]
75
+ },
76
+ {
77
+ "cell_type": "code",
78
+ "execution_count": null,
79
+ "id": "8f09a3ee",
80
+ "metadata": {
81
+ "lines_to_next_cell": 1
82
+ },
83
+ "outputs": [],
84
+ "source": [
85
+ "URGENT_REPO = \"/kaggle/working/URGENT-MOS\"\n",
86
+ "URGENT_CKPT = \"urgent-challenge/urgent-mos-f1c1m5dcorpus\" # tự tải từ HuggingFace\n",
87
+ "_T1 = {}\n",
88
+ "\n",
89
+ "def _t1_load():\n",
90
+ " \"\"\"Nạp URGENT-MOS 1 lần (clone repo + sys.path + checkpoint).\"\"\"\n",
91
+ " if \"m\" in _T1:\n",
92
+ " return _T1[\"m\"]\n",
93
+ " if not os.path.isdir(URGENT_REPO):\n",
94
+ " subprocess.run(f\"git clone -q https://github.com/vvwangvv/URGENT-MOS.git {URGENT_REPO}\",\n",
95
+ " shell=True, check=True)\n",
96
+ " if URGENT_REPO not in sys.path:\n",
97
+ " sys.path.insert(0, URGENT_REPO)\n",
98
+ " import importlib\n",
99
+ " importlib.invalidate_caches()\n",
100
+ " try:\n",
101
+ " importlib.import_module(\"urgent_mos.api.infer\")\n",
102
+ " except Exception:\n",
103
+ " subprocess.run(f\"pip install -q -e {URGENT_REPO}\", shell=True, check=False)\n",
104
+ " importlib.invalidate_caches()\n",
105
+ " from urgent_mos.utils import load_model_from_checkpoint\n",
106
+ " m = load_model_from_checkpoint(URGENT_CKPT, DEVICE)\n",
107
+ " m.eval()\n",
108
+ " _T1[\"m\"] = m\n",
109
+ " return m\n",
110
+ "\n",
111
+ "def t1_predict(audio_a, audio_b):\n",
112
+ " if not audio_a:\n",
113
+ " return \"⚠️ Hãy tải lên ít nhất **Audio A**.\"\n",
114
+ " try:\n",
115
+ " m = _t1_load()\n",
116
+ " from urgent_mos.api.infer import infer, infer_pairs\n",
117
+ " wa = torch.from_numpy(librosa.load(audio_a, sr=SR, mono=True)[0]).float()\n",
118
+ " acr_a = max(1.0, min(5.0, _scalar(infer(m, [wa], sample_rate=[SR],\n",
119
+ " batch_frames=None, num_workers=0)[0][\"mos_overall\"])))\n",
120
+ " out = f\"**ACR (Audio A): {acr_a:.3f}** (chất lượng tuyệt đối, thang 1–5)\"\n",
121
+ " if audio_b:\n",
122
+ " wb = torch.from_numpy(librosa.load(audio_b, sr=SR, mono=True)[0]).float()\n",
123
+ " acr_b = max(1.0, min(5.0, _scalar(infer(m, [wb], sample_rate=[SR],\n",
124
+ " batch_frames=None, num_workers=0)[0][\"mos_overall\"])))\n",
125
+ " ccr = max(-3.0, min(3.0, _scalar(infer_pairs(m, [(wa, wb)], sample_rate=[(SR, SR)],\n",
126
+ " batch_frames=None, num_workers=0)[0][\"mos_overall\"])))\n",
127
+ " out += (f\"\\n\\n**ACR (Audio B): {acr_b:.3f}**\"\n",
128
+ " f\"\\n\\n**CCR (A so với B): {ccr:+.3f}** (>0: A tốt hơn B; thang −3..+3)\")\n",
129
+ " return out\n",
130
+ " except Exception as e:\n",
131
+ " return f\"❌ Track 1 lỗi: `{repr(e)}`\\n\\nKiểm tra **Internet On** (cần tải URGENT-MOS từ GitHub/HuggingFace).\""
132
+ ]
133
+ },
134
+ {
135
+ "cell_type": "markdown",
136
+ "id": "21d2f185",
137
+ "metadata": {},
138
+ "source": [
139
+ "## 3. TRACK 3 — ECAPA fine-tuned (spk_sim + acc_sim) · lazy-load"
140
+ ]
141
+ },
142
+ {
143
+ "cell_type": "code",
144
+ "execution_count": null,
145
+ "id": "143f5e27",
146
+ "metadata": {
147
+ "lines_to_next_cell": 1
148
+ },
149
+ "outputs": [],
150
+ "source": [
151
+ "T3_REPO = \"/kaggle/working/vmc2026-baselines/track3\"\n",
152
+ "CKPT_SPK = f\"{T3_REPO}/official-egs/spk_sim_adamw_lr1e-3/model_spk_sim_step20000.pt\"\n",
153
+ "CKPT_ACC = f\"{T3_REPO}/official-egs/acc_sim_adamw_lr1e-3/model_acc_sim_step20000.pt\"\n",
154
+ "_T3 = {}\n",
155
+ "\n",
156
+ "def _t3_load():\n",
157
+ " if \"spk\" in _T3:\n",
158
+ " return _T3\n",
159
+ " repo_root = \"/kaggle/working/vmc2026-baselines\"\n",
160
+ " if not os.path.isdir(repo_root):\n",
161
+ " subprocess.run(f\"git clone -q https://github.com/voicemos-challenge/vmc2026-baselines.git {repo_root}\",\n",
162
+ " shell=True, check=True)\n",
163
+ " if T3_REPO not in sys.path:\n",
164
+ " sys.path.insert(0, T3_REPO)\n",
165
+ " from model import Model\n",
166
+ " spk = Model(mlp_heads=[\"spk_sim\"])\n",
167
+ " spk.load_state_dict(torch.load(CKPT_SPK, map_location=\"cpu\"))\n",
168
+ " acc = Model(mlp_heads=[\"acc_sim\"])\n",
169
+ " acc.load_state_dict(torch.load(CKPT_ACC, map_location=\"cpu\"))\n",
170
+ " _T3.update(spk=spk.to(DEVICE).eval(), acc=acc.to(DEVICE).eval())\n",
171
+ " return _T3\n",
172
+ "\n",
173
+ "def t3_predict(audio_test, audio_ref):\n",
174
+ " if not audio_test or not audio_ref:\n",
175
+ " return \"⚠️ Cần **cả 2 file**: audio test + audio reference.\"\n",
176
+ " try:\n",
177
+ " M = _t3_load()\n",
178
+ " ta = torch.from_numpy(librosa.load(audio_test, sr=SR, mono=True)[0]).float().unsqueeze(0).to(DEVICE)\n",
179
+ " tb = torch.from_numpy(librosa.load(audio_ref, sr=SR, mono=True)[0]).float().unsqueeze(0).to(DEVICE)\n",
180
+ " with torch.no_grad():\n",
181
+ " o_spk = M[\"spk\"](ta, tb)\n",
182
+ " spk = float(o_spk[\"spk_sim\"].item())\n",
183
+ " acc = float(M[\"acc\"](ta, tb)[\"acc_sim\"].item())\n",
184
+ " cos = float(o_spk[\"cos_sim\"].item())\n",
185
+ " return (f\"**Speaker similarity: {spk:.3f}** (1–5)\\n\\n\"\n",
186
+ " f\"**Accent similarity : {acc:.3f}** (1–5)\\n\\n\"\n",
187
+ " f\"Cosine zero-shot (tham khảo): {cos:.3f}\")\n",
188
+ " except Exception as e:\n",
189
+ " return f\"❌ Track 3 lỗi: `{repr(e)}`\\n\\nKiểm tra **Internet On** (clone repo baseline chứa checkpoint).\""
190
+ ]
191
+ },
192
+ {
193
+ "cell_type": "markdown",
194
+ "id": "6fd2047d",
195
+ "metadata": {},
196
+ "source": [
197
+ "## 4. TRACK 2 — exp08 Emotional TTS Evaluator (EMOS/CAT/VAD) · lazy-load\n",
198
+ "\n",
199
+ "Model TỐT NHẤT: WavLM fine-tune (warm-start SAILER) + audeering frozen → trunk → 3 head.\n",
200
+ "Hằng kiến trúc PHẢI khớp exp08 (ckpt không lưu các số này)."
201
+ ]
202
+ },
203
+ {
204
+ "cell_type": "code",
205
+ "execution_count": null,
206
+ "id": "acfc33a7",
207
+ "metadata": {
208
+ "lines_to_next_cell": 1
209
+ },
210
+ "outputs": [],
211
+ "source": [
212
+ "EMO_MAX_SEC, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, USE_AMP = 8, 512, 128, 0.3, True\n",
213
+ "EMOTIONS5 = [\"angry\", \"happy\", \"neutral\", \"sad\", \"surprised\"]\n",
214
+ "_EMO_ALIAS = {\n",
215
+ " \"angry\": \"angry\", \"anger\": \"angry\", \"happy\": \"happy\", \"happiness\": \"happy\", \"joy\": \"happy\",\n",
216
+ " \"neutral\": \"neutral\", \"calm\": \"neutral\", \"sad\": \"sad\", \"sadness\": \"sad\",\n",
217
+ " \"surprise\": \"surprised\", \"surprised\": \"surprised\", \"surprising\": \"surprised\",\n",
218
+ "}\n",
219
+ "def norm_emotion(label):\n",
220
+ " key = str(label).strip().lower()\n",
221
+ " return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)\n",
222
+ "\n",
223
+ "def _t2_find_ckpt():\n",
224
+ " for pat in [\"ft_emotion_full_20epoch*.pt\", \"ft_emotion_full*.pt\"]:\n",
225
+ " for base in [\"/kaggle/input\", \"/kaggle/working\"]:\n",
226
+ " hits = sorted(glob.glob(os.path.join(base, \"**\", pat), recursive=True))\n",
227
+ " if hits:\n",
228
+ " return hits[0]\n",
229
+ " return \"\"\n",
230
+ "\n",
231
+ "_T2 = {}\n",
232
+ "\n",
233
+ "def _t2_load():\n",
234
+ " if \"infer\" in _T2:\n",
235
+ " return _T2[\"infer\"]\n",
236
+ " import torch.nn as nn\n",
237
+ " ckpt_path = _t2_find_ckpt()\n",
238
+ " assert ckpt_path, \"Không thấy ft_emotion_full*.pt — Add Input dataset checkpoint exp08 (slug toanminh222/cache-exp8)?\"\n",
239
+ " # code SAILER để dựng backbone WavLM\n",
240
+ " repo = \"/kaggle/working/vox-profile-release\"\n",
241
+ " if not os.path.exists(repo):\n",
242
+ " subprocess.run([\"git\", \"clone\", \"--depth\", \"1\",\n",
243
+ " \"https://github.com/tiantiaf0627/vox-profile-release.git\", repo], check=True)\n",
244
+ " if repo not in sys.path:\n",
245
+ " sys.path.insert(0, repo)\n",
246
+ "\n",
247
+ " ckpt = torch.load(ckpt_path, map_location=\"cpu\", weights_only=False)\n",
248
+ " assert \"wavlm\" in ckpt and \"heads\" in ckpt, \"Checkpoint thiếu 'wavlm'/'heads' → cần bản đủ ft_emotion_full_20epoch.pt.\"\n",
249
+ " AUD_DIM = int(ckpt.get(\"AUD_DIM\", 0)); USE_AUDEERING = AUD_DIM > 0\n",
250
+ "\n",
251
+ " def find_hf_backbone(module):\n",
252
+ " cands = []\n",
253
+ " for name, m in module.named_modules():\n",
254
+ " enc = getattr(m, \"encoder\", None)\n",
255
+ " if getattr(m, \"feature_extractor\", None) is not None and enc is not None \\\n",
256
+ " and getattr(enc, \"layers\", None) is not None:\n",
257
+ " cands.append((name, m))\n",
258
+ " if not cands:\n",
259
+ " return None, None\n",
260
+ " cands.sort(key=lambda nm: sum(p.numel() for p in nm[1].parameters()), reverse=True)\n",
261
+ " return cands[0]\n",
262
+ "\n",
263
+ " wavlm = None\n",
264
+ " try:\n",
265
+ " from src.model.emotion.wavlm_emotion import WavLMWrapper\n",
266
+ " _wrapper = WavLMWrapper.from_pretrained(\"tiantiaf/wavlm-large-categorical-emotion\")\n",
267
+ " _name, wavlm = find_hf_backbone(_wrapper)\n",
268
+ " except Exception as e:\n",
269
+ " print(\"⚠️ SAILER wrapper lỗi:\", repr(e), \"→ fallback WavLM trắng.\")\n",
270
+ " if wavlm is None:\n",
271
+ " from transformers import WavLMModel\n",
272
+ " wavlm = WavLMModel.from_pretrained(\"microsoft/wavlm-large\")\n",
273
+ " wavlm = wavlm.to(DEVICE).eval()\n",
274
+ " WAVLM_DIM = int(wavlm.config.hidden_size)\n",
275
+ " wavlm.config.layerdrop = 0.0\n",
276
+ " wavlm.load_state_dict(ckpt[\"wavlm\"], strict=False)\n",
277
+ "\n",
278
+ " def masked_mean(hidden, attn_mask):\n",
279
+ " if attn_mask is None:\n",
280
+ " return hidden.mean(dim=1)\n",
281
+ " try:\n",
282
+ " fm = wavlm._get_feature_vector_attention_mask(hidden.shape[1], attn_mask)\n",
283
+ " except Exception:\n",
284
+ " return hidden.mean(dim=1)\n",
285
+ " fm = fm.unsqueeze(-1).to(hidden.dtype)\n",
286
+ " return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)\n",
287
+ "\n",
288
+ " # audeering frozen (nếu ckpt dùng)\n",
289
+ " aud_backbone = aud_head = aud_proc = None\n",
290
+ " if USE_AUDEERING:\n",
291
+ " from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor\n",
292
+ " from huggingface_hub import hf_hub_download\n",
293
+ " AUD_NAME = \"audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim\"\n",
294
+ " aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)\n",
295
+ " aud_backbone = Wav2Vec2Model(Wav2Vec2Config.from_pretrained(AUD_NAME))\n",
296
+ " try:\n",
297
+ " _sd = __import__(\"safetensors.torch\", fromlist=[\"load_file\"]).load_file(\n",
298
+ " hf_hub_download(AUD_NAME, \"model.safetensors\"))\n",
299
+ " except Exception:\n",
300
+ " _sd = torch.load(hf_hub_download(AUD_NAME, \"pytorch_model.bin\"), map_location=\"cpu\")\n",
301
+ " bb_sd = {k[len(\"wav2vec2.\"):]: v for k, v in _sd.items() if k.startswith(\"wav2vec2.\")}\n",
302
+ " aud_backbone.load_state_dict(bb_sd, strict=False)\n",
303
+ " _hid = _sd[\"classifier.dense.weight\"].shape[0]\n",
304
+ " aud_head = nn.Sequential(nn.Linear(_hid, _hid), nn.Tanh(),\n",
305
+ " nn.Linear(_hid, _sd[\"classifier.out_proj.weight\"].shape[0]))\n",
306
+ " aud_head[0].weight.data.copy_(_sd[\"classifier.dense.weight\"]); aud_head[0].bias.data.copy_(_sd[\"classifier.dense.bias\"])\n",
307
+ " aud_head[2].weight.data.copy_(_sd[\"classifier.out_proj.weight\"]); aud_head[2].bias.data.copy_(_sd[\"classifier.out_proj.bias\"])\n",
308
+ " aud_backbone = aud_backbone.to(DEVICE).eval(); aud_head = aud_head.to(DEVICE).eval()\n",
309
+ "\n",
310
+ " @torch.no_grad()\n",
311
+ " def audeering_feat(wave):\n",
312
+ " x = aud_proc(wave, sampling_rate=SR).input_values[0]\n",
313
+ " x = torch.from_numpy(np.asarray(x, dtype=np.float32)).unsqueeze(0).to(DEVICE)\n",
314
+ " h = aud_backbone(x)[0].mean(dim=1)\n",
315
+ " out = aud_head(h)[0].cpu().numpy()\n",
316
+ " vad = np.array([1 + 4 * out[2], 1 + 4 * out[0], 1 + 4 * out[1]], dtype=np.float32)\n",
317
+ " return np.concatenate([h[0].cpu().numpy(), vad]).astype(np.float32)\n",
318
+ "\n",
319
+ " N_EMO = len(EMOTIONS5)\n",
320
+ " TRUNK_IN = WAVLM_DIM + (AUD_DIM if USE_AUDEERING else 0)\n",
321
+ "\n",
322
+ " class EmoHeads(nn.Module):\n",
323
+ " def __init__(self, d_in, trunk_h, head_h, p, n_emo):\n",
324
+ " super().__init__()\n",
325
+ " self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),\n",
326
+ " nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))\n",
327
+ " self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))\n",
328
+ " self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))\n",
329
+ " self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))\n",
330
+ " def forward(self, feat, tgt):\n",
331
+ " h = self.trunk(feat)\n",
332
+ " return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)\n",
333
+ "\n",
334
+ " heads = EmoHeads(TRUNK_IN, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(DEVICE).eval()\n",
335
+ " heads.load_state_dict(ckpt[\"heads\"], strict=False)\n",
336
+ " emos_mu, emos_sd = float(ckpt[\"emos_mu\"]), float(ckpt[\"emos_sd\"])\n",
337
+ " vad_mu = np.asarray(ckpt[\"vad_mu\"], dtype=np.float32); vad_sd = np.asarray(ckpt[\"vad_sd\"], dtype=np.float32)\n",
338
+ "\n",
339
+ " def onehot_target(tgt):\n",
340
+ " v = np.zeros(N_EMO, dtype=np.float32)\n",
341
+ " if tgt in EMOTIONS5:\n",
342
+ " v[EMOTIONS5.index(tgt)] = 1.0\n",
343
+ " return v\n",
344
+ "\n",
345
+ " @torch.no_grad()\n",
346
+ " def infer_wave(wave, target_emotion):\n",
347
+ " wave = wave[: EMO_MAX_SEC * SR].astype(np.float32)\n",
348
+ " iv = torch.from_numpy(wave).unsqueeze(0).to(DEVICE)\n",
349
+ " am = torch.ones((1, len(wave)), dtype=torch.long, device=DEVICE)\n",
350
+ " tgt = torch.from_numpy(onehot_target(norm_emotion(target_emotion) if target_emotion else None)).unsqueeze(0).to(DEVICE)\n",
351
+ " with torch.cuda.amp.autocast(enabled=USE_AMP and DEVICE == \"cuda\"):\n",
352
+ " fw = wavlm(iv, attention_mask=am).last_hidden_state\n",
353
+ " fw = masked_mean(fw, am)\n",
354
+ " if USE_AUDEERING:\n",
355
+ " fw = torch.cat([fw, torch.from_numpy(audeering_feat(wave)).unsqueeze(0).to(DEVICE)], dim=1)\n",
356
+ " emos_p, cat_l, vad_p = heads(fw, tgt)\n",
357
+ " emos = float(emos_p.item()) * emos_sd + emos_mu\n",
358
+ " cat5 = torch.softmax(cat_l, 1)[0].float().cpu().numpy()\n",
359
+ " vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu\n",
360
+ " return emos, cat5, vad3\n",
361
+ "\n",
362
+ " print(f\"✅ Track 2 exp08 nạp xong (audeering {'ON' if USE_AUDEERING else 'OFF'}) từ {ckpt_path}\")\n",
363
+ " _T2[\"infer\"] = infer_wave\n",
364
+ " return infer_wave\n",
365
+ "\n",
366
+ "def t2_predict(audio, target_emotion):\n",
367
+ " \"\"\"Trả: verdict(md), EMOS(number), CAT(label dict), VAL, ARO, DOM.\"\"\"\n",
368
+ " if not audio:\n",
369
+ " return \"### ⚠️ Hãy tải audio (giọng TTS).\", None, {}, None, None, None\n",
370
+ " try:\n",
371
+ " infer_wave = _t2_load()\n",
372
+ " wave, _ = librosa.load(audio, sr=SR, mono=True)\n",
373
+ " emos, cat5, vad3 = infer_wave(wave, target_emotion)\n",
374
+ " cat_dict = {e: float(cat5[i]) for i, e in enumerate(EMOTIONS5)}\n",
375
+ " perceived = EMOTIONS5[int(np.argmax(cat5))]\n",
376
+ " if target_emotion:\n",
377
+ " match = \"✅ **KHỚP** target\" if perceived == norm_emotion(target_emotion) else \"⚠️ **LỆCH** target\"\n",
378
+ " band = \"🟢 tốt\" if emos >= 4 else (\"🟡 khá\" if emos >= 3 else \"🔴 yếu\")\n",
379
+ " verdict = (f\"### Kết luận biểu cảm\\n\"\n",
380
+ " f\"- Cảm xúc cảm nhận: **{perceived}** → {match} (`{target_emotion}`)\\n\"\n",
381
+ " f\"- EMOS = **{emos:.2f}/5** → biểu cảm {band}\")\n",
382
+ " else:\n",
383
+ " verdict = (f\"### Kết luận biểu cảm\\n- Cảm xúc cảm nhận: **{perceived}**\\n\"\n",
384
+ " f\"- *(Chọn cảm xúc target để bật EMOS — độ khớp ý đồ)*\")\n",
385
+ " emos = None\n",
386
+ " return verdict, (round(emos, 3) if emos is not None else None), cat_dict, \\\n",
387
+ " round(float(vad3[0]), 3), round(float(vad3[1]), 3), round(float(vad3[2]), 3)\n",
388
+ " except Exception as e:\n",
389
+ " return f\"### ❌ Track 2 lỗi\\n`{repr(e)}`\\n\\nĐã Add Input checkpoint exp08 + Internet On chưa?\", None, {}, None, None, None"
390
+ ]
391
+ },
392
+ {
393
+ "cell_type": "markdown",
394
+ "id": "81da6580",
395
+ "metadata": {},
396
+ "source": [
397
+ "## 5. Giao diện Gradio GỘP — 3 tab + launch"
398
+ ]
399
+ },
400
+ {
401
+ "cell_type": "code",
402
+ "execution_count": null,
403
+ "id": "418857af",
404
+ "metadata": {},
405
+ "outputs": [],
406
+ "source": [
407
+ "import gradio as gr\n",
408
+ "\n",
409
+ "INTRO = (\n",
410
+ " \"# 🎙️ VoiceMOS Challenge 2026 — Demo 3 Track\\n\"\n",
411
+ " \"Một link cho cả 3 track. Mỗi tab nhận audio → trả điểm bộ chấm tự động.\\n\\n\"\n",
412
+ " \"| Track | Bài toán | Output |\\n|---|---|---|\\n\"\n",
413
+ " \"| **1** | Speech Enhancement | ACR (chất lượng) · CCR (so sánh cặp) |\\n\"\n",
414
+ " \"| **2** | Emotional TTS | EMOS · CAT · VAD (5 cột cảm xúc) — *model tốt nhất exp08* |\\n\"\n",
415
+ " \"| **3** | Speaker/Accent | spk_sim · acc_sim |\\n\\n\"\n",
416
+ " \"> Model nạp **lần đầu bấm nút** (chờ ~1–2 phút tải). Tab thiếu checkpoint chỉ báo lỗi trong tab đó.\"\n",
417
+ ")\n",
418
+ "\n",
419
+ "with gr.Blocks(title=\"VMC2026 — Demo 3 Track\") as demo:\n",
420
+ " gr.Markdown(INTRO)\n",
421
+ "\n",
422
+ " with gr.Tab(\"1️⃣ Track 1 · Chất lượng (ACR/CCR)\"):\n",
423
+ " gr.Markdown(\"Tải **Audio A** → ACR (1–5). Tải thêm **Audio B** → CCR (A vs B, −3..+3, >0 = A tốt hơn).\")\n",
424
+ " t1a = gr.Audio(type=\"filepath\", label=\"Audio A (bắt buộc)\")\n",
425
+ " t1b = gr.Audio(type=\"filepath\", label=\"Audio B (tùy chọn — để tính CCR)\")\n",
426
+ " t1out = gr.Markdown()\n",
427
+ " gr.Button(\"Dự đoán\", variant=\"primary\").click(t1_predict, [t1a, t1b], t1out)\n",
428
+ "\n",
429
+ " with gr.Tab(\"2️⃣ Track 2 · Cảm xúc (EMOS/CAT/VAD)\"):\n",
430
+ " gr.Markdown(\"Model tốt nhất **exp08** (WavLM fine-tune + audeering, offline). \"\n",
431
+ " \"Chọn **cảm xúc target** để bật EMOS (độ khớp ý đồ).\")\n",
432
+ " with gr.Row():\n",
433
+ " with gr.Column(scale=1):\n",
434
+ " t2a = gr.Audio(type=\"filepath\", label=\"Audio (giọng TTS)\")\n",
435
+ " t2tgt = gr.Dropdown(EMOTIONS5, label=\"🎯 Cảm xúc target (cho EMOS)\")\n",
436
+ " t2btn = gr.Button(\"Chấm cảm xúc\", variant=\"primary\")\n",
437
+ " with gr.Column(scale=2):\n",
438
+ " t2verdict = gr.Markdown()\n",
439
+ " t2emos = gr.Number(label=\"EMOS — khớp cảm xúc target (1–5)\", interactive=False)\n",
440
+ " t2cat = gr.Label(label=\"CAT — phân bố cảm xúc cảm nhận (5 lớp)\")\n",
441
+ " gr.Markdown(\"**VAD — toạ độ cảm xúc liên tục (1–5):**\")\n",
442
+ " with gr.Row():\n",
443
+ " t2val = gr.Number(label=\"Valence (tích cực↑)\", interactive=False)\n",
444
+ " t2aro = gr.Number(label=\"Arousal (kích động↑)\", interactive=False)\n",
445
+ " t2dom = gr.Number(label=\"Dominance (chi phối↑)\", interactive=False)\n",
446
+ " t2btn.click(t2_predict, [t2a, t2tgt], [t2verdict, t2emos, t2cat, t2val, t2aro, t2dom])\n",
447
+ "\n",
448
+ " with gr.Tab(\"3️⃣ Track 3 · Speaker/Accent\"):\n",
449
+ " gr.Markdown(\"Tải **audio cần đánh giá** + **audio tham chiếu** → độ giống người nói & accent (1–5).\")\n",
450
+ " t3t = gr.Audio(type=\"filepath\", label=\"Audio cần đánh giá (test)\")\n",
451
+ " t3r = gr.Audio(type=\"filepath\", label=\"Audio tham chiếu (reference)\")\n",
452
+ " t3out = gr.Markdown()\n",
453
+ " gr.Button(\"Dự đoán\", variant=\"primary\").click(t3_predict, [t3t, t3r], t3out)\n",
454
+ "\n",
455
+ "demo.launch(share=True)"
456
+ ]
457
+ },
458
+ {
459
+ "cell_type": "markdown",
460
+ "id": "b04c3d5e",
461
+ "metadata": {},
462
+ "source": [
463
+ "## Ghi chú\n",
464
+ "- **Lazy-load:** mỗi `_tN_load()` nạp model 1 lần rồi cache module-level → tab nào không bấm thì không tốn RAM/VRAM.\n",
465
+ "- Track 1 cần URGENT-MOS (GitHub + HuggingFace); Track 3 clone repo baseline (có sẵn checkpoint); Track 2 cần\n",
466
+ " checkpoint exp08 (`ft_emotion_full_20epoch.pt`, slug `toanminh222/cache-exp8`) + tải WavLM/SAILER/audeering.\n",
467
+ "- Hằng `TRUNK_HIDDEN/HEAD_HIDDEN/EMO_MAX_SEC` của Track 2 PHẢI khớp exp08 (ckpt không lưu) — sai là lệch shape.\n",
468
+ "- 3 tab độc lập: thiếu checkpoint/Internet của 1 track chỉ báo lỗi trong tab đó, 2 tab còn lại vẫn chạy.\n",
469
+ "- Cần **GPU T4 + Internet On**. Bản chỉ Track 2 đầy đủ (có tab metric val nội bộ) ở `track2/demo_track2_emotion_gradio`."
470
+ ]
471
+ }
472
+ ],
473
+ "metadata": {
474
+ "jupytext": {
475
+ "cell_metadata_filter": "-all",
476
+ "main_language": "python",
477
+ "notebook_metadata_filter": "-all"
478
+ }
479
+ },
480
+ "nbformat": 4,
481
+ "nbformat_minor": 5
482
+ }
demo_all_tracks_gradio_pipeline.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %% [markdown]
2
+ # # VMC2026 — Demo Gradio GỘP 3 TRACK (1 link cho mentor)
3
+ #
4
+ # Gộp 3 demo lẻ (`track1/`, `track2/`, `track3/`) vào **1 app Gradio 3 tab**:
5
+ # - **Track 1** · Speech Enhancement → **ACR** (chất lượng A) + **CCR** (so A vs B). Model: URGENT-MOS.
6
+ # - **Track 2** · Emotional TTS → **EMOS / CAT / VAD**. Model TỐT NHẤT = **exp08** (WavLM fine-tune + audeering).
7
+ # - **Track 3** · Speaker/Accent → **spk_sim / acc_sim**. Model: ECAPA fine-tuned (baseline BTC).
8
+ #
9
+ # > **Lazy-load:** mỗi track chỉ nạp model khi bạn bấm "Dự đoán" ở tab đó → tab nào thiếu checkpoint/repo
10
+ # > chỉ báo lỗi trong tab đó, KHÔNG sập cả app. Track 1 & 3 chỉ cần Internet; Track 2 cần thêm checkpoint exp08.
11
+ #
12
+ # ### Cách chạy trên Kaggle
13
+ # 1. Settings → **GPU T4 + Internet On**.
14
+ # 2. (Cho Track 2) Add Input: dataset Track 2 (`sets/train.csv`, `wav/`, `metadata.csv`) + dataset chứa
15
+ # `ft_emotion_full_20epoch.pt` (slug `toanminh222/cache-exp8`). Thiếu thì 2 tab kia vẫn chạy.
16
+ # 3. **Run All** → cell cuối in link `*.gradio.live` (sống ~72h) → gửi mentor.
17
+
18
+ # %% [markdown]
19
+ # ## 1. Cài đặt gói (1 lần cho cả 3 track)
20
+
21
+ # %%
22
+ # !pip install -q gradio librosa soundfile speechbrain torchaudio loralib scipy scikit-learn pandas tqdm
23
+
24
+ import os, sys, glob, subprocess
25
+
26
+ def pip_install(*pkgs):
27
+ subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs], check=False)
28
+
29
+ # Cài nhẹ (Kaggle có sẵn torch/transformers/numpy → KHÔNG đụng numpy để tránh lệch ABI)
30
+ pip_install("gradio", "librosa", "soundfile", "speechbrain", "torchaudio",
31
+ "loralib", "scipy", "scikit-learn", "pandas", "tqdm")
32
+
33
+ import librosa
34
+ import numpy as np
35
+ import torch
36
+
37
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
38
+ SR = 16000
39
+ print("Device:", DEVICE, ("✅ " + torch.cuda.get_device_name(0)) if DEVICE == "cuda" else "⚠️ CPU (chậm)")
40
+
41
+ def _stem(p):
42
+ return os.path.splitext(os.path.basename(str(p)))[0]
43
+
44
+ def _scalar(x):
45
+ return float(x.item()) if hasattr(x, "item") else float(x)
46
+
47
+ # %% [markdown]
48
+ # ## 2. TRACK 1 — URGENT-MOS (ACR + CCR) · lazy-load
49
+
50
+ # %%
51
+ URGENT_REPO = "/kaggle/working/URGENT-MOS"
52
+ URGENT_CKPT = "urgent-challenge/urgent-mos-f1c1m5dcorpus" # tự tải từ HuggingFace
53
+ _T1 = {}
54
+
55
+ def _t1_load():
56
+ """Nạp URGENT-MOS 1 lần (clone repo + sys.path + checkpoint)."""
57
+ if "m" in _T1:
58
+ return _T1["m"]
59
+ if not os.path.isdir(URGENT_REPO):
60
+ subprocess.run(f"git clone -q https://github.com/vvwangvv/URGENT-MOS.git {URGENT_REPO}",
61
+ shell=True, check=True)
62
+ if URGENT_REPO not in sys.path:
63
+ sys.path.insert(0, URGENT_REPO)
64
+ import importlib
65
+ importlib.invalidate_caches()
66
+ try:
67
+ importlib.import_module("urgent_mos.api.infer")
68
+ except Exception:
69
+ subprocess.run(f"pip install -q -e {URGENT_REPO}", shell=True, check=False)
70
+ importlib.invalidate_caches()
71
+ from urgent_mos.utils import load_model_from_checkpoint
72
+ m = load_model_from_checkpoint(URGENT_CKPT, DEVICE)
73
+ m.eval()
74
+ _T1["m"] = m
75
+ return m
76
+
77
+ def t1_predict(audio_a, audio_b):
78
+ if not audio_a:
79
+ return "⚠️ Hãy tải lên ít nhất **Audio A**."
80
+ try:
81
+ m = _t1_load()
82
+ from urgent_mos.api.infer import infer, infer_pairs
83
+ wa = torch.from_numpy(librosa.load(audio_a, sr=SR, mono=True)[0]).float()
84
+ acr_a = max(1.0, min(5.0, _scalar(infer(m, [wa], sample_rate=[SR],
85
+ batch_frames=None, num_workers=0)[0]["mos_overall"])))
86
+ out = f"**ACR (Audio A): {acr_a:.3f}** (chất lượng tuyệt đối, thang 1–5)"
87
+ if audio_b:
88
+ wb = torch.from_numpy(librosa.load(audio_b, sr=SR, mono=True)[0]).float()
89
+ acr_b = max(1.0, min(5.0, _scalar(infer(m, [wb], sample_rate=[SR],
90
+ batch_frames=None, num_workers=0)[0]["mos_overall"])))
91
+ ccr = max(-3.0, min(3.0, _scalar(infer_pairs(m, [(wa, wb)], sample_rate=[(SR, SR)],
92
+ batch_frames=None, num_workers=0)[0]["mos_overall"])))
93
+ out += (f"\n\n**ACR (Audio B): {acr_b:.3f}**"
94
+ f"\n\n**CCR (A so với B): {ccr:+.3f}** (>0: A tốt hơn B; thang −3..+3)")
95
+ return out
96
+ except Exception as e:
97
+ return f"❌ Track 1 lỗi: `{repr(e)}`\n\nKiểm tra **Internet On** (cần tải URGENT-MOS từ GitHub/HuggingFace)."
98
+
99
+ # %% [markdown]
100
+ # ## 3. TRACK 3 — ECAPA fine-tuned (spk_sim + acc_sim) · lazy-load
101
+
102
+ # %%
103
+ T3_REPO = "/kaggle/working/vmc2026-baselines/track3"
104
+ CKPT_SPK = f"{T3_REPO}/official-egs/spk_sim_adamw_lr1e-3/model_spk_sim_step20000.pt"
105
+ CKPT_ACC = f"{T3_REPO}/official-egs/acc_sim_adamw_lr1e-3/model_acc_sim_step20000.pt"
106
+ _T3 = {}
107
+
108
+ def _t3_load():
109
+ if "spk" in _T3:
110
+ return _T3
111
+ repo_root = "/kaggle/working/vmc2026-baselines"
112
+ if not os.path.isdir(repo_root):
113
+ subprocess.run(f"git clone -q https://github.com/voicemos-challenge/vmc2026-baselines.git {repo_root}",
114
+ shell=True, check=True)
115
+ if T3_REPO not in sys.path:
116
+ sys.path.insert(0, T3_REPO)
117
+ from model import Model
118
+ spk = Model(mlp_heads=["spk_sim"])
119
+ spk.load_state_dict(torch.load(CKPT_SPK, map_location="cpu"))
120
+ acc = Model(mlp_heads=["acc_sim"])
121
+ acc.load_state_dict(torch.load(CKPT_ACC, map_location="cpu"))
122
+ _T3.update(spk=spk.to(DEVICE).eval(), acc=acc.to(DEVICE).eval())
123
+ return _T3
124
+
125
+ def t3_predict(audio_test, audio_ref):
126
+ if not audio_test or not audio_ref:
127
+ return "⚠️ Cần **cả 2 file**: audio test + audio reference."
128
+ try:
129
+ M = _t3_load()
130
+ ta = torch.from_numpy(librosa.load(audio_test, sr=SR, mono=True)[0]).float().unsqueeze(0).to(DEVICE)
131
+ tb = torch.from_numpy(librosa.load(audio_ref, sr=SR, mono=True)[0]).float().unsqueeze(0).to(DEVICE)
132
+ with torch.no_grad():
133
+ o_spk = M["spk"](ta, tb)
134
+ spk = float(o_spk["spk_sim"].item())
135
+ acc = float(M["acc"](ta, tb)["acc_sim"].item())
136
+ cos = float(o_spk["cos_sim"].item())
137
+ return (f"**Speaker similarity: {spk:.3f}** (1–5)\n\n"
138
+ f"**Accent similarity : {acc:.3f}** (1–5)\n\n"
139
+ f"Cosine zero-shot (tham khảo): {cos:.3f}")
140
+ except Exception as e:
141
+ return f"❌ Track 3 lỗi: `{repr(e)}`\n\nKiểm tra **Internet On** (clone repo baseline chứa checkpoint)."
142
+
143
+ # %% [markdown]
144
+ # ## 4. TRACK 2 — exp08 Emotional TTS Evaluator (EMOS/CAT/VAD) · lazy-load
145
+ #
146
+ # Model TỐT NHẤT: WavLM fine-tune (warm-start SAILER) + audeering frozen → trunk → 3 head.
147
+ # Hằng kiến trúc PHẢI khớp exp08 (ckpt không lưu các số này).
148
+
149
+ # %%
150
+ EMO_MAX_SEC, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, USE_AMP = 8, 512, 128, 0.3, True
151
+ EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"]
152
+ _EMO_ALIAS = {
153
+ "angry": "angry", "anger": "angry", "happy": "happy", "happiness": "happy", "joy": "happy",
154
+ "neutral": "neutral", "calm": "neutral", "sad": "sad", "sadness": "sad",
155
+ "surprise": "surprised", "surprised": "surprised", "surprising": "surprised",
156
+ }
157
+ def norm_emotion(label):
158
+ key = str(label).strip().lower()
159
+ return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)
160
+
161
+ def _t2_find_ckpt():
162
+ for pat in ["ft_emotion_full_20epoch*.pt", "ft_emotion_full*.pt"]:
163
+ for base in ["/kaggle/input", "/kaggle/working"]:
164
+ hits = sorted(glob.glob(os.path.join(base, "**", pat), recursive=True))
165
+ if hits:
166
+ return hits[0]
167
+ return ""
168
+
169
+ _T2 = {}
170
+
171
+ def _t2_load():
172
+ if "infer" in _T2:
173
+ return _T2["infer"]
174
+ import torch.nn as nn
175
+ ckpt_path = _t2_find_ckpt()
176
+ assert ckpt_path, "Không thấy ft_emotion_full*.pt — Add Input dataset checkpoint exp08 (slug toanminh222/cache-exp8)?"
177
+ # code SAILER để dựng backbone WavLM
178
+ repo = "/kaggle/working/vox-profile-release"
179
+ if not os.path.exists(repo):
180
+ subprocess.run(["git", "clone", "--depth", "1",
181
+ "https://github.com/tiantiaf0627/vox-profile-release.git", repo], check=True)
182
+ if repo not in sys.path:
183
+ sys.path.insert(0, repo)
184
+
185
+ ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
186
+ assert "wavlm" in ckpt and "heads" in ckpt, "Checkpoint thiếu 'wavlm'/'heads' → cần bản đủ ft_emotion_full_20epoch.pt."
187
+ AUD_DIM = int(ckpt.get("AUD_DIM", 0)); USE_AUDEERING = AUD_DIM > 0
188
+
189
+ def find_hf_backbone(module):
190
+ cands = []
191
+ for name, m in module.named_modules():
192
+ enc = getattr(m, "encoder", None)
193
+ if getattr(m, "feature_extractor", None) is not None and enc is not None \
194
+ and getattr(enc, "layers", None) is not None:
195
+ cands.append((name, m))
196
+ if not cands:
197
+ return None, None
198
+ cands.sort(key=lambda nm: sum(p.numel() for p in nm[1].parameters()), reverse=True)
199
+ return cands[0]
200
+
201
+ wavlm = None
202
+ try:
203
+ from src.model.emotion.wavlm_emotion import WavLMWrapper
204
+ _wrapper = WavLMWrapper.from_pretrained("tiantiaf/wavlm-large-categorical-emotion")
205
+ _name, wavlm = find_hf_backbone(_wrapper)
206
+ except Exception as e:
207
+ print("⚠️ SAILER wrapper lỗi:", repr(e), "→ fallback WavLM trắng.")
208
+ if wavlm is None:
209
+ from transformers import WavLMModel
210
+ wavlm = WavLMModel.from_pretrained("microsoft/wavlm-large")
211
+ wavlm = wavlm.to(DEVICE).eval()
212
+ WAVLM_DIM = int(wavlm.config.hidden_size)
213
+ wavlm.config.layerdrop = 0.0
214
+ wavlm.load_state_dict(ckpt["wavlm"], strict=False)
215
+
216
+ def masked_mean(hidden, attn_mask):
217
+ if attn_mask is None:
218
+ return hidden.mean(dim=1)
219
+ try:
220
+ fm = wavlm._get_feature_vector_attention_mask(hidden.shape[1], attn_mask)
221
+ except Exception:
222
+ return hidden.mean(dim=1)
223
+ fm = fm.unsqueeze(-1).to(hidden.dtype)
224
+ return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)
225
+
226
+ # audeering frozen (nếu ckpt dùng)
227
+ aud_backbone = aud_head = aud_proc = None
228
+ if USE_AUDEERING:
229
+ from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor
230
+ from huggingface_hub import hf_hub_download
231
+ AUD_NAME = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
232
+ aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)
233
+ aud_backbone = Wav2Vec2Model(Wav2Vec2Config.from_pretrained(AUD_NAME))
234
+ try:
235
+ _sd = __import__("safetensors.torch", fromlist=["load_file"]).load_file(
236
+ hf_hub_download(AUD_NAME, "model.safetensors"))
237
+ except Exception:
238
+ _sd = torch.load(hf_hub_download(AUD_NAME, "pytorch_model.bin"), map_location="cpu")
239
+ bb_sd = {k[len("wav2vec2."):]: v for k, v in _sd.items() if k.startswith("wav2vec2.")}
240
+ aud_backbone.load_state_dict(bb_sd, strict=False)
241
+ _hid = _sd["classifier.dense.weight"].shape[0]
242
+ aud_head = nn.Sequential(nn.Linear(_hid, _hid), nn.Tanh(),
243
+ nn.Linear(_hid, _sd["classifier.out_proj.weight"].shape[0]))
244
+ aud_head[0].weight.data.copy_(_sd["classifier.dense.weight"]); aud_head[0].bias.data.copy_(_sd["classifier.dense.bias"])
245
+ aud_head[2].weight.data.copy_(_sd["classifier.out_proj.weight"]); aud_head[2].bias.data.copy_(_sd["classifier.out_proj.bias"])
246
+ aud_backbone = aud_backbone.to(DEVICE).eval(); aud_head = aud_head.to(DEVICE).eval()
247
+
248
+ @torch.no_grad()
249
+ def audeering_feat(wave):
250
+ x = aud_proc(wave, sampling_rate=SR).input_values[0]
251
+ x = torch.from_numpy(np.asarray(x, dtype=np.float32)).unsqueeze(0).to(DEVICE)
252
+ h = aud_backbone(x)[0].mean(dim=1)
253
+ out = aud_head(h)[0].cpu().numpy()
254
+ vad = np.array([1 + 4 * out[2], 1 + 4 * out[0], 1 + 4 * out[1]], dtype=np.float32)
255
+ return np.concatenate([h[0].cpu().numpy(), vad]).astype(np.float32)
256
+
257
+ N_EMO = len(EMOTIONS5)
258
+ TRUNK_IN = WAVLM_DIM + (AUD_DIM if USE_AUDEERING else 0)
259
+
260
+ class EmoHeads(nn.Module):
261
+ def __init__(self, d_in, trunk_h, head_h, p, n_emo):
262
+ super().__init__()
263
+ self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),
264
+ nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))
265
+ self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))
266
+ self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))
267
+ self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))
268
+ def forward(self, feat, tgt):
269
+ h = self.trunk(feat)
270
+ return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)
271
+
272
+ heads = EmoHeads(TRUNK_IN, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(DEVICE).eval()
273
+ heads.load_state_dict(ckpt["heads"], strict=False)
274
+ emos_mu, emos_sd = float(ckpt["emos_mu"]), float(ckpt["emos_sd"])
275
+ vad_mu = np.asarray(ckpt["vad_mu"], dtype=np.float32); vad_sd = np.asarray(ckpt["vad_sd"], dtype=np.float32)
276
+
277
+ def onehot_target(tgt):
278
+ v = np.zeros(N_EMO, dtype=np.float32)
279
+ if tgt in EMOTIONS5:
280
+ v[EMOTIONS5.index(tgt)] = 1.0
281
+ return v
282
+
283
+ @torch.no_grad()
284
+ def infer_wave(wave, target_emotion):
285
+ wave = wave[: EMO_MAX_SEC * SR].astype(np.float32)
286
+ iv = torch.from_numpy(wave).unsqueeze(0).to(DEVICE)
287
+ am = torch.ones((1, len(wave)), dtype=torch.long, device=DEVICE)
288
+ tgt = torch.from_numpy(onehot_target(norm_emotion(target_emotion) if target_emotion else None)).unsqueeze(0).to(DEVICE)
289
+ with torch.cuda.amp.autocast(enabled=USE_AMP and DEVICE == "cuda"):
290
+ fw = wavlm(iv, attention_mask=am).last_hidden_state
291
+ fw = masked_mean(fw, am)
292
+ if USE_AUDEERING:
293
+ fw = torch.cat([fw, torch.from_numpy(audeering_feat(wave)).unsqueeze(0).to(DEVICE)], dim=1)
294
+ emos_p, cat_l, vad_p = heads(fw, tgt)
295
+ emos = float(emos_p.item()) * emos_sd + emos_mu
296
+ cat5 = torch.softmax(cat_l, 1)[0].float().cpu().numpy()
297
+ vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu
298
+ return emos, cat5, vad3
299
+
300
+ print(f"✅ Track 2 exp08 nạp xong (audeering {'ON' if USE_AUDEERING else 'OFF'}) từ {ckpt_path}")
301
+ _T2["infer"] = infer_wave
302
+ return infer_wave
303
+
304
+ def t2_predict(audio, target_emotion):
305
+ """Trả: verdict(md), EMOS(number), CAT(label dict), VAL, ARO, DOM."""
306
+ if not audio:
307
+ return "### ⚠️ Hãy tải audio (giọng TTS).", None, {}, None, None, None
308
+ try:
309
+ infer_wave = _t2_load()
310
+ wave, _ = librosa.load(audio, sr=SR, mono=True)
311
+ emos, cat5, vad3 = infer_wave(wave, target_emotion)
312
+ cat_dict = {e: float(cat5[i]) for i, e in enumerate(EMOTIONS5)}
313
+ perceived = EMOTIONS5[int(np.argmax(cat5))]
314
+ if target_emotion:
315
+ match = "✅ **KHỚP** target" if perceived == norm_emotion(target_emotion) else "⚠️ **LỆCH** target"
316
+ band = "🟢 tốt" if emos >= 4 else ("🟡 khá" if emos >= 3 else "🔴 yếu")
317
+ verdict = (f"### Kết luận biểu cảm\n"
318
+ f"- Cảm xúc cảm nhận: **{perceived}** → {match} (`{target_emotion}`)\n"
319
+ f"- EMOS = **{emos:.2f}/5** → biểu cảm {band}")
320
+ else:
321
+ verdict = (f"### Kết luận biểu cảm\n- Cảm xúc cảm nhận: **{perceived}**\n"
322
+ f"- *(Chọn cảm xúc target để bật EMOS — độ khớp ý đồ)*")
323
+ emos = None
324
+ return verdict, (round(emos, 3) if emos is not None else None), cat_dict, \
325
+ round(float(vad3[0]), 3), round(float(vad3[1]), 3), round(float(vad3[2]), 3)
326
+ except Exception as e:
327
+ return f"### ❌ Track 2 lỗi\n`{repr(e)}`\n\nĐã Add Input checkpoint exp08 + Internet On chưa?", None, {}, None, None, None
328
+
329
+ # %% [markdown]
330
+ # ## 5. Giao diện Gradio GỘP — 3 tab + launch
331
+
332
+ # %%
333
+ import gradio as gr
334
+
335
+ INTRO = (
336
+ "# 🎙️ VoiceMOS Challenge 2026 — Demo 3 Track\n"
337
+ "Một link cho cả 3 track. Mỗi tab nhận audio → trả điểm bộ chấm tự động.\n\n"
338
+ "| Track | Bài toán | Output |\n|---|---|---|\n"
339
+ "| **1** | Speech Enhancement | ACR (chất lượng) · CCR (so sánh cặp) |\n"
340
+ "| **2** | Emotional TTS | EMOS · CAT · VAD (5 cột cảm xúc) — *model tốt nhất exp08* |\n"
341
+ "| **3** | Speaker/Accent | spk_sim · acc_sim |\n\n"
342
+ "> Model nạp **lần đầu bấm nút** (chờ ~1–2 phút tải). Tab thiếu checkpoint chỉ báo lỗi trong tab đó."
343
+ )
344
+
345
+ with gr.Blocks(title="VMC2026 — Demo 3 Track") as demo:
346
+ gr.Markdown(INTRO)
347
+
348
+ with gr.Tab("1️⃣ Track 1 · Chất lượng (ACR/CCR)"):
349
+ gr.Markdown("Tải **Audio A** → ACR (1–5). Tải thêm **Audio B** → CCR (A vs B, −3..+3, >0 = A tốt hơn).")
350
+ t1a = gr.Audio(type="filepath", label="Audio A (bắt buộc)")
351
+ t1b = gr.Audio(type="filepath", label="Audio B (tùy chọn — để tính CCR)")
352
+ t1out = gr.Markdown()
353
+ gr.Button("Dự đoán", variant="primary").click(t1_predict, [t1a, t1b], t1out)
354
+
355
+ with gr.Tab("2️⃣ Track 2 · Cảm xúc (EMOS/CAT/VAD)"):
356
+ gr.Markdown("Model tốt nhất **exp08** (WavLM fine-tune + audeering, offline). "
357
+ "Chọn **cảm xúc target** để bật EMOS (độ khớp ý đồ).")
358
+ with gr.Row():
359
+ with gr.Column(scale=1):
360
+ t2a = gr.Audio(type="filepath", label="Audio (giọng TTS)")
361
+ t2tgt = gr.Dropdown(EMOTIONS5, label="🎯 Cảm xúc target (cho EMOS)")
362
+ t2btn = gr.Button("Chấm cảm xúc", variant="primary")
363
+ with gr.Column(scale=2):
364
+ t2verdict = gr.Markdown()
365
+ t2emos = gr.Number(label="EMOS — khớp cảm xúc target (1–5)", interactive=False)
366
+ t2cat = gr.Label(label="CAT — phân bố cảm xúc cảm nhận (5 lớp)")
367
+ gr.Markdown("**VAD — toạ độ cảm xúc liên tục (1–5):**")
368
+ with gr.Row():
369
+ t2val = gr.Number(label="Valence (tích cực↑)", interactive=False)
370
+ t2aro = gr.Number(label="Arousal (kích động↑)", interactive=False)
371
+ t2dom = gr.Number(label="Dominance (chi phối↑)", interactive=False)
372
+ t2btn.click(t2_predict, [t2a, t2tgt], [t2verdict, t2emos, t2cat, t2val, t2aro, t2dom])
373
+
374
+ with gr.Tab("3️⃣ Track 3 · Speaker/Accent"):
375
+ gr.Markdown("Tải **audio cần đánh giá** + **audio tham chiếu** → độ giống người nói & accent (1–5).")
376
+ t3t = gr.Audio(type="filepath", label="Audio cần đánh giá (test)")
377
+ t3r = gr.Audio(type="filepath", label="Audio tham chiếu (reference)")
378
+ t3out = gr.Markdown()
379
+ gr.Button("Dự đoán", variant="primary").click(t3_predict, [t3t, t3r], t3out)
380
+
381
+ demo.launch(share=True)
382
+
383
+ # %% [markdown]
384
+ # ## Ghi chú
385
+ # - **Lazy-load:** mỗi `_tN_load()` nạp model 1 lần rồi cache module-level → tab nào không bấm thì không tốn RAM/VRAM.
386
+ # - Track 1 cần URGENT-MOS (GitHub + HuggingFace); Track 3 clone repo baseline (có sẵn checkpoint); Track 2 cần
387
+ # checkpoint exp08 (`ft_emotion_full_20epoch.pt`, slug `toanminh222/cache-exp8`) + tải WavLM/SAILER/audeering.
388
+ # - Hằng `TRUNK_HIDDEN/HEAD_HIDDEN/EMO_MAX_SEC` của Track 2 PHẢI khớp exp08 (ckpt không lưu) — sai là lệch shape.
389
+ # - 3 tab độc lập: thiếu checkpoint/Internet của 1 track chỉ báo lỗi trong tab đó, 2 tab còn lại vẫn chạy.
390
+ # - Cần **GPU T4 + Internet On**. Bản chỉ Track 2 đầy đủ (có tab metric val nội bộ) ở `track2/demo_track2_emotion_gradio`.
demo_run_from_hf.ipynb ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "50fca144",
6
+ "metadata": {},
7
+ "source": [
8
+ "# VMC2026 — Chạy demo Gradio trên KAGGLE bằng cách KÉO code UI từ Hugging Face\n",
9
+ "\n",
10
+ "Chiến lược: **HF = nơi chứa code UI** (Space `tranminhtoan140601/voicemos2026-demo`),\n",
11
+ "**Kaggle = nơi chạy** (GPU T4 free). Notebook này tải `app.py` từ Space về rồi chạy →\n",
12
+ "ra link `*.gradio.live` (sống ~72h) để gửi mentor. KHÔNG tốn GPU trả phí của HF.\n",
13
+ "\n",
14
+ "`app.py` tự nhận môi trường: trên Kaggle → `share=True` (link công khai); checkpoint Track 2\n",
15
+ "tự tải từ HF Models repo `tranminhtoan140601/voicemos2026-track2-emotion`.\n",
16
+ "\n",
17
+ "### Cách chạy\n",
18
+ "1. Settings → **GPU T4 + Internet On**.\n",
19
+ "2. **Run All** → cell cuối in link `*.gradio.live`."
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "markdown",
24
+ "id": "53fd0798",
25
+ "metadata": {},
26
+ "source": [
27
+ "## 1. Cài deps (khớp Space) — KHÔNG đụng numpy/torch có sẵn Kaggle"
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "code",
32
+ "execution_count": null,
33
+ "id": "59468886",
34
+ "metadata": {},
35
+ "outputs": [],
36
+ "source": [
37
+ "!pip install -q gradio==6.17.3 huggingface_hub librosa soundfile speechbrain loralib scipy scikit-learn pandas\n",
38
+ "\n",
39
+ "import subprocess, sys\n",
40
+ "\n",
41
+ "def pip_install(*pkgs):\n",
42
+ " subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *pkgs], check=False)\n",
43
+ "\n",
44
+ "pip_install(\"gradio==6.17.3\", \"huggingface_hub\", \"librosa\", \"soundfile\",\n",
45
+ " \"speechbrain\", \"loralib\", \"scipy\", \"scikit-learn\", \"pandas\")"
46
+ ]
47
+ },
48
+ {
49
+ "cell_type": "markdown",
50
+ "id": "1feff99f",
51
+ "metadata": {},
52
+ "source": [
53
+ "## 2. Kéo code UI (app.py) từ HF Space về Kaggle"
54
+ ]
55
+ },
56
+ {
57
+ "cell_type": "code",
58
+ "execution_count": null,
59
+ "id": "ca69f24e",
60
+ "metadata": {},
61
+ "outputs": [],
62
+ "source": [
63
+ "import os\n",
64
+ "from huggingface_hub import snapshot_download\n",
65
+ "\n",
66
+ "SPACE_REPO = \"tranminhtoan140601/voicemos2026-demo\"\n",
67
+ "LOCAL_DIR = \"/kaggle/working/vmc_demo\"\n",
68
+ "\n",
69
+ "# Tải toàn bộ repo Space (app.py + requirements + README) về local\n",
70
+ "snapshot_download(repo_id=SPACE_REPO, repo_type=\"space\", local_dir=LOCAL_DIR)\n",
71
+ "print(\"✅ Đã kéo Space về:\", LOCAL_DIR)\n",
72
+ "print(\"Files:\", os.listdir(LOCAL_DIR))"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "markdown",
77
+ "id": "613d7d61",
78
+ "metadata": {},
79
+ "source": [
80
+ "## 3. Chạy app.py (Kaggle có GPU → nhanh; app.py tự share=True ra link gradio.live)\n",
81
+ "\n",
82
+ "`app.py` tải checkpoint Track 2 từ HF Models repo, clone URGENT-MOS/SAILER/baseline lúc bấm nút.\n",
83
+ "Cell này sẽ **chạy mãi** (server Gradio) — đợi dòng `Running on public URL: https://....gradio.live`."
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "code",
88
+ "execution_count": null,
89
+ "id": "8e779da3",
90
+ "metadata": {},
91
+ "outputs": [],
92
+ "source": [
93
+ "# Chạy như tiến trình con để giữ log; KHÔNG có SPACE_ID nên app.py tự bật share=True\n",
94
+ "subprocess.run([sys.executable, \"app.py\"], cwd=LOCAL_DIR, check=True)"
95
+ ]
96
+ },
97
+ {
98
+ "cell_type": "markdown",
99
+ "id": "288e61cb",
100
+ "metadata": {},
101
+ "source": [
102
+ "## Ghi chú\n",
103
+ "- Đây là cách \"1 nguồn code, chạy nơi có GPU free\": sửa UI thì sửa trên Space HF → chạy lại notebook này.\n",
104
+ "- Nếu muốn chạy bản local trong `kaggle_baseline/demo_all_tracks_gradio` (code inline) thì dùng notebook đó.\n",
105
+ "- Lần đầu bấm nút mỗi track sẽ tải model (WavLM/SAILER/URGENT-MOS/ECAPA) → chờ chút; Kaggle có GPU nên inference nhanh.\n",
106
+ "- Cần **Internet On** (tải code HF + model) + **GPU T4**."
107
+ ]
108
+ }
109
+ ],
110
+ "metadata": {
111
+ "jupytext": {
112
+ "cell_metadata_filter": "-all",
113
+ "main_language": "python",
114
+ "notebook_metadata_filter": "-all"
115
+ }
116
+ },
117
+ "nbformat": 4,
118
+ "nbformat_minor": 5
119
+ }
demo_run_from_hf_pipeline.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %% [markdown]
2
+ # # VMC2026 — Chạy demo Gradio trên KAGGLE bằng cách KÉO code UI từ Hugging Face
3
+ #
4
+ # Chiến lược: **HF = nơi chứa code UI** (Space `tranminhtoan140601/voicemos2026-demo`),
5
+ # **Kaggle = nơi chạy** (GPU T4 free). Notebook này tải `app.py` từ Space về rồi chạy →
6
+ # ra link `*.gradio.live` (sống ~72h) để gửi mentor. KHÔNG tốn GPU trả phí của HF.
7
+ #
8
+ # `app.py` tự nhận môi trường: trên Kaggle → `share=True` (link công khai); checkpoint Track 2
9
+ # tự tải từ HF Models repo `tranminhtoan140601/voicemos2026-track2-emotion`.
10
+ #
11
+ # ### Cách chạy
12
+ # 1. Settings → **GPU T4 + Internet On**.
13
+ # 2. **Run All** → cell cuối in link `*.gradio.live`.
14
+
15
+ # %% [markdown]
16
+ # ## 1. Cài deps (khớp Space) — KHÔNG đụng numpy/torch có sẵn Kaggle
17
+
18
+ # %%
19
+ # !pip install -q gradio==6.17.3 huggingface_hub librosa soundfile speechbrain loralib scipy scikit-learn pandas
20
+
21
+ import subprocess, sys
22
+
23
+ def pip_install(*pkgs):
24
+ subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs], check=False)
25
+
26
+ pip_install("gradio==6.17.3", "huggingface_hub", "librosa", "soundfile",
27
+ "speechbrain", "loralib", "scipy", "scikit-learn", "pandas")
28
+
29
+ # %% [markdown]
30
+ # ## 2. Kéo code UI (app.py) từ HF Space về Kaggle
31
+
32
+ # %%
33
+ import os
34
+ from huggingface_hub import snapshot_download
35
+
36
+ SPACE_REPO = "tranminhtoan140601/voicemos2026-demo"
37
+ LOCAL_DIR = "/kaggle/working/vmc_demo"
38
+
39
+ # Tải toàn bộ repo Space (app.py + requirements + README) về local
40
+ snapshot_download(repo_id=SPACE_REPO, repo_type="space", local_dir=LOCAL_DIR)
41
+ print("✅ Đã kéo Space về:", LOCAL_DIR)
42
+ print("Files:", os.listdir(LOCAL_DIR))
43
+
44
+ # %% [markdown]
45
+ # ## 3. Chạy app.py (Kaggle có GPU → nhanh; app.py tự share=True ra link gradio.live)
46
+ #
47
+ # `app.py` tải checkpoint Track 2 từ HF Models repo, clone URGENT-MOS/SAILER/baseline lúc bấm nút.
48
+ # Cell này sẽ **chạy mãi** (server Gradio) — đợi dòng `Running on public URL: https://....gradio.live`.
49
+
50
+ # %%
51
+ # Chạy như tiến trình con để giữ log; KHÔNG có SPACE_ID nên app.py tự bật share=True
52
+ subprocess.run([sys.executable, "app.py"], cwd=LOCAL_DIR, check=True)
53
+
54
+ # %% [markdown]
55
+ # ## Ghi chú
56
+ # - Đây là cách "1 nguồn code, chạy nơi có GPU free": sửa UI thì sửa trên Space HF → chạy lại notebook này.
57
+ # - Nếu muốn chạy bản local trong `kaggle_baseline/demo_all_tracks_gradio` (code inline) thì dùng notebook đó.
58
+ # - Lần đầu bấm nút mỗi track sẽ tải model (WavLM/SAILER/URGENT-MOS/ECAPA) → chờ chút; Kaggle có GPU nên inference nhanh.
59
+ # - Cần **Internet On** (tải code HF + model) + **GPU T4**.
track1/demo_track1_gradio.ipynb ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# VMC2026 Track 1 — Demo Gradio (Speech Enhancement: ACR + CCR)\n",
8
+ "\n",
9
+ "Baseline **URGENT-MOS**. Tải **Audio A** → **ACR** (chất lượng 1–5).\n",
10
+ "Tải thêm **Audio B** → **CCR** (so sánh A vs B, thang −3..+3, >0 nghĩa là A tốt hơn).\n",
11
+ "\n",
12
+ "### Cách dùng trên Kaggle\n",
13
+ "1. Settings → **GPU T4 + Internet On**.\n",
14
+ "2. **Run All** → cell cuối in link `*.gradio.live` (sống ~72h) → gửi mentor."
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "markdown",
19
+ "metadata": {},
20
+ "source": [
21
+ "## 1. Cài đặt + clone URGENT-MOS"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": null,
27
+ "metadata": {},
28
+ "outputs": [],
29
+ "source": [
30
+ "!pip install -q gradio librosa soundfile\n",
31
+ "!git clone -q https://github.com/vvwangvv/URGENT-MOS.git /kaggle/working/URGENT-MOS\n",
32
+ "!pip install -q -e /kaggle/working/URGENT-MOS"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "markdown",
37
+ "metadata": {},
38
+ "source": [
39
+ "## 2. Nạp model + hàm dự đoán"
40
+ ]
41
+ },
42
+ {
43
+ "cell_type": "code",
44
+ "execution_count": null,
45
+ "id": "68754b2d",
46
+ "metadata": {},
47
+ "outputs": [],
48
+ "source": [
49
+ "import os, sys, subprocess, librosa\n",
50
+ "\n",
51
+ "DEVICE = \"cuda\"\n",
52
+ "URGENT_REPO = \"/kaggle/working/URGENT-MOS\"\n",
53
+ "URGENT_CKPT = \"urgent-challenge/urgent-mos-f1c1m5dcorpus\" # tự tải từ HuggingFace\n",
54
+ "\n",
55
+ "\n",
56
+ "def _ensure_urgent_mos():\n",
57
+ " \"\"\"Tự clone + cài URGENT-MOS nếu chưa có (phòng khi cell cài chưa chạy).\"\"\"\n",
58
+ " if not os.path.isdir(URGENT_REPO):\n",
59
+ " subprocess.run(f\"git clone -q https://github.com/vvwangvv/URGENT-MOS.git {URGENT_REPO}\",\n",
60
+ " shell=True, check=True)\n",
61
+ " subprocess.run(f\"pip install -q -e {URGENT_REPO}\", shell=True, check=True)\n",
62
+ " if URGENT_REPO not in sys.path: # package nằm ở root repo → thêm vào path là import được\n",
63
+ " sys.path.insert(0, URGENT_REPO)\n",
64
+ "\n",
65
+ "\n",
66
+ "_M = {}\n",
67
+ "\n",
68
+ "def _load():\n",
69
+ " if \"m\" not in _M:\n",
70
+ " _ensure_urgent_mos()\n",
71
+ " import torch\n",
72
+ " from urgent_mos.utils import load_model_from_checkpoint\n",
73
+ " dev = DEVICE if torch.cuda.is_available() else \"cpu\"\n",
74
+ " m = load_model_from_checkpoint(URGENT_CKPT, dev)\n",
75
+ " m.eval()\n",
76
+ " _M[\"m\"] = m\n",
77
+ " return _M[\"m\"]\n",
78
+ "\n",
79
+ "\n",
80
+ "def _scalar(x):\n",
81
+ " return float(x.item()) if hasattr(x, \"item\") else float(x)\n",
82
+ "\n",
83
+ "\n",
84
+ "def predict(audio_a, audio_b):\n",
85
+ " import torch\n",
86
+ " from urgent_mos.api.infer import infer, infer_pairs\n",
87
+ " if not audio_a:\n",
88
+ " return \"⚠️ Hãy tải lên ít nhất Audio A.\"\n",
89
+ " m = _load()\n",
90
+ " wa = torch.from_numpy(librosa.load(audio_a, sr=16000, mono=True)[0]).float()\n",
91
+ " acr_a = max(1.0, min(5.0, _scalar(infer(m, [wa], sample_rate=[16000],\n",
92
+ " batch_frames=None, num_workers=0)[0][\"mos_overall\"])))\n",
93
+ " out = f\"ACR (Audio A): {acr_a:.3f} (chất lượng tuyệt đối, thang 1–5)\"\n",
94
+ " if audio_b:\n",
95
+ " wb = torch.from_numpy(librosa.load(audio_b, sr=16000, mono=True)[0]).float()\n",
96
+ " acr_b = max(1.0, min(5.0, _scalar(infer(m, [wb], sample_rate=[16000],\n",
97
+ " batch_frames=None, num_workers=0)[0][\"mos_overall\"])))\n",
98
+ " ccr = max(-3.0, min(3.0, _scalar(infer_pairs(m, [(wa, wb)], sample_rate=[(16000, 16000)],\n",
99
+ " batch_frames=None, num_workers=0)[0][\"mos_overall\"])))\n",
100
+ " out += (f\"\\nACR (Audio B): {acr_b:.3f}\"\n",
101
+ " f\"\\nCCR (A so với B): {ccr:+.3f} (>0: A tốt hơn B; thang −3..+3)\")\n",
102
+ " return out"
103
+ ]
104
+ },
105
+ {
106
+ "cell_type": "markdown",
107
+ "metadata": {},
108
+ "source": [
109
+ "## 3. Giao diện Gradio + launch"
110
+ ]
111
+ },
112
+ {
113
+ "cell_type": "code",
114
+ "execution_count": null,
115
+ "metadata": {},
116
+ "outputs": [],
117
+ "source": [
118
+ "import gradio as gr\n",
119
+ "\n",
120
+ "with gr.Blocks(title=\"VMC2026 Track 1 — ACR/CCR\") as demo:\n",
121
+ " gr.Markdown(\"# 🎙️ Track 1 · Speech Enhancement (ACR / CCR)\\n\"\n",
122
+ " \"Tải **Audio A** để có ACR. Tải thêm **Audio B** để so sánh CCR (A vs B).\")\n",
123
+ " a = gr.Audio(type=\"filepath\", label=\"Audio A (bắt buộc)\")\n",
124
+ " b = gr.Audio(type=\"filepath\", label=\"Audio B (tùy chọn — để tính CCR)\")\n",
125
+ " out = gr.Textbox(label=\"Kết quả\", lines=4)\n",
126
+ " gr.Button(\"Dự đoán\", variant=\"primary\").click(predict, [a, b], out)\n",
127
+ "\n",
128
+ "demo.launch(share=True)"
129
+ ]
130
+ }
131
+ ],
132
+ "metadata": {
133
+ "kernelspec": {
134
+ "display_name": "Python 3",
135
+ "language": "python",
136
+ "name": "python3"
137
+ },
138
+ "language_info": {
139
+ "name": "python"
140
+ }
141
+ },
142
+ "nbformat": 4,
143
+ "nbformat_minor": 5
144
+ }
track1/demo_track1_gradio_pipeline.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %% [markdown]
2
+ # # VMC2026 Track 1 — Demo Gradio (Speech Enhancement: ACR + CCR)
3
+ #
4
+ # Baseline **URGENT-MOS**. Tải **Audio A** → **ACR** (chất lượng 1–5).
5
+ # Tải thêm **Audio B** → **CCR** (so sánh A vs B, thang −3..+3, >0 nghĩa là A tốt hơn).
6
+ #
7
+ # ### Cách dùng trên Kaggle
8
+ # 1. Settings → **GPU T4 + Internet On**.
9
+ # 2. **Run All** → cell cuối in link `*.gradio.live` (sống ~72h) → gửi mentor.
10
+
11
+ # %% [markdown]
12
+ # ## 1. Cài đặt + clone URGENT-MOS
13
+
14
+ # %%
15
+ # !pip install -q gradio librosa soundfile
16
+ # !git clone -q https://github.com/vvwangvv/URGENT-MOS.git /kaggle/working/URGENT-MOS
17
+ # !pip install -q -e /kaggle/working/URGENT-MOS
18
+
19
+ # %% [markdown]
20
+ # ## 2. Nạp model + hàm dự đoán
21
+
22
+ # %%
23
+ import os, sys, subprocess, librosa
24
+
25
+ DEVICE = "cuda"
26
+ URGENT_REPO = "/kaggle/working/URGENT-MOS"
27
+ URGENT_CKPT = "urgent-challenge/urgent-mos-f1c1m5dcorpus" # tự tải từ HuggingFace
28
+
29
+
30
+ def _ensure_urgent_mos():
31
+ """Đảm bảo import được `urgent_mos`: clone repo nếu thiếu, cài deps, thêm vào sys.path."""
32
+ if not os.path.isdir(URGENT_REPO):
33
+ subprocess.run(f"git clone -q https://github.com/vvwangvv/URGENT-MOS.git {URGENT_REPO}",
34
+ shell=True, check=True)
35
+ # package nằm ở ROOT repo → thêm vào path là import được, KHÔNG phụ thuộc pip install thành công
36
+ if URGENT_REPO not in sys.path:
37
+ sys.path.insert(0, URGENT_REPO)
38
+ import importlib
39
+ importlib.invalidate_caches()
40
+ # thử import; nếu thiếu dependency thì cài editable (kéo theo torchcodec, hydra-core, omegaconf...)
41
+ try:
42
+ importlib.import_module("urgent_mos.api.infer")
43
+ except Exception:
44
+ subprocess.run(f"pip install -q -e {URGENT_REPO}", shell=True)
45
+ importlib.invalidate_caches()
46
+
47
+
48
+ _M = {}
49
+
50
+ def _load():
51
+ if "m" not in _M:
52
+ _ensure_urgent_mos()
53
+ import torch
54
+ from urgent_mos.utils import load_model_from_checkpoint
55
+ dev = DEVICE if torch.cuda.is_available() else "cpu"
56
+ m = load_model_from_checkpoint(URGENT_CKPT, dev)
57
+ m.eval()
58
+ _M["m"] = m
59
+ return _M["m"]
60
+
61
+
62
+ def _scalar(x):
63
+ return float(x.item()) if hasattr(x, "item") else float(x)
64
+
65
+
66
+ def predict(audio_a, audio_b):
67
+ if not audio_a:
68
+ return "⚠️ Hãy tải lên ít nhất Audio A."
69
+ import torch
70
+ m = _load() # _load() tự ensure repo + sys.path TRƯỚC
71
+ from urgent_mos.api.infer import infer, infer_pairs # giờ mới import được
72
+ wa = torch.from_numpy(librosa.load(audio_a, sr=16000, mono=True)[0]).float()
73
+ acr_a = max(1.0, min(5.0, _scalar(infer(m, [wa], sample_rate=[16000],
74
+ batch_frames=None, num_workers=0)[0]["mos_overall"])))
75
+ out = f"ACR (Audio A): {acr_a:.3f} (chất lượng tuyệt đối, thang 1–5)"
76
+ if audio_b:
77
+ wb = torch.from_numpy(librosa.load(audio_b, sr=16000, mono=True)[0]).float()
78
+ acr_b = max(1.0, min(5.0, _scalar(infer(m, [wb], sample_rate=[16000],
79
+ batch_frames=None, num_workers=0)[0]["mos_overall"])))
80
+ ccr = max(-3.0, min(3.0, _scalar(infer_pairs(m, [(wa, wb)], sample_rate=[(16000, 16000)],
81
+ batch_frames=None, num_workers=0)[0]["mos_overall"])))
82
+ out += (f"\nACR (Audio B): {acr_b:.3f}"
83
+ f"\nCCR (A so với B): {ccr:+.3f} (>0: A tốt hơn B; thang −3..+3)")
84
+ return out
85
+
86
+ # %% [markdown]
87
+ # ## 3. Giao diện Gradio + launch
88
+
89
+ # %%
90
+ import gradio as gr
91
+
92
+ with gr.Blocks(title="VMC2026 Track 1 — ACR/CCR") as demo:
93
+ gr.Markdown("# 🎙️ Track 1 · Speech Enhancement (ACR / CCR)\n"
94
+ "Tải **Audio A** để có ACR. Tải thêm **Audio B** để so sánh CCR (A vs B).")
95
+ a = gr.Audio(type="filepath", label="Audio A (bắt buộc)")
96
+ b = gr.Audio(type="filepath", label="Audio B (tùy chọn — để tính CCR)")
97
+ out = gr.Textbox(label="Kết quả", lines=4)
98
+ gr.Button("Dự đoán", variant="primary").click(predict, [a, b], out)
99
+
100
+ demo.launch(share=True)
track1/track1_baseline.ipynb ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# VMC2026 Track 1 — Baseline (URGENT-MOS)\n",
8
+ "\n",
9
+ "Chạy ngay được — data dev công khai trên HuggingFace, checkpoint tự tải.\n",
10
+ "\n",
11
+ "**Trước khi chạy:** Session options → Accelerator = **GPU T4**, Internet = **On** (verify phone nếu cần).\n",
12
+ "\n",
13
+ "Output: `submission_track1.zip` (chứa `predictions.csv`) → nộp Track 1."
14
+ ]
15
+ },
16
+ {
17
+ "cell_type": "markdown",
18
+ "metadata": {},
19
+ "source": [
20
+ "## 1. Cài đặt URGENT-MOS"
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": null,
26
+ "metadata": {},
27
+ "outputs": [],
28
+ "source": [
29
+ "!git clone -q https://github.com/vvwangvv/URGENT-MOS.git /kaggle/working/URGENT-MOS\n",
30
+ "!pip install -q -e /kaggle/working/URGENT-MOS"
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "markdown",
35
+ "metadata": {},
36
+ "source": [
37
+ "## 2. Smoke test (10 mẫu)\n",
38
+ "Kiểm tra môi trường + tải checkpoint `urgent-challenge/urgent-mos-f1c1m5dcorpus` từ HF (lần đầu hơi lâu)."
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": null,
44
+ "metadata": {},
45
+ "outputs": [],
46
+ "source": [
47
+ "!cd /kaggle/working/URGENT-MOS && python scripts/infer_vmc2026_track1.py --split dev --limit 10 --output /kaggle/working/predictions_smoke.csv\n",
48
+ "import pandas as pd\n",
49
+ "pd.read_csv('/kaggle/working/predictions_smoke.csv').head()"
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "markdown",
54
+ "metadata": {},
55
+ "source": [
56
+ "## 3. Inference đầy đủ (ACR 1008 + CCR 2520)\n",
57
+ "Nếu OOM: thêm `--batch-frames 8000` (hoặc nhỏ hơn)."
58
+ ]
59
+ },
60
+ {
61
+ "cell_type": "code",
62
+ "execution_count": null,
63
+ "metadata": {},
64
+ "outputs": [],
65
+ "source": "!cd /kaggle/working/URGENT-MOS && python scripts/infer_vmc2026_track1.py --split dev --output /kaggle/working/predictions.csv"
66
+ },
67
+ {
68
+ "cell_type": "markdown",
69
+ "metadata": {},
70
+ "source": [
71
+ "## 4. Validate"
72
+ ]
73
+ },
74
+ {
75
+ "cell_type": "code",
76
+ "execution_count": null,
77
+ "metadata": {},
78
+ "outputs": [],
79
+ "source": "import pandas as pd\ndf = pd.read_csv('/kaggle/working/predictions.csv')\nassert list(df.columns) == ['sample_id', 'pred_score'], df.columns.tolist()\nacr = df[df['sample_id'].str.contains('-acr_')]\nccr = df[df['sample_id'].str.contains('-ccr_')]\nprint(f'Tổng {len(df)} | ACR {len(acr)} | CCR {len(ccr)}')\nprint('ACR:', acr['pred_score'].min(), '→', acr['pred_score'].max(), '(cần [1,5])')\nprint('CCR:', ccr['pred_score'].min(), '→', ccr['pred_score'].max(), '(cần [-3,+3])')"
80
+ },
81
+ {
82
+ "cell_type": "markdown",
83
+ "metadata": {},
84
+ "source": [
85
+ "## 5. Đóng zip nộp"
86
+ ]
87
+ },
88
+ {
89
+ "cell_type": "code",
90
+ "execution_count": null,
91
+ "metadata": {},
92
+ "outputs": [],
93
+ "source": "!cd /kaggle/working && zip -j submission_track1.zip predictions.csv && unzip -l submission_track1.zip\nprint('Tải /kaggle/working/submission_track1.zip → nộp My Submissions (chọn Track 1, bỏ chọn track khác)')"
94
+ }
95
+ ],
96
+ "metadata": {
97
+ "kernelspec": {
98
+ "display_name": "Python 3",
99
+ "language": "python",
100
+ "name": "python3"
101
+ },
102
+ "language_info": {
103
+ "name": "python"
104
+ }
105
+ },
106
+ "nbformat": 4,
107
+ "nbformat_minor": 5
108
+ }
track1/track1_baseline_pipeline.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %% [markdown]
2
+ # # VMC2026 Track 1 — Baseline Pipeline (Kaggle)
3
+ #
4
+ # Baseline = **URGENT-MOS** (checkpoint pre-trained tự tải từ HuggingFace).
5
+ # Repo có sẵn script xuất đúng format nộp → gần như không phải code gì.
6
+ #
7
+ # **Không bị chặn bởi license:** dev data Track 1 công khai trên HuggingFace
8
+ # (`urgent-challenge/vmc2026-track1-dev`, configs `acr` + `ccr`).
9
+ #
10
+ # **Cách dùng:** Notebook → GPU T4 + Internet On → chạy lần lượt các cell.
11
+ # Output: `predictions.csv` (`sample_id,pred_score`) → zip → nộp Track 1.
12
+ # ⚠️ File trong zip BẮT BUỘC tên `predictions.csv` (guideline Track 1) — đừng để `predictions_dev.csv`.
13
+
14
+ # %% [markdown]
15
+ # ## 1. Cài đặt URGENT-MOS
16
+
17
+ # %%
18
+ # !git clone -q https://github.com/vvwangvv/URGENT-MOS.git /kaggle/working/URGENT-MOS
19
+ # !pip install -q -e /kaggle/working/URGENT-MOS
20
+
21
+ # %% [markdown]
22
+ # ## 2. Smoke test (vài mẫu) — kiểm tra môi trường + tải checkpoint
23
+ # Checkpoint `urgent-challenge/urgent-mos-f1c1m5dcorpus` tự tải từ HF lần chạy đầu.
24
+
25
+ # %%
26
+ # !cd /kaggle/working/URGENT-MOS && python scripts/infer_vmc2026_track1.py \
27
+ # --split dev --limit 10 --output /kaggle/working/predictions_smoke.csv
28
+ # import pandas as pd; pd.read_csv("/kaggle/working/predictions_smoke.csv").head()
29
+
30
+ # %% [markdown]
31
+ # ## 3. Inference đầy đủ trên dev set (ACR + CCR)
32
+ # Script tự tải dataset từ HF, chạy cả ACR (1008) + CCR (2520) → 1 file predictions.
33
+
34
+ # %%
35
+ # !cd /kaggle/working/URGENT-MOS && python scripts/infer_vmc2026_track1.py \
36
+ # --split dev --output /kaggle/working/predictions.csv
37
+ # Nếu OOM: thêm --batch-frames <N> để giảm bộ nhớ.
38
+
39
+ # %% [markdown]
40
+ # ## 4. Validate + đóng zip nộp
41
+
42
+ # %%
43
+ import pandas as pd
44
+
45
+ PRED = "/kaggle/working/predictions.csv"
46
+ df = pd.read_csv(PRED)
47
+ assert list(df.columns) == ["sample_id", "pred_score"], f"Header sai: {df.columns.tolist()}"
48
+
49
+ acr = df[df["sample_id"].str.contains("-acr_")]
50
+ ccr = df[df["sample_id"].str.contains("-ccr_")]
51
+ print(f"Tổng {len(df)} dòng | ACR {len(acr)} | CCR {len(ccr)}")
52
+ print("ACR range:", acr["pred_score"].min(), "→", acr["pred_score"].max(), "(cần [1,5])")
53
+ print("CCR range:", ccr["pred_score"].min(), "→", ccr["pred_score"].max(), "(cần [-3,+3])")
54
+ # Kỳ vọng dev: ACR=1008, CCR=2520
55
+
56
+ # %%
57
+ # !cd /kaggle/working && zip -j submission_track1.zip predictions.csv && unzip -l submission_track1.zip
58
+
59
+ # %% [markdown]
60
+ # ## Ghi chú
61
+ # - Nộp: My Submissions → chọn **Track 1**, **bỏ chọn** track khác → upload `submission_track1.zip`.
62
+ # - File nộp Track 1 tên **`predictions.csv`** (KHÁC Track 2/3 dùng `answer.txt`). Script đã xuất đúng cột `sample_id,pred_score`.
63
+ # - Eval phase: đổi `--split test` (sau khi eval data ra 31/7).
64
+ # - GPU khuyến nghị; chỉ inference nên nhẹ, fit T4 16GB.
track2/demo_track2_emotion_gradio.ipynb ADDED
@@ -0,0 +1,533 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "73831f26",
6
+ "metadata": {},
7
+ "source": [
8
+ "# VMC2026 Track 2 — Demo Gradio \"Emotional TTS Evaluator\" (model TỐT NHẤT = exp08)\n",
9
+ "\n",
10
+ "Demo này dùng **checkpoint cảm xúc tốt nhất** (`ft_emotion_full_20epoch.pt`: WavLM fine-tune warm-start\n",
11
+ "SAILER + audeering frozen) để chấm **5 cột cảm xúc** của 1 file giọng TTS: **EMOS / CAT / VAL / ARO / DOM**.\n",
12
+ "Khác demo cũ (`demo_track2_gradio`) dùng baseline UTMOS+emotion2vec+Gemini — bản này KHÔNG cần API.\n",
13
+ "\n",
14
+ "**2 tab:**\n",
15
+ "1. *Chấm 1 file TTS* — tải audio + chọn cảm xúc target → ra điểm biểu cảm cảm xúc + diễn giải KHỚP/LỆCH.\n",
16
+ "2. *Metric bộ chấm* — tính UTT-SRCC (EMOS/VAD) + CAT-err trên val nội bộ (train.csv) → cho biết độ tin cậy.\n",
17
+ "\n",
18
+ "**Cách chạy Kaggle:** GPU **T4** + Internet **On** → Add Input (1) dataset Track 2, (2) dataset chứa\n",
19
+ "`ft_emotion_full_20epoch.pt` → Run All → cell cuối in link `*.gradio.live`."
20
+ ]
21
+ },
22
+ {
23
+ "cell_type": "markdown",
24
+ "id": "d505183e",
25
+ "metadata": {},
26
+ "source": [
27
+ "## 0. Cấu hình — auto-dò DATA_ROOT + checkpoint"
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "code",
32
+ "execution_count": null,
33
+ "id": "20d538e6",
34
+ "metadata": {},
35
+ "outputs": [],
36
+ "source": [
37
+ "import os, glob\n",
38
+ "\n",
39
+ "def find_data_root(search_root=\"/kaggle/input\"):\n",
40
+ " cands = []\n",
41
+ " for train_csv in glob.glob(os.path.join(search_root, \"**\", \"sets\", \"train.csv\"), recursive=True):\n",
42
+ " root = os.path.dirname(os.path.dirname(train_csv))\n",
43
+ " score = os.path.isdir(os.path.join(root, \"wav\")) + os.path.exists(os.path.join(root, \"metadata.csv\"))\n",
44
+ " cands.append((score, root))\n",
45
+ " cands.sort(reverse=True)\n",
46
+ " return cands\n",
47
+ "\n",
48
+ "_cands = find_data_root(\"/kaggle/input\")\n",
49
+ "if _cands:\n",
50
+ " print(\"🔎 Ứng viên DATA_ROOT:\")\n",
51
+ " for sc, r in _cands:\n",
52
+ " print(f\" [{sc}/2] {r}\")\n",
53
+ " DATA_ROOT = _cands[0][1]\n",
54
+ " print(f\"👉 Tự chọn DATA_ROOT = {DATA_ROOT}\")\n",
55
+ "else:\n",
56
+ " DATA_ROOT = \"/kaggle/input/datasets/minhtoan2\" # dự phòng\n",
57
+ " print(f\"❌ Không thấy sets/train.csv → dự phòng {DATA_ROOT} (đã Add Input chưa?)\")\n",
58
+ "\n",
59
+ "WAV_DIR = f\"{DATA_ROOT}/wav\"\n",
60
+ "METADATA_CSV = f\"{DATA_ROOT}/metadata.csv\"\n",
61
+ "TRAIN_CSV = f\"{DATA_ROOT}/sets/train.csv\"\n",
62
+ "\n",
63
+ "# ── Checkpoint cảm xúc exp08 (ưu tiên bản 20 epoch = TỐT NHẤT) ─────────────────\n",
64
+ "CKPT_PATH = \"\" # << \"\" = auto-dò; hoặc trỏ tay \"/kaggle/input/<slug>/ft_emotion_full_20epoch.pt\"\n",
65
+ "\n",
66
+ "def find_ckpt(explicit):\n",
67
+ " if explicit and os.path.exists(explicit):\n",
68
+ " return explicit\n",
69
+ " pats = [\"ft_emotion_full_20epoch*.pt\", \"ft_emotion_full*.pt\"] # ưu tiên bản 20epoch\n",
70
+ " for pat in pats:\n",
71
+ " for base in [\"/kaggle/input\", \"/kaggle/working\"]:\n",
72
+ " hits = sorted(glob.glob(os.path.join(base, \"**\", pat), recursive=True))\n",
73
+ " if hits:\n",
74
+ " return hits[0]\n",
75
+ " return \"\"\n",
76
+ "\n",
77
+ "CKPT_PATH = find_ckpt(CKPT_PATH)\n",
78
+ "assert CKPT_PATH, \"❌ Không thấy ft_emotion_full*.pt. Add Input dataset chứa checkpoint exp08 chưa?\"\n",
79
+ "print(\"✅ Checkpoint:\", CKPT_PATH)\n",
80
+ "\n",
81
+ "# ── Hằng kiến trúc PHẢI khớp exp08 (ckpt không lưu các số này) ────────────────\n",
82
+ "DEVICE = \"cuda\"\n",
83
+ "SR = 16000\n",
84
+ "EMO_MAX_SEC = 8\n",
85
+ "TRUNK_HIDDEN = 512\n",
86
+ "HEAD_HIDDEN = 128\n",
87
+ "DROPOUT = 0.3 # không ảnh hưởng eval\n",
88
+ "USE_AMP = True\n",
89
+ "\n",
90
+ "EMOTIONS5 = [\"angry\", \"happy\", \"neutral\", \"sad\", \"surprised\"]\n",
91
+ "_EMO_ALIAS = {\n",
92
+ " \"angry\": \"angry\", \"anger\": \"angry\",\n",
93
+ " \"happy\": \"happy\", \"happiness\": \"happy\", \"joy\": \"happy\",\n",
94
+ " \"neutral\": \"neutral\", \"calm\": \"neutral\",\n",
95
+ " \"sad\": \"sad\", \"sadness\": \"sad\",\n",
96
+ " \"surprise\": \"surprised\", \"surprised\": \"surprised\", \"surprising\": \"surprised\",\n",
97
+ "}\n",
98
+ "\n",
99
+ "def norm_emotion(label):\n",
100
+ " key = str(label).strip().lower()\n",
101
+ " return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)\n",
102
+ "\n",
103
+ "def stem(p):\n",
104
+ " return os.path.splitext(os.path.basename(str(p)))[0]\n",
105
+ "\n",
106
+ "# Mốc exp08 (val nội bộ / DEV) để so trong tab metric\n",
107
+ "EXP08 = {\"emos\": 0.811, \"cat_err\": 0.133, \"val\": 0.659, \"aro\": 0.793, \"dom\": 0.751}"
108
+ ]
109
+ },
110
+ {
111
+ "cell_type": "markdown",
112
+ "id": "daadb3d8",
113
+ "metadata": {},
114
+ "source": [
115
+ "## 1. Cài đặt + clone code SAILER"
116
+ ]
117
+ },
118
+ {
119
+ "cell_type": "code",
120
+ "execution_count": null,
121
+ "id": "6303e9cf",
122
+ "metadata": {},
123
+ "outputs": [],
124
+ "source": [
125
+ "import sys, subprocess\n",
126
+ "\n",
127
+ "def pip_install(*pkgs):\n",
128
+ " subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *pkgs], check=True)\n",
129
+ "\n",
130
+ "pip_install(\"gradio\", \"loralib\", \"speechbrain\", \"librosa\", \"soundfile\",\n",
131
+ " \"scipy\", \"scikit-learn\", \"pandas\", \"tqdm\")\n",
132
+ "\n",
133
+ "REPO_DIR = \"/kaggle/working/vox-profile-release\"\n",
134
+ "if not os.path.exists(REPO_DIR):\n",
135
+ " subprocess.run([\"git\", \"clone\", \"--depth\", \"1\",\n",
136
+ " \"https://github.com/tiantiaf0627/vox-profile-release.git\", REPO_DIR], check=True)\n",
137
+ "if REPO_DIR not in sys.path:\n",
138
+ " sys.path.insert(0, REPO_DIR)"
139
+ ]
140
+ },
141
+ {
142
+ "cell_type": "markdown",
143
+ "id": "233b6770",
144
+ "metadata": {},
145
+ "source": [
146
+ "## 2. Nạp model exp08 (backbone WavLM ft + audeering frozen + heads) — 1 lần"
147
+ ]
148
+ },
149
+ {
150
+ "cell_type": "code",
151
+ "execution_count": null,
152
+ "id": "1881d27c",
153
+ "metadata": {
154
+ "lines_to_next_cell": 1
155
+ },
156
+ "outputs": [],
157
+ "source": [
158
+ "import torch\n",
159
+ "import torch.nn as nn\n",
160
+ "import torch.nn.functional as F\n",
161
+ "import numpy as np\n",
162
+ "import librosa\n",
163
+ "\n",
164
+ "device = DEVICE if torch.cuda.is_available() else \"cpu\"\n",
165
+ "print(\"Device:\", device, (\"✅ \" + torch.cuda.get_device_name(0)) if device == \"cuda\" else \"⚠️ CPU (chậm)\")\n",
166
+ "\n",
167
+ "ckpt = torch.load(CKPT_PATH, map_location=\"cpu\", weights_only=False) # ckpt có numpy → cần False\n",
168
+ "assert \"wavlm\" in ckpt and \"heads\" in ckpt, \"❌ Checkpoint thiếu 'wavlm'/'heads' → cần ft_emotion_full_20epoch.pt đủ.\"\n",
169
+ "AUD_DIM = int(ckpt.get(\"AUD_DIM\", 0))\n",
170
+ "USE_AUDEERING = AUD_DIM > 0\n",
171
+ "print(\"✅ Nạp ckpt | keys:\", list(ckpt.keys()), \"| AUD_DIM:\", AUD_DIM, \"(audeering\", \"ON)\" if USE_AUDEERING else \"OFF)\")\n",
172
+ "\n",
173
+ "def find_hf_backbone(module):\n",
174
+ " cands = []\n",
175
+ " for name, m in module.named_modules():\n",
176
+ " enc = getattr(m, \"encoder\", None)\n",
177
+ " if getattr(m, \"feature_extractor\", None) is not None and enc is not None \\\n",
178
+ " and getattr(enc, \"layers\", None) is not None:\n",
179
+ " cands.append((name, m))\n",
180
+ " if not cands:\n",
181
+ " return None, None\n",
182
+ " cands.sort(key=lambda nm: sum(p.numel() for p in nm[1].parameters()), reverse=True)\n",
183
+ " return cands[0]\n",
184
+ "\n",
185
+ "wavlm = None\n",
186
+ "try:\n",
187
+ " from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402\n",
188
+ " _wrapper = WavLMWrapper.from_pretrained(\"tiantiaf/wavlm-large-categorical-emotion\")\n",
189
+ " _name, wavlm = find_hf_backbone(_wrapper)\n",
190
+ " if wavlm is not None:\n",
191
+ " print(f\"✅ Dựng backbone WavLM từ SAILER wrapper tại '.{_name}'\")\n",
192
+ "except Exception as e:\n",
193
+ " print(\"⚠️ Lỗi nạp SAILER wrapper:\", repr(e), \"→ fallback WavLM trắng.\")\n",
194
+ "if wavlm is None:\n",
195
+ " from transformers import WavLMModel\n",
196
+ " wavlm = WavLMModel.from_pretrained(\"microsoft/wavlm-large\")\n",
197
+ " print(\"ℹ️ Fallback: microsoft/wavlm-large.\")\n",
198
+ "\n",
199
+ "wavlm = wavlm.to(device).eval()\n",
200
+ "WAVLM_DIM = int(wavlm.config.hidden_size)\n",
201
+ "wavlm.config.layerdrop = 0.0\n",
202
+ "_miss, _unexp = wavlm.load_state_dict(ckpt[\"wavlm\"], strict=False)\n",
203
+ "print(f\"🔁 load wavlm: thiếu {len(_miss)} / dư {len(_unexp)} key (kỳ vọng ~0)\")\n",
204
+ "\n",
205
+ "def masked_mean(hidden, attn_mask):\n",
206
+ " if attn_mask is None:\n",
207
+ " return hidden.mean(dim=1)\n",
208
+ " try:\n",
209
+ " fm = wavlm._get_feature_vector_attention_mask(hidden.shape[1], attn_mask)\n",
210
+ " except Exception:\n",
211
+ " return hidden.mean(dim=1)\n",
212
+ " fm = fm.unsqueeze(-1).to(hidden.dtype)\n",
213
+ " return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)\n",
214
+ "\n",
215
+ "@torch.no_grad()\n",
216
+ "def wavlm_embed(input_values, attn_mask):\n",
217
+ " out = wavlm(input_values, attention_mask=attn_mask).last_hidden_state\n",
218
+ " return masked_mean(out, attn_mask)\n",
219
+ "\n",
220
+ "# ── audeering frozen (đặc trưng phụ) — chỉ dựng nếu ckpt có dùng ──\n",
221
+ "aud_backbone = aud_head = aud_proc = None\n",
222
+ "if USE_AUDEERING:\n",
223
+ " from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor\n",
224
+ " from huggingface_hub import hf_hub_download\n",
225
+ " AUD_NAME = \"audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim\"\n",
226
+ " aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)\n",
227
+ " aud_cfg = Wav2Vec2Config.from_pretrained(AUD_NAME)\n",
228
+ " aud_backbone = Wav2Vec2Model(aud_cfg)\n",
229
+ " try:\n",
230
+ " _sd = __import__(\"safetensors.torch\", fromlist=[\"load_file\"]).load_file(\n",
231
+ " hf_hub_download(AUD_NAME, \"model.safetensors\"))\n",
232
+ " except Exception:\n",
233
+ " _sd = torch.load(hf_hub_download(AUD_NAME, \"pytorch_model.bin\"), map_location=\"cpu\")\n",
234
+ " bb_sd = {k[len(\"wav2vec2.\"):]: v for k, v in _sd.items() if k.startswith(\"wav2vec2.\")}\n",
235
+ " aud_backbone.load_state_dict(bb_sd, strict=False)\n",
236
+ " _hid = _sd[\"classifier.dense.weight\"].shape[0]\n",
237
+ " aud_head = nn.Sequential(nn.Linear(_hid, _hid), nn.Tanh(), nn.Linear(_hid, _sd[\"classifier.out_proj.weight\"].shape[0]))\n",
238
+ " aud_head[0].weight.data.copy_(_sd[\"classifier.dense.weight\"]); aud_head[0].bias.data.copy_(_sd[\"classifier.dense.bias\"])\n",
239
+ " aud_head[2].weight.data.copy_(_sd[\"classifier.out_proj.weight\"]); aud_head[2].bias.data.copy_(_sd[\"classifier.out_proj.bias\"])\n",
240
+ " aud_backbone = aud_backbone.to(device).eval()\n",
241
+ " aud_head = aud_head.to(device).eval()\n",
242
+ " assert _hid + 3 == AUD_DIM, f\"⚠️ AUD_DIM dựng ({_hid+3}) ≠ ckpt ({AUD_DIM})\"\n",
243
+ " print(f\"✅ audeering frozen ({AUD_DIM}-D)\")\n",
244
+ "\n",
245
+ "@torch.no_grad()\n",
246
+ "def audeering_feat(wave):\n",
247
+ " x = aud_proc(wave, sampling_rate=SR).input_values[0]\n",
248
+ " x = torch.from_numpy(np.asarray(x, dtype=np.float32)).unsqueeze(0).to(device)\n",
249
+ " h = aud_backbone(x)[0].mean(dim=1)\n",
250
+ " out = aud_head(h)[0].cpu().numpy()\n",
251
+ " vad = np.array([1 + 4 * out[2], 1 + 4 * out[0], 1 + 4 * out[1]], dtype=np.float32) # [VAL,ARO,DOM]\n",
252
+ " return np.concatenate([h[0].cpu().numpy(), vad]).astype(np.float32)\n",
253
+ "\n",
254
+ "# ── EmoHeads (khớp exp08) + nạp trọng số + chuẩn hóa từ ckpt ──\n",
255
+ "N_EMO = len(EMOTIONS5)\n",
256
+ "TRUNK_IN = WAVLM_DIM + (AUD_DIM if USE_AUDEERING else 0)\n",
257
+ "\n",
258
+ "class EmoHeads(nn.Module):\n",
259
+ " def __init__(self, d_in, trunk_h, head_h, p, n_emo):\n",
260
+ " super().__init__()\n",
261
+ " self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),\n",
262
+ " nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))\n",
263
+ " self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))\n",
264
+ " self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))\n",
265
+ " self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))\n",
266
+ " def forward(self, feat, tgt):\n",
267
+ " h = self.trunk(feat)\n",
268
+ " return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)\n",
269
+ "\n",
270
+ "heads = EmoHeads(TRUNK_IN, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device).eval()\n",
271
+ "_hm, _hu = heads.load_state_dict(ckpt[\"heads\"], strict=False)\n",
272
+ "print(f\"🔁 load heads: thiếu {len(_hm)} / dư {len(_hu)} key (kỳ vọng 0)\")\n",
273
+ "\n",
274
+ "emos_mu = float(ckpt[\"emos_mu\"]); emos_sd = float(ckpt[\"emos_sd\"])\n",
275
+ "vad_mu = np.asarray(ckpt[\"vad_mu\"], dtype=np.float32); vad_sd = np.asarray(ckpt[\"vad_sd\"], dtype=np.float32)\n",
276
+ "print(f\"Chuẩn hóa từ ckpt: emos μ={emos_mu:.3f} σ={emos_sd:.3f} | vad μ={np.round(vad_mu,2)}\")\n",
277
+ "\n",
278
+ "def onehot_target(tgt):\n",
279
+ " v = np.zeros(N_EMO, dtype=np.float32)\n",
280
+ " if tgt in EMOTIONS5:\n",
281
+ " v[EMOTIONS5.index(tgt)] = 1.0\n",
282
+ " return v"
283
+ ]
284
+ },
285
+ {
286
+ "cell_type": "markdown",
287
+ "id": "3c7ce7b5",
288
+ "metadata": {},
289
+ "source": [
290
+ "## 3. Hàm suy luận lõi (1 wave numpy → emos/cat5/vad3)"
291
+ ]
292
+ },
293
+ {
294
+ "cell_type": "code",
295
+ "execution_count": null,
296
+ "id": "daf81b82",
297
+ "metadata": {
298
+ "lines_to_next_cell": 1
299
+ },
300
+ "outputs": [],
301
+ "source": [
302
+ "@torch.no_grad()\n",
303
+ "def infer_wave(wave, target_emotion):\n",
304
+ " \"\"\"wave: numpy float32 (đã 16k mono). target_emotion: str hoặc None. Trả (emos, cat5, vad3).\"\"\"\n",
305
+ " wave = wave[: EMO_MAX_SEC * SR].astype(np.float32)\n",
306
+ " iv = torch.from_numpy(wave).unsqueeze(0).to(device)\n",
307
+ " am = torch.ones((1, len(wave)), dtype=torch.long, device=device)\n",
308
+ " tgt = torch.from_numpy(onehot_target(norm_emotion(target_emotion) if target_emotion else None)).unsqueeze(0).to(device)\n",
309
+ " with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
310
+ " fw = wavlm_embed(iv, am)\n",
311
+ " if USE_AUDEERING:\n",
312
+ " fw = torch.cat([fw, torch.from_numpy(audeering_feat(wave)).unsqueeze(0).to(device)], dim=1)\n",
313
+ " emos_p, cat_l, vad_p = heads(fw, tgt)\n",
314
+ " emos = float(emos_p.item()) * emos_sd + emos_mu\n",
315
+ " cat5 = F.softmax(cat_l, 1)[0].float().cpu().numpy()\n",
316
+ " vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu\n",
317
+ " return emos, cat5, vad3"
318
+ ]
319
+ },
320
+ {
321
+ "cell_type": "markdown",
322
+ "id": "768a4de5",
323
+ "metadata": {},
324
+ "source": [
325
+ "## 4. Hàm metric val nội bộ (UTT-SRCC + CAT-err) — đánh giá độ tin cậy bộ chấm"
326
+ ]
327
+ },
328
+ {
329
+ "cell_type": "code",
330
+ "execution_count": null,
331
+ "id": "f4900fe6",
332
+ "metadata": {
333
+ "lines_to_next_cell": 1
334
+ },
335
+ "outputs": [],
336
+ "source": [
337
+ "import pandas as pd\n",
338
+ "from scipy.stats import spearmanr\n",
339
+ "from sklearn.model_selection import train_test_split\n",
340
+ "from tqdm.auto import tqdm\n",
341
+ "\n",
342
+ "def parse_emocat_votes(cell):\n",
343
+ " v = np.zeros(N_EMO, dtype=np.float32)\n",
344
+ " for tok in str(cell).replace(\"/\", \",\").replace(\";\", \",\").replace(\"|\", \",\").replace(\" \", \",\").split(\",\"):\n",
345
+ " e = norm_emotion(tok)\n",
346
+ " if e in EMOTIONS5:\n",
347
+ " v[EMOTIONS5.index(e)] += 1.0\n",
348
+ " return v\n",
349
+ "\n",
350
+ "def _col(cols_map, *names, df=None, default_idx=None):\n",
351
+ " for n in names:\n",
352
+ " if n in cols_map:\n",
353
+ " return cols_map[n]\n",
354
+ " return list(df.columns)[default_idx] if default_idx is not None else None\n",
355
+ "\n",
356
+ "def load_train_labels():\n",
357
+ " df = pd.read_csv(TRAIN_CSV, sep=\"|\")\n",
358
+ " cols = {c.lower().strip(): c for c in df.columns}\n",
359
+ " wav_col = _col(cols, \"wavid\", \"wav\", df=df, default_idx=1)\n",
360
+ " emos_col = _col(cols, \"emos\", \"emo\", \"emomos\")\n",
361
+ " val_col = _col(cols, \"val\", \"valence\"); aro_col = _col(cols, \"aro\", \"arousal\"); dom_col = _col(cols, \"dom\", \"dominance\")\n",
362
+ " cat_col = _col(cols, \"emocat\", \"cat\", \"emotion\")\n",
363
+ " df[\"_stem\"] = df[wav_col].map(stem)\n",
364
+ " rows = []\n",
365
+ " for sid, g in df.groupby(\"_stem\"):\n",
366
+ " rec = {\"wavID\": sid, \"emos\": float(g[emos_col].mean())}\n",
367
+ " rec[\"val\"] = float(g[val_col].mean()) if val_col else np.nan\n",
368
+ " rec[\"aro\"] = float(g[aro_col].mean()) if aro_col else np.nan\n",
369
+ " rec[\"dom\"] = float(g[dom_col].mean()) if dom_col else np.nan\n",
370
+ " votes = np.zeros(N_EMO, dtype=np.float32)\n",
371
+ " if cat_col:\n",
372
+ " for cell in g[cat_col]:\n",
373
+ " votes += parse_emocat_votes(cell)\n",
374
+ " s = votes.sum()\n",
375
+ " cat = votes / s if s > 0 else np.full(N_EMO, 0.2, dtype=np.float32)\n",
376
+ " for i in range(N_EMO):\n",
377
+ " rec[f\"cat{i}\"] = float(cat[i])\n",
378
+ " rows.append(rec)\n",
379
+ " return pd.DataFrame(rows)\n",
380
+ "\n",
381
+ "# target cảm xúc theo wav (cho EMOS) từ metadata\n",
382
+ "def load_target_emotions():\n",
383
+ " tgt = {}\n",
384
+ " if os.path.exists(METADATA_CSV):\n",
385
+ " with open(METADATA_CSV, encoding=\"utf-8\") as f:\n",
386
+ " for ln in f:\n",
387
+ " parts = ln.strip().split(\"|\")\n",
388
+ " if len(parts) >= 2:\n",
389
+ " tgt[stem(parts[0])] = norm_emotion(parts[1])\n",
390
+ " return tgt\n",
391
+ "\n",
392
+ "_target_map = None\n",
393
+ "_val_df = None\n",
394
+ "def _prep_eval():\n",
395
+ " \"\"\"Lazy: đọc nhãn + tách 10% val nội bộ (seed 42, khớp exp08).\"\"\"\n",
396
+ " global _target_map, _val_df\n",
397
+ " if _val_df is None:\n",
398
+ " _target_map = load_target_emotions()\n",
399
+ " df = load_train_labels()\n",
400
+ " df = df[df[\"wavID\"].map(lambda s: os.path.exists(os.path.join(WAV_DIR, s + \".wav\")))].reset_index(drop=True)\n",
401
+ " _, va = train_test_split(np.arange(len(df)), test_size=0.10, random_state=42)\n",
402
+ " _val_df = df.iloc[va].reset_index(drop=True)\n",
403
+ " return _target_map, _val_df\n",
404
+ "\n",
405
+ "def eval_metrics(limit):\n",
406
+ " tmap, vdf = _prep_eval()\n",
407
+ " n = min(int(limit), len(vdf))\n",
408
+ " P = {\"emos\": [], \"val\": [], \"aro\": [], \"dom\": []}; Y = {\"emos\": [], \"val\": [], \"aro\": [], \"dom\": []}\n",
409
+ " catP, catY = [], []\n",
410
+ " for i in tqdm(range(n), desc=\"eval\"):\n",
411
+ " r = vdf.iloc[i]; sid = r[\"wavID\"]\n",
412
+ " wav = os.path.join(WAV_DIR, sid + \".wav\")\n",
413
+ " wave, _ = librosa.load(wav, sr=SR, mono=True)\n",
414
+ " emos, cat5, vad3 = infer_wave(wave, tmap.get(sid))\n",
415
+ " P[\"emos\"].append(emos); Y[\"emos\"].append(float(r[\"emos\"]))\n",
416
+ " for j, t in enumerate([\"val\", \"aro\", \"dom\"]):\n",
417
+ " P[t].append(float(vad3[j])); Y[t].append(float(r[t]))\n",
418
+ " catP.append(cat5); catY.append([r[f\"cat{k}\"] for k in range(N_EMO)])\n",
419
+ " rows = []\n",
420
+ " for t in [\"emos\", \"val\", \"aro\", \"dom\"]:\n",
421
+ " srcc = spearmanr(P[t], Y[t]).correlation\n",
422
+ " rows.append([t.upper(), f\"{srcc:.4f}\", f\"{EXP08.get(t, float('nan')):.3f}\"])\n",
423
+ " cat_err = float(np.abs(np.array(catP) - np.array(catY)).sum(1).mean())\n",
424
+ " rows.append([\"CAT-err ↓\", f\"{cat_err:.4f}\", f\"{EXP08['cat_err']:.3f}\"])\n",
425
+ " return rows"
426
+ ]
427
+ },
428
+ {
429
+ "cell_type": "markdown",
430
+ "id": "36b32835",
431
+ "metadata": {},
432
+ "source": [
433
+ "## 5. Giao diện Gradio (2 tab)"
434
+ ]
435
+ },
436
+ {
437
+ "cell_type": "code",
438
+ "execution_count": null,
439
+ "id": "5ea0f8f6",
440
+ "metadata": {},
441
+ "outputs": [],
442
+ "source": [
443
+ "import gradio as gr\n",
444
+ "\n",
445
+ "def ui_predict(audio, target_emotion):\n",
446
+ " \"\"\"Trả về: verdict(md) · EMOS(number) · CAT(label) · VAL/ARO/DOM(number).\"\"\"\n",
447
+ " if not audio:\n",
448
+ " return \"### ⚠️ Hãy tải audio.\", None, {}, None, None, None\n",
449
+ " wave, _ = librosa.load(audio, sr=SR, mono=True)\n",
450
+ " emos, cat5, vad3 = infer_wave(wave, target_emotion)\n",
451
+ " cat_dict = {e: float(cat5[i]) for i, e in enumerate(EMOTIONS5)}\n",
452
+ " perceived = EMOTIONS5[int(np.argmax(cat5))]\n",
453
+ " if target_emotion:\n",
454
+ " match = \"✅ **KHỚP** target\" if perceived == norm_emotion(target_emotion) else \"⚠️ **LỆCH** target\"\n",
455
+ " band = \"🟢 tốt\" if emos >= 4 else (\"🟡 khá\" if emos >= 3 else \"🔴 yếu\")\n",
456
+ " verdict = (f\"### Kết luận biểu cảm\\n\"\n",
457
+ " f\"- Cảm xúc cảm nhận: **{perceived}** → {match} (`{target_emotion}`)\\n\"\n",
458
+ " f\"- EMOS = **{emos:.2f}/5** → biểu cảm {band}\")\n",
459
+ " else:\n",
460
+ " verdict = (f\"### Kết luận biểu cảm\\n\"\n",
461
+ " f\"- Cảm xúc cảm nhận: **{perceived}**\\n\"\n",
462
+ " f\"- *(Chọn cảm xúc target để bật EMOS — độ khớp ý đồ)*\")\n",
463
+ " emos = None\n",
464
+ " return verdict, (round(emos, 3) if emos is not None else None), cat_dict, \\\n",
465
+ " round(float(vad3[0]), 3), round(float(vad3[1]), 3), round(float(vad3[2]), 3)\n",
466
+ "\n",
467
+ "def ui_eval(limit):\n",
468
+ " return eval_metrics(limit)\n",
469
+ "\n",
470
+ "INTRO = (\n",
471
+ " \"# 🎙️ Emotional TTS Evaluator — VoiceMOS 2026 Track 2\\n\"\n",
472
+ " \"Bộ chấm **độ biểu cảm cảm xúc** của giọng TTS, chạy bằng model tốt nhất (**exp08**: WavLM fine-tune + \"\n",
473
+ " \"audeering). Offline, không cần API.\\n\\n\"\n",
474
+ " \"> **5 output dưới đây CHÍNH LÀ định nghĩa \\\"expressive emotion\\\" của Track 2** — mỗi cái trả lời một câu hỏi:\\n\"\n",
475
+ " \"> **EMOS** = có đúng cảm xúc được yêu cầu không · **CAT** = người nghe cảm nhận cảm xúc nào · \"\n",
476
+ " \"**VAD** = hóa trị / cường độ / chi phối.\"\n",
477
+ ")\n",
478
+ "\n",
479
+ "with gr.Blocks(title=\"VMC2026 Track 2 — Emotional TTS Evaluator (exp08)\") as demo:\n",
480
+ " gr.Markdown(INTRO)\n",
481
+ " with gr.Tab(\"🎯 Chấm 1 file TTS\"):\n",
482
+ " with gr.Row():\n",
483
+ " with gr.Column(scale=1):\n",
484
+ " a = gr.Audio(type=\"filepath\", label=\"Audio (giọng TTS)\")\n",
485
+ " tgt = gr.Dropdown(EMOTIONS5, label=\"🎯 Cảm xúc target (cho EMOS)\")\n",
486
+ " btn = gr.Button(\"Chấm cảm xúc\", variant=\"primary\")\n",
487
+ " with gr.Column(scale=2):\n",
488
+ " verdict = gr.Markdown()\n",
489
+ " with gr.Row():\n",
490
+ " emos_o = gr.Number(label=\"EMOS — khớp cảm xúc target (1–5)\", interactive=False)\n",
491
+ " cat_o = gr.Label(label=\"CAT — phân bố cảm xúc cảm nhận (5 lớp)\")\n",
492
+ " gr.Markdown(\"**VAD — toạ độ cảm xúc liên tục (1–5):**\")\n",
493
+ " with gr.Row():\n",
494
+ " val_o = gr.Number(label=\"Valence (tích cực↑)\", interactive=False)\n",
495
+ " aro_o = gr.Number(label=\"Arousal (kích động↑)\", interactive=False)\n",
496
+ " dom_o = gr.Number(label=\"Dominance (chi phối↑)\", interactive=False)\n",
497
+ " btn.click(ui_predict, [a, tgt], [verdict, emos_o, cat_o, val_o, aro_o, dom_o])\n",
498
+ " with gr.Tab(\"📊 Độ tin cậy bộ chấm\"):\n",
499
+ " gr.Markdown(\"Đo model tái lập nhãn người tốt tới đâu trên **val nội bộ** (10% train.csv, seed 42) — \"\n",
500
+ " \"**UTT-SRCC** (EMOS/VAD, cao=tốt) + **CAT-err** (thấp=tốt).\\n\"\n",
501
+ " \"⚠️ Dev label ẩn → đây **KHÔNG** phải điểm leaderboard, chỉ để biết bộ chấm đáng tin cỡ nào.\")\n",
502
+ " lim = gr.Slider(20, 300, value=100, step=20, label=\"Số mẫu val để chấm (nhiều = chậm)\")\n",
503
+ " tbl = gr.Dataframe(headers=[\"Cột\", \"Model (val nội bộ)\", \"Mốc exp08\"],\n",
504
+ " label=\"UTT-SRCC / CAT-err\", interactive=False)\n",
505
+ " gr.Button(\"Chạy đánh giá\", variant=\"primary\").click(ui_eval, [lim], [tbl])\n",
506
+ "\n",
507
+ "demo.launch(share=True)"
508
+ ]
509
+ },
510
+ {
511
+ "cell_type": "markdown",
512
+ "id": "fcea6f73",
513
+ "metadata": {},
514
+ "source": [
515
+ "## Ghi chú\n",
516
+ "- Hằng `TRUNK_HIDDEN/HEAD_HIDDEN` PHẢI khớp exp08 (ckpt không lưu) — sai là lệch key/shape.\n",
517
+ "- EMOS cần cảm xúc target → chưa chọn dropdown thì chỉ hiện CAT/VAD.\n",
518
+ "- exp08 = mean-pool (không Mamba) → demo dùng `masked_mean`.\n",
519
+ "- Metric tab chấm trên val nội bộ train.csv (dev ẩn) → con số ~ mốc exp08 nếu trùng tập val.\n",
520
+ "- Cần GPU T4 + Internet On (tải WavLM/SAILER/audeering lần đầu)."
521
+ ]
522
+ }
523
+ ],
524
+ "metadata": {
525
+ "jupytext": {
526
+ "cell_metadata_filter": "-all",
527
+ "main_language": "python",
528
+ "notebook_metadata_filter": "-all"
529
+ }
530
+ },
531
+ "nbformat": 4,
532
+ "nbformat_minor": 5
533
+ }
track2/demo_track2_emotion_gradio_pipeline.py ADDED
@@ -0,0 +1,431 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %% [markdown]
2
+ # # VMC2026 Track 2 — Demo Gradio "Emotional TTS Evaluator" (model TỐT NHẤT = exp08)
3
+ #
4
+ # Demo này dùng **checkpoint cảm xúc tốt nhất** (`ft_emotion_full_20epoch.pt`: WavLM fine-tune warm-start
5
+ # SAILER + audeering frozen) để chấm **5 cột cảm xúc** của 1 file giọng TTS: **EMOS / CAT / VAL / ARO / DOM**.
6
+ # Khác demo cũ (`demo_track2_gradio`) dùng baseline UTMOS+emotion2vec+Gemini — bản này KHÔNG cần API.
7
+ #
8
+ # **2 tab:**
9
+ # 1. *Chấm 1 file TTS* — tải audio + chọn cảm xúc target → ra điểm biểu cảm cảm xúc + diễn giải KHỚP/LỆCH.
10
+ # 2. *Metric bộ chấm* — tính UTT-SRCC (EMOS/VAD) + CAT-err trên val nội bộ (train.csv) → cho biết độ tin cậy.
11
+ #
12
+ # **Cách chạy Kaggle:** GPU **T4** + Internet **On** → Add Input (1) dataset Track 2, (2) dataset chứa
13
+ # `ft_emotion_full_20epoch.pt` → Run All → cell cuối in link `*.gradio.live`.
14
+
15
+ # %% [markdown]
16
+ # ## 0. Cấu hình — auto-dò DATA_ROOT + checkpoint
17
+
18
+ # %%
19
+ import os, glob
20
+
21
+ def find_data_root(search_root="/kaggle/input"):
22
+ cands = []
23
+ for train_csv in glob.glob(os.path.join(search_root, "**", "sets", "train.csv"), recursive=True):
24
+ root = os.path.dirname(os.path.dirname(train_csv))
25
+ score = os.path.isdir(os.path.join(root, "wav")) + os.path.exists(os.path.join(root, "metadata.csv"))
26
+ cands.append((score, root))
27
+ cands.sort(reverse=True)
28
+ return cands
29
+
30
+ _cands = find_data_root("/kaggle/input")
31
+ if _cands:
32
+ print("🔎 Ứng viên DATA_ROOT:")
33
+ for sc, r in _cands:
34
+ print(f" [{sc}/2] {r}")
35
+ DATA_ROOT = _cands[0][1]
36
+ print(f"👉 Tự chọn DATA_ROOT = {DATA_ROOT}")
37
+ else:
38
+ DATA_ROOT = "/kaggle/input/datasets/minhtoan2" # dự phòng
39
+ print(f"❌ Không thấy sets/train.csv → dự phòng {DATA_ROOT} (đã Add Input chưa?)")
40
+
41
+ WAV_DIR = f"{DATA_ROOT}/wav"
42
+ METADATA_CSV = f"{DATA_ROOT}/metadata.csv"
43
+ TRAIN_CSV = f"{DATA_ROOT}/sets/train.csv"
44
+
45
+ # ── Checkpoint cảm xúc exp08 (ưu tiên bản 20 epoch = TỐT NHẤT) ─────────────────
46
+ CKPT_PATH = "" # << "" = auto-dò; hoặc trỏ tay "/kaggle/input/<slug>/ft_emotion_full_20epoch.pt"
47
+
48
+ def find_ckpt(explicit):
49
+ if explicit and os.path.exists(explicit):
50
+ return explicit
51
+ pats = ["ft_emotion_full_20epoch*.pt", "ft_emotion_full*.pt"] # ưu tiên bản 20epoch
52
+ for pat in pats:
53
+ for base in ["/kaggle/input", "/kaggle/working"]:
54
+ hits = sorted(glob.glob(os.path.join(base, "**", pat), recursive=True))
55
+ if hits:
56
+ return hits[0]
57
+ return ""
58
+
59
+ CKPT_PATH = find_ckpt(CKPT_PATH)
60
+ assert CKPT_PATH, "❌ Không thấy ft_emotion_full*.pt. Add Input dataset chứa checkpoint exp08 chưa?"
61
+ print("✅ Checkpoint:", CKPT_PATH)
62
+
63
+ # ── Hằng kiến trúc PHẢI khớp exp08 (ckpt không lưu các số này) ────────────────
64
+ DEVICE = "cuda"
65
+ SR = 16000
66
+ EMO_MAX_SEC = 8
67
+ TRUNK_HIDDEN = 512
68
+ HEAD_HIDDEN = 128
69
+ DROPOUT = 0.3 # không ảnh hưởng eval
70
+ USE_AMP = True
71
+
72
+ EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"]
73
+ _EMO_ALIAS = {
74
+ "angry": "angry", "anger": "angry",
75
+ "happy": "happy", "happiness": "happy", "joy": "happy",
76
+ "neutral": "neutral", "calm": "neutral",
77
+ "sad": "sad", "sadness": "sad",
78
+ "surprise": "surprised", "surprised": "surprised", "surprising": "surprised",
79
+ }
80
+
81
+ def norm_emotion(label):
82
+ key = str(label).strip().lower()
83
+ return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)
84
+
85
+ def stem(p):
86
+ return os.path.splitext(os.path.basename(str(p)))[0]
87
+
88
+ # Mốc exp08 (val nội bộ / DEV) để so trong tab metric
89
+ EXP08 = {"emos": 0.811, "cat_err": 0.133, "val": 0.659, "aro": 0.793, "dom": 0.751}
90
+
91
+ # %% [markdown]
92
+ # ## 1. Cài đặt + clone code SAILER
93
+
94
+ # %%
95
+ import sys, subprocess
96
+
97
+ def pip_install(*pkgs):
98
+ subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs], check=True)
99
+
100
+ pip_install("gradio", "loralib", "speechbrain", "librosa", "soundfile",
101
+ "scipy", "scikit-learn", "pandas", "tqdm")
102
+
103
+ REPO_DIR = "/kaggle/working/vox-profile-release"
104
+ if not os.path.exists(REPO_DIR):
105
+ subprocess.run(["git", "clone", "--depth", "1",
106
+ "https://github.com/tiantiaf0627/vox-profile-release.git", REPO_DIR], check=True)
107
+ if REPO_DIR not in sys.path:
108
+ sys.path.insert(0, REPO_DIR)
109
+
110
+ # %% [markdown]
111
+ # ## 2. Nạp model exp08 (backbone WavLM ft + audeering frozen + heads) — 1 lần
112
+
113
+ # %%
114
+ import torch
115
+ import torch.nn as nn
116
+ import torch.nn.functional as F
117
+ import numpy as np
118
+ import librosa
119
+
120
+ device = DEVICE if torch.cuda.is_available() else "cpu"
121
+ print("Device:", device, ("✅ " + torch.cuda.get_device_name(0)) if device == "cuda" else "⚠️ CPU (chậm)")
122
+
123
+ ckpt = torch.load(CKPT_PATH, map_location="cpu", weights_only=False) # ckpt có numpy → cần False
124
+ assert "wavlm" in ckpt and "heads" in ckpt, "❌ Checkpoint thiếu 'wavlm'/'heads' → cần ft_emotion_full_20epoch.pt đủ."
125
+ AUD_DIM = int(ckpt.get("AUD_DIM", 0))
126
+ USE_AUDEERING = AUD_DIM > 0
127
+ print("✅ Nạp ckpt | keys:", list(ckpt.keys()), "| AUD_DIM:", AUD_DIM, "(audeering", "ON)" if USE_AUDEERING else "OFF)")
128
+
129
+ def find_hf_backbone(module):
130
+ cands = []
131
+ for name, m in module.named_modules():
132
+ enc = getattr(m, "encoder", None)
133
+ if getattr(m, "feature_extractor", None) is not None and enc is not None \
134
+ and getattr(enc, "layers", None) is not None:
135
+ cands.append((name, m))
136
+ if not cands:
137
+ return None, None
138
+ cands.sort(key=lambda nm: sum(p.numel() for p in nm[1].parameters()), reverse=True)
139
+ return cands[0]
140
+
141
+ wavlm = None
142
+ try:
143
+ from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402
144
+ _wrapper = WavLMWrapper.from_pretrained("tiantiaf/wavlm-large-categorical-emotion")
145
+ _name, wavlm = find_hf_backbone(_wrapper)
146
+ if wavlm is not None:
147
+ print(f"✅ Dựng backbone WavLM từ SAILER wrapper tại '.{_name}'")
148
+ except Exception as e:
149
+ print("⚠️ Lỗi nạp SAILER wrapper:", repr(e), "→ fallback WavLM trắng.")
150
+ if wavlm is None:
151
+ from transformers import WavLMModel
152
+ wavlm = WavLMModel.from_pretrained("microsoft/wavlm-large")
153
+ print("ℹ️ Fallback: microsoft/wavlm-large.")
154
+
155
+ wavlm = wavlm.to(device).eval()
156
+ WAVLM_DIM = int(wavlm.config.hidden_size)
157
+ wavlm.config.layerdrop = 0.0
158
+ _miss, _unexp = wavlm.load_state_dict(ckpt["wavlm"], strict=False)
159
+ print(f"🔁 load wavlm: thiếu {len(_miss)} / dư {len(_unexp)} key (kỳ vọng ~0)")
160
+
161
+ def masked_mean(hidden, attn_mask):
162
+ if attn_mask is None:
163
+ return hidden.mean(dim=1)
164
+ try:
165
+ fm = wavlm._get_feature_vector_attention_mask(hidden.shape[1], attn_mask)
166
+ except Exception:
167
+ return hidden.mean(dim=1)
168
+ fm = fm.unsqueeze(-1).to(hidden.dtype)
169
+ return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)
170
+
171
+ @torch.no_grad()
172
+ def wavlm_embed(input_values, attn_mask):
173
+ out = wavlm(input_values, attention_mask=attn_mask).last_hidden_state
174
+ return masked_mean(out, attn_mask)
175
+
176
+ # ── audeering frozen (đặc trưng phụ) — chỉ dựng nếu ckpt có dùng ──
177
+ aud_backbone = aud_head = aud_proc = None
178
+ if USE_AUDEERING:
179
+ from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor
180
+ from huggingface_hub import hf_hub_download
181
+ AUD_NAME = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
182
+ aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)
183
+ aud_cfg = Wav2Vec2Config.from_pretrained(AUD_NAME)
184
+ aud_backbone = Wav2Vec2Model(aud_cfg)
185
+ try:
186
+ _sd = __import__("safetensors.torch", fromlist=["load_file"]).load_file(
187
+ hf_hub_download(AUD_NAME, "model.safetensors"))
188
+ except Exception:
189
+ _sd = torch.load(hf_hub_download(AUD_NAME, "pytorch_model.bin"), map_location="cpu")
190
+ bb_sd = {k[len("wav2vec2."):]: v for k, v in _sd.items() if k.startswith("wav2vec2.")}
191
+ aud_backbone.load_state_dict(bb_sd, strict=False)
192
+ _hid = _sd["classifier.dense.weight"].shape[0]
193
+ aud_head = nn.Sequential(nn.Linear(_hid, _hid), nn.Tanh(), nn.Linear(_hid, _sd["classifier.out_proj.weight"].shape[0]))
194
+ aud_head[0].weight.data.copy_(_sd["classifier.dense.weight"]); aud_head[0].bias.data.copy_(_sd["classifier.dense.bias"])
195
+ aud_head[2].weight.data.copy_(_sd["classifier.out_proj.weight"]); aud_head[2].bias.data.copy_(_sd["classifier.out_proj.bias"])
196
+ aud_backbone = aud_backbone.to(device).eval()
197
+ aud_head = aud_head.to(device).eval()
198
+ assert _hid + 3 == AUD_DIM, f"⚠️ AUD_DIM dựng ({_hid+3}) ≠ ckpt ({AUD_DIM})"
199
+ print(f"✅ audeering frozen ({AUD_DIM}-D)")
200
+
201
+ @torch.no_grad()
202
+ def audeering_feat(wave):
203
+ x = aud_proc(wave, sampling_rate=SR).input_values[0]
204
+ x = torch.from_numpy(np.asarray(x, dtype=np.float32)).unsqueeze(0).to(device)
205
+ h = aud_backbone(x)[0].mean(dim=1)
206
+ out = aud_head(h)[0].cpu().numpy()
207
+ vad = np.array([1 + 4 * out[2], 1 + 4 * out[0], 1 + 4 * out[1]], dtype=np.float32) # [VAL,ARO,DOM]
208
+ return np.concatenate([h[0].cpu().numpy(), vad]).astype(np.float32)
209
+
210
+ # ── EmoHeads (khớp exp08) + nạp trọng số + chuẩn hóa từ ckpt ──
211
+ N_EMO = len(EMOTIONS5)
212
+ TRUNK_IN = WAVLM_DIM + (AUD_DIM if USE_AUDEERING else 0)
213
+
214
+ class EmoHeads(nn.Module):
215
+ def __init__(self, d_in, trunk_h, head_h, p, n_emo):
216
+ super().__init__()
217
+ self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),
218
+ nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))
219
+ self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))
220
+ self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))
221
+ self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))
222
+ def forward(self, feat, tgt):
223
+ h = self.trunk(feat)
224
+ return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)
225
+
226
+ heads = EmoHeads(TRUNK_IN, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device).eval()
227
+ _hm, _hu = heads.load_state_dict(ckpt["heads"], strict=False)
228
+ print(f"🔁 load heads: thiếu {len(_hm)} / dư {len(_hu)} key (kỳ vọng 0)")
229
+
230
+ emos_mu = float(ckpt["emos_mu"]); emos_sd = float(ckpt["emos_sd"])
231
+ vad_mu = np.asarray(ckpt["vad_mu"], dtype=np.float32); vad_sd = np.asarray(ckpt["vad_sd"], dtype=np.float32)
232
+ print(f"Chuẩn hóa từ ckpt: emos μ={emos_mu:.3f} σ={emos_sd:.3f} | vad μ={np.round(vad_mu,2)}")
233
+
234
+ def onehot_target(tgt):
235
+ v = np.zeros(N_EMO, dtype=np.float32)
236
+ if tgt in EMOTIONS5:
237
+ v[EMOTIONS5.index(tgt)] = 1.0
238
+ return v
239
+
240
+ # %% [markdown]
241
+ # ## 3. Hàm suy luận lõi (1 wave numpy → emos/cat5/vad3)
242
+
243
+ # %%
244
+ @torch.no_grad()
245
+ def infer_wave(wave, target_emotion):
246
+ """wave: numpy float32 (đã 16k mono). target_emotion: str hoặc None. Trả (emos, cat5, vad3)."""
247
+ wave = wave[: EMO_MAX_SEC * SR].astype(np.float32)
248
+ iv = torch.from_numpy(wave).unsqueeze(0).to(device)
249
+ am = torch.ones((1, len(wave)), dtype=torch.long, device=device)
250
+ tgt = torch.from_numpy(onehot_target(norm_emotion(target_emotion) if target_emotion else None)).unsqueeze(0).to(device)
251
+ with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
252
+ fw = wavlm_embed(iv, am)
253
+ if USE_AUDEERING:
254
+ fw = torch.cat([fw, torch.from_numpy(audeering_feat(wave)).unsqueeze(0).to(device)], dim=1)
255
+ emos_p, cat_l, vad_p = heads(fw, tgt)
256
+ emos = float(emos_p.item()) * emos_sd + emos_mu
257
+ cat5 = F.softmax(cat_l, 1)[0].float().cpu().numpy()
258
+ vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu
259
+ return emos, cat5, vad3
260
+
261
+ # %% [markdown]
262
+ # ## 4. Hàm metric val nội bộ (UTT-SRCC + CAT-err) — đánh giá độ tin cậy bộ chấm
263
+
264
+ # %%
265
+ import pandas as pd
266
+ from scipy.stats import spearmanr
267
+ from sklearn.model_selection import train_test_split
268
+ from tqdm.auto import tqdm
269
+
270
+ def parse_emocat_votes(cell):
271
+ v = np.zeros(N_EMO, dtype=np.float32)
272
+ for tok in str(cell).replace("/", ",").replace(";", ",").replace("|", ",").replace(" ", ",").split(","):
273
+ e = norm_emotion(tok)
274
+ if e in EMOTIONS5:
275
+ v[EMOTIONS5.index(e)] += 1.0
276
+ return v
277
+
278
+ def _col(cols_map, *names, df=None, default_idx=None):
279
+ for n in names:
280
+ if n in cols_map:
281
+ return cols_map[n]
282
+ return list(df.columns)[default_idx] if default_idx is not None else None
283
+
284
+ def load_train_labels():
285
+ df = pd.read_csv(TRAIN_CSV, sep="|")
286
+ cols = {c.lower().strip(): c for c in df.columns}
287
+ wav_col = _col(cols, "wavid", "wav", df=df, default_idx=1)
288
+ emos_col = _col(cols, "emos", "emo", "emomos")
289
+ val_col = _col(cols, "val", "valence"); aro_col = _col(cols, "aro", "arousal"); dom_col = _col(cols, "dom", "dominance")
290
+ cat_col = _col(cols, "emocat", "cat", "emotion")
291
+ df["_stem"] = df[wav_col].map(stem)
292
+ rows = []
293
+ for sid, g in df.groupby("_stem"):
294
+ rec = {"wavID": sid, "emos": float(g[emos_col].mean())}
295
+ rec["val"] = float(g[val_col].mean()) if val_col else np.nan
296
+ rec["aro"] = float(g[aro_col].mean()) if aro_col else np.nan
297
+ rec["dom"] = float(g[dom_col].mean()) if dom_col else np.nan
298
+ votes = np.zeros(N_EMO, dtype=np.float32)
299
+ if cat_col:
300
+ for cell in g[cat_col]:
301
+ votes += parse_emocat_votes(cell)
302
+ s = votes.sum()
303
+ cat = votes / s if s > 0 else np.full(N_EMO, 0.2, dtype=np.float32)
304
+ for i in range(N_EMO):
305
+ rec[f"cat{i}"] = float(cat[i])
306
+ rows.append(rec)
307
+ return pd.DataFrame(rows)
308
+
309
+ # target cảm xúc theo wav (cho EMOS) từ metadata
310
+ def load_target_emotions():
311
+ tgt = {}
312
+ if os.path.exists(METADATA_CSV):
313
+ with open(METADATA_CSV, encoding="utf-8") as f:
314
+ for ln in f:
315
+ parts = ln.strip().split("|")
316
+ if len(parts) >= 2:
317
+ tgt[stem(parts[0])] = norm_emotion(parts[1])
318
+ return tgt
319
+
320
+ _target_map = None
321
+ _val_df = None
322
+ def _prep_eval():
323
+ """Lazy: đọc nhãn + tách 10% val nội bộ (seed 42, khớp exp08)."""
324
+ global _target_map, _val_df
325
+ if _val_df is None:
326
+ _target_map = load_target_emotions()
327
+ df = load_train_labels()
328
+ df = df[df["wavID"].map(lambda s: os.path.exists(os.path.join(WAV_DIR, s + ".wav")))].reset_index(drop=True)
329
+ _, va = train_test_split(np.arange(len(df)), test_size=0.10, random_state=42)
330
+ _val_df = df.iloc[va].reset_index(drop=True)
331
+ return _target_map, _val_df
332
+
333
+ def eval_metrics(limit):
334
+ tmap, vdf = _prep_eval()
335
+ n = min(int(limit), len(vdf))
336
+ P = {"emos": [], "val": [], "aro": [], "dom": []}; Y = {"emos": [], "val": [], "aro": [], "dom": []}
337
+ catP, catY = [], []
338
+ for i in tqdm(range(n), desc="eval"):
339
+ r = vdf.iloc[i]; sid = r["wavID"]
340
+ wav = os.path.join(WAV_DIR, sid + ".wav")
341
+ wave, _ = librosa.load(wav, sr=SR, mono=True)
342
+ emos, cat5, vad3 = infer_wave(wave, tmap.get(sid))
343
+ P["emos"].append(emos); Y["emos"].append(float(r["emos"]))
344
+ for j, t in enumerate(["val", "aro", "dom"]):
345
+ P[t].append(float(vad3[j])); Y[t].append(float(r[t]))
346
+ catP.append(cat5); catY.append([r[f"cat{k}"] for k in range(N_EMO)])
347
+ rows = []
348
+ for t in ["emos", "val", "aro", "dom"]:
349
+ srcc = spearmanr(P[t], Y[t]).correlation
350
+ rows.append([t.upper(), f"{srcc:.4f}", f"{EXP08.get(t, float('nan')):.3f}"])
351
+ cat_err = float(np.abs(np.array(catP) - np.array(catY)).sum(1).mean())
352
+ rows.append(["CAT-err ↓", f"{cat_err:.4f}", f"{EXP08['cat_err']:.3f}"])
353
+ return rows
354
+
355
+ # %% [markdown]
356
+ # ## 5. Giao diện Gradio (2 tab)
357
+
358
+ # %%
359
+ import gradio as gr
360
+
361
+ def ui_predict(audio, target_emotion):
362
+ """Trả về: verdict(md) · EMOS(number) · CAT(label) · VAL/ARO/DOM(number)."""
363
+ if not audio:
364
+ return "### ⚠️ Hãy tải audio.", None, {}, None, None, None
365
+ wave, _ = librosa.load(audio, sr=SR, mono=True)
366
+ emos, cat5, vad3 = infer_wave(wave, target_emotion)
367
+ cat_dict = {e: float(cat5[i]) for i, e in enumerate(EMOTIONS5)}
368
+ perceived = EMOTIONS5[int(np.argmax(cat5))]
369
+ if target_emotion:
370
+ match = "✅ **KHỚP** target" if perceived == norm_emotion(target_emotion) else "⚠️ **LỆCH** target"
371
+ band = "🟢 tốt" if emos >= 4 else ("🟡 khá" if emos >= 3 else "🔴 yếu")
372
+ verdict = (f"### Kết luận biểu cảm\n"
373
+ f"- Cảm xúc cảm nhận: **{perceived}** → {match} (`{target_emotion}`)\n"
374
+ f"- EMOS = **{emos:.2f}/5** → biểu cảm {band}")
375
+ else:
376
+ verdict = (f"### Kết luận biểu cảm\n"
377
+ f"- Cảm xúc cảm nhận: **{perceived}**\n"
378
+ f"- *(Chọn cảm xúc target để bật EMOS — độ khớp ý đồ)*")
379
+ emos = None
380
+ return verdict, (round(emos, 3) if emos is not None else None), cat_dict, \
381
+ round(float(vad3[0]), 3), round(float(vad3[1]), 3), round(float(vad3[2]), 3)
382
+
383
+ def ui_eval(limit):
384
+ return eval_metrics(limit)
385
+
386
+ INTRO = (
387
+ "# 🎙️ Emotional TTS Evaluator — VoiceMOS 2026 Track 2\n"
388
+ "Bộ chấm **độ biểu cảm cảm xúc** của giọng TTS, chạy bằng model tốt nhất (**exp08**: WavLM fine-tune + "
389
+ "audeering). Offline, không cần API.\n\n"
390
+ "> **5 output dưới đây CHÍNH LÀ định nghĩa \"expressive emotion\" của Track 2** — mỗi cái trả lời một câu hỏi:\n"
391
+ "> **EMOS** = có đúng cảm xúc được yêu cầu không · **CAT** = người nghe cảm nhận cảm xúc nào · "
392
+ "**VAD** = hóa trị / cường độ / chi phối."
393
+ )
394
+
395
+ with gr.Blocks(title="VMC2026 Track 2 — Emotional TTS Evaluator (exp08)") as demo:
396
+ gr.Markdown(INTRO)
397
+ with gr.Tab("🎯 Chấm 1 file TTS"):
398
+ with gr.Row():
399
+ with gr.Column(scale=1):
400
+ a = gr.Audio(type="filepath", label="Audio (giọng TTS)")
401
+ tgt = gr.Dropdown(EMOTIONS5, label="🎯 Cảm xúc target (cho EMOS)")
402
+ btn = gr.Button("Chấm cảm xúc", variant="primary")
403
+ with gr.Column(scale=2):
404
+ verdict = gr.Markdown()
405
+ with gr.Row():
406
+ emos_o = gr.Number(label="EMOS — khớp cảm xúc target (1–5)", interactive=False)
407
+ cat_o = gr.Label(label="CAT — phân bố cảm xúc cảm nhận (5 lớp)")
408
+ gr.Markdown("**VAD — toạ độ cảm xúc liên tục (1–5):**")
409
+ with gr.Row():
410
+ val_o = gr.Number(label="Valence (tích cực↑)", interactive=False)
411
+ aro_o = gr.Number(label="Arousal (kích động↑)", interactive=False)
412
+ dom_o = gr.Number(label="Dominance (chi phối↑)", interactive=False)
413
+ btn.click(ui_predict, [a, tgt], [verdict, emos_o, cat_o, val_o, aro_o, dom_o])
414
+ with gr.Tab("📊 Độ tin cậy bộ chấm"):
415
+ gr.Markdown("Đo model tái lập nhãn người tốt tới đâu trên **val nội bộ** (10% train.csv, seed 42) — "
416
+ "**UTT-SRCC** (EMOS/VAD, cao=tốt) + **CAT-err** (thấp=tốt).\n"
417
+ "⚠️ Dev label ẩn → đây **KHÔNG** phải điểm leaderboard, chỉ để biết bộ chấm đáng tin cỡ nào.")
418
+ lim = gr.Slider(20, 300, value=100, step=20, label="Số mẫu val để chấm (nhiều = chậm)")
419
+ tbl = gr.Dataframe(headers=["Cột", "Model (val nội bộ)", "Mốc exp08"],
420
+ label="UTT-SRCC / CAT-err", interactive=False)
421
+ gr.Button("Chạy đánh giá", variant="primary").click(ui_eval, [lim], [tbl])
422
+
423
+ demo.launch(share=True)
424
+
425
+ # %% [markdown]
426
+ # ## Ghi chú
427
+ # - Hằng `TRUNK_HIDDEN/HEAD_HIDDEN` PHẢI khớp exp08 (ckpt không lưu) — sai là lệch key/shape.
428
+ # - EMOS cần cảm xúc target → chưa chọn dropdown thì chỉ hiện CAT/VAD.
429
+ # - exp08 = mean-pool (không Mamba) → demo dùng `masked_mean`.
430
+ # - Metric tab chấm trên val nội bộ train.csv (dev ẩn) → con số ~ mốc exp08 nếu trùng tập val.
431
+ # - Cần GPU T4 + Internet On (tải WavLM/SAILER/audeering lần đầu).
track2/demo_track2_gradio.ipynb ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# VMC2026 Track 2 — Demo Gradio (Emotional TTS: QMOS / CAT / EMOS / VAD)\n",
8
+ "\n",
9
+ "- **QMOS** (UTMOS) + **CAT** (emotion2vec, 5 cảm xúc): chạy ngay, chỉ cần audio.\n",
10
+ "- **EMOS / VAD** (Gemini): tùy chọn — cần dán `GEMINI_API_KEY` + chọn cảm xúc target.\n",
11
+ "\n",
12
+ "### Cách dùng trên Kaggle\n",
13
+ "1. Settings → **GPU T4 + Internet On**.\n",
14
+ "2. **Run All** → cell cuối in link `*.gradio.live` (sống ~72h)."
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "markdown",
19
+ "metadata": {},
20
+ "source": [
21
+ "## 1. Cài đặt"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": null,
27
+ "metadata": {},
28
+ "outputs": [],
29
+ "source": [
30
+ "!pip install -q gradio speechmos funasr librosa soundfile google-genai"
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "markdown",
35
+ "metadata": {},
36
+ "source": [
37
+ "## 2. Nạp model + hàm dự đoán"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "code",
42
+ "execution_count": null,
43
+ "metadata": {},
44
+ "outputs": [],
45
+ "source": [
46
+ "import re, json, librosa\n",
47
+ "\n",
48
+ "GEMINI_MODEL = \"gemini-2.0-flash\"\n",
49
+ "EMOTIONS5 = [\"angry\", \"happy\", \"neutral\", \"sad\", \"surprised\"]\n",
50
+ "\n",
51
+ "_M = {}\n",
52
+ "\n",
53
+ "def _qmos():\n",
54
+ " if \"qmos\" not in _M:\n",
55
+ " import torch\n",
56
+ " _M[\"qmos\"] = torch.hub.load(\"tarepan/SpeechMOS:v1.2.0\", \"utmos22_strong\", trust_repo=True)\n",
57
+ " return _M[\"qmos\"]\n",
58
+ "\n",
59
+ "\n",
60
+ "def _emocat():\n",
61
+ " if \"emocat\" not in _M:\n",
62
+ " from funasr import AutoModel\n",
63
+ " _M[\"emocat\"] = AutoModel(model=\"iic/emotion2vec_plus_large\", hub=\"hf\")\n",
64
+ " return _M[\"emocat\"]\n",
65
+ "\n",
66
+ "\n",
67
+ "def _gemini_emos_vad(audio_path, target_emotion, api_key):\n",
68
+ " \"\"\"EMOS (1-5 độ khớp cảm xúc target) + VAD (val/aro/dom 1-5) — bản demo gọn qua Gemini.\"\"\"\n",
69
+ " from google import genai\n",
70
+ " from google.genai import types\n",
71
+ " client = genai.Client(api_key=api_key)\n",
72
+ " part = types.Part.from_bytes(data=open(audio_path, \"rb\").read(), mime_type=\"audio/wav\")\n",
73
+ " cfg = types.GenerateContentConfig(temperature=0.0)\n",
74
+ "\n",
75
+ " p_emos = (f\"The target emotion is '{target_emotion}'. On a scale of 1 to 5, how well does the \"\n",
76
+ " f\"speaker express that emotion? 5=perfect match, 1=no match. Answer with ONLY one integer 1-5.\")\n",
77
+ " r = client.models.generate_content(model=GEMINI_MODEL, config=cfg, contents=[p_emos, part])\n",
78
+ " mm = re.search(r\"[1-5]\", getattr(r, \"text\", \"\") or \"\")\n",
79
+ " emos = int(mm.group()) if mm else None\n",
80
+ "\n",
81
+ " p_vad = ('Rate this speech on three 1-5 scales: Valence (1=very negative,5=very positive), '\n",
82
+ " 'Arousal (1=very calm,5=very excited), Dominance (1=very submissive,5=very dominant). '\n",
83
+ " 'Answer ONLY as JSON: {\"val\":x,\"aro\":y,\"dom\":z}.')\n",
84
+ " r2 = client.models.generate_content(model=GEMINI_MODEL, config=cfg, contents=[p_vad, part])\n",
85
+ " val = aro = dom = None\n",
86
+ " try:\n",
87
+ " d = json.loads(re.search(r\"\\{.*\\}\", getattr(r2, \"text\", \"\") or \"\", re.S).group())\n",
88
+ " val, aro, dom = d.get(\"val\"), d.get(\"aro\"), d.get(\"dom\")\n",
89
+ " except Exception:\n",
90
+ " pass\n",
91
+ " return emos, (val, aro, dom)\n",
92
+ "\n",
93
+ "\n",
94
+ "def predict(audio, target_emotion, gemini_key):\n",
95
+ " import torch\n",
96
+ " if not audio:\n",
97
+ " return \"⚠️ Hãy tải audio.\", {}\n",
98
+ " wav = librosa.load(audio, sr=16000, mono=True)[0]\n",
99
+ " # QMOS\n",
100
+ " qmos = float(_qmos()(torch.from_numpy(wav).unsqueeze(0), sr=16000).mean().item())\n",
101
+ " # CAT\n",
102
+ " rec = _emocat().generate(audio, granularity=\"utterance\", extract_embedding=False)\n",
103
+ " probs = {e: 0.0 for e in EMOTIONS5}\n",
104
+ " for lab, sc in zip(rec[0][\"labels\"], rec[0][\"scores\"]):\n",
105
+ " name = lab.split(\"/\")[-1]\n",
106
+ " if name in probs:\n",
107
+ " probs[name] = float(sc)\n",
108
+ " tot = sum(probs.values())\n",
109
+ " if tot > 0:\n",
110
+ " probs = {k: v / tot for k, v in probs.items()}\n",
111
+ "\n",
112
+ " lines = [f\"QMOS (chất lượng giọng, 1–5): {qmos:.3f}\"]\n",
113
+ " if gemini_key and target_emotion:\n",
114
+ " try:\n",
115
+ " emos, (val, aro, dom) = _gemini_emos_vad(audio, target_emotion, gemini_key)\n",
116
+ " lines.append(f\"EMOS (độ khớp cảm xúc '{target_emotion}', 1–5): {emos}\")\n",
117
+ " lines.append(f\"VAD — Valence: {val} · Arousal: {aro} · Dominance: {dom}\")\n",
118
+ " except Exception as e:\n",
119
+ " lines.append(f\"(EMOS/VAD lỗi: {e})\")\n",
120
+ " else:\n",
121
+ " lines.append(\"(EMOS/VAD: dán GEMINI_API_KEY + chọn cảm xúc target để bật)\")\n",
122
+ " return \"\\n\".join(lines), probs"
123
+ ]
124
+ },
125
+ {
126
+ "cell_type": "markdown",
127
+ "metadata": {},
128
+ "source": [
129
+ "## 3. Giao diện Gradio + launch"
130
+ ]
131
+ },
132
+ {
133
+ "cell_type": "code",
134
+ "execution_count": null,
135
+ "metadata": {},
136
+ "outputs": [],
137
+ "source": [
138
+ "import gradio as gr\n",
139
+ "\n",
140
+ "with gr.Blocks(title=\"VMC2026 Track 2 — Emotional TTS\") as demo:\n",
141
+ " gr.Markdown(\"# 🎙️ Track 2 · Emotional TTS (QMOS / CAT / EMOS / VAD)\\n\"\n",
142
+ " \"QMOS + phân bố cảm xúc (CAT) chạy ngay. EMOS/VAD cần Gemini key + cảm xúc target.\")\n",
143
+ " a = gr.Audio(type=\"filepath\", label=\"Audio\")\n",
144
+ " with gr.Row():\n",
145
+ " tgt = gr.Dropdown(EMOTIONS5, label=\"Cảm xúc target (cho EMOS, tùy chọn)\")\n",
146
+ " key = gr.Textbox(label=\"GEMINI_API_KEY (tùy chọn)\", type=\"password\")\n",
147
+ " out = gr.Textbox(label=\"Kết quả số\", lines=5)\n",
148
+ " lbl = gr.Label(label=\"CAT — phân bố cảm xúc cảm nhận\")\n",
149
+ " gr.Button(\"Dự đoán\", variant=\"primary\").click(predict, [a, tgt, key], [out, lbl])\n",
150
+ "\n",
151
+ "demo.launch(share=True)"
152
+ ]
153
+ },
154
+ {
155
+ "cell_type": "markdown",
156
+ "metadata": {},
157
+ "source": [
158
+ "## Ghi chú\n",
159
+ "- EMOS/VAD là bản demo gọn (prompt rút gọn) — KHÔNG hoàn toàn giống script baseline gốc, chỉ minh họa."
160
+ ]
161
+ }
162
+ ],
163
+ "metadata": {
164
+ "kernelspec": {
165
+ "display_name": "Python 3",
166
+ "language": "python",
167
+ "name": "python3"
168
+ },
169
+ "language_info": {
170
+ "name": "python"
171
+ }
172
+ },
173
+ "nbformat": 4,
174
+ "nbformat_minor": 5
175
+ }
track2/demo_track2_gradio_pipeline.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %% [markdown]
2
+ # # VMC2026 Track 2 — Demo Gradio (Emotional TTS: QMOS / CAT / EMOS / VAD)
3
+ #
4
+ # - **QMOS** (UTMOS) + **CAT** (emotion2vec, 5 cảm xúc): chạy ngay, chỉ cần audio.
5
+ # - **EMOS / VAD** (Gemini): tùy chọn — cần dán `GEMINI_API_KEY` + chọn cảm xúc target.
6
+ #
7
+ # ### Cách dùng trên Kaggle
8
+ # 1. Settings → **GPU T4 + Internet On**.
9
+ # 2. **Run All** → cell cuối in link `*.gradio.live` (sống ~72h).
10
+
11
+ # %% [markdown]
12
+ # ## 1. Cài đặt
13
+
14
+ # %%
15
+ # !pip install -q gradio speechmos funasr librosa soundfile google-genai
16
+
17
+ # %% [markdown]
18
+ # ## 2. Nạp model + hàm dự đoán
19
+
20
+ # %%
21
+ import re, json, librosa
22
+
23
+ GEMINI_MODEL = "gemini-2.0-flash"
24
+ EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"]
25
+
26
+ _M = {}
27
+
28
+ def _qmos():
29
+ if "qmos" not in _M:
30
+ import torch
31
+ _M["qmos"] = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True)
32
+ return _M["qmos"]
33
+
34
+
35
+ def _emocat():
36
+ if "emocat" not in _M:
37
+ from funasr import AutoModel
38
+ _M["emocat"] = AutoModel(model="iic/emotion2vec_plus_large", hub="hf")
39
+ return _M["emocat"]
40
+
41
+
42
+ def _gemini_emos_vad(audio_path, target_emotion, api_key):
43
+ """EMOS (1-5 độ khớp cảm xúc target) + VAD (val/aro/dom 1-5) — bản demo gọn qua Gemini."""
44
+ from google import genai
45
+ from google.genai import types
46
+ client = genai.Client(api_key=api_key)
47
+ part = types.Part.from_bytes(data=open(audio_path, "rb").read(), mime_type="audio/wav")
48
+ cfg = types.GenerateContentConfig(temperature=0.0)
49
+
50
+ p_emos = (f"The target emotion is '{target_emotion}'. On a scale of 1 to 5, how well does the "
51
+ f"speaker express that emotion? 5=perfect match, 1=no match. Answer with ONLY one integer 1-5.")
52
+ r = client.models.generate_content(model=GEMINI_MODEL, config=cfg, contents=[p_emos, part])
53
+ mm = re.search(r"[1-5]", getattr(r, "text", "") or "")
54
+ emos = int(mm.group()) if mm else None
55
+
56
+ p_vad = ('Rate this speech on three 1-5 scales: Valence (1=very negative,5=very positive), '
57
+ 'Arousal (1=very calm,5=very excited), Dominance (1=very submissive,5=very dominant). '
58
+ 'Answer ONLY as JSON: {"val":x,"aro":y,"dom":z}.')
59
+ r2 = client.models.generate_content(model=GEMINI_MODEL, config=cfg, contents=[p_vad, part])
60
+ val = aro = dom = None
61
+ try:
62
+ d = json.loads(re.search(r"\{.*\}", getattr(r2, "text", "") or "", re.S).group())
63
+ val, aro, dom = d.get("val"), d.get("aro"), d.get("dom")
64
+ except Exception:
65
+ pass
66
+ return emos, (val, aro, dom)
67
+
68
+
69
+ def predict(audio, target_emotion, gemini_key):
70
+ import torch
71
+ if not audio:
72
+ return "⚠️ Hãy tải audio.", {}
73
+ wav = librosa.load(audio, sr=16000, mono=True)[0]
74
+ # QMOS
75
+ qmos = float(_qmos()(torch.from_numpy(wav).unsqueeze(0), sr=16000).mean().item())
76
+ # CAT
77
+ rec = _emocat().generate(audio, granularity="utterance", extract_embedding=False)
78
+ probs = {e: 0.0 for e in EMOTIONS5}
79
+ for lab, sc in zip(rec[0]["labels"], rec[0]["scores"]):
80
+ name = lab.split("/")[-1]
81
+ if name in probs:
82
+ probs[name] = float(sc)
83
+ tot = sum(probs.values())
84
+ if tot > 0:
85
+ probs = {k: v / tot for k, v in probs.items()}
86
+
87
+ lines = [f"QMOS (chất lượng giọng, 1–5): {qmos:.3f}"]
88
+ if gemini_key and target_emotion:
89
+ try:
90
+ emos, (val, aro, dom) = _gemini_emos_vad(audio, target_emotion, gemini_key)
91
+ lines.append(f"EMOS (độ khớp cảm xúc '{target_emotion}', 1–5): {emos}")
92
+ lines.append(f"VAD — Valence: {val} · Arousal: {aro} · Dominance: {dom}")
93
+ except Exception as e:
94
+ lines.append(f"(EMOS/VAD lỗi: {e})")
95
+ else:
96
+ lines.append("(EMOS/VAD: dán GEMINI_API_KEY + chọn cảm xúc target để bật)")
97
+ return "\n".join(lines), probs
98
+
99
+ # %% [markdown]
100
+ # ## 3. Giao diện Gradio + launch
101
+
102
+ # %%
103
+ import gradio as gr
104
+
105
+ with gr.Blocks(title="VMC2026 Track 2 — Emotional TTS") as demo:
106
+ gr.Markdown("# 🎙️ Track 2 · Emotional TTS (QMOS / CAT / EMOS / VAD)\n"
107
+ "QMOS + phân bố cảm xúc (CAT) chạy ngay. EMOS/VAD cần Gemini key + cảm xúc target.")
108
+ a = gr.Audio(type="filepath", label="Audio")
109
+ with gr.Row():
110
+ tgt = gr.Dropdown(EMOTIONS5, label="Cảm xúc target (cho EMOS, tùy chọn)")
111
+ key = gr.Textbox(label="GEMINI_API_KEY (tùy chọn)", type="password")
112
+ out = gr.Textbox(label="Kết quả số", lines=5)
113
+ lbl = gr.Label(label="CAT — phân bố cảm xúc cảm nhận")
114
+ gr.Button("Dự đoán", variant="primary").click(predict, [a, tgt, key], [out, lbl])
115
+
116
+ demo.launch(share=True)
117
+
118
+ # %% [markdown]
119
+ # ## Ghi chú
120
+ # - EMOS/VAD là bản demo gọn (prompt rút gọn) — KHÔNG hoàn toàn giống script baseline gốc, chỉ minh họa.
track2/exp02_train_emos.ipynb ADDED
@@ -0,0 +1,542 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "b886ed3a",
6
+ "metadata": {},
7
+ "source": [
8
+ "# VMC2026 Track 2 — exp02 (EMOS có train) — Kaggle\n",
9
+ "\n",
10
+ "**Mục tiêu:** train một model dự đoán **EMOS** (độ khớp cảm xúc target) từ ~12.746 mẫu\n",
11
+ "có nhãn người nghe trong `sets/train.csv`, kỳ vọng **vượt baseline 0.194** (exp01 offline).\n",
12
+ "\n",
13
+ "## Ý tưởng (đọc 1 lần cho hiểu)\n",
14
+ "EMOS phụ thuộc **cả audio LẪN cảm xúc target** (cùng audio \"vui\": target=happy → điểm cao,\n",
15
+ "target=sad → điểm thấp). Vì vậy model phải nhận vào cả hai:\n",
16
+ "\n",
17
+ "```\n",
18
+ "mỗi wav ─► emotion2vec ─► (a) embedding ~D chiều ┐\n",
19
+ " (b) xác suất 5 cảm xúc ├─► nối ─► MLP head ─► EMOS (1–5)\n",
20
+ " target emotion ───► one-hot 5 chiều ┘ (CÁI MÌNH TRAIN)\n",
21
+ "```\n",
22
+ "\n",
23
+ "- **Backbone emotion2vec ĐÓNG BĂNG** (không train lại) → chỉ trích đặc trưng. Nhẹ GPU, ít data vẫn ổn.\n",
24
+ "- **Chỉ train MLP head nhỏ** → học ánh xạ `(đặc trưng + target) → điểm người chấm`.\n",
25
+ "- **Nhãn vàng** = trung bình `eMOS` của mọi listener trên cùng 1 wav (gộp theo `wavID`).\n",
26
+ "- Embedding **trích 1 lần → cache .npz** (12.746 file rất lâu, chạy lại tốn giờ GPU).\n",
27
+ "- Tách 10% train làm **validation nội bộ** → đo SRCC trong lúc train (DEV không có nhãn để tự chấm).\n",
28
+ "- Cuối cùng xuất `answer.txt` **đầy đủ**: QMOS=SpeechMOS · CAT=emotion2vec · **EMOS=head vừa train** → nộp được ngay.\n",
29
+ "\n",
30
+ "**Cách chạy trên Kaggle:** Settings → Accelerator = **GPU T4**, Internet = **On** → + Add Input dataset\n",
31
+ "Track 2 (15.477 wav, có `sets/train.csv`) → sửa `DATA_ROOT` ở cell 0 → Run All."
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "markdown",
36
+ "id": "d0fadc26",
37
+ "metadata": {},
38
+ "source": [
39
+ "## 0. Cấu hình — SỬA Ở ĐÂY"
40
+ ]
41
+ },
42
+ {
43
+ "cell_type": "code",
44
+ "execution_count": null,
45
+ "id": "7fac05b4",
46
+ "metadata": {},
47
+ "outputs": [],
48
+ "source": [
49
+ "import os, glob, json, time\n",
50
+ "\n",
51
+ "# ── Data Track 2 (dataset 15.477 wav đã ráp, có sets/train.csv) ──────────────\n",
52
+ "DATA_ROOT = \"/kaggle/input/vmc2026-track2-full/vmc2026-track2\" # << SỬA slug cho khớp Add Input\n",
53
+ "WAV_DIR = f\"{DATA_ROOT}/wav\"\n",
54
+ "METADATA_CSV = f\"{DATA_ROOT}/metadata.csv\" # wavID|emotion|transcript (KHÔNG header) → target emotion\n",
55
+ "TRAIN_CSV = f\"{DATA_ROOT}/sets/train.csv\" # nhãn người nghe: lisID,wavID,qMOS,emoCat,eMOS,val,dom,aro\n",
56
+ "DEV_SCP = f\"{DATA_ROOT}/sets/dev.scp\" # danh sách wav tập DEV (tập cần nộp ở training phase)\n",
57
+ "\n",
58
+ "OUT_DIR = \"/kaggle/working\"\n",
59
+ "CACHE_DIR = \"/kaggle/working/emb_cache\" # nơi lưu embedding đã trích (tái dùng giữa các lần chạy)\n",
60
+ "os.makedirs(CACHE_DIR, exist_ok=True)\n",
61
+ "\n",
62
+ "# ── Siêu tham số train (đổi nếu muốn thử nghiệm) ─────────────────────────────\n",
63
+ "DEVICE = \"cuda\" # \"cuda\" trên Kaggle GPU; \"cpu\" nếu không có GPU\n",
64
+ "HIDDEN = 256 # số neuron lớp ẩn của MLP head\n",
65
+ "DROPOUT = 0.3\n",
66
+ "LR = 1e-3\n",
67
+ "EPOCHS = 60\n",
68
+ "BATCH = 64\n",
69
+ "VAL_FRAC = 0.10 # 10% train → validation nội bộ (đo SRCC)\n",
70
+ "PATIENCE = 12 # early stop: dừng nếu val-SRCC không cải thiện sau N epoch\n",
71
+ "SEED = 42\n",
72
+ "\n",
73
+ "LIMIT_TRAIN = None # đặt số nhỏ (vd 300) để chạy thử nhanh; None = full\n",
74
+ "USE_CLASSPROB = True # thêm 5 xác suất cảm xúc của emotion2vec vào feature (tín hiệu exp01)\n",
75
+ "\n",
76
+ "EMOTIONS5 = [\"angry\", \"happy\", \"neutral\", \"sad\", \"surprised\"]\n",
77
+ "\n",
78
+ "_EMO_ALIAS = {\n",
79
+ " \"angry\": \"angry\", \"anger\": \"angry\",\n",
80
+ " \"happy\": \"happy\", \"happiness\": \"happy\", \"joy\": \"happy\",\n",
81
+ " \"neutral\": \"neutral\", \"calm\": \"neutral\",\n",
82
+ " \"sad\": \"sad\", \"sadness\": \"sad\",\n",
83
+ " \"surprise\": \"surprised\", \"surprised\": \"surprised\", \"surprising\": \"surprised\",\n",
84
+ "}\n",
85
+ "\n",
86
+ "def norm_emotion(label):\n",
87
+ " \"\"\"Đưa nhãn cảm xúc bất kỳ về 1 trong EMOTIONS5; None nếu không khớp.\"\"\"\n",
88
+ " key = str(label).strip().lower()\n",
89
+ " return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)\n",
90
+ "\n",
91
+ "def stem(path_or_name):\n",
92
+ " \"\"\"Lấy tên file không đuôi, để khớp wavID giữa train.csv / metadata / dev.scp.\"\"\"\n",
93
+ " return os.path.splitext(os.path.basename(str(path_or_name)))[0]\n",
94
+ "\n",
95
+ "print(\"DATA_ROOT:\", DATA_ROOT)\n",
96
+ "for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP]:\n",
97
+ " print((\" ✅ \" if os.path.exists(p) else \" ❌ THIẾU \") + p)"
98
+ ]
99
+ },
100
+ {
101
+ "cell_type": "markdown",
102
+ "id": "b947aceb",
103
+ "metadata": {},
104
+ "source": [
105
+ "## 1. Cài đặt"
106
+ ]
107
+ },
108
+ {
109
+ "cell_type": "code",
110
+ "execution_count": null,
111
+ "id": "49850676",
112
+ "metadata": {},
113
+ "outputs": [],
114
+ "source": [
115
+ "!pip install -q speechmos funasr librosa soundfile pandas scipy scikit-learn tqdm"
116
+ ]
117
+ },
118
+ {
119
+ "cell_type": "markdown",
120
+ "id": "b4e67d6a",
121
+ "metadata": {},
122
+ "source": [
123
+ "## 2. Đọc & gộp nhãn\n",
124
+ "- `train.csv`: mỗi dòng = 1 listener chấm 1 wav → **gộp trung bình eMOS theo wavID** = nhãn vàng.\n",
125
+ "- `metadata.csv`: lấy **cảm xúc target** cho mỗi wav (chuẩn hóa về 5 lớp)."
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "code",
130
+ "execution_count": null,
131
+ "id": "88b2fb84",
132
+ "metadata": {},
133
+ "outputs": [],
134
+ "source": [
135
+ "import pandas as pd\n",
136
+ "\n",
137
+ "def load_target_emotions():\n",
138
+ " \"\"\"metadata.csv (wavID|emotion|transcript, KHÔNG header) → {stem: emotion_chuẩn|None}.\"\"\"\n",
139
+ " tgt = {}\n",
140
+ " with open(METADATA_CSV, encoding=\"utf-8\") as f:\n",
141
+ " for ln in f:\n",
142
+ " parts = ln.strip().split(\"|\")\n",
143
+ " if len(parts) < 2:\n",
144
+ " continue\n",
145
+ " tgt[stem(parts[0])] = norm_emotion(parts[1])\n",
146
+ " return tgt\n",
147
+ "\n",
148
+ "def load_train_labels():\n",
149
+ " \"\"\"train.csv → DataFrame [wavID(stem), emos] đã gộp trung bình theo wav.\"\"\"\n",
150
+ " df = pd.read_csv(TRAIN_CSV)\n",
151
+ " # Chuẩn hóa tên cột (phòng khi viết hoa/thường khác nhau)\n",
152
+ " cols = {c.lower().strip(): c for c in df.columns}\n",
153
+ " wav_col = cols.get(\"wavid\") or cols.get(\"wav\") or list(df.columns)[1]\n",
154
+ " emos_col = cols.get(\"emos\") or cols.get(\"emo\") or cols.get(\"emomos\")\n",
155
+ " assert emos_col, f\"Không thấy cột eMOS trong train.csv (cột hiện có: {list(df.columns)})\"\n",
156
+ " g = df.groupby(df[wav_col].map(stem))[emos_col].mean()\n",
157
+ " out = g.reset_index()\n",
158
+ " out.columns = [\"wavID\", \"emos\"]\n",
159
+ " return out\n",
160
+ "\n",
161
+ "target_map = load_target_emotions()\n",
162
+ "train_df = load_train_labels()\n",
163
+ "print(f\"Target emotions: {len(target_map)} | wav train (đã gộp): {len(train_df)}\")\n",
164
+ "print(\"eMOS thống kê:\", train_df[\"emos\"].describe()[[\"mean\", \"std\", \"min\", \"max\"]].to_dict())\n",
165
+ "train_df.head()"
166
+ ]
167
+ },
168
+ {
169
+ "cell_type": "markdown",
170
+ "id": "06dea3ef",
171
+ "metadata": {},
172
+ "source": [
173
+ "## 3. Trích đặc trưng emotion2vec (có cache)\n",
174
+ "Mỗi wav → 1 lần `generate(extract_embedding=True)` cho ra **embedding** (cho EMOS) +\n",
175
+ "**xác suất 5 lớp** (cho CAT và làm feature). Lưu cache `.npz` để lần sau khỏi chạy lại."
176
+ ]
177
+ },
178
+ {
179
+ "cell_type": "code",
180
+ "execution_count": null,
181
+ "id": "c63cc1f5",
182
+ "metadata": {
183
+ "lines_to_next_cell": 1
184
+ },
185
+ "outputs": [],
186
+ "source": [
187
+ "import numpy as np\n",
188
+ "\n",
189
+ "_e2v_model = None\n",
190
+ "def get_e2v():\n",
191
+ " global _e2v_model\n",
192
+ " if _e2v_model is None:\n",
193
+ " from funasr import AutoModel\n",
194
+ " _e2v_model = AutoModel(model=\"iic/emotion2vec_plus_large\", hub=\"hf\")\n",
195
+ " return _e2v_model\n",
196
+ "\n",
197
+ "def extract_one(wav_path):\n",
198
+ " \"\"\"→ (emb: np.float32[D], probs5: np.float32[5] tổng=1). None nếu lỗi/thiếu file.\"\"\"\n",
199
+ " if not os.path.exists(wav_path):\n",
200
+ " return None\n",
201
+ " rec = get_e2v().generate(wav_path, granularity=\"utterance\", extract_embedding=True)\n",
202
+ " r = rec[0]\n",
203
+ " emb = np.asarray(r[\"feats\"], dtype=np.float32).reshape(-1)\n",
204
+ " probs = {e: 0.0 for e in EMOTIONS5}\n",
205
+ " for lab, sc in zip(r[\"labels\"], r[\"scores\"]):\n",
206
+ " name = lab.split(\"/\")[-1]\n",
207
+ " if name in probs:\n",
208
+ " probs[name] = float(sc)\n",
209
+ " tot = sum(probs.values())\n",
210
+ " if tot > 0:\n",
211
+ " probs = {k: v / tot for k, v in probs.items()}\n",
212
+ " probs5 = np.array([probs[e] for e in EMOTIONS5], dtype=np.float32)\n",
213
+ " return emb, probs5\n",
214
+ "\n",
215
+ "def extract_set(stems, tag):\n",
216
+ " \"\"\"Trích (hoặc nạp cache) cho danh sách stem. Trả về dict {stem: (emb, probs5)}.\n",
217
+ " Cache lưu tại CACHE_DIR/<tag>.npz; tự b�� qua stem đã có để chạy nối tiếp được.\"\"\"\n",
218
+ " from tqdm.auto import tqdm\n",
219
+ " cache_path = os.path.join(CACHE_DIR, f\"{tag}.npz\")\n",
220
+ " store = {}\n",
221
+ " if os.path.exists(cache_path):\n",
222
+ " z = np.load(cache_path, allow_pickle=True)\n",
223
+ " store = {k: z[k] for k in z.files}\n",
224
+ " print(f\"[{tag}] nạp cache: {len(store)} mẫu\")\n",
225
+ " todo = [s for s in stems if s not in store]\n",
226
+ " if not todo:\n",
227
+ " print(f\"[{tag}] đủ cache, bỏ qua trích.\")\n",
228
+ " else:\n",
229
+ " miss = 0\n",
230
+ " for i, s in enumerate(tqdm(todo, desc=f\"trích {tag}\")):\n",
231
+ " res = extract_one(os.path.join(WAV_DIR, s + \".wav\"))\n",
232
+ " if res is None:\n",
233
+ " miss += 1\n",
234
+ " continue\n",
235
+ " emb, probs5 = res\n",
236
+ " store[s] = np.concatenate([emb, probs5]).astype(np.float32) # [D + 5]\n",
237
+ " if (i + 1) % 500 == 0: # lưu cache định kỳ phòng ngắt session\n",
238
+ " np.savez(cache_path, **store)\n",
239
+ " np.savez(cache_path, **store)\n",
240
+ " if miss:\n",
241
+ " print(f\"[{tag}] {miss} file thiếu/ lỗi → bỏ qua.\")\n",
242
+ " print(f\"[{tag}] tổng cache: {len(store)} mẫu → {cache_path}\")\n",
243
+ " # tách lại thành (emb, probs5)\n",
244
+ " out = {}\n",
245
+ " for s, vec in store.items():\n",
246
+ " out[s] = (vec[:-5], vec[-5:])\n",
247
+ " return out\n",
248
+ "\n",
249
+ "# Trích cho tập train\n",
250
+ "train_stems = list(train_df[\"wavID\"])\n",
251
+ "if LIMIT_TRAIN:\n",
252
+ " train_stems = train_stems[:LIMIT_TRAIN]\n",
253
+ "train_feat = extract_set(train_stems, \"train\")\n",
254
+ "EMB_DIM = next(iter(train_feat.values()))[0].shape[0]\n",
255
+ "print(\"EMB_DIM =\", EMB_DIM)"
256
+ ]
257
+ },
258
+ {
259
+ "cell_type": "markdown",
260
+ "id": "015834c1",
261
+ "metadata": {},
262
+ "source": [
263
+ "## 4. Dựng feature + nhãn cho train\n",
264
+ "Feature mỗi wav = `[embedding | (probs5 nếu bật) | one-hot target(5)]`. Bỏ wav thiếu target/feature."
265
+ ]
266
+ },
267
+ {
268
+ "cell_type": "code",
269
+ "execution_count": null,
270
+ "id": "62e048bd",
271
+ "metadata": {},
272
+ "outputs": [],
273
+ "source": [
274
+ "def onehot_target(tgt):\n",
275
+ " v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
276
+ " if tgt in EMOTIONS5:\n",
277
+ " v[EMOTIONS5.index(tgt)] = 1.0\n",
278
+ " return v\n",
279
+ "\n",
280
+ "def build_feature(stem_id, feat_map):\n",
281
+ " pack = feat_map.get(stem_id)\n",
282
+ " if pack is None:\n",
283
+ " return None\n",
284
+ " emb, probs5 = pack\n",
285
+ " tgt = target_map.get(stem_id)\n",
286
+ " if tgt is None: # không biết cảm xúc target → không train được mẫu này\n",
287
+ " return None\n",
288
+ " parts = [emb]\n",
289
+ " if USE_CLASSPROB:\n",
290
+ " parts.append(probs5)\n",
291
+ " parts.append(onehot_target(tgt))\n",
292
+ " return np.concatenate(parts).astype(np.float32)\n",
293
+ "\n",
294
+ "emos_label = dict(zip(train_df[\"wavID\"], train_df[\"emos\"]))\n",
295
+ "X, y = [], []\n",
296
+ "for s in train_stems:\n",
297
+ " f = build_feature(s, train_feat)\n",
298
+ " if f is None or s not in emos_label:\n",
299
+ " continue\n",
300
+ " X.append(f); y.append(emos_label[s])\n",
301
+ "X = np.stack(X); y = np.array(y, dtype=np.float32)\n",
302
+ "FEAT_DIM = X.shape[1]\n",
303
+ "print(f\"Train: X={X.shape} y={y.shape} FEAT_DIM={FEAT_DIM}\")\n",
304
+ "\n",
305
+ "# Chuẩn hóa feature (z-score) — lưu mean/std để áp dụng y hệt lúc dự đoán DEV.\n",
306
+ "feat_mean = X.mean(0, keepdims=True)\n",
307
+ "feat_std = X.std(0, keepdims=True) + 1e-6\n",
308
+ "Xn = (X - feat_mean) / feat_std"
309
+ ]
310
+ },
311
+ {
312
+ "cell_type": "markdown",
313
+ "id": "02718889",
314
+ "metadata": {},
315
+ "source": [
316
+ "## 5. Model (MLP head) + train loop\n",
317
+ "Loss = MSE. Theo dõi **SRCC** trên validation nội bộ; lưu model tốt nhất (early stopping)."
318
+ ]
319
+ },
320
+ {
321
+ "cell_type": "code",
322
+ "execution_count": null,
323
+ "id": "f5c52fa4",
324
+ "metadata": {
325
+ "lines_to_next_cell": 1
326
+ },
327
+ "outputs": [],
328
+ "source": [
329
+ "import torch, torch.nn as nn\n",
330
+ "from scipy.stats import spearmanr\n",
331
+ "from sklearn.model_selection import train_test_split\n",
332
+ "\n",
333
+ "torch.manual_seed(SEED); np.random.seed(SEED)\n",
334
+ "device = DEVICE if torch.cuda.is_available() else \"cpu\"\n",
335
+ "print(\"Device:\", device)\n",
336
+ "\n",
337
+ "Xtr, Xva, ytr, yva = train_test_split(Xn, y, test_size=VAL_FRAC, random_state=SEED)\n",
338
+ "Xtr_t = torch.tensor(Xtr, device=device); ytr_t = torch.tensor(ytr, device=device).unsqueeze(1)\n",
339
+ "Xva_t = torch.tensor(Xva, device=device); yva_t = torch.tensor(yva, device=device).unsqueeze(1)\n",
340
+ "\n",
341
+ "class EmosHead(nn.Module):\n",
342
+ " def __init__(self, d_in, hidden, p):\n",
343
+ " super().__init__()\n",
344
+ " self.net = nn.Sequential(\n",
345
+ " nn.Linear(d_in, hidden), nn.ReLU(), nn.Dropout(p),\n",
346
+ " nn.Linear(hidden, hidden // 2), nn.ReLU(), nn.Dropout(p),\n",
347
+ " nn.Linear(hidden // 2, 1),\n",
348
+ " )\n",
349
+ " def forward(self, x):\n",
350
+ " return self.net(x)\n",
351
+ "\n",
352
+ "model = EmosHead(FEAT_DIM, HIDDEN, DROPOUT).to(device)\n",
353
+ "opt = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-5)\n",
354
+ "lossf = nn.MSELoss()\n",
355
+ "\n",
356
+ "def val_srcc():\n",
357
+ " model.eval()\n",
358
+ " with torch.no_grad():\n",
359
+ " pred = model(Xva_t).cpu().numpy().ravel()\n",
360
+ " return spearmanr(pred, yva).correlation\n",
361
+ "\n",
362
+ "best_srcc, best_state, bad = -1.0, None, 0\n",
363
+ "n = Xtr_t.shape[0]\n",
364
+ "for ep in range(1, EPOCHS + 1):\n",
365
+ " model.train()\n",
366
+ " perm = torch.randperm(n, device=device)\n",
367
+ " tot = 0.0\n",
368
+ " for i in range(0, n, BATCH):\n",
369
+ " idx = perm[i:i + BATCH]\n",
370
+ " opt.zero_grad()\n",
371
+ " out = model(Xtr_t[idx])\n",
372
+ " loss = lossf(out, ytr_t[idx])\n",
373
+ " loss.backward(); opt.step()\n",
374
+ " tot += loss.item() * len(idx)\n",
375
+ " srcc = val_srcc()\n",
376
+ " if srcc > best_srcc:\n",
377
+ " best_srcc, best_state, bad = srcc, {k: v.cpu().clone() for k, v in model.state_dict().items()}, 0\n",
378
+ " else:\n",
379
+ " bad += 1\n",
380
+ " if ep % 5 == 0 or ep == 1:\n",
381
+ " print(f\"epoch {ep:3d} | train MSE {tot/n:.4f} | val SRCC {srcc:.4f} | best {best_srcc:.4f}\")\n",
382
+ " if bad >= PATIENCE:\n",
383
+ " print(f\"Early stop ở epoch {ep} (val SRCC không tăng {PATIENCE} epoch).\")\n",
384
+ " break\n",
385
+ "\n",
386
+ "model.load_state_dict(best_state)\n",
387
+ "print(f\"\\n✅ VAL SRCC tốt nhất = {best_srcc:.4f} (baseline exp01 ≈ 0.194 — so ở đây)\")\n",
388
+ "\n",
389
+ "# Lưu model + tham số chuẩn hóa để tái dùng / mô tả hệ thống.\n",
390
+ "torch.save({\"state\": best_state, \"feat_mean\": feat_mean, \"feat_std\": feat_std,\n",
391
+ " \"EMB_DIM\": EMB_DIM, \"FEAT_DIM\": FEAT_DIM, \"USE_CLASSPROB\": USE_CLASSPROB,\n",
392
+ " \"EMOTIONS5\": EMOTIONS5, \"val_srcc\": float(best_srcc)},\n",
393
+ " os.path.join(OUT_DIR, \"emos_head.pt\"))\n",
394
+ "print(\"Đã lưu\", os.path.join(OUT_DIR, \"emos_head.pt\"))"
395
+ ]
396
+ },
397
+ {
398
+ "cell_type": "markdown",
399
+ "id": "aef6f3ee",
400
+ "metadata": {},
401
+ "source": [
402
+ "## 6. Dự đoán DEV → `answer.txt` đầy đủ\n",
403
+ "- **EMOS** = head vừa train (cần embedding + target của từng wav DEV).\n",
404
+ "- **CAT** = xác suất 5 lớp emotion2vec (đã có sẵn khi trích đặc trưng).\n",
405
+ "- **QMOS** = SpeechMOS (UTMOS) — bắt buộc, chạy thêm ở đây để answer.txt hợp lệ."
406
+ ]
407
+ },
408
+ {
409
+ "cell_type": "code",
410
+ "execution_count": null,
411
+ "id": "8f94d7ab",
412
+ "metadata": {
413
+ "lines_to_next_cell": 1
414
+ },
415
+ "outputs": [],
416
+ "source": [
417
+ "def list_dev():\n",
418
+ " with open(DEV_SCP) as f:\n",
419
+ " return [ln.strip() for ln in f if ln.strip()] # tên file .wav\n",
420
+ "\n",
421
+ "dev_names = list_dev()\n",
422
+ "dev_stems = [stem(n) for n in dev_names]\n",
423
+ "print(\"DEV:\", len(dev_names), \"mẫu\")\n",
424
+ "\n",
425
+ "# 6a. Trích đặc trưng emotion2vec cho DEV (cache riêng)\n",
426
+ "dev_feat = extract_set(dev_stems, \"dev\")\n",
427
+ "\n",
428
+ "# 6b. EMOS từ head đã train\n",
429
+ "def predict_emos(stem_id):\n",
430
+ " f = build_feature(stem_id, dev_feat)\n",
431
+ " if f is None:\n",
432
+ " return None\n",
433
+ " fn = (f[None, :] - feat_mean) / feat_std\n",
434
+ " model.eval()\n",
435
+ " with torch.no_grad():\n",
436
+ " return float(model(torch.tensor(fn, dtype=torch.float32, device=device)).item())\n",
437
+ "\n",
438
+ "# 6c. QMOS = SpeechMOS\n",
439
+ "def run_qmos(names):\n",
440
+ " import librosa\n",
441
+ " predictor = torch.hub.load(\"tarepan/SpeechMOS:v1.2.0\", \"utmos22_strong\", trust_repo=True)\n",
442
+ " out = {}\n",
443
+ " from tqdm.auto import tqdm\n",
444
+ " for n in tqdm(names, desc=\"QMOS\"):\n",
445
+ " p = os.path.join(WAV_DIR, n)\n",
446
+ " if not os.path.exists(p):\n",
447
+ " continue\n",
448
+ " wave, _ = librosa.load(p, sr=16000, mono=True)\n",
449
+ " out[n] = float(predictor(torch.from_numpy(wave).unsqueeze(0), sr=16000).mean().item())\n",
450
+ " return out\n",
451
+ "\n",
452
+ "qmos_scores = run_qmos(dev_names)"
453
+ ]
454
+ },
455
+ {
456
+ "cell_type": "code",
457
+ "execution_count": null,
458
+ "id": "6a6680d0",
459
+ "metadata": {
460
+ "lines_to_next_cell": 1
461
+ },
462
+ "outputs": [],
463
+ "source": [
464
+ "def fmt_cat(probs5):\n",
465
+ " return \"|\".join(f\"{e}:{probs5[i]:.6g}\" for i, e in enumerate(EMOTIONS5))\n",
466
+ "\n",
467
+ "def build_answer(out_path):\n",
468
+ " n_emos = n_default = 0\n",
469
+ " with open(out_path, \"w\") as f:\n",
470
+ " f.write(\"wav,QMOS,EMOS,CAT\\n\")\n",
471
+ " for name in dev_names:\n",
472
+ " sid = stem(name)\n",
473
+ " emos = predict_emos(sid)\n",
474
+ " if emos is None:\n",
475
+ " emos = 3.0; n_default += 1\n",
476
+ " else:\n",
477
+ " n_emos += 1\n",
478
+ " qmos = qmos_scores.get(name, 3.0)\n",
479
+ " probs5 = dev_feat[sid][1] if sid in dev_feat else np.full(5, 0.2, dtype=np.float32)\n",
480
+ " f.write(f\"{name},{qmos:.6g},{emos:.6g},{fmt_cat(probs5)}\\n\")\n",
481
+ " print(f\"Ghi {len(dev_names)} dòng → {out_path} | EMOS thật {n_emos}, mặc định {n_default}\")\n",
482
+ "\n",
483
+ "answer_path = os.path.join(OUT_DIR, \"answer.txt\")\n",
484
+ "build_answer(answer_path)"
485
+ ]
486
+ },
487
+ {
488
+ "cell_type": "markdown",
489
+ "id": "6179873f",
490
+ "metadata": {},
491
+ "source": [
492
+ "## 7. Validate + đóng zip"
493
+ ]
494
+ },
495
+ {
496
+ "cell_type": "code",
497
+ "execution_count": null,
498
+ "id": "30ee8626",
499
+ "metadata": {},
500
+ "outputs": [],
501
+ "source": [
502
+ "def validate(path):\n",
503
+ " import csv\n",
504
+ " with open(path) as f:\n",
505
+ " rows = list(csv.reader(f))\n",
506
+ " header = rows[0]\n",
507
+ " assert header[0] == \"wav\" and \"QMOS\" in header and \"EMOS\" in header, \"Header sai\"\n",
508
+ " for i, r in enumerate(rows[1:], 2):\n",
509
+ " assert len(r) == len(header), f\"Dòng {i} sai số cột\"\n",
510
+ " print(f\"OK: {len(rows)-1} dòng, header = {header}\")\n",
511
+ "\n",
512
+ "validate(answer_path)\n",
513
+ "!cd /kaggle/working && zip -j submission_track2_exp02.zip answer.txt && unzip -l submission_track2_exp02.zip\n",
514
+ "print(\"Sẵn sàng nộp: /kaggle/working/submission_track2_exp02.zip\")"
515
+ ]
516
+ },
517
+ {
518
+ "cell_type": "markdown",
519
+ "id": "316e3e1f",
520
+ "metadata": {},
521
+ "source": [
522
+ "## Ghi chú\n",
523
+ "- **VAL SRCC** in ở mục 5 là ước lượng nội bộ (10% train) — so với baseline 0.194 để biết có khá hơn không.\n",
524
+ " Điểm DEV thật phải nộp lên CodaBench mới biết (My Submissions → Track 2, bỏ chọn track khác).\n",
525
+ "- Muốn thử nhanh: đặt `LIMIT_TRAIN = 300` ở cell 0.\n",
526
+ "- Embedding đã cache trong `/kaggle/working/emb_cache/` → **Save Version** để giữ, lần sau train head khỏi trích lại.\n",
527
+ "- Hướng cải tiến tiếp: thêm head QMOS/CAT/VAD dùng chung backbone (exp02 multi-task đầy đủ);\n",
528
+ " thử backbone wav2vec2/WavLM; thêm ranking loss; fine-tune nhẹ backbone.\n",
529
+ "- Nhớ ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (mục exp02)."
530
+ ]
531
+ }
532
+ ],
533
+ "metadata": {
534
+ "jupytext": {
535
+ "cell_metadata_filter": "-all",
536
+ "main_language": "python",
537
+ "notebook_metadata_filter": "-all"
538
+ }
539
+ },
540
+ "nbformat": 4,
541
+ "nbformat_minor": 5
542
+ }
track2/exp02_train_emos_pipeline.py ADDED
@@ -0,0 +1,407 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %% [markdown]
2
+ # # VMC2026 Track 2 — exp02 (EMOS có train) — Kaggle
3
+ #
4
+ # **Mục tiêu:** train một model dự đoán **EMOS** (độ khớp cảm xúc target) từ ~12.746 mẫu
5
+ # có nhãn người nghe trong `sets/train.csv`, kỳ vọng **vượt baseline 0.194** (exp01 offline).
6
+ #
7
+ # ## Ý tưởng (đọc 1 lần cho hiểu)
8
+ # EMOS phụ thuộc **cả audio LẪN cảm xúc target** (cùng audio "vui": target=happy → điểm cao,
9
+ # target=sad → điểm thấp). Vì vậy model phải nhận vào cả hai:
10
+ #
11
+ # ```
12
+ # mỗi wav ─► emotion2vec ─► (a) embedding ~D chiều ┐
13
+ # (b) xác suất 5 cảm xúc ├─► nối ─► MLP head ─► EMOS (1–5)
14
+ # target emotion ───► one-hot 5 chiều ┘ (CÁI MÌNH TRAIN)
15
+ # ```
16
+ #
17
+ # - **Backbone emotion2vec ĐÓNG BĂNG** (không train lại) → chỉ trích đặc trưng. Nhẹ GPU, ít data vẫn ổn.
18
+ # - **Chỉ train MLP head nhỏ** → học ánh xạ `(đặc trưng + target) → điểm người chấm`.
19
+ # - **Nhãn vàng** = trung bình `eMOS` của mọi listener trên cùng 1 wav (gộp theo `wavID`).
20
+ # - Embedding **trích 1 lần → cache .npz** (12.746 file rất lâu, chạy lại tốn giờ GPU).
21
+ # - Tách 10% train làm **validation nội bộ** → đo SRCC trong lúc train (DEV không có nhãn để tự chấm).
22
+ # - Cuối cùng xuất `answer.txt` **đầy đủ**: QMOS=SpeechMOS · CAT=emotion2vec · **EMOS=head vừa train** → nộp được ngay.
23
+ #
24
+ # **Cách chạy trên Kaggle:** Settings → Accelerator = **GPU T4**, Internet = **On** → + Add Input dataset
25
+ # Track 2 (15.477 wav, có `sets/train.csv`) → sửa `DATA_ROOT` ở cell 0 → Run All.
26
+
27
+ # %% [markdown]
28
+ # ## 0. Cấu hình — SỬA Ở ĐÂY
29
+
30
+ # %%
31
+ import os, glob, json, time
32
+
33
+ # ── Data Track 2 (dataset 15.477 wav đã ráp, có sets/train.csv) ──────────────
34
+ DATA_ROOT = "/kaggle/input/vmc2026-track2-full/vmc2026-track2" # << SỬA slug cho khớp Add Input
35
+ WAV_DIR = f"{DATA_ROOT}/wav"
36
+ METADATA_CSV = f"{DATA_ROOT}/metadata.csv" # wavID|emotion|transcript (KHÔNG header) → target emotion
37
+ TRAIN_CSV = f"{DATA_ROOT}/sets/train.csv" # nhãn người nghe: lisID,wavID,qMOS,emoCat,eMOS,val,dom,aro
38
+ DEV_SCP = f"{DATA_ROOT}/sets/dev.scp" # danh sách wav tập DEV (tập cần nộp ở training phase)
39
+
40
+ OUT_DIR = "/kaggle/working"
41
+ CACHE_DIR = "/kaggle/working/emb_cache" # nơi lưu embedding đã trích (tái dùng giữa các lần chạy)
42
+ os.makedirs(CACHE_DIR, exist_ok=True)
43
+
44
+ # ── Siêu tham số train (đổi nếu muốn thử nghiệm) ─────────────────────────────
45
+ DEVICE = "cuda" # "cuda" trên Kaggle GPU; "cpu" nếu không có GPU
46
+ HIDDEN = 256 # số neuron lớp ẩn của MLP head
47
+ DROPOUT = 0.3
48
+ LR = 1e-3
49
+ EPOCHS = 60
50
+ BATCH = 64
51
+ VAL_FRAC = 0.10 # 10% train → validation nội bộ (đo SRCC)
52
+ PATIENCE = 12 # early stop: dừng nếu val-SRCC không cải thiện sau N epoch
53
+ SEED = 42
54
+
55
+ LIMIT_TRAIN = None # đặt số nhỏ (vd 300) để chạy thử nhanh; None = full
56
+ USE_CLASSPROB = True # thêm 5 xác suất cảm xúc của emotion2vec vào feature (tín hiệu exp01)
57
+
58
+ EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"]
59
+
60
+ _EMO_ALIAS = {
61
+ "angry": "angry", "anger": "angry",
62
+ "happy": "happy", "happiness": "happy", "joy": "happy",
63
+ "neutral": "neutral", "calm": "neutral",
64
+ "sad": "sad", "sadness": "sad",
65
+ "surprise": "surprised", "surprised": "surprised", "surprising": "surprised",
66
+ }
67
+
68
+ def norm_emotion(label):
69
+ """Đưa nhãn cảm xúc bất kỳ về 1 trong EMOTIONS5; None nếu không khớp."""
70
+ key = str(label).strip().lower()
71
+ return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)
72
+
73
+ def stem(path_or_name):
74
+ """Lấy tên file không đuôi, để khớp wavID giữa train.csv / metadata / dev.scp."""
75
+ return os.path.splitext(os.path.basename(str(path_or_name)))[0]
76
+
77
+ print("DATA_ROOT:", DATA_ROOT)
78
+ for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP]:
79
+ print((" ✅ " if os.path.exists(p) else " ❌ THIẾU ") + p)
80
+
81
+ # %% [markdown]
82
+ # ## 1. Cài đặt
83
+
84
+ # %%
85
+ # !pip install -q speechmos funasr librosa soundfile pandas scipy scikit-learn tqdm
86
+
87
+ # %% [markdown]
88
+ # ## 2. Đọc & gộp nhãn
89
+ # - `train.csv`: mỗi dòng = 1 listener chấm 1 wav → **gộp trung bình eMOS theo wavID** = nhãn vàng.
90
+ # - `metadata.csv`: lấy **cảm xúc target** cho mỗi wav (chuẩn hóa về 5 lớp).
91
+
92
+ # %%
93
+ import pandas as pd
94
+
95
+ def load_target_emotions():
96
+ """metadata.csv (wavID|emotion|transcript, KHÔNG header) → {stem: emotion_chuẩn|None}."""
97
+ tgt = {}
98
+ with open(METADATA_CSV, encoding="utf-8") as f:
99
+ for ln in f:
100
+ parts = ln.strip().split("|")
101
+ if len(parts) < 2:
102
+ continue
103
+ tgt[stem(parts[0])] = norm_emotion(parts[1])
104
+ return tgt
105
+
106
+ def load_train_labels():
107
+ """train.csv → DataFrame [wavID(stem), emos] đã gộp trung bình theo wav."""
108
+ df = pd.read_csv(TRAIN_CSV)
109
+ # Chuẩn hóa tên cột (phòng khi viết hoa/thường khác nhau)
110
+ cols = {c.lower().strip(): c for c in df.columns}
111
+ wav_col = cols.get("wavid") or cols.get("wav") or list(df.columns)[1]
112
+ emos_col = cols.get("emos") or cols.get("emo") or cols.get("emomos")
113
+ assert emos_col, f"Không thấy cột eMOS trong train.csv (cột hiện có: {list(df.columns)})"
114
+ g = df.groupby(df[wav_col].map(stem))[emos_col].mean()
115
+ out = g.reset_index()
116
+ out.columns = ["wavID", "emos"]
117
+ return out
118
+
119
+ target_map = load_target_emotions()
120
+ train_df = load_train_labels()
121
+ print(f"Target emotions: {len(target_map)} | wav train (đã gộp): {len(train_df)}")
122
+ print("eMOS thống kê:", train_df["emos"].describe()[["mean", "std", "min", "max"]].to_dict())
123
+ train_df.head()
124
+
125
+ # %% [markdown]
126
+ # ## 3. Trích đặc trưng emotion2vec (có cache)
127
+ # Mỗi wav → 1 lần `generate(extract_embedding=True)` cho ra **embedding** (cho EMOS) +
128
+ # **xác suất 5 lớp** (cho CAT và làm feature). Lưu cache `.npz` để lần sau khỏi chạy lại.
129
+
130
+ # %%
131
+ import numpy as np
132
+
133
+ _e2v_model = None
134
+ def get_e2v():
135
+ global _e2v_model
136
+ if _e2v_model is None:
137
+ from funasr import AutoModel
138
+ _e2v_model = AutoModel(model="iic/emotion2vec_plus_large", hub="hf")
139
+ return _e2v_model
140
+
141
+ def extract_one(wav_path):
142
+ """→ (emb: np.float32[D], probs5: np.float32[5] tổng=1). None nếu lỗi/thiếu file."""
143
+ if not os.path.exists(wav_path):
144
+ return None
145
+ rec = get_e2v().generate(wav_path, granularity="utterance", extract_embedding=True)
146
+ r = rec[0]
147
+ emb = np.asarray(r["feats"], dtype=np.float32).reshape(-1)
148
+ probs = {e: 0.0 for e in EMOTIONS5}
149
+ for lab, sc in zip(r["labels"], r["scores"]):
150
+ name = lab.split("/")[-1]
151
+ if name in probs:
152
+ probs[name] = float(sc)
153
+ tot = sum(probs.values())
154
+ if tot > 0:
155
+ probs = {k: v / tot for k, v in probs.items()}
156
+ probs5 = np.array([probs[e] for e in EMOTIONS5], dtype=np.float32)
157
+ return emb, probs5
158
+
159
+ def extract_set(stems, tag):
160
+ """Trích (hoặc nạp cache) cho danh sách stem. Trả về dict {stem: (emb, probs5)}.
161
+ Cache lưu tại CACHE_DIR/<tag>.npz; tự bỏ qua stem đã có để chạy nối tiếp được."""
162
+ from tqdm.auto import tqdm
163
+ cache_path = os.path.join(CACHE_DIR, f"{tag}.npz")
164
+ store = {}
165
+ if os.path.exists(cache_path):
166
+ z = np.load(cache_path, allow_pickle=True)
167
+ store = {k: z[k] for k in z.files}
168
+ print(f"[{tag}] nạp cache: {len(store)} mẫu")
169
+ todo = [s for s in stems if s not in store]
170
+ if not todo:
171
+ print(f"[{tag}] đủ cache, bỏ qua trích.")
172
+ else:
173
+ miss = 0
174
+ for i, s in enumerate(tqdm(todo, desc=f"trích {tag}")):
175
+ res = extract_one(os.path.join(WAV_DIR, s + ".wav"))
176
+ if res is None:
177
+ miss += 1
178
+ continue
179
+ emb, probs5 = res
180
+ store[s] = np.concatenate([emb, probs5]).astype(np.float32) # [D + 5]
181
+ if (i + 1) % 500 == 0: # lưu cache định kỳ phòng ngắt session
182
+ np.savez(cache_path, **store)
183
+ np.savez(cache_path, **store)
184
+ if miss:
185
+ print(f"[{tag}] {miss} file thiếu/ lỗi → bỏ qua.")
186
+ print(f"[{tag}] tổng cache: {len(store)} mẫu → {cache_path}")
187
+ # tách lại thành (emb, probs5)
188
+ out = {}
189
+ for s, vec in store.items():
190
+ out[s] = (vec[:-5], vec[-5:])
191
+ return out
192
+
193
+ # Trích cho tập train
194
+ train_stems = list(train_df["wavID"])
195
+ if LIMIT_TRAIN:
196
+ train_stems = train_stems[:LIMIT_TRAIN]
197
+ train_feat = extract_set(train_stems, "train")
198
+ EMB_DIM = next(iter(train_feat.values()))[0].shape[0]
199
+ print("EMB_DIM =", EMB_DIM)
200
+
201
+ # %% [markdown]
202
+ # ## 4. Dựng feature + nhãn cho train
203
+ # Feature mỗi wav = `[embedding | (probs5 nếu bật) | one-hot target(5)]`. Bỏ wav thiếu target/feature.
204
+
205
+ # %%
206
+ def onehot_target(tgt):
207
+ v = np.zeros(len(EMOTIONS5), dtype=np.float32)
208
+ if tgt in EMOTIONS5:
209
+ v[EMOTIONS5.index(tgt)] = 1.0
210
+ return v
211
+
212
+ def build_feature(stem_id, feat_map):
213
+ pack = feat_map.get(stem_id)
214
+ if pack is None:
215
+ return None
216
+ emb, probs5 = pack
217
+ tgt = target_map.get(stem_id)
218
+ if tgt is None: # không biết cảm xúc target → không train được mẫu này
219
+ return None
220
+ parts = [emb]
221
+ if USE_CLASSPROB:
222
+ parts.append(probs5)
223
+ parts.append(onehot_target(tgt))
224
+ return np.concatenate(parts).astype(np.float32)
225
+
226
+ emos_label = dict(zip(train_df["wavID"], train_df["emos"]))
227
+ X, y = [], []
228
+ for s in train_stems:
229
+ f = build_feature(s, train_feat)
230
+ if f is None or s not in emos_label:
231
+ continue
232
+ X.append(f); y.append(emos_label[s])
233
+ X = np.stack(X); y = np.array(y, dtype=np.float32)
234
+ FEAT_DIM = X.shape[1]
235
+ print(f"Train: X={X.shape} y={y.shape} FEAT_DIM={FEAT_DIM}")
236
+
237
+ # Chuẩn hóa feature (z-score) — lưu mean/std để áp dụng y hệt lúc dự đoán DEV.
238
+ feat_mean = X.mean(0, keepdims=True)
239
+ feat_std = X.std(0, keepdims=True) + 1e-6
240
+ Xn = (X - feat_mean) / feat_std
241
+
242
+ # %% [markdown]
243
+ # ## 5. Model (MLP head) + train loop
244
+ # Loss = MSE. Theo dõi **SRCC** trên validation nội bộ; lưu model tốt nhất (early stopping).
245
+
246
+ # %%
247
+ import torch, torch.nn as nn
248
+ from scipy.stats import spearmanr
249
+ from sklearn.model_selection import train_test_split
250
+
251
+ torch.manual_seed(SEED); np.random.seed(SEED)
252
+ device = DEVICE if torch.cuda.is_available() else "cpu"
253
+ print("Device:", device)
254
+
255
+ Xtr, Xva, ytr, yva = train_test_split(Xn, y, test_size=VAL_FRAC, random_state=SEED)
256
+ Xtr_t = torch.tensor(Xtr, device=device); ytr_t = torch.tensor(ytr, device=device).unsqueeze(1)
257
+ Xva_t = torch.tensor(Xva, device=device); yva_t = torch.tensor(yva, device=device).unsqueeze(1)
258
+
259
+ class EmosHead(nn.Module):
260
+ def __init__(self, d_in, hidden, p):
261
+ super().__init__()
262
+ self.net = nn.Sequential(
263
+ nn.Linear(d_in, hidden), nn.ReLU(), nn.Dropout(p),
264
+ nn.Linear(hidden, hidden // 2), nn.ReLU(), nn.Dropout(p),
265
+ nn.Linear(hidden // 2, 1),
266
+ )
267
+ def forward(self, x):
268
+ return self.net(x)
269
+
270
+ model = EmosHead(FEAT_DIM, HIDDEN, DROPOUT).to(device)
271
+ opt = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-5)
272
+ lossf = nn.MSELoss()
273
+
274
+ def val_srcc():
275
+ model.eval()
276
+ with torch.no_grad():
277
+ pred = model(Xva_t).cpu().numpy().ravel()
278
+ return spearmanr(pred, yva).correlation
279
+
280
+ best_srcc, best_state, bad = -1.0, None, 0
281
+ n = Xtr_t.shape[0]
282
+ for ep in range(1, EPOCHS + 1):
283
+ model.train()
284
+ perm = torch.randperm(n, device=device)
285
+ tot = 0.0
286
+ for i in range(0, n, BATCH):
287
+ idx = perm[i:i + BATCH]
288
+ opt.zero_grad()
289
+ out = model(Xtr_t[idx])
290
+ loss = lossf(out, ytr_t[idx])
291
+ loss.backward(); opt.step()
292
+ tot += loss.item() * len(idx)
293
+ srcc = val_srcc()
294
+ if srcc > best_srcc:
295
+ best_srcc, best_state, bad = srcc, {k: v.cpu().clone() for k, v in model.state_dict().items()}, 0
296
+ else:
297
+ bad += 1
298
+ if ep % 5 == 0 or ep == 1:
299
+ print(f"epoch {ep:3d} | train MSE {tot/n:.4f} | val SRCC {srcc:.4f} | best {best_srcc:.4f}")
300
+ if bad >= PATIENCE:
301
+ print(f"Early stop ở epoch {ep} (val SRCC không tăng {PATIENCE} epoch).")
302
+ break
303
+
304
+ model.load_state_dict(best_state)
305
+ print(f"\n✅ VAL SRCC tốt nhất = {best_srcc:.4f} (baseline exp01 ≈ 0.194 — so ở đây)")
306
+
307
+ # Lưu model + tham số chuẩn hóa để tái dùng / mô tả hệ thống.
308
+ torch.save({"state": best_state, "feat_mean": feat_mean, "feat_std": feat_std,
309
+ "EMB_DIM": EMB_DIM, "FEAT_DIM": FEAT_DIM, "USE_CLASSPROB": USE_CLASSPROB,
310
+ "EMOTIONS5": EMOTIONS5, "val_srcc": float(best_srcc)},
311
+ os.path.join(OUT_DIR, "emos_head.pt"))
312
+ print("Đã lưu", os.path.join(OUT_DIR, "emos_head.pt"))
313
+
314
+ # %% [markdown]
315
+ # ## 6. Dự đoán DEV → `answer.txt` đầy đủ
316
+ # - **EMOS** = head vừa train (cần embedding + target của từng wav DEV).
317
+ # - **CAT** = xác suất 5 lớp emotion2vec (đã có sẵn khi trích đặc trưng).
318
+ # - **QMOS** = SpeechMOS (UTMOS) — bắt buộc, chạy thêm ở đây để answer.txt hợp lệ.
319
+
320
+ # %%
321
+ def list_dev():
322
+ with open(DEV_SCP) as f:
323
+ return [ln.strip() for ln in f if ln.strip()] # tên file .wav
324
+
325
+ dev_names = list_dev()
326
+ dev_stems = [stem(n) for n in dev_names]
327
+ print("DEV:", len(dev_names), "mẫu")
328
+
329
+ # 6a. Trích đặc trưng emotion2vec cho DEV (cache riêng)
330
+ dev_feat = extract_set(dev_stems, "dev")
331
+
332
+ # 6b. EMOS từ head đã train
333
+ def predict_emos(stem_id):
334
+ f = build_feature(stem_id, dev_feat)
335
+ if f is None:
336
+ return None
337
+ fn = (f[None, :] - feat_mean) / feat_std
338
+ model.eval()
339
+ with torch.no_grad():
340
+ return float(model(torch.tensor(fn, dtype=torch.float32, device=device)).item())
341
+
342
+ # 6c. QMOS = SpeechMOS
343
+ def run_qmos(names):
344
+ import librosa
345
+ predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True)
346
+ out = {}
347
+ from tqdm.auto import tqdm
348
+ for n in tqdm(names, desc="QMOS"):
349
+ p = os.path.join(WAV_DIR, n)
350
+ if not os.path.exists(p):
351
+ continue
352
+ wave, _ = librosa.load(p, sr=16000, mono=True)
353
+ out[n] = float(predictor(torch.from_numpy(wave).unsqueeze(0), sr=16000).mean().item())
354
+ return out
355
+
356
+ qmos_scores = run_qmos(dev_names)
357
+
358
+ # %%
359
+ def fmt_cat(probs5):
360
+ return "|".join(f"{e}:{probs5[i]:.6g}" for i, e in enumerate(EMOTIONS5))
361
+
362
+ def build_answer(out_path):
363
+ n_emos = n_default = 0
364
+ with open(out_path, "w") as f:
365
+ f.write("wav,QMOS,EMOS,CAT\n")
366
+ for name in dev_names:
367
+ sid = stem(name)
368
+ emos = predict_emos(sid)
369
+ if emos is None:
370
+ emos = 3.0; n_default += 1
371
+ else:
372
+ n_emos += 1
373
+ qmos = qmos_scores.get(name, 3.0)
374
+ probs5 = dev_feat[sid][1] if sid in dev_feat else np.full(5, 0.2, dtype=np.float32)
375
+ f.write(f"{name},{qmos:.6g},{emos:.6g},{fmt_cat(probs5)}\n")
376
+ print(f"Ghi {len(dev_names)} dòng → {out_path} | EMOS thật {n_emos}, mặc định {n_default}")
377
+
378
+ answer_path = os.path.join(OUT_DIR, "answer.txt")
379
+ build_answer(answer_path)
380
+
381
+ # %% [markdown]
382
+ # ## 7. Validate + đóng zip
383
+
384
+ # %%
385
+ def validate(path):
386
+ import csv
387
+ with open(path) as f:
388
+ rows = list(csv.reader(f))
389
+ header = rows[0]
390
+ assert header[0] == "wav" and "QMOS" in header and "EMOS" in header, "Header sai"
391
+ for i, r in enumerate(rows[1:], 2):
392
+ assert len(r) == len(header), f"Dòng {i} sai số cột"
393
+ print(f"OK: {len(rows)-1} dòng, header = {header}")
394
+
395
+ validate(answer_path)
396
+ # !cd /kaggle/working && zip -j submission_track2_exp02.zip answer.txt && unzip -l submission_track2_exp02.zip
397
+ print("Sẵn sàng nộp: /kaggle/working/submission_track2_exp02.zip")
398
+
399
+ # %% [markdown]
400
+ # ## Ghi chú
401
+ # - **VAL SRCC** in ở mục 5 là ước lượng nội bộ (10% train) — so với baseline 0.194 để biết có khá hơn không.
402
+ # Điểm DEV thật phải nộp lên CodaBench mới biết (My Submissions → Track 2, bỏ chọn track khác).
403
+ # - Muốn thử nhanh: đặt `LIMIT_TRAIN = 300` ở cell 0.
404
+ # - Embedding đã cache trong `/kaggle/working/emb_cache/` → **Save Version** để giữ, lần sau train head khỏi trích lại.
405
+ # - Hướng cải tiến tiếp: thêm head QMOS/CAT/VAD dùng chung backbone (exp02 multi-task đầy đủ);
406
+ # thử backbone wav2vec2/WavLM; thêm ranking loss; fine-tune nhẹ backbone.
407
+ # - Nhớ ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (mục exp02).
track2/exp03_emos_sailer.ipynb ADDED
@@ -0,0 +1,392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "a6ae46f8",
6
+ "metadata": {},
7
+ "source": [
8
+ "# VMC2026 Track 2 — exp03 (EMOS bằng SAILER, offline) — Kaggle\n",
9
+ "\n",
10
+ "**Mục tiêu:** chấm **EMOS** (độ khớp cảm xúc target) bằng model **SAILER**\n",
11
+ "(`tiantiaf/wavlm-large-categorical-emotion`, vô địch Interspeech 2025 SER),\n",
12
+ "thay cho emotion2vec — KHÔNG train, chỉ lấy xác suất lớp cảm xúc target.\n",
13
+ "\n",
14
+ "## Ý tưởng (đọc 1 lần cho hiểu)\n",
15
+ "SAILER nhận 1 wav → xuất **logits 9 lớp cảm xúc** → softmax → **xác suất từng lớp**.\n",
16
+ "EMOS = mức khớp cảm xúc target → lấy thẳng **P(cảm xúc target)** rồi kéo về thang 1–5:\n",
17
+ "\n",
18
+ "```\n",
19
+ "mỗi wav ─► SAILER (WavLM-large) ─► softmax 9 lớp ─┬─► P(target) ─► EMOS = 1 + 4·P\n",
20
+ " └─► 5 lớp (renorm) ─► CAT\n",
21
+ " target emotion (metadata.csv) ─────────────────┘\n",
22
+ "```\n",
23
+ "\n",
24
+ "- **9 lớp SAILER:** `Anger, Contempt, Disgust, Fear, Happiness, Neutral, Sadness, Surprise, Other`.\n",
25
+ " → đủ cả 5 lớp challenge (angry/happy/neutral/sad/surprised).\n",
26
+ "- **EMOS** = `1 + 4·P(target)` (scale [0,1]→[1,5]); SRCC bất biến với scale tuyến tính.\n",
27
+ "- **CAT** = lấy xác suất 5 lớp challenge từ chính SAILER (renormalize tổng=1).\n",
28
+ "- **VAD** = arousal/valence/dominance SAILER xuất sẵn (sigmoid 0–1 → 1–5) → 1 model lo EMOS+CAT+VAD!\n",
29
+ "- **QMOS** = SpeechMOS (UTMOS) — bắt buộc để `answer.txt` hợp lệ.\n",
30
+ "- KHÔNG train → nộp được ngay. So điểm EMOS với baseline emotion2vec (0.194) và exp01.\n",
31
+ "\n",
32
+ "**Cách chạy trên Kaggle:** Settings → Accelerator = **GPU T4**, Internet = **On**\n",
33
+ "→ + Add Input dataset Track 2 (15.477 wav, có `sets/dev.scp`, `metadata.csv`)\n",
34
+ "→ sửa `DATA_ROOT` ở cell 0 → Run All.\n",
35
+ "\n",
36
+ "⚠️ License SAILER = **Open RAIL** (phi thương mại) → phải khai báo trong `docs/12_system_description.md`."
37
+ ]
38
+ },
39
+ {
40
+ "cell_type": "markdown",
41
+ "id": "f25f6ed7",
42
+ "metadata": {},
43
+ "source": [
44
+ "## 0. Cấu hình — SỬA Ở ĐÂY"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": null,
50
+ "id": "4cb1fd8c",
51
+ "metadata": {},
52
+ "outputs": [],
53
+ "source": [
54
+ "import os\n",
55
+ "\n",
56
+ "# ── Data Track 2 (dataset 15.477 wav đã ráp) ────────────────────────────────\n",
57
+ "DATA_ROOT = \"/kaggle/input/vmc2026-track2-full/vmc2026-track2\" # << SỬA slug cho khớp Add Input\n",
58
+ "WAV_DIR = f\"{DATA_ROOT}/wav\"\n",
59
+ "METADATA_CSV = f\"{DATA_ROOT}/metadata.csv\" # wavID|emotion|transcript (KHÔNG header) → target emotion\n",
60
+ "DEV_SCP = f\"{DATA_ROOT}/sets/dev.scp\" # danh sách wav tập DEV (tập cần nộp ở training phase)\n",
61
+ "\n",
62
+ "OUT_DIR = \"/kaggle/working\"\n",
63
+ "\n",
64
+ "DEVICE = \"cuda\" # \"cuda\" trên Kaggle GPU; \"cpu\" nếu không có GPU\n",
65
+ "MAX_SECONDS = 15 # SAILER nhận tối đa 15s (giới hạn của model)\n",
66
+ "SR = 16000 # SAILER cần 16kHz mono\n",
67
+ "LIMIT = None # đặt số nhỏ (vd 20) để chạy thử nhanh; None = full DEV\n",
68
+ "\n",
69
+ "# 5 lớp cảm xúc challenge (thứ tự cố định cho cột CAT)\n",
70
+ "EMOTIONS5 = [\"angry\", \"happy\", \"neutral\", \"sad\", \"surprised\"]\n",
71
+ "\n",
72
+ "# 9 lớp SAILER (đúng thứ tự model xuất) + chỉ số của 5 lớp challenge trong đó\n",
73
+ "SAILER9 = [\"Anger\", \"Contempt\", \"Disgust\", \"Fear\", \"Happiness\", \"Neutral\", \"Sadness\", \"Surprise\", \"Other\"]\n",
74
+ "EMO2SAILER = {\"angry\": 0, \"happy\": 4, \"neutral\": 5, \"sad\": 6, \"surprised\": 7} # EMOTIONS5 → index trong SAILER9\n",
75
+ "\n",
76
+ "_EMO_ALIAS = {\n",
77
+ " \"angry\": \"angry\", \"anger\": \"angry\",\n",
78
+ " \"happy\": \"happy\", \"happiness\": \"happy\", \"joy\": \"happy\",\n",
79
+ " \"neutral\": \"neutral\", \"calm\": \"neutral\",\n",
80
+ " \"sad\": \"sad\", \"sadness\": \"sad\",\n",
81
+ " \"surprise\": \"surprised\", \"surprised\": \"surprised\", \"surprising\": \"surprised\",\n",
82
+ "}\n",
83
+ "\n",
84
+ "def norm_emotion(label):\n",
85
+ " \"\"\"Đưa nhãn cảm xúc bất kỳ về 1 trong EMOTIONS5; None nếu không khớp.\"\"\"\n",
86
+ " key = str(label).strip().lower()\n",
87
+ " return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)\n",
88
+ "\n",
89
+ "def stem(path_or_name):\n",
90
+ " return os.path.splitext(os.path.basename(str(path_or_name)))[0]\n",
91
+ "\n",
92
+ "print(\"DATA_ROOT:\", DATA_ROOT)\n",
93
+ "for p in [WAV_DIR, METADATA_CSV, DEV_SCP]:\n",
94
+ " print((\" ✅ \" if os.path.exists(p) else \" ❌ THIẾU \") + p)"
95
+ ]
96
+ },
97
+ {
98
+ "cell_type": "markdown",
99
+ "id": "18c48274",
100
+ "metadata": {},
101
+ "source": [
102
+ "## 1. Cài đặt + tải code SAILER\n",
103
+ "SAILER cần file `WavLMWrapper` trong repo `vox-profile-release`.\n",
104
+ "⚠️ **KHÔNG** `pip install -e .` (build wheel của repo hay lỗi trên Kaggle). Thay vào đó:\n",
105
+ "chỉ **clone + thêm repo vào `sys.path`** rồi cài đúng vài thư viện model cần\n",
106
+ "(`transformers/torch/huggingface_hub` Kaggle đã có sẵn; chỉ thiếu `loralib`, `speechbrain`)."
107
+ ]
108
+ },
109
+ {
110
+ "cell_type": "code",
111
+ "execution_count": null,
112
+ "id": "bd8f98d9",
113
+ "metadata": {},
114
+ "outputs": [],
115
+ "source": [
116
+ "import sys, subprocess\n",
117
+ "\n",
118
+ "def pip_install(*pkgs):\n",
119
+ " subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *pkgs], check=True)\n",
120
+ "\n",
121
+ "REPO_DIR = \"/kaggle/working/vox-profile-release\"\n",
122
+ "if not os.path.exists(REPO_DIR):\n",
123
+ " subprocess.run([\"git\", \"clone\", \"--depth\", \"1\",\n",
124
+ " \"https://github.com/tiantiaf0627/vox-profile-release.git\", REPO_DIR], check=True)\n",
125
+ "\n",
126
+ "# Deps mà WavLMWrapper cần (xem import trong src/model/emotion/wavlm_emotion.py) + thư viện chấm QMOS.\n",
127
+ "pip_install(\"loralib\", \"speechbrain\", \"speechmos\", \"librosa\", \"soundfile\", \"scipy\", \"tqdm\")\n",
128
+ "\n",
129
+ "if REPO_DIR not in sys.path:\n",
130
+ " sys.path.insert(0, REPO_DIR) # để `from src.model.emotion... import WavLMWrapper` chạy được"
131
+ ]
132
+ },
133
+ {
134
+ "cell_type": "markdown",
135
+ "id": "00a49544",
136
+ "metadata": {},
137
+ "source": [
138
+ "## 2. Nạp model SAILER"
139
+ ]
140
+ },
141
+ {
142
+ "cell_type": "code",
143
+ "execution_count": null,
144
+ "id": "2756567f",
145
+ "metadata": {
146
+ "lines_to_next_cell": 1
147
+ },
148
+ "outputs": [],
149
+ "source": [
150
+ "import torch\n",
151
+ "import torch.nn.functional as F\n",
152
+ "\n",
153
+ "device = DEVICE if torch.cuda.is_available() else \"cpu\"\n",
154
+ "print(\"Device:\", device)\n",
155
+ "\n",
156
+ "from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402\n",
157
+ "\n",
158
+ "sailer = WavLMWrapper.from_pretrained(\"tiantiaf/wavlm-large-categorical-emotion\").to(device)\n",
159
+ "sailer.eval()\n",
160
+ "print(\"✅ Đã nạp SAILER (wavlm-large-categorical-emotion)\")"
161
+ ]
162
+ },
163
+ {
164
+ "cell_type": "markdown",
165
+ "id": "6c7b6e84",
166
+ "metadata": {},
167
+ "source": [
168
+ "## 3. Đọc cảm xúc target cho mỗi wav (từ metadata.csv)"
169
+ ]
170
+ },
171
+ {
172
+ "cell_type": "code",
173
+ "execution_count": null,
174
+ "id": "8e6f4b36",
175
+ "metadata": {},
176
+ "outputs": [],
177
+ "source": [
178
+ "def load_target_emotions():\n",
179
+ " \"\"\"metadata.csv (wavID|emotion|transcript, KHÔNG header) → {stem: emotion_chuẩn|None}.\"\"\"\n",
180
+ " tgt = {}\n",
181
+ " with open(METADATA_CSV, encoding=\"utf-8\") as f:\n",
182
+ " for ln in f:\n",
183
+ " parts = ln.strip().split(\"|\")\n",
184
+ " if len(parts) < 2:\n",
185
+ " continue\n",
186
+ " tgt[stem(parts[0])] = norm_emotion(parts[1])\n",
187
+ " return tgt\n",
188
+ "\n",
189
+ "target_map = load_target_emotions()\n",
190
+ "print(f\"Target emotions: {len(target_map)} wav | ví dụ:\", dict(list(target_map.items())[:3]))"
191
+ ]
192
+ },
193
+ {
194
+ "cell_type": "markdown",
195
+ "id": "7467cfc6",
196
+ "metadata": {},
197
+ "source": [
198
+ "## 4. Hàm chấm 1 wav bằng SAILER → xác suất 9 lớp + VAD\n",
199
+ "WavLMWrapper khi `return_feature=True` trả **6 giá trị**:\n",
200
+ "`predicted(logits 9 lớp), features, detailed_logits, arousal, valence, dominance` (VAD sigmoid 0–1).\n",
201
+ "→ 1 model lo cả **EMOS** (P target), **CAT** (5 lớp renorm) **và VAD** (mở 3 cột đang trống!)."
202
+ ]
203
+ },
204
+ {
205
+ "cell_type": "code",
206
+ "execution_count": null,
207
+ "id": "d3bb3d01",
208
+ "metadata": {
209
+ "lines_to_next_cell": 1
210
+ },
211
+ "outputs": [],
212
+ "source": [
213
+ "import numpy as np\n",
214
+ "import librosa\n",
215
+ "\n",
216
+ "@torch.no_grad()\n",
217
+ "def sailer_infer(wav_path):\n",
218
+ " \"\"\"→ (probs9: float32[9], vad3: float32[3] theo thứ tự [VAL,ARO,DOM] thang 1–5);\n",
219
+ " None nếu thiếu/lỗi file.\"\"\"\n",
220
+ " if not os.path.exists(wav_path):\n",
221
+ " return None\n",
222
+ " wave, _ = librosa.load(wav_path, sr=SR, mono=True)\n",
223
+ " wave = wave[: MAX_SECONDS * SR] # cắt tối đa 15s\n",
224
+ " data = torch.from_numpy(wave).float().unsqueeze(0).to(device)\n",
225
+ " logits, _feat, _det, arousal, valence, dominance = sailer(data, return_feature=True)\n",
226
+ " probs9 = F.softmax(logits, dim=1)[0].detach().cpu().numpy().astype(np.float32)\n",
227
+ " # VAD sigmoid [0,1] → thang 1–5 cho khớp ví dụ BTC (SRCC bất biến với scale tuyến tính)\n",
228
+ " v, a, d = float(valence.item()), float(arousal.item()), float(dominance.item())\n",
229
+ " vad3 = np.array([1 + 4 * v, 1 + 4 * a, 1 + 4 * d], dtype=np.float32) # [VAL, ARO, DOM]\n",
230
+ " return probs9, vad3\n",
231
+ "\n",
232
+ "def emos_from_probs(probs9, target):\n",
233
+ " \"\"\"EMOS = 1 + 4·P(target). None nếu không biết target → để caller xử lý mặc định.\"\"\"\n",
234
+ " if target is None or target not in EMO2SAILER:\n",
235
+ " return None\n",
236
+ " return 1.0 + 4.0 * float(probs9[EMO2SAILER[target]])\n",
237
+ "\n",
238
+ "def cat5_from_probs(probs9):\n",
239
+ " \"\"\"Lấy 5 lớp challenge từ 9 lớp SAILER rồi renormalize tổng=1.\"\"\"\n",
240
+ " v = np.array([probs9[EMO2SAILER[e]] for e in EMOTIONS5], dtype=np.float32)\n",
241
+ " s = v.sum()\n",
242
+ " return v / s if s > 0 else np.full(5, 0.2, dtype=np.float32)"
243
+ ]
244
+ },
245
+ {
246
+ "cell_type": "markdown",
247
+ "id": "14f8e54a",
248
+ "metadata": {},
249
+ "source": [
250
+ "## 5. QMOS = SpeechMOS (UTMOS) — bắt buộc cho answer.txt"
251
+ ]
252
+ },
253
+ {
254
+ "cell_type": "code",
255
+ "execution_count": null,
256
+ "id": "992bd84b",
257
+ "metadata": {
258
+ "lines_to_next_cell": 1
259
+ },
260
+ "outputs": [],
261
+ "source": [
262
+ "@torch.no_grad()\n",
263
+ "def run_qmos(names):\n",
264
+ " predictor = torch.hub.load(\"tarepan/SpeechMOS:v1.2.0\", \"utmos22_strong\", trust_repo=True).to(device).eval()\n",
265
+ " from tqdm.auto import tqdm\n",
266
+ " out = {}\n",
267
+ " for n in tqdm(names, desc=\"QMOS\"):\n",
268
+ " p = os.path.join(WAV_DIR, n)\n",
269
+ " if not os.path.exists(p):\n",
270
+ " continue\n",
271
+ " wave, _ = librosa.load(p, sr=SR, mono=True)\n",
272
+ " x = torch.from_numpy(wave).unsqueeze(0).to(device) # đẩy input lên GPU\n",
273
+ " out[n] = float(predictor(x, sr=SR).mean().item())\n",
274
+ " return out"
275
+ ]
276
+ },
277
+ {
278
+ "cell_type": "markdown",
279
+ "id": "58afec1d",
280
+ "metadata": {},
281
+ "source": [
282
+ "## 6. Chạy trên DEV → `answer.txt` đầy đủ (QMOS, EMOS, CAT, VAL, ARO, DOM)"
283
+ ]
284
+ },
285
+ {
286
+ "cell_type": "code",
287
+ "execution_count": null,
288
+ "id": "77b1eb8c",
289
+ "metadata": {
290
+ "lines_to_next_cell": 1
291
+ },
292
+ "outputs": [],
293
+ "source": [
294
+ "def list_dev():\n",
295
+ " with open(DEV_SCP) as f:\n",
296
+ " return [ln.strip() for ln in f if ln.strip()]\n",
297
+ "\n",
298
+ "dev_names = list_dev()\n",
299
+ "if LIMIT:\n",
300
+ " dev_names = dev_names[:LIMIT]\n",
301
+ "print(\"DEV:\", len(dev_names), \"mẫu\")\n",
302
+ "\n",
303
+ "qmos_scores = run_qmos(dev_names)\n",
304
+ "\n",
305
+ "def fmt_cat(probs5):\n",
306
+ " return \"|\".join(f\"{e}:{probs5[i]:.6g}\" for i, e in enumerate(EMOTIONS5))\n",
307
+ "\n",
308
+ "def build_answer(out_path):\n",
309
+ " from tqdm.auto import tqdm\n",
310
+ " n_emos = n_default = 0\n",
311
+ " with open(out_path, \"w\") as f:\n",
312
+ " f.write(\"wav,QMOS,EMOS,CAT,VAL,ARO,DOM\\n\")\n",
313
+ " for name in tqdm(dev_names, desc=\"SAILER EMOS/CAT/VAD\"):\n",
314
+ " sid = stem(name)\n",
315
+ " out = sailer_infer(os.path.join(WAV_DIR, name))\n",
316
+ " if out is None:\n",
317
+ " emos, cat5 = 3.0, np.full(5, 0.2, dtype=np.float32)\n",
318
+ " vad3 = np.array([3.0, 3.0, 3.0], dtype=np.float32)\n",
319
+ " n_default += 1\n",
320
+ " else:\n",
321
+ " probs9, vad3 = out\n",
322
+ " emos = emos_from_probs(probs9, target_map.get(sid))\n",
323
+ " if emos is None:\n",
324
+ " emos = 3.0; n_default += 1\n",
325
+ " else:\n",
326
+ " n_emos += 1\n",
327
+ " cat5 = cat5_from_probs(probs9)\n",
328
+ " qmos = qmos_scores.get(name, 3.0)\n",
329
+ " f.write(f\"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},\"\n",
330
+ " f\"{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\\n\")\n",
331
+ " print(f\"Ghi {len(dev_names)} dòng → {out_path} | EMOS thật {n_emos}, mặc định {n_default}\")\n",
332
+ "\n",
333
+ "answer_path = os.path.join(OUT_DIR, \"answer.txt\")\n",
334
+ "build_answer(answer_path)"
335
+ ]
336
+ },
337
+ {
338
+ "cell_type": "markdown",
339
+ "id": "5e3efd8d",
340
+ "metadata": {},
341
+ "source": [
342
+ "## 7. Validate + đóng zip"
343
+ ]
344
+ },
345
+ {
346
+ "cell_type": "code",
347
+ "execution_count": null,
348
+ "id": "f816cbb2",
349
+ "metadata": {},
350
+ "outputs": [],
351
+ "source": [
352
+ "def validate(path):\n",
353
+ " import csv\n",
354
+ " with open(path) as f:\n",
355
+ " rows = list(csv.reader(f))\n",
356
+ " header = rows[0]\n",
357
+ " assert header[0] == \"wav\" and \"QMOS\" in header and \"EMOS\" in header, \"Header sai\"\n",
358
+ " for i, r in enumerate(rows[1:], 2):\n",
359
+ " assert len(r) == len(header), f\"Dòng {i} sai số cột\"\n",
360
+ " print(f\"OK: {len(rows)-1} dòng, header = {header}\")\n",
361
+ "\n",
362
+ "validate(answer_path)\n",
363
+ "os.system(f\"cd {OUT_DIR} && zip -j submission_track2_exp03_sailer.zip answer.txt && unzip -l submission_track2_exp03_sailer.zip\")\n",
364
+ "print(\"Sẵn sàng nộp:\", os.path.join(OUT_DIR, \"submission_track2_exp03_sailer.zip\"))"
365
+ ]
366
+ },
367
+ {
368
+ "cell_type": "markdown",
369
+ "id": "8ca84ef6",
370
+ "metadata": {},
371
+ "source": [
372
+ "## Ghi chú\n",
373
+ "- **Chưa chạy thật bao giờ** → lần đầu đặt `LIMIT = 20` ở cell 0 để bắt lỗi setup (clone repo / import / model).\n",
374
+ "- Điểm DEV thật phải nộp lên CodaBench mới biết (My Submissions → Track 2, bỏ chọn track khác).\n",
375
+ "- Notebook này đổi **EMOS + CAT + VAD** sang SAILER (1 model lo 6 cột metric). QMOS vẫn SpeechMOS cũ.\n",
376
+ " Muốn ablation EMOS sạch (giữ CAT=emotion2vec) thì chỉ lấy cột EMOS từ đây, ghép với CAT của `track2_baseline`.\n",
377
+ "- Rủi ro setup duy nhất = import `src.model.emotion.wavlm_emotion` (cần repo vox-profile-release).\n",
378
+ " Nếu lỗi import: kiểm tra `REPO_DIR` đã clone + `sys.path` đã thêm REPO_DIR (KHÔNG dùng pip install -e .).\n",
379
+ "- Nhớ ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (mục exp03)."
380
+ ]
381
+ }
382
+ ],
383
+ "metadata": {
384
+ "jupytext": {
385
+ "cell_metadata_filter": "-all",
386
+ "main_language": "python",
387
+ "notebook_metadata_filter": "-all"
388
+ }
389
+ },
390
+ "nbformat": 4,
391
+ "nbformat_minor": 5
392
+ }
track2/exp03_emos_sailer_pipeline.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %% [markdown]
2
+ # # VMC2026 Track 2 — exp03 (EMOS bằng SAILER, offline) — Kaggle
3
+ #
4
+ # **Mục tiêu:** chấm **EMOS** (độ khớp cảm xúc target) bằng model **SAILER**
5
+ # (`tiantiaf/wavlm-large-categorical-emotion`, vô địch Interspeech 2025 SER),
6
+ # thay cho emotion2vec — KHÔNG train, chỉ lấy xác suất lớp cảm xúc target.
7
+ #
8
+ # ## Ý tưởng (đọc 1 lần cho hiểu)
9
+ # SAILER nhận 1 wav → xuất **logits 9 lớp cảm xúc** → softmax → **xác suất từng lớp**.
10
+ # EMOS = mức khớp cảm xúc target → lấy thẳng **P(cảm xúc target)** rồi kéo về thang 1–5:
11
+ #
12
+ # ```
13
+ # mỗi wav ─► SAILER (WavLM-large) ─► softmax 9 lớp ─┬─► P(target) ─► EMOS = 1 + 4·P
14
+ # └─► 5 lớp (renorm) ─► CAT
15
+ # target emotion (metadata.csv) ─────────────────┘
16
+ # ```
17
+ #
18
+ # - **9 lớp SAILER:** `Anger, Contempt, Disgust, Fear, Happiness, Neutral, Sadness, Surprise, Other`.
19
+ # → đủ cả 5 lớp challenge (angry/happy/neutral/sad/surprised).
20
+ # - **EMOS** = `1 + 4·P(target)` (scale [0,1]→[1,5]); SRCC bất biến với scale tuyến tính.
21
+ # - **CAT** = lấy xác suất 5 lớp challenge từ chính SAILER (renormalize tổng=1).
22
+ # - **VAD** = arousal/valence/dominance SAILER xuất sẵn (sigmoid 0–1 → 1–5) → 1 model lo EMOS+CAT+VAD!
23
+ # - **QMOS** = SpeechMOS (UTMOS) — bắt buộc để `answer.txt` hợp lệ.
24
+ # - KHÔNG train → nộp được ngay. So điểm EMOS với baseline emotion2vec (0.194) và exp01.
25
+ #
26
+ # **Cách chạy trên Kaggle:** Settings → Accelerator = **GPU T4**, Internet = **On**
27
+ # → + Add Input dataset Track 2 (15.477 wav, có `sets/dev.scp`, `metadata.csv`)
28
+ # → sửa `DATA_ROOT` ở cell 0 → Run All.
29
+ #
30
+ # ⚠️ License SAILER = **Open RAIL** (phi thương mại) → phải khai báo trong `docs/12_system_description.md`.
31
+
32
+ # %% [markdown]
33
+ # ## 0. Cấu hình — SỬA Ở ĐÂY
34
+
35
+ # %%
36
+ import os
37
+
38
+ # ── Data Track 2 (dataset 15.477 wav đã ráp) ────────────────────────────────
39
+ DATA_ROOT = "/kaggle/input/vmc2026-track2-full/vmc2026-track2" # << SỬA slug cho khớp Add Input
40
+ WAV_DIR = f"{DATA_ROOT}/wav"
41
+ METADATA_CSV = f"{DATA_ROOT}/metadata.csv" # wavID|emotion|transcript (KHÔNG header) → target emotion
42
+ DEV_SCP = f"{DATA_ROOT}/sets/dev.scp" # danh sách wav tập DEV (tập cần nộp ở training phase)
43
+
44
+ OUT_DIR = "/kaggle/working"
45
+
46
+ DEVICE = "cuda" # "cuda" trên Kaggle GPU; "cpu" nếu không có GPU
47
+ MAX_SECONDS = 15 # SAILER nhận tối đa 15s (giới hạn của model)
48
+ SR = 16000 # SAILER cần 16kHz mono
49
+ LIMIT = None # đặt số nhỏ (vd 20) để chạy thử nhanh; None = full DEV
50
+
51
+ # 5 lớp cảm xúc challenge (thứ tự cố định cho cột CAT)
52
+ EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"]
53
+
54
+ # 9 lớp SAILER (đúng thứ tự model xuất) + chỉ số của 5 lớp challenge trong đó
55
+ SAILER9 = ["Anger", "Contempt", "Disgust", "Fear", "Happiness", "Neutral", "Sadness", "Surprise", "Other"]
56
+ EMO2SAILER = {"angry": 0, "happy": 4, "neutral": 5, "sad": 6, "surprised": 7} # EMOTIONS5 → index trong SAILER9
57
+
58
+ _EMO_ALIAS = {
59
+ "angry": "angry", "anger": "angry",
60
+ "happy": "happy", "happiness": "happy", "joy": "happy",
61
+ "neutral": "neutral", "calm": "neutral",
62
+ "sad": "sad", "sadness": "sad",
63
+ "surprise": "surprised", "surprised": "surprised", "surprising": "surprised",
64
+ }
65
+
66
+ def norm_emotion(label):
67
+ """Đưa nhãn cảm xúc bất kỳ về 1 trong EMOTIONS5; None nếu không khớp."""
68
+ key = str(label).strip().lower()
69
+ return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)
70
+
71
+ def stem(path_or_name):
72
+ return os.path.splitext(os.path.basename(str(path_or_name)))[0]
73
+
74
+ print("DATA_ROOT:", DATA_ROOT)
75
+ for p in [WAV_DIR, METADATA_CSV, DEV_SCP]:
76
+ print((" ✅ " if os.path.exists(p) else " ❌ THIẾU ") + p)
77
+
78
+ # %% [markdown]
79
+ # ## 1. Cài đặt + tải code SAILER
80
+ # SAILER cần file `WavLMWrapper` trong repo `vox-profile-release`.
81
+ # ⚠️ **KHÔNG** `pip install -e .` (build wheel của repo hay lỗi trên Kaggle). Thay vào đó:
82
+ # chỉ **clone + thêm repo vào `sys.path`** rồi cài đúng vài thư viện model cần
83
+ # (`transformers/torch/huggingface_hub` Kaggle đã có sẵn; chỉ thiếu `loralib`, `speechbrain`).
84
+
85
+ # %%
86
+ import sys, subprocess
87
+
88
+ def pip_install(*pkgs):
89
+ subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs], check=True)
90
+
91
+ REPO_DIR = "/kaggle/working/vox-profile-release"
92
+ if not os.path.exists(REPO_DIR):
93
+ subprocess.run(["git", "clone", "--depth", "1",
94
+ "https://github.com/tiantiaf0627/vox-profile-release.git", REPO_DIR], check=True)
95
+
96
+ # Deps mà WavLMWrapper cần (xem import trong src/model/emotion/wavlm_emotion.py) + thư viện chấm QMOS.
97
+ pip_install("loralib", "speechbrain", "speechmos", "librosa", "soundfile", "scipy", "tqdm")
98
+
99
+ if REPO_DIR not in sys.path:
100
+ sys.path.insert(0, REPO_DIR) # để `from src.model.emotion... import WavLMWrapper` chạy được
101
+
102
+ # %% [markdown]
103
+ # ## 2. Nạp model SAILER
104
+
105
+ # %%
106
+ import torch
107
+ import torch.nn.functional as F
108
+
109
+ device = DEVICE if torch.cuda.is_available() else "cpu"
110
+ print("Device:", device)
111
+
112
+ from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402
113
+
114
+ sailer = WavLMWrapper.from_pretrained("tiantiaf/wavlm-large-categorical-emotion").to(device)
115
+ sailer.eval()
116
+ print("✅ Đã nạp SAILER (wavlm-large-categorical-emotion)")
117
+
118
+ # %% [markdown]
119
+ # ## 3. Đọc cảm xúc target cho mỗi wav (từ metadata.csv)
120
+
121
+ # %%
122
+ def load_target_emotions():
123
+ """metadata.csv (wavID|emotion|transcript, KHÔNG header) → {stem: emotion_chuẩn|None}."""
124
+ tgt = {}
125
+ with open(METADATA_CSV, encoding="utf-8") as f:
126
+ for ln in f:
127
+ parts = ln.strip().split("|")
128
+ if len(parts) < 2:
129
+ continue
130
+ tgt[stem(parts[0])] = norm_emotion(parts[1])
131
+ return tgt
132
+
133
+ target_map = load_target_emotions()
134
+ print(f"Target emotions: {len(target_map)} wav | ví dụ:", dict(list(target_map.items())[:3]))
135
+
136
+ # %% [markdown]
137
+ # ## 4. Hàm chấm 1 wav bằng SAILER → xác suất 9 lớp + VAD
138
+ # WavLMWrapper khi `return_feature=True` trả **6 giá trị**:
139
+ # `predicted(logits 9 lớp), features, detailed_logits, arousal, valence, dominance` (VAD sigmoid 0–1).
140
+ # → 1 model lo cả **EMOS** (P target), **CAT** (5 lớp renorm) **và VAD** (mở 3 cột đang trống!).
141
+
142
+ # %%
143
+ import numpy as np
144
+ import librosa
145
+
146
+ @torch.no_grad()
147
+ def sailer_infer(wav_path):
148
+ """→ (probs9: float32[9], vad3: float32[3] theo thứ tự [VAL,ARO,DOM] thang 1–5);
149
+ None nếu thiếu/lỗi file."""
150
+ if not os.path.exists(wav_path):
151
+ return None
152
+ wave, _ = librosa.load(wav_path, sr=SR, mono=True)
153
+ wave = wave[: MAX_SECONDS * SR] # cắt tối đa 15s
154
+ data = torch.from_numpy(wave).float().unsqueeze(0).to(device)
155
+ logits, _feat, _det, arousal, valence, dominance = sailer(data, return_feature=True)
156
+ probs9 = F.softmax(logits, dim=1)[0].detach().cpu().numpy().astype(np.float32)
157
+ # VAD sigmoid [0,1] → thang 1–5 cho khớp ví dụ BTC (SRCC bất biến với scale tuyến tính)
158
+ v, a, d = float(valence.item()), float(arousal.item()), float(dominance.item())
159
+ vad3 = np.array([1 + 4 * v, 1 + 4 * a, 1 + 4 * d], dtype=np.float32) # [VAL, ARO, DOM]
160
+ return probs9, vad3
161
+
162
+ def emos_from_probs(probs9, target):
163
+ """EMOS = 1 + 4·P(target). None nếu không biết target → để caller xử lý mặc định."""
164
+ if target is None or target not in EMO2SAILER:
165
+ return None
166
+ return 1.0 + 4.0 * float(probs9[EMO2SAILER[target]])
167
+
168
+ def cat5_from_probs(probs9):
169
+ """Lấy 5 lớp challenge từ 9 lớp SAILER rồi renormalize tổng=1."""
170
+ v = np.array([probs9[EMO2SAILER[e]] for e in EMOTIONS5], dtype=np.float32)
171
+ s = v.sum()
172
+ return v / s if s > 0 else np.full(5, 0.2, dtype=np.float32)
173
+
174
+ # %% [markdown]
175
+ # ## 5. QMOS = SpeechMOS (UTMOS) — bắt buộc cho answer.txt
176
+
177
+ # %%
178
+ @torch.no_grad()
179
+ def run_qmos(names):
180
+ predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True).to(device).eval()
181
+ from tqdm.auto import tqdm
182
+ out = {}
183
+ for n in tqdm(names, desc="QMOS"):
184
+ p = os.path.join(WAV_DIR, n)
185
+ if not os.path.exists(p):
186
+ continue
187
+ wave, _ = librosa.load(p, sr=SR, mono=True)
188
+ x = torch.from_numpy(wave).unsqueeze(0).to(device) # đẩy input lên GPU
189
+ out[n] = float(predictor(x, sr=SR).mean().item())
190
+ return out
191
+
192
+ # %% [markdown]
193
+ # ## 6. Chạy trên DEV → `answer.txt` đầy đủ (QMOS, EMOS, CAT, VAL, ARO, DOM)
194
+
195
+ # %%
196
+ def list_dev():
197
+ with open(DEV_SCP) as f:
198
+ return [ln.strip() for ln in f if ln.strip()]
199
+
200
+ dev_names = list_dev()
201
+ if LIMIT:
202
+ dev_names = dev_names[:LIMIT]
203
+ print("DEV:", len(dev_names), "mẫu")
204
+
205
+ qmos_scores = run_qmos(dev_names)
206
+
207
+ def fmt_cat(probs5):
208
+ return "|".join(f"{e}:{probs5[i]:.6g}" for i, e in enumerate(EMOTIONS5))
209
+
210
+ def build_answer(out_path):
211
+ from tqdm.auto import tqdm
212
+ n_emos = n_default = 0
213
+ with open(out_path, "w") as f:
214
+ f.write("wav,QMOS,EMOS,CAT,VAL,ARO,DOM\n")
215
+ for name in tqdm(dev_names, desc="SAILER EMOS/CAT/VAD"):
216
+ sid = stem(name)
217
+ out = sailer_infer(os.path.join(WAV_DIR, name))
218
+ if out is None:
219
+ emos, cat5 = 3.0, np.full(5, 0.2, dtype=np.float32)
220
+ vad3 = np.array([3.0, 3.0, 3.0], dtype=np.float32)
221
+ n_default += 1
222
+ else:
223
+ probs9, vad3 = out
224
+ emos = emos_from_probs(probs9, target_map.get(sid))
225
+ if emos is None:
226
+ emos = 3.0; n_default += 1
227
+ else:
228
+ n_emos += 1
229
+ cat5 = cat5_from_probs(probs9)
230
+ qmos = qmos_scores.get(name, 3.0)
231
+ f.write(f"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},"
232
+ f"{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\n")
233
+ print(f"Ghi {len(dev_names)} dòng → {out_path} | EMOS thật {n_emos}, mặc định {n_default}")
234
+
235
+ answer_path = os.path.join(OUT_DIR, "answer.txt")
236
+ build_answer(answer_path)
237
+
238
+ # %% [markdown]
239
+ # ## 7. Validate + đóng zip
240
+
241
+ # %%
242
+ def validate(path):
243
+ import csv
244
+ with open(path) as f:
245
+ rows = list(csv.reader(f))
246
+ header = rows[0]
247
+ assert header[0] == "wav" and "QMOS" in header and "EMOS" in header, "Header sai"
248
+ for i, r in enumerate(rows[1:], 2):
249
+ assert len(r) == len(header), f"Dòng {i} sai số cột"
250
+ print(f"OK: {len(rows)-1} dòng, header = {header}")
251
+
252
+ validate(answer_path)
253
+ os.system(f"cd {OUT_DIR} && zip -j submission_track2_exp03_sailer.zip answer.txt && unzip -l submission_track2_exp03_sailer.zip")
254
+ print("Sẵn sàng nộp:", os.path.join(OUT_DIR, "submission_track2_exp03_sailer.zip"))
255
+
256
+ # %% [markdown]
257
+ # ## Ghi chú
258
+ # - **Chưa chạy thật bao giờ** → lần đầu đặt `LIMIT = 20` ở cell 0 để bắt lỗi setup (clone repo / import / model).
259
+ # - Điểm DEV thật phải nộp lên CodaBench mới biết (My Submissions → Track 2, bỏ chọn track khác).
260
+ # - Notebook này đổi **EMOS + CAT + VAD** sang SAILER (1 model lo 6 cột metric). QMOS vẫn SpeechMOS cũ.
261
+ # Muốn ablation EMOS sạch (giữ CAT=emotion2vec) thì chỉ lấy cột EMOS từ đây, ghép với CAT của `track2_baseline`.
262
+ # - Rủi ro setup duy nhất = import `src.model.emotion.wavlm_emotion` (cần repo vox-profile-release).
263
+ # Nếu lỗi import: kiểm tra `REPO_DIR` đã clone + `sys.path` đã thêm REPO_DIR (KHÔNG dùng pip install -e .).
264
+ # - Nhớ ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (mục exp03).
track2/exp04_fusion.ipynb ADDED
@@ -0,0 +1,790 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "d85dcf89",
6
+ "metadata": {},
7
+ "source": [
8
+ "# VMC2026 Track 2 — exp04 (FUSION multi-task) — Kaggle\n",
9
+ "\n",
10
+ "**Mục tiêu:** gộp 2 backbone bổ sung nhau (**emotion2vec** thắng EMOS · **SAILER/WavLM** thắng VAD)\n",
11
+ "thành **1 model multi-task** dự đoán chung 5 đầu ra cảm xúc: **EMOS · CAT · VAL · ARO · DOM**.\n",
12
+ "QMOS để **riêng** (giữ SpeechMOS) — đúng thiết kế đã chốt: *\"QMOS riêng + 5 cảm xúc chung\"*.\n",
13
+ "\n",
14
+ "## Ý tưởng (đọc 1 lần cho hiểu)\n",
15
+ "Bằng chứng để fusion (từ exp01 & exp03): emotion2vec đứng đầu **EMOS** (0.637), SAILER đứng đầu\n",
16
+ "**VAD** (ARO 0.712 / DOM 0.630). Hai model \"nhìn\" cảm xúc theo cách khác nhau → **nối đặc trưng**\n",
17
+ "của cả hai rồi cho một mạng nhỏ học → kỳ vọng mạnh hơn từng model lẻ.\n",
18
+ "\n",
19
+ "```\n",
20
+ " ┌─ emotion2vec ─► embedding ~D1 + xác suất 5 lớp ─┐\n",
21
+ " mỗi wav ──────►│ ├─► NỐI ─► TRUNK chung\n",
22
+ " └─ SAILER(WavLM) ► embedding ~D2 + 9 lớp + VAD3 ─┘ (Linear+ReLU)\n",
23
+ " │\n",
24
+ " ┌───────────────────────────────────────────────┤\n",
25
+ " target emotion(one-hot)│ │\n",
26
+ " ▼ ▼\n",
27
+ " [EMOS head] [CAT head] [VAD head]\n",
28
+ " (cần target) (5 lớp) (VAL/ARO/DOM)\n",
29
+ "```\n",
30
+ "\n",
31
+ "- **Cả 2 backbone ĐÓNG BĂNG** → chỉ trích đặc trưng (cache `.npz`), **chỉ train phần trunk + head nhỏ**\n",
32
+ " → nhẹ GPU, train vài phút, hợp T4. (Né fine-tune end-to-end lúc đầu.)\n",
33
+ "- **EMOS phụ thuộc target** (cùng audio, target khác → điểm khác) → EMOS head nhận thêm one-hot target.\n",
34
+ " **CAT/VAD** là cảm nhận về chính audio → chỉ cần trunk (không cần target).\n",
35
+ "- **Nhãn vàng** gộp theo `wavID` từ `sets/train.csv`:\n",
36
+ " EMOS = TB `eMOS` · VAL/ARO/DOM = TB `val/aro/dom` · CAT = **tỉ lệ vote 5 lớp** của `emoCat`.\n",
37
+ "- **Cân loss = uncertainty weighting** (Kendall 2018): mỗi task có 1 trọng số σ **tự học**\n",
38
+ " → không phải dò tay. Có cờ `USE_UNCERTAINTY=False` để quay về trọng số cố định khi cần debug.\n",
39
+ "- Cuối cùng xuất `answer.txt` **đủ 7 cột**: `wav,QMOS,EMOS,CAT,VAL,ARO,DOM`\n",
40
+ " (QMOS=SpeechMOS · 5 cột còn lại = model fusion) → nộp được ngay. So mốc: EMOS 0.637 · VAD ARO 0.712.\n",
41
+ "\n",
42
+ "**Cách chạy trên Kaggle:** Settings → Accelerator = **GPU T4**, Internet = **On**\n",
43
+ "→ + Add Input dataset Track 2 (15.477 wav, có `sets/train.csv`, `sets/dev.scp`, `metadata.csv`)\n",
44
+ "→ sửa `DATA_ROOT` ở cell 0 → Run All. Lần đầu nên đặt `LIMIT_TRAIN = 300`, `LIMIT_DEV = 20` để bắt lỗi setup."
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "markdown",
49
+ "id": "5101bb4e",
50
+ "metadata": {},
51
+ "source": [
52
+ "## 0. Cấu hình — SỬA Ở ĐÂY"
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "code",
57
+ "execution_count": null,
58
+ "id": "3fee9b16",
59
+ "metadata": {},
60
+ "outputs": [],
61
+ "source": [
62
+ "import os\n",
63
+ "\n",
64
+ "# ── Data Track 2 (dataset 15.477 wav đã ráp, có sets/train.csv) ──────────────\n",
65
+ "DATA_ROOT = \"/kaggle/input/vmc2026-track2-full/vmc2026-track2\" # << SỬA slug cho khớp Add Input\n",
66
+ "WAV_DIR = f\"{DATA_ROOT}/wav\"\n",
67
+ "METADATA_CSV = f\"{DATA_ROOT}/metadata.csv\" # wavID|emotion|transcript (KHÔNG header) → target emotion\n",
68
+ "TRAIN_CSV = f\"{DATA_ROOT}/sets/train.csv\" # nhãn người nghe: lisID,wavID,qMOS,emoCat,eMOS,val,dom,aro\n",
69
+ "DEV_SCP = f\"{DATA_ROOT}/sets/dev.scp\" # danh sách wav tập DEV (tập cần nộp ở training phase)\n",
70
+ "\n",
71
+ "OUT_DIR = \"/kaggle/working\"\n",
72
+ "CACHE_DIR = \"/kaggle/working/fusion_cache\" # cache embedding 2 backbone (tái dùng giữa các lần chạy)\n",
73
+ "os.makedirs(CACHE_DIR, exist_ok=True)\n",
74
+ "\n",
75
+ "# ── Siêu tham số train ───────────────────────────────────────────────────────\n",
76
+ "DEVICE = \"cuda\" # \"cuda\" trên Kaggle GPU; \"cpu\" nếu không có GPU\n",
77
+ "TRUNK_HIDDEN = 512 # số neuron lớp trunk chung\n",
78
+ "HEAD_HIDDEN = 128 # số neuron lớp ẩn mỗi head\n",
79
+ "DROPOUT = 0.3\n",
80
+ "LR = 1e-3\n",
81
+ "EPOCHS = 80\n",
82
+ "BATCH = 64\n",
83
+ "VAL_FRAC = 0.10 # 10% train → validation nội bộ (đo SRCC từng task)\n",
84
+ "PATIENCE = 15 # early stop theo điểm tổng val (xem SCORE_FOR_STOP)\n",
85
+ "SEED = 42\n",
86
+ "\n",
87
+ "USE_UNCERTAINTY = True # True = tự cân loss (Kendall); False = dùng LOSS_W cố định bên dưới\n",
88
+ "LOSS_W = {\"emos\": 1.0, \"cat\": 1.0, \"val\": 1.0, \"aro\": 1.0, \"dom\": 1.0} # chỉ dùng khi tắt uncertainty\n",
89
+ "USE_E2V = True # bật/tắt nhánh emotion2vec trong fusion (để ablation)\n",
90
+ "USE_SAILER = True # bật/tắt nhánh SAILER trong fusion (để ablation)\n",
91
+ "USE_CLASSPROB = True # thêm xác suất lớp (e2v 5 + sailer 9) + VAD3 của SAILER vào feature\n",
92
+ "\n",
93
+ "LIMIT_TRAIN = None # đặt số nhỏ (vd 300) để chạy thử nhanh; None = full\n",
94
+ "LIMIT_DEV = None # đặt số nhỏ (vd 20) để chạy thử nhanh; None = full\n",
95
+ "\n",
96
+ "EMOTIONS5 = [\"angry\", \"happy\", \"neutral\", \"sad\", \"surprised\"]\n",
97
+ "\n",
98
+ "# 9 lớp SAILER (đúng thứ tự model xuất) + chỉ số của 5 lớp challenge trong đó\n",
99
+ "SAILER9 = [\"Anger\", \"Contempt\", \"Disgust\", \"Fear\", \"Happiness\", \"Neutral\", \"Sadness\", \"Surprise\", \"Other\"]\n",
100
+ "EMO2SAILER = {\"angry\": 0, \"happy\": 4, \"neutral\": 5, \"sad\": 6, \"surprised\": 7} # EMOTIONS5 → index trong SAILER9\n",
101
+ "\n",
102
+ "_EMO_ALIAS = {\n",
103
+ " \"angry\": \"angry\", \"anger\": \"angry\",\n",
104
+ " \"happy\": \"happy\", \"happiness\": \"happy\", \"joy\": \"happy\",\n",
105
+ " \"neutral\": \"neutral\", \"calm\": \"neutral\",\n",
106
+ " \"sad\": \"sad\", \"sadness\": \"sad\",\n",
107
+ " \"surprise\": \"surprised\", \"surprised\": \"surprised\", \"surprising\": \"surprised\",\n",
108
+ "}\n",
109
+ "\n",
110
+ "def norm_emotion(label):\n",
111
+ " \"\"\"Đưa nhãn cảm xúc bất kỳ về 1 trong EMOTIONS5; None nếu không khớp.\"\"\"\n",
112
+ " key = str(label).strip().lower()\n",
113
+ " return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)\n",
114
+ "\n",
115
+ "def stem(path_or_name):\n",
116
+ " \"\"\"Lấy tên file không đuôi, để khớp wavID giữa train.csv / metadata / dev.scp.\"\"\"\n",
117
+ " return os.path.splitext(os.path.basename(str(path_or_name)))[0]\n",
118
+ "\n",
119
+ "assert USE_E2V or USE_SAILER, \"Phải bật ít nhất 1 backbone (USE_E2V hoặc USE_SAILER).\"\n",
120
+ "print(\"DATA_ROOT:\", DATA_ROOT)\n",
121
+ "for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP]:\n",
122
+ " print((\" ✅ \" if os.path.exists(p) else \" ❌ THIẾU \") + p)"
123
+ ]
124
+ },
125
+ {
126
+ "cell_type": "markdown",
127
+ "id": "580854fb",
128
+ "metadata": {},
129
+ "source": [
130
+ "## 1. Cài đặt + tải code SAILER\n",
131
+ "emotion2vec qua `funasr` (offline). SAILER cần `WavLMWrapper` trong repo `vox-profile-release`\n",
132
+ "→ **clone + sys.path** (KHÔNG `pip install -e .` vì build wheel hay lỗi trên Kaggle)."
133
+ ]
134
+ },
135
+ {
136
+ "cell_type": "code",
137
+ "execution_count": null,
138
+ "id": "a0ea1faa",
139
+ "metadata": {},
140
+ "outputs": [],
141
+ "source": [
142
+ "import sys, subprocess\n",
143
+ "\n",
144
+ "def pip_install(*pkgs):\n",
145
+ " subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *pkgs], check=True)\n",
146
+ "\n",
147
+ "pip_install(\"speechmos\", \"funasr\", \"librosa\", \"soundfile\", \"pandas\", \"scipy\", \"scikit-learn\", \"tqdm\")\n",
148
+ "\n",
149
+ "if USE_SAILER:\n",
150
+ " pip_install(\"loralib\", \"speechbrain\") # deps WavLMWrapper cần\n",
151
+ " REPO_DIR = \"/kaggle/working/vox-profile-release\"\n",
152
+ " if not os.path.exists(REPO_DIR):\n",
153
+ " subprocess.run([\"git\", \"clone\", \"--depth\", \"1\",\n",
154
+ " \"https://github.com/tiantiaf0627/vox-profile-release.git\", REPO_DIR], check=True)\n",
155
+ " if REPO_DIR not in sys.path:\n",
156
+ " sys.path.insert(0, REPO_DIR)"
157
+ ]
158
+ },
159
+ {
160
+ "cell_type": "markdown",
161
+ "id": "43033a70",
162
+ "metadata": {},
163
+ "source": [
164
+ "## 2. Đọc & gộp nhãn (gộp theo wavID)\n",
165
+ "- `train.csv`: mỗi dòng = 1 listener chấm 1 wav → gộp **theo wavID**:\n",
166
+ " EMOS=TB `eMOS` · VAL/ARO/DOM=TB `val/aro/dom` · CAT=**tỉ lệ vote 5 lớp** của `emoCat`.\n",
167
+ "- `metadata.csv`: lấy **cảm xúc target** cho mỗi wav (để feed EMOS head)."
168
+ ]
169
+ },
170
+ {
171
+ "cell_type": "code",
172
+ "execution_count": null,
173
+ "id": "d4051547",
174
+ "metadata": {},
175
+ "outputs": [],
176
+ "source": [
177
+ "import numpy as np\n",
178
+ "import pandas as pd\n",
179
+ "\n",
180
+ "def load_target_emotions():\n",
181
+ " \"\"\"metadata.csv (wavID|emotion|transcript, KHÔNG header) → {stem: emotion_chuẩn|None}.\"\"\"\n",
182
+ " tgt = {}\n",
183
+ " with open(METADATA_CSV, encoding=\"utf-8\") as f:\n",
184
+ " for ln in f:\n",
185
+ " parts = ln.strip().split(\"|\")\n",
186
+ " if len(parts) < 2:\n",
187
+ " continue\n",
188
+ " tgt[stem(parts[0])] = norm_emotion(parts[1])\n",
189
+ " return tgt\n",
190
+ "\n",
191
+ "def _col(cols_map, *names, default_idx=None, df=None):\n",
192
+ " for n in names:\n",
193
+ " if n in cols_map:\n",
194
+ " return cols_map[n]\n",
195
+ " return list(df.columns)[default_idx] if default_idx is not None else None\n",
196
+ "\n",
197
+ "def parse_emocat_votes(cell):\n",
198
+ " \"\"\"1 ô emoCat (có thể đa nhãn, vd 'happy;surprised') → vector đếm 5 lớp (chưa chuẩn hóa).\"\"\"\n",
199
+ " v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
200
+ " for tok in str(cell).replace(\"/\", \",\").replace(\";\", \",\").replace(\"|\", \",\").replace(\" \", \",\").split(\",\"):\n",
201
+ " e = norm_emotion(tok)\n",
202
+ " if e in EMOTIONS5:\n",
203
+ " v[EMOTIONS5.index(e)] += 1.0\n",
204
+ " return v\n",
205
+ "\n",
206
+ "def load_train_labels():\n",
207
+ " \"\"\"train.csv → DataFrame [wavID, emos, val, aro, dom, cat0..cat4] gộp theo wav.\n",
208
+ " CAT = tỉ lệ vote 5 lớp (tổng=1); nếu wav không có vote hợp lệ → phân phối đều.\"\"\"\n",
209
+ " # train.csv phân tách bằng \"|\"; cột emoCat đa nhãn dùng \",\" bên trong (vd \"Angry,Surprised\").\n",
210
+ " df = pd.read_csv(TRAIN_CSV, sep=\"|\")\n",
211
+ " cols = {c.lower().strip(): c for c in df.columns}\n",
212
+ " wav_col = _col(cols, \"wavid\", \"wav\", default_idx=1, df=df)\n",
213
+ " emos_col = _col(cols, \"emos\", \"emo\", \"emomos\")\n",
214
+ " val_col = _col(cols, \"val\", \"valence\")\n",
215
+ " aro_col = _col(cols, \"aro\", \"arousal\")\n",
216
+ " dom_col = _col(cols, \"dom\", \"dominance\")\n",
217
+ " cat_col = _col(cols, \"emocat\", \"cat\", \"emotion\")\n",
218
+ " assert emos_col, f\"Không thấy cột eMOS trong train.csv (cột: {list(df.columns)})\"\n",
219
+ "\n",
220
+ " df[\"_stem\"] = df[wav_col].map(stem)\n",
221
+ " rows = []\n",
222
+ " for sid, g in df.groupby(\"_stem\"):\n",
223
+ " rec = {\"wavID\": sid, \"emos\": float(g[emos_col].mean())}\n",
224
+ " rec[\"val\"] = float(g[val_col].mean()) if val_col else np.nan\n",
225
+ " rec[\"aro\"] = float(g[aro_col].mean()) if aro_col else np.nan\n",
226
+ " rec[\"dom\"] = float(g[dom_col].mean()) if dom_col else np.nan\n",
227
+ " votes = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
228
+ " if cat_col:\n",
229
+ " for cell in g[cat_col]:\n",
230
+ " votes += parse_emocat_votes(cell)\n",
231
+ " s = votes.sum()\n",
232
+ " cat = votes / s if s > 0 else np.full(len(EMOTIONS5), 1.0 / len(EMOTIONS5), dtype=np.float32)\n",
233
+ " for i in range(len(EMOTIONS5)):\n",
234
+ " rec[f\"cat{i}\"] = float(cat[i])\n",
235
+ " rows.append(rec)\n",
236
+ " return pd.DataFrame(rows)\n",
237
+ "\n",
238
+ "target_map = load_target_emotions()\n",
239
+ "train_df = load_train_labels()\n",
240
+ "HAS_VAD = bool(train_df[\"val\"].notna().any())\n",
241
+ "print(f\"Target emotions: {len(target_map)} | wav train (gộp): {len(train_df)} | có nhãn VAD: {HAS_VAD}\")\n",
242
+ "print(\"eMOS:\", train_df[\"emos\"].describe()[[\"mean\", \"std\", \"min\", \"max\"]].to_dict())\n",
243
+ "train_df.head()"
244
+ ]
245
+ },
246
+ {
247
+ "cell_type": "markdown",
248
+ "id": "6c5e27b9",
249
+ "metadata": {},
250
+ "source": [
251
+ "## 3. Trích đặc trưng 2 backbone (có cache riêng từng model)\n",
252
+ "- **emotion2vec** → embedding + xác suất 5 lớp (như exp02).\n",
253
+ "- **SAILER** → embedding (features) + xác suất 9 lớp + VAD3 (như exp03).\n",
254
+ "Mỗi backbone cache riêng (`e2v_<tag>.npz`, `sailer_<tag>.npz`) → chạy nối tiếp được, đổi 1 backbone\n",
255
+ "không phải trích lại cái kia. Trích xong **giải phóng GPU** rồi mới nạp backbone sau."
256
+ ]
257
+ },
258
+ {
259
+ "cell_type": "code",
260
+ "execution_count": null,
261
+ "id": "ebaac593",
262
+ "metadata": {
263
+ "lines_to_next_cell": 1
264
+ },
265
+ "outputs": [],
266
+ "source": [
267
+ "import torch\n",
268
+ "import torch.nn.functional as F\n",
269
+ "\n",
270
+ "device = DEVICE if torch.cuda.is_available() else \"cpu\"\n",
271
+ "print(\"Device:\", device)\n",
272
+ "if device == \"cuda\":\n",
273
+ " print(\" ✅ GPU:\", torch.cuda.get_device_name(0))\n",
274
+ "else:\n",
275
+ " print(\" ⚠️ KHÔNG thấy GPU! Trích đặc trưng ~15k file trên CPU rất lâu.\")\n",
276
+ " print(\" → Settings → Accelerator = GPU T4 rồi chạy lại.\")\n",
277
+ "\n",
278
+ "# ---- emotion2vec ----\n",
279
+ "def extract_e2v(stems, tag):\n",
280
+ " \"\"\"→ dict {stem: (emb[D1], probs5[5])}. Cache CACHE_DIR/e2v_<tag>.npz.\"\"\"\n",
281
+ " from tqdm.auto import tqdm\n",
282
+ " cache_path = os.path.join(CACHE_DIR, f\"e2v_{tag}.npz\")\n",
283
+ " store = {}\n",
284
+ " if os.path.exists(cache_path):\n",
285
+ " z = np.load(cache_path, allow_pickle=True)\n",
286
+ " store = {k: z[k] for k in z.files}\n",
287
+ " print(f\"[e2v/{tag}] nạp cache: {len(store)}\")\n",
288
+ " todo = [s for s in stems if s not in store]\n",
289
+ " if todo:\n",
290
+ " import logging\n",
291
+ " logging.getLogger(\"funasr\").setLevel(logging.ERROR) # bớt log ồn của funasr\n",
292
+ " from funasr import AutoModel\n",
293
+ " m = AutoModel(model=\"iic/emotion2vec_plus_large\", hub=\"hf\", device=device,\n",
294
+ " disable_update=True, disable_pbar=True, disable_log=True) # ép GPU + tắt log\n",
295
+ " miss = 0\n",
296
+ " for i, s in enumerate(tqdm(todo, desc=f\"e2v {tag}\")):\n",
297
+ " wav = os.path.join(WAV_DIR, s + \".wav\")\n",
298
+ " if not os.path.exists(wav):\n",
299
+ " miss += 1; continue\n",
300
+ " r = m.generate(wav, granularity=\"utterance\", extract_embedding=True)[0]\n",
301
+ " emb = np.asarray(r[\"feats\"], dtype=np.float32).reshape(-1)\n",
302
+ " probs = {e: 0.0 for e in EMOTIONS5}\n",
303
+ " for lab, sc in zip(r[\"labels\"], r[\"scores\"]):\n",
304
+ " name = lab.split(\"/\")[-1]\n",
305
+ " if name in probs:\n",
306
+ " probs[name] = float(sc)\n",
307
+ " tot = sum(probs.values())\n",
308
+ " p5 = np.array([probs[e] / tot if tot > 0 else 0.2 for e in EMOTIONS5], dtype=np.float32)\n",
309
+ " store[s] = np.concatenate([emb, p5]).astype(np.float32) # [D1 + 5]\n",
310
+ " if (i + 1) % 500 == 0:\n",
311
+ " np.savez(cache_path, **store)\n",
312
+ " np.savez(cache_path, **store)\n",
313
+ " del m\n",
314
+ " torch.cuda.empty_cache() if device == \"cuda\" else None\n",
315
+ " if miss:\n",
316
+ " print(f\"[e2v/{tag}] {miss} file thiếu → bỏ qua.\")\n",
317
+ " return {s: (v[:-5], v[-5:]) for s, v in store.items()}\n",
318
+ "\n",
319
+ "# ---- SAILER ----\n",
320
+ "def _pool_feat(features):\n",
321
+ " \"\"\"features (tensor) → vector 1 chiều (mean-pool nếu còn chiều thời gian).\"\"\"\n",
322
+ " f = features.detach().cpu().numpy()\n",
323
+ " if f.ndim <= 1:\n",
324
+ " return f.reshape(-1).astype(np.float32)\n",
325
+ " return f.mean(axis=tuple(range(f.ndim - 1))).reshape(-1).astype(np.float32)\n",
326
+ "\n",
327
+ "def extract_sailer(stems, tag):\n",
328
+ " \"\"\"→ dict {stem: (emb[D2], probs9[9], vad3[3] thang 1–5)}. Cache CACHE_DIR/sailer_<tag>.npz.\n",
329
+ " Mỗi mẫu lưu vector [emb | probs9(9) | vad3(3)] → cắt lại khi nạp.\"\"\"\n",
330
+ " import librosa\n",
331
+ " from tqdm.auto import tqdm\n",
332
+ " cache_path = os.path.join(CACHE_DIR, f\"sailer_{tag}.npz\")\n",
333
+ " store = {}\n",
334
+ " if os.path.exists(cache_path):\n",
335
+ " z = np.load(cache_path, allow_pickle=True)\n",
336
+ " store = {k: z[k] for k in z.files}\n",
337
+ " print(f\"[sailer/{tag}] nạp cache: {len(store)}\")\n",
338
+ " todo = [s for s in stems if s not in store]\n",
339
+ " if todo:\n",
340
+ " from src.model.emotion.wavlm_emotion import WavLMWrapper\n",
341
+ " sailer = WavLMWrapper.from_pretrained(\"tiantiaf/wavlm-large-categorical-emotion\").to(device).eval()\n",
342
+ " miss = 0\n",
343
+ " with torch.no_grad():\n",
344
+ " for i, s in enumerate(tqdm(todo, desc=f\"sailer {tag}\")):\n",
345
+ " wav = os.path.join(WAV_DIR, s + \".wav\")\n",
346
+ " if not os.path.exists(wav):\n",
347
+ " miss += 1; continue\n",
348
+ " wave, _ = librosa.load(wav, sr=16000, mono=True)\n",
349
+ " wave = wave[: 15 * 16000]\n",
350
+ " data = torch.from_numpy(wave).float().unsqueeze(0).to(device)\n",
351
+ " logits, feat, _det, arousal, valence, dominance = sailer(data, return_feature=True)\n",
352
+ " emb = _pool_feat(feat)\n",
353
+ " p9 = F.softmax(logits, dim=1)[0].detach().cpu().numpy().astype(np.float32)\n",
354
+ " vad3 = np.array([1 + 4 * float(valence.item()),\n",
355
+ " 1 + 4 * float(arousal.item()),\n",
356
+ " 1 + 4 * float(dominance.item())], dtype=np.float32) # [VAL,ARO,DOM]\n",
357
+ " store[s] = np.concatenate([emb, p9, vad3]).astype(np.float32) # [D2 + 9 + 3]\n",
358
+ " if (i + 1) % 500 == 0:\n",
359
+ " np.savez(cache_path, **store)\n",
360
+ " np.savez(cache_path, **store)\n",
361
+ " del sailer\n",
362
+ " torch.cuda.empty_cache() if device == \"cuda\" else None\n",
363
+ " if miss:\n",
364
+ " print(f\"[sailer/{tag}] {miss} file thiếu → bỏ qua.\")\n",
365
+ " return {s: (v[:-12], v[-12:-3], v[-3:]) for s, v in store.items()}"
366
+ ]
367
+ },
368
+ {
369
+ "cell_type": "markdown",
370
+ "id": "751d646c",
371
+ "metadata": {},
372
+ "source": [
373
+ "## 4. Dựng feature + nhãn cho train\n",
374
+ "Feature audio (KHÔNG gồm target) = nối các phần đang bật:\n",
375
+ "`[e2v_emb | e2v_probs5 | sailer_emb | sailer_probs9 | sailer_vad3]`.\n",
376
+ "One-hot target để **riêng** (chỉ EMOS head dùng). Bỏ wav thiếu feature."
377
+ ]
378
+ },
379
+ {
380
+ "cell_type": "code",
381
+ "execution_count": null,
382
+ "id": "005cdf2f",
383
+ "metadata": {},
384
+ "outputs": [],
385
+ "source": [
386
+ "train_stems = list(train_df[\"wavID\"])\n",
387
+ "if LIMIT_TRAIN:\n",
388
+ " train_stems = train_stems[:LIMIT_TRAIN]\n",
389
+ "\n",
390
+ "e2v_tr = extract_e2v(train_stems, \"train\") if USE_E2V else {}\n",
391
+ "sailer_tr = extract_sailer(train_stems, \"train\") if USE_SAILER else {}\n",
392
+ "\n",
393
+ "def audio_feature(sid, e2v_map, sailer_map):\n",
394
+ " \"\"\"Nối đặc trưng audio cho 1 wav. None nếu thiếu phần bắt buộc.\"\"\"\n",
395
+ " parts = []\n",
396
+ " if USE_E2V:\n",
397
+ " pk = e2v_map.get(sid)\n",
398
+ " if pk is None:\n",
399
+ " return None\n",
400
+ " emb, p5 = pk\n",
401
+ " parts.append(emb)\n",
402
+ " if USE_CLASSPROB:\n",
403
+ " parts.append(p5)\n",
404
+ " if USE_SAILER:\n",
405
+ " pk = sailer_map.get(sid)\n",
406
+ " if pk is None:\n",
407
+ " return None\n",
408
+ " emb, p9, vad3 = pk\n",
409
+ " parts.append(emb)\n",
410
+ " if USE_CLASSPROB:\n",
411
+ " parts.append(p9); parts.append(vad3)\n",
412
+ " return np.concatenate(parts).astype(np.float32)\n",
413
+ "\n",
414
+ "def onehot_target(tgt):\n",
415
+ " v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
416
+ " if tgt in EMOTIONS5:\n",
417
+ " v[EMOTIONS5.index(tgt)] = 1.0\n",
418
+ " return v\n",
419
+ "\n",
420
+ "lab = train_df.set_index(\"wavID\")\n",
421
+ "X, T, y_emos, y_vad, y_cat = [], [], [], [], []\n",
422
+ "for s in train_stems:\n",
423
+ " f = audio_feature(s, e2v_tr, sailer_tr)\n",
424
+ " tgt = target_map.get(s)\n",
425
+ " if f is None or tgt is None or s not in lab.index:\n",
426
+ " continue\n",
427
+ " X.append(f)\n",
428
+ " T.append(onehot_target(tgt))\n",
429
+ " y_emos.append(lab.loc[s, \"emos\"])\n",
430
+ " y_vad.append([lab.loc[s, \"val\"], lab.loc[s, \"aro\"], lab.loc[s, \"dom\"]])\n",
431
+ " y_cat.append([lab.loc[s, f\"cat{i}\"] for i in range(len(EMOTIONS5))])\n",
432
+ "\n",
433
+ "X = np.stack(X).astype(np.float32)\n",
434
+ "T = np.stack(T).astype(np.float32)\n",
435
+ "y_emos = np.array(y_emos, dtype=np.float32)\n",
436
+ "y_vad = np.array(y_vad, dtype=np.float32) # [N,3] (VAL,ARO,DOM) — có thể toàn NaN nếu thiếu nhãn\n",
437
+ "y_cat = np.array(y_cat, dtype=np.float32) # [N,5] phân phối tổng=1\n",
438
+ "FEAT_DIM = X.shape[1]\n",
439
+ "print(f\"Train: X={X.shape} target={T.shape} emos={y_emos.shape} vad={y_vad.shape} cat={y_cat.shape}\")\n",
440
+ "\n",
441
+ "# Chuẩn hóa feature audio (z-score) — lưu mean/std để áp dụng y hệt lúc dự đoán DEV.\n",
442
+ "feat_mean = X.mean(0, keepdims=True)\n",
443
+ "feat_std = X.std(0, keepdims=True) + 1e-6\n",
444
+ "Xn = (X - feat_mean) / feat_std\n",
445
+ "\n",
446
+ "# Chuẩn hóa nhãn liên tục (eMOS, VAD) về z-score → các MSE cùng thang (uncertainty weighting ổn định hơn).\n",
447
+ "# SRCC bất biến với scale → khi xuất answer.txt chỉ cần đảo z-score về thang gốc cho đẹp.\n",
448
+ "emos_mu, emos_sd = float(y_emos.mean()), float(y_emos.std() + 1e-6)\n",
449
+ "y_emos_z = (y_emos - emos_mu) / emos_sd\n",
450
+ "if HAS_VAD:\n",
451
+ " vad_mu = np.nanmean(y_vad, axis=0)\n",
452
+ " vad_sd = np.nanstd(y_vad, axis=0) + 1e-6\n",
453
+ " y_vad_z = (y_vad - vad_mu) / vad_sd\n",
454
+ "else:\n",
455
+ " vad_mu = np.zeros(3, dtype=np.float32); vad_sd = np.ones(3, dtype=np.float32)\n",
456
+ " y_vad_z = np.zeros_like(y_vad)"
457
+ ]
458
+ },
459
+ {
460
+ "cell_type": "markdown",
461
+ "id": "f41faa42",
462
+ "metadata": {},
463
+ "source": [
464
+ "## 5. Model fusion multi-task + train loop\n",
465
+ "- **Trunk** chung: `Linear(FEAT_DIM→TRUNK_HIDDEN)+ReLU+Dropout` (×2).\n",
466
+ "- **EMOS head**: nối `[trunk | one-hot target]` → MLP → 1 (vì EMOS phụ thuộc target).\n",
467
+ "- **CAT head**: trunk → 5 logits → softmax (dự đoán phân phối vote). Loss = soft-CE (KL).\n",
468
+ "- **VAD head**: trunk → 3 (VAL/ARO/DOM). Loss = MSE (bỏ qua nếu thiếu nhãn VAD).\n",
469
+ "- **Cân loss**: uncertainty weighting — tổng `Σ exp(-sᵢ)·Lᵢ + sᵢ`, `sᵢ=log σᵢ²` **học được**."
470
+ ]
471
+ },
472
+ {
473
+ "cell_type": "code",
474
+ "execution_count": null,
475
+ "id": "dc5e0242",
476
+ "metadata": {
477
+ "lines_to_next_cell": 1
478
+ },
479
+ "outputs": [],
480
+ "source": [
481
+ "import torch.nn as nn\n",
482
+ "from scipy.stats import spearmanr\n",
483
+ "from sklearn.model_selection import train_test_split\n",
484
+ "\n",
485
+ "torch.manual_seed(SEED); np.random.seed(SEED)\n",
486
+ "N_EMO = len(EMOTIONS5)\n",
487
+ "\n",
488
+ "idx_all = np.arange(X.shape[0])\n",
489
+ "tr_idx, va_idx = train_test_split(idx_all, test_size=VAL_FRAC, random_state=SEED)\n",
490
+ "\n",
491
+ "def to_t(a):\n",
492
+ " return torch.tensor(a, dtype=torch.float32, device=device)\n",
493
+ "\n",
494
+ "Xn_t, T_t = to_t(Xn), to_t(T)\n",
495
+ "emos_t = to_t(y_emos_z).unsqueeze(1)\n",
496
+ "vad_t = to_t(y_vad_z)\n",
497
+ "cat_t = to_t(y_cat)\n",
498
+ "\n",
499
+ "class FusionMTL(nn.Module):\n",
500
+ " def __init__(self, d_in, trunk_h, head_h, p, n_emo):\n",
501
+ " super().__init__()\n",
502
+ " self.trunk = nn.Sequential(\n",
503
+ " nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),\n",
504
+ " nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p),\n",
505
+ " )\n",
506
+ " self.emos = nn.Sequential( # nhận [trunk | target]\n",
507
+ " nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))\n",
508
+ " self.cat = nn.Sequential(\n",
509
+ " nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))\n",
510
+ " self.vad = nn.Sequential(\n",
511
+ " nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))\n",
512
+ "\n",
513
+ " def forward(self, x, tgt):\n",
514
+ " h = self.trunk(x)\n",
515
+ " emos = self.emos(torch.cat([h, tgt], dim=1))\n",
516
+ " cat_logits = self.cat(h)\n",
517
+ " vad = self.vad(h)\n",
518
+ " return emos, cat_logits, vad\n",
519
+ "\n",
520
+ "model = FusionMTL(FEAT_DIM, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device)\n",
521
+ "\n",
522
+ "# Trọng số bất định (log σ²) cho 5 task: emos, cat, val, aro, dom.\n",
523
+ "TASKS = [\"emos\", \"cat\", \"val\", \"aro\", \"dom\"]\n",
524
+ "log_var = nn.Parameter(torch.zeros(len(TASKS), device=device))\n",
525
+ "params = list(model.parameters()) + ([log_var] if USE_UNCERTAINTY else [])\n",
526
+ "opt = torch.optim.Adam(params, lr=LR, weight_decay=1e-5)\n",
527
+ "\n",
528
+ "mse = nn.MSELoss(reduction=\"none\")\n",
529
+ "\n",
530
+ "def soft_ce(logits, target_dist):\n",
531
+ " \"\"\"Cross-entropy với nhãn mềm (phân phối): −Σ p·log q.\"\"\"\n",
532
+ " logq = F.log_softmax(logits, dim=1)\n",
533
+ " return -(target_dist * logq).sum(dim=1)\n",
534
+ "\n",
535
+ "def task_losses(emos_p, cat_logits, vad_p, b):\n",
536
+ " \"\"\"Trả về dict loss TB từng task cho 1 batch (chỉ số b).\"\"\"\n",
537
+ " L = {}\n",
538
+ " L[\"emos\"] = mse(emos_p, emos_t[b]).mean()\n",
539
+ " L[\"cat\"] = soft_ce(cat_logits, cat_t[b]).mean()\n",
540
+ " if HAS_VAD:\n",
541
+ " L[\"val\"] = mse(vad_p[:, 0:1], vad_t[b, 0:1]).mean()\n",
542
+ " L[\"aro\"] = mse(vad_p[:, 1:2], vad_t[b, 1:2]).mean()\n",
543
+ " L[\"dom\"] = mse(vad_p[:, 2:3], vad_t[b, 2:3]).mean()\n",
544
+ " else:\n",
545
+ " z = torch.zeros((), device=device)\n",
546
+ " L[\"val\"] = L[\"aro\"] = L[\"dom\"] = z\n",
547
+ " return L\n",
548
+ "\n",
549
+ "def combine(L):\n",
550
+ " \"\"\"Gộp 5 loss thành 1 số: uncertainty weighting hoặc trọng số cố định.\"\"\"\n",
551
+ " if USE_UNCERTAINTY:\n",
552
+ " tot = 0.0\n",
553
+ " for i, t in enumerate(TASKS):\n",
554
+ " tot = tot + torch.exp(-log_var[i]) * L[t] + log_var[i]\n",
555
+ " return tot\n",
556
+ " return sum(LOSS_W[t] * L[t] for t in TASKS)\n",
557
+ "\n",
558
+ "@torch.no_grad()\n",
559
+ "def eval_val():\n",
560
+ " \"\"\"SRCC từng task trên tập val nội bộ (CAT báo bằng −KL để 'cao=tốt' cho early-stop).\"\"\"\n",
561
+ " model.eval()\n",
562
+ " ep, cl, vp = model(Xn_t[va_idx], T_t[va_idx])\n",
563
+ " ep = ep.cpu().numpy().ravel()\n",
564
+ " out = {\"emos\": spearmanr(ep, y_emos[va_idx]).correlation}\n",
565
+ " if HAS_VAD:\n",
566
+ " vp = vp.cpu().numpy()\n",
567
+ " for j, t in enumerate([\"val\", \"aro\", \"dom\"]):\n",
568
+ " out[t] = spearmanr(vp[:, j], y_vad[va_idx, j]).correlation\n",
569
+ " # CAT: dùng −KL(p‖q) trung bình (càng gần 0 càng tốt) → đổi dấu để hợp early-stop\n",
570
+ " q = F.softmax(cl, dim=1).cpu().numpy()\n",
571
+ " p = y_cat[va_idx]\n",
572
+ " kl = (p * (np.log(p + 1e-9) - np.log(q + 1e-9))).sum(1).mean()\n",
573
+ " out[\"cat_negkl\"] = float(-kl)\n",
574
+ " return out\n",
575
+ "\n",
576
+ "def val_score(m):\n",
577
+ " \"\"\"Điểm tổng để early-stop = TB SRCC các task liên tục có nhãn.\"\"\"\n",
578
+ " keys = [\"emos\"] + ([\"val\", \"aro\", \"dom\"] if HAS_VAD else [])\n",
579
+ " return float(np.mean([m[k] for k in keys]))\n",
580
+ "\n",
581
+ "best_score, best_state, bad = -1e9, None, 0\n",
582
+ "tr_t = torch.tensor(tr_idx, device=device)\n",
583
+ "for ep in range(1, EPOCHS + 1):\n",
584
+ " model.train()\n",
585
+ " perm = tr_t[torch.randperm(len(tr_t), device=device)]\n",
586
+ " run = 0.0\n",
587
+ " for i in range(0, len(perm), BATCH):\n",
588
+ " b = perm[i:i + BATCH]\n",
589
+ " opt.zero_grad()\n",
590
+ " emos_p, cat_logits, vad_p = model(Xn_t[b], T_t[b])\n",
591
+ " L = task_losses(emos_p, cat_logits, vad_p, b)\n",
592
+ " loss = combine(L)\n",
593
+ " loss.backward(); opt.step()\n",
594
+ " run += loss.item() * len(b)\n",
595
+ " m = eval_val()\n",
596
+ " sc = val_score(m)\n",
597
+ " if sc > best_score:\n",
598
+ " best_score = sc\n",
599
+ " best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}\n",
600
+ " bad = 0\n",
601
+ " else:\n",
602
+ " bad += 1\n",
603
+ " if ep % 5 == 0 or ep == 1:\n",
604
+ " msg = \" \".join(f\"{k}={m[k]:.3f}\" for k in m)\n",
605
+ " print(f\"epoch {ep:3d} | loss {run/len(perm):.4f} | {msg} | best {best_score:.4f}\")\n",
606
+ " if bad >= PATIENCE:\n",
607
+ " print(f\"Early stop ở epoch {ep}.\")\n",
608
+ " break\n",
609
+ "\n",
610
+ "model.load_state_dict(best_state)\n",
611
+ "final = eval_val()\n",
612
+ "print(\"\\n✅ VAL (nội bộ) tốt nhất:\")\n",
613
+ "print(f\" EMOS SRCC = {final['emos']:.4f} (so mốc exp01 emotion2vec = 0.637)\")\n",
614
+ "if HAS_VAD:\n",
615
+ " print(f\" VAL/ARO/DOM SRCC = {final['val']:.4f} / {final['aro']:.4f} / {final['dom']:.4f}\"\n",
616
+ " f\" (so mốc SAILER = 0.341 / 0.712 / 0.630)\")\n",
617
+ "if USE_UNCERTAINTY:\n",
618
+ " print(\" log σ² mỗi task:\", {t: round(float(log_var[i]), 3) for i, t in enumerate(TASKS)})\n",
619
+ "\n",
620
+ "# Lưu model + tham số chuẩn hóa.\n",
621
+ "torch.save({\"state\": best_state, \"feat_mean\": feat_mean, \"feat_std\": feat_std,\n",
622
+ " \"emos_mu\": emos_mu, \"emos_sd\": emos_sd, \"vad_mu\": vad_mu, \"vad_sd\": vad_sd,\n",
623
+ " \"FEAT_DIM\": FEAT_DIM, \"EMOTIONS5\": EMOTIONS5, \"HAS_VAD\": HAS_VAD,\n",
624
+ " \"USE_E2V\": USE_E2V, \"USE_SAILER\": USE_SAILER, \"USE_CLASSPROB\": USE_CLASSPROB,\n",
625
+ " \"TRUNK_HIDDEN\": TRUNK_HIDDEN, \"HEAD_HIDDEN\": HEAD_HIDDEN, \"val_score\": best_score},\n",
626
+ " os.path.join(OUT_DIR, \"fusion_mtl.pt\"))\n",
627
+ "print(\"Đã lưu\", os.path.join(OUT_DIR, \"fusion_mtl.pt\"))"
628
+ ]
629
+ },
630
+ {
631
+ "cell_type": "markdown",
632
+ "id": "39e3c014",
633
+ "metadata": {},
634
+ "source": [
635
+ "## 6. Dự đoán DEV → `answer.txt` đầy đủ 7 cột\n",
636
+ "- **EMOS/CAT/VAD** = model fusion (đảo z-score về thang gốc cho EMOS/VAD; CAT = softmax 5 lớp).\n",
637
+ "- **QMOS** = SpeechMOS (UTMOS) — để riêng, đúng thiết kế."
638
+ ]
639
+ },
640
+ {
641
+ "cell_type": "code",
642
+ "execution_count": null,
643
+ "id": "c9d06ec4",
644
+ "metadata": {
645
+ "lines_to_next_cell": 1
646
+ },
647
+ "outputs": [],
648
+ "source": [
649
+ "def list_dev():\n",
650
+ " with open(DEV_SCP) as f:\n",
651
+ " return [ln.strip() for ln in f if ln.strip()] # tên file .wav\n",
652
+ "\n",
653
+ "dev_names = list_dev()\n",
654
+ "if LIMIT_DEV:\n",
655
+ " dev_names = dev_names[:LIMIT_DEV]\n",
656
+ "dev_stems = [stem(n) for n in dev_names]\n",
657
+ "print(\"DEV:\", len(dev_names), \"mẫu\")\n",
658
+ "\n",
659
+ "# 6a. Trích đặc trưng 2 backbone cho DEV (cache riêng)\n",
660
+ "e2v_dev = extract_e2v(dev_stems, \"dev\") if USE_E2V else {}\n",
661
+ "sailer_dev = extract_sailer(dev_stems, \"dev\") if USE_SAILER else {}\n",
662
+ "\n",
663
+ "# 6b. Dự đoán 5 cột cảm xúc bằng model fusion\n",
664
+ "@torch.no_grad()\n",
665
+ "def predict_emotion(sid):\n",
666
+ " f = audio_feature(sid, e2v_dev, sailer_dev)\n",
667
+ " if f is None:\n",
668
+ " return None\n",
669
+ " fn = (f[None, :] - feat_mean) / feat_std\n",
670
+ " tgt = onehot_target(target_map.get(sid))[None, :]\n",
671
+ " model.eval()\n",
672
+ " emos_p, cat_logits, vad_p = model(to_t(fn), to_t(tgt))\n",
673
+ " emos = float(emos_p.item()) * emos_sd + emos_mu # đảo z-score\n",
674
+ " cat5 = F.softmax(cat_logits, dim=1)[0].cpu().numpy()\n",
675
+ " vad3 = vad_p[0].cpu().numpy() * vad_sd + vad_mu # [VAL,ARO,DOM]\n",
676
+ " return emos, cat5, vad3\n",
677
+ "\n",
678
+ "# 6c. QMOS = SpeechMOS (để riêng)\n",
679
+ "@torch.no_grad()\n",
680
+ "def run_qmos(names):\n",
681
+ " import librosa\n",
682
+ " from tqdm.auto import tqdm\n",
683
+ " predictor = torch.hub.load(\"tarepan/SpeechMOS:v1.2.0\", \"utmos22_strong\", trust_repo=True).to(device).eval()\n",
684
+ " out = {}\n",
685
+ " for n in tqdm(names, desc=\"QMOS\"):\n",
686
+ " p = os.path.join(WAV_DIR, n)\n",
687
+ " if not os.path.exists(p):\n",
688
+ " continue\n",
689
+ " wave, _ = librosa.load(p, sr=16000, mono=True)\n",
690
+ " out[n] = float(predictor(torch.from_numpy(wave).unsqueeze(0).to(device), sr=16000).mean().item())\n",
691
+ " return out\n",
692
+ "\n",
693
+ "qmos_scores = run_qmos(dev_names)"
694
+ ]
695
+ },
696
+ {
697
+ "cell_type": "code",
698
+ "execution_count": null,
699
+ "id": "999f19fc",
700
+ "metadata": {
701
+ "lines_to_next_cell": 1
702
+ },
703
+ "outputs": [],
704
+ "source": [
705
+ "def fmt_cat(probs5):\n",
706
+ " return \"|\".join(f\"{e}:{probs5[i]:.6g}\" for i, e in enumerate(EMOTIONS5))\n",
707
+ "\n",
708
+ "def build_answer(out_path):\n",
709
+ " from tqdm.auto import tqdm\n",
710
+ " n_real = n_default = 0\n",
711
+ " with open(out_path, \"w\") as f:\n",
712
+ " f.write(\"wav,QMOS,EMOS,CAT,VAL,ARO,DOM\\n\")\n",
713
+ " for name in tqdm(dev_names, desc=\"answer\"):\n",
714
+ " sid = stem(name)\n",
715
+ " pred = predict_emotion(sid)\n",
716
+ " if pred is None:\n",
717
+ " emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0])\n",
718
+ " n_default += 1\n",
719
+ " else:\n",
720
+ " emos, cat5, vad3 = pred\n",
721
+ " n_real += 1\n",
722
+ " qmos = qmos_scores.get(name, 3.0)\n",
723
+ " f.write(f\"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},\"\n",
724
+ " f\"{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\\n\")\n",
725
+ " print(f\"Ghi {len(dev_names)} dòng → {out_path} | fusion thật {n_real}, mặc định {n_default}\")\n",
726
+ "\n",
727
+ "answer_path = os.path.join(OUT_DIR, \"answer.txt\")\n",
728
+ "build_answer(answer_path)"
729
+ ]
730
+ },
731
+ {
732
+ "cell_type": "markdown",
733
+ "id": "708acd7a",
734
+ "metadata": {},
735
+ "source": [
736
+ "## 7. Validate + đóng zip"
737
+ ]
738
+ },
739
+ {
740
+ "cell_type": "code",
741
+ "execution_count": null,
742
+ "id": "ba406750",
743
+ "metadata": {},
744
+ "outputs": [],
745
+ "source": [
746
+ "def validate(path):\n",
747
+ " import csv\n",
748
+ " with open(path) as f:\n",
749
+ " rows = list(csv.reader(f))\n",
750
+ " header = rows[0]\n",
751
+ " assert header[0] == \"wav\" and \"QMOS\" in header and \"EMOS\" in header, \"Header sai\"\n",
752
+ " for i, r in enumerate(rows[1:], 2):\n",
753
+ " assert len(r) == len(header), f\"Dòng {i} sai số cột\"\n",
754
+ " print(f\"OK: {len(rows)-1} dòng, header = {header}\")\n",
755
+ "\n",
756
+ "validate(answer_path)\n",
757
+ "os.system(f\"cd {OUT_DIR} && zip -j submission_track2_exp04_fusion.zip answer.txt && unzip -l submission_track2_exp04_fusion.zip\")\n",
758
+ "print(\"Sẵn sàng nộp:\", os.path.join(OUT_DIR, \"submission_track2_exp04_fusion.zip\"))"
759
+ ]
760
+ },
761
+ {
762
+ "cell_type": "markdown",
763
+ "id": "c0f4e2ae",
764
+ "metadata": {},
765
+ "source": [
766
+ "## Ghi chú\n",
767
+ "- **Lần đầu**: đặt `LIMIT_TRAIN=300`, `LIMIT_DEV=20` ở cell 0 để bắt lỗi setup (clone repo / import / model).\n",
768
+ " Chạy OK rồi đặt `None` chạy full.\n",
769
+ "- **VAL SRCC** ở mục 5 là ước lượng nội bộ (10% train) → so mốc EMOS 0.637 / ARO 0.712. Điểm DEV thật\n",
770
+ " phải nộp CodaBench mới biết (My Submissions → Track 2, bỏ chọn track khác).\n",
771
+ "- Embedding đã cache trong `/kaggle/working/fusion_cache/` → **Save Version** để giữ; lần sau đổi\n",
772
+ " siêu tham số/đổi cách cân loss chỉ train lại head (vài phút), khỏi trích lại.\n",
773
+ "- **Ablation cho paper** (đổi cờ ở cell 0, train lại head):\n",
774
+ " `USE_E2V=False` (chỉ SAILER) · `USE_SAILER=False` (chỉ emotion2vec) · `USE_UNCERTAINTY=False` (trọng số tay)\n",
775
+ " · `USE_CLASSPROB=False` (chỉ embedding) → điền bảng ablation `docs/04_experiments_log.md`.\n",
776
+ "- License SAILER = **Open RAIL (phi thương mại)** → nhắc trong `docs/12_system_description.md`.\n",
777
+ "- Nhớ ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (mục exp04)."
778
+ ]
779
+ }
780
+ ],
781
+ "metadata": {
782
+ "jupytext": {
783
+ "cell_metadata_filter": "-all",
784
+ "main_language": "python",
785
+ "notebook_metadata_filter": "-all"
786
+ }
787
+ },
788
+ "nbformat": 4,
789
+ "nbformat_minor": 5
790
+ }
track2/exp04_fusion_pipeline.py ADDED
@@ -0,0 +1,652 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %% [markdown]
2
+ # # VMC2026 Track 2 — exp04 (FUSION multi-task) — Kaggle
3
+ #
4
+ # **Mục tiêu:** gộp 2 backbone bổ sung nhau (**emotion2vec** thắng EMOS · **SAILER/WavLM** thắng VAD)
5
+ # thành **1 model multi-task** dự đoán chung 5 đầu ra cảm xúc: **EMOS · CAT · VAL · ARO · DOM**.
6
+ # QMOS để **riêng** (giữ SpeechMOS) — đúng thiết kế đã chốt: *"QMOS riêng + 5 cảm xúc chung"*.
7
+ #
8
+ # ## Ý tưởng (đọc 1 lần cho hiểu)
9
+ # Bằng chứng để fusion (từ exp01 & exp03): emotion2vec đứng đầu **EMOS** (0.637), SAILER đứng đầu
10
+ # **VAD** (ARO 0.712 / DOM 0.630). Hai model "nhìn" cảm xúc theo cách khác nhau → **nối đặc trưng**
11
+ # của cả hai rồi cho một mạng nhỏ học → kỳ vọng mạnh hơn từng model lẻ.
12
+ #
13
+ # ```
14
+ # ┌─ emotion2vec ─► embedding ~D1 + xác suất 5 lớp ─┐
15
+ # mỗi wav ──────►│ ├─► NỐI ─► TRUNK chung
16
+ # └─ SAILER(WavLM) ► embedding ~D2 + 9 lớp + VAD3 ─┘ (Linear+ReLU)
17
+ # │
18
+ # ┌───────────────────────────────────────────────┤
19
+ # target emotion(one-hot)│ │
20
+ # ▼ ▼
21
+ # [EMOS head] [CAT head] [VAD head]
22
+ # (cần target) (5 lớp) (VAL/ARO/DOM)
23
+ # ```
24
+ #
25
+ # - **Cả 2 backbone ĐÓNG BĂNG** → chỉ trích đặc trưng (cache `.npz`), **chỉ train phần trunk + head nhỏ**
26
+ # → nhẹ GPU, train vài phút, hợp T4. (Né fine-tune end-to-end lúc đầu.)
27
+ # - **EMOS phụ thuộc target** (cùng audio, target khác → điểm khác) → EMOS head nhận thêm one-hot target.
28
+ # **CAT/VAD** là cảm nhận về chính audio → chỉ cần trunk (không cần target).
29
+ # - **Nhãn vàng** gộp theo `wavID` từ `sets/train.csv`:
30
+ # EMOS = TB `eMOS` · VAL/ARO/DOM = TB `val/aro/dom` · CAT = **tỉ lệ vote 5 lớp** của `emoCat`.
31
+ # - **Cân loss = uncertainty weighting** (Kendall 2018): mỗi task có 1 trọng số σ **tự học**
32
+ # → không phải dò tay. Có cờ `USE_UNCERTAINTY=False` để quay về trọng số cố định khi cần debug.
33
+ # - Cuối cùng xuất `answer.txt` **đủ 7 cột**: `wav,QMOS,EMOS,CAT,VAL,ARO,DOM`
34
+ # (QMOS=SpeechMOS · 5 cột còn lại = model fusion) → nộp được ngay. So mốc: EMOS 0.637 · VAD ARO 0.712.
35
+ #
36
+ # **Cách chạy trên Kaggle:** Settings → Accelerator = **GPU T4**, Internet = **On**
37
+ # → + Add Input dataset Track 2 (15.477 wav, có `sets/train.csv`, `sets/dev.scp`, `metadata.csv`)
38
+ # → sửa `DATA_ROOT` ở cell 0 → Run All. Lần đầu nên đặt `LIMIT_TRAIN = 300`, `LIMIT_DEV = 20` để bắt lỗi setup.
39
+
40
+ # %% [markdown]
41
+ # ## 0. Cấu hình — SỬA Ở ĐÂY
42
+
43
+ # %%
44
+ import os
45
+
46
+ # ── Data Track 2 (dataset 15.477 wav đã ráp, có sets/train.csv) ──────────────
47
+ DATA_ROOT = "/kaggle/input/vmc2026-track2-full/vmc2026-track2" # << SỬA slug cho khớp Add Input
48
+ WAV_DIR = f"{DATA_ROOT}/wav"
49
+ METADATA_CSV = f"{DATA_ROOT}/metadata.csv" # wavID|emotion|transcript (KHÔNG header) → target emotion
50
+ TRAIN_CSV = f"{DATA_ROOT}/sets/train.csv" # nhãn người nghe: lisID,wavID,qMOS,emoCat,eMOS,val,dom,aro
51
+ DEV_SCP = f"{DATA_ROOT}/sets/dev.scp" # danh sách wav tập DEV (tập cần nộp ở training phase)
52
+
53
+ OUT_DIR = "/kaggle/working"
54
+ CACHE_DIR = "/kaggle/working/fusion_cache" # cache embedding 2 backbone (tái dùng giữa các lần chạy)
55
+ os.makedirs(CACHE_DIR, exist_ok=True)
56
+
57
+ # ── Siêu tham số train ───────────────────────────────────────────────────────
58
+ DEVICE = "cuda" # "cuda" trên Kaggle GPU; "cpu" nếu không có GPU
59
+ TRUNK_HIDDEN = 512 # số neuron lớp trunk chung
60
+ HEAD_HIDDEN = 128 # số neuron lớp ẩn mỗi head
61
+ DROPOUT = 0.3
62
+ LR = 1e-3
63
+ EPOCHS = 80
64
+ BATCH = 64
65
+ VAL_FRAC = 0.10 # 10% train → validation nội bộ (đo SRCC từng task)
66
+ PATIENCE = 15 # early stop theo điểm tổng val (xem SCORE_FOR_STOP)
67
+ SEED = 42
68
+
69
+ USE_UNCERTAINTY = True # True = tự cân loss (Kendall); False = dùng LOSS_W cố định bên dưới
70
+ LOSS_W = {"emos": 1.0, "cat": 1.0, "val": 1.0, "aro": 1.0, "dom": 1.0} # chỉ dùng khi tắt uncertainty
71
+ USE_E2V = True # bật/tắt nhánh emotion2vec trong fusion (để ablation)
72
+ USE_SAILER = True # bật/tắt nhánh SAILER trong fusion (để ablation)
73
+ USE_CLASSPROB = True # thêm xác suất lớp (e2v 5 + sailer 9) + VAD3 của SAILER vào feature
74
+
75
+ LIMIT_TRAIN = None # đặt số nhỏ (vd 300) để chạy thử nhanh; None = full
76
+ LIMIT_DEV = None # đặt số nhỏ (vd 20) để chạy thử nhanh; None = full
77
+
78
+ EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"]
79
+
80
+ # 9 lớp SAILER (đúng thứ tự model xuất) + chỉ số của 5 lớp challenge trong đó
81
+ SAILER9 = ["Anger", "Contempt", "Disgust", "Fear", "Happiness", "Neutral", "Sadness", "Surprise", "Other"]
82
+ EMO2SAILER = {"angry": 0, "happy": 4, "neutral": 5, "sad": 6, "surprised": 7} # EMOTIONS5 → index trong SAILER9
83
+
84
+ _EMO_ALIAS = {
85
+ "angry": "angry", "anger": "angry",
86
+ "happy": "happy", "happiness": "happy", "joy": "happy",
87
+ "neutral": "neutral", "calm": "neutral",
88
+ "sad": "sad", "sadness": "sad",
89
+ "surprise": "surprised", "surprised": "surprised", "surprising": "surprised",
90
+ }
91
+
92
+ def norm_emotion(label):
93
+ """Đưa nhãn cảm xúc bất kỳ về 1 trong EMOTIONS5; None nếu không khớp."""
94
+ key = str(label).strip().lower()
95
+ return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)
96
+
97
+ def stem(path_or_name):
98
+ """Lấy tên file không đuôi, để khớp wavID giữa train.csv / metadata / dev.scp."""
99
+ return os.path.splitext(os.path.basename(str(path_or_name)))[0]
100
+
101
+ assert USE_E2V or USE_SAILER, "Phải bật ít nhất 1 backbone (USE_E2V hoặc USE_SAILER)."
102
+ print("DATA_ROOT:", DATA_ROOT)
103
+ for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP]:
104
+ print((" ✅ " if os.path.exists(p) else " ❌ THIẾU ") + p)
105
+
106
+ # %% [markdown]
107
+ # ## 1. Cài đặt + tải code SAILER
108
+ # emotion2vec qua `funasr` (offline). SAILER cần `WavLMWrapper` trong repo `vox-profile-release`
109
+ # → **clone + sys.path** (KHÔNG `pip install -e .` vì build wheel hay lỗi trên Kaggle).
110
+
111
+ # %%
112
+ import sys, subprocess
113
+
114
+ def pip_install(*pkgs):
115
+ subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs], check=True)
116
+
117
+ pip_install("speechmos", "funasr", "librosa", "soundfile", "pandas", "scipy", "scikit-learn", "tqdm")
118
+
119
+ if USE_SAILER:
120
+ pip_install("loralib", "speechbrain") # deps WavLMWrapper cần
121
+ REPO_DIR = "/kaggle/working/vox-profile-release"
122
+ if not os.path.exists(REPO_DIR):
123
+ subprocess.run(["git", "clone", "--depth", "1",
124
+ "https://github.com/tiantiaf0627/vox-profile-release.git", REPO_DIR], check=True)
125
+ if REPO_DIR not in sys.path:
126
+ sys.path.insert(0, REPO_DIR)
127
+
128
+ # %% [markdown]
129
+ # ## 2. Đọc & gộp nhãn (gộp theo wavID)
130
+ # - `train.csv`: mỗi dòng = 1 listener chấm 1 wav → gộp **theo wavID**:
131
+ # EMOS=TB `eMOS` · VAL/ARO/DOM=TB `val/aro/dom` · CAT=**tỉ lệ vote 5 lớp** của `emoCat`.
132
+ # - `metadata.csv`: lấy **cảm xúc target** cho mỗi wav (để feed EMOS head).
133
+
134
+ # %%
135
+ import numpy as np
136
+ import pandas as pd
137
+
138
+ def load_target_emotions():
139
+ """metadata.csv (wavID|emotion|transcript, KHÔNG header) → {stem: emotion_chuẩn|None}."""
140
+ tgt = {}
141
+ with open(METADATA_CSV, encoding="utf-8") as f:
142
+ for ln in f:
143
+ parts = ln.strip().split("|")
144
+ if len(parts) < 2:
145
+ continue
146
+ tgt[stem(parts[0])] = norm_emotion(parts[1])
147
+ return tgt
148
+
149
+ def _col(cols_map, *names, default_idx=None, df=None):
150
+ for n in names:
151
+ if n in cols_map:
152
+ return cols_map[n]
153
+ return list(df.columns)[default_idx] if default_idx is not None else None
154
+
155
+ def parse_emocat_votes(cell):
156
+ """1 ô emoCat (có thể đa nhãn, vd 'happy;surprised') → vector đếm 5 lớp (chưa chuẩn hóa)."""
157
+ v = np.zeros(len(EMOTIONS5), dtype=np.float32)
158
+ for tok in str(cell).replace("/", ",").replace(";", ",").replace("|", ",").replace(" ", ",").split(","):
159
+ e = norm_emotion(tok)
160
+ if e in EMOTIONS5:
161
+ v[EMOTIONS5.index(e)] += 1.0
162
+ return v
163
+
164
+ def load_train_labels():
165
+ """train.csv → DataFrame [wavID, emos, val, aro, dom, cat0..cat4] gộp theo wav.
166
+ CAT = tỉ lệ vote 5 lớp (tổng=1); nếu wav không có vote hợp lệ → phân phối đều."""
167
+ # train.csv phân tách bằng "|"; cột emoCat đa nhãn dùng "," bên trong (vd "Angry,Surprised").
168
+ df = pd.read_csv(TRAIN_CSV, sep="|")
169
+ cols = {c.lower().strip(): c for c in df.columns}
170
+ wav_col = _col(cols, "wavid", "wav", default_idx=1, df=df)
171
+ emos_col = _col(cols, "emos", "emo", "emomos")
172
+ val_col = _col(cols, "val", "valence")
173
+ aro_col = _col(cols, "aro", "arousal")
174
+ dom_col = _col(cols, "dom", "dominance")
175
+ cat_col = _col(cols, "emocat", "cat", "emotion")
176
+ assert emos_col, f"Không thấy cột eMOS trong train.csv (cột: {list(df.columns)})"
177
+
178
+ df["_stem"] = df[wav_col].map(stem)
179
+ rows = []
180
+ for sid, g in df.groupby("_stem"):
181
+ rec = {"wavID": sid, "emos": float(g[emos_col].mean())}
182
+ rec["val"] = float(g[val_col].mean()) if val_col else np.nan
183
+ rec["aro"] = float(g[aro_col].mean()) if aro_col else np.nan
184
+ rec["dom"] = float(g[dom_col].mean()) if dom_col else np.nan
185
+ votes = np.zeros(len(EMOTIONS5), dtype=np.float32)
186
+ if cat_col:
187
+ for cell in g[cat_col]:
188
+ votes += parse_emocat_votes(cell)
189
+ s = votes.sum()
190
+ cat = votes / s if s > 0 else np.full(len(EMOTIONS5), 1.0 / len(EMOTIONS5), dtype=np.float32)
191
+ for i in range(len(EMOTIONS5)):
192
+ rec[f"cat{i}"] = float(cat[i])
193
+ rows.append(rec)
194
+ return pd.DataFrame(rows)
195
+
196
+ target_map = load_target_emotions()
197
+ train_df = load_train_labels()
198
+ HAS_VAD = bool(train_df["val"].notna().any())
199
+ print(f"Target emotions: {len(target_map)} | wav train (gộp): {len(train_df)} | có nhãn VAD: {HAS_VAD}")
200
+ print("eMOS:", train_df["emos"].describe()[["mean", "std", "min", "max"]].to_dict())
201
+ train_df.head()
202
+
203
+ # %% [markdown]
204
+ # ## 3. Trích đặc trưng 2 backbone (có cache riêng từng model)
205
+ # - **emotion2vec** → embedding + xác suất 5 lớp (như exp02).
206
+ # - **SAILER** → embedding (features) + xác suất 9 lớp + VAD3 (như exp03).
207
+ # Mỗi backbone cache riêng (`e2v_<tag>.npz`, `sailer_<tag>.npz`) → chạy nối tiếp được, đổi 1 backbone
208
+ # không phải trích lại cái kia. Trích xong **giải phóng GPU** rồi mới nạp backbone sau.
209
+
210
+ # %%
211
+ import torch
212
+ import torch.nn.functional as F
213
+
214
+ device = DEVICE if torch.cuda.is_available() else "cpu"
215
+ print("Device:", device)
216
+ if device == "cuda":
217
+ print(" ✅ GPU:", torch.cuda.get_device_name(0))
218
+ else:
219
+ print(" ⚠️ KHÔNG thấy GPU! Trích đặc trưng ~15k file trên CPU rất lâu.")
220
+ print(" → Settings → Accelerator = GPU T4 rồi chạy lại.")
221
+
222
+ # ---- emotion2vec ----
223
+ def extract_e2v(stems, tag):
224
+ """→ dict {stem: (emb[D1], probs5[5])}. Cache CACHE_DIR/e2v_<tag>.npz."""
225
+ from tqdm.auto import tqdm
226
+ cache_path = os.path.join(CACHE_DIR, f"e2v_{tag}.npz")
227
+ store = {}
228
+ if os.path.exists(cache_path):
229
+ z = np.load(cache_path, allow_pickle=True)
230
+ store = {k: z[k] for k in z.files}
231
+ print(f"[e2v/{tag}] nạp cache: {len(store)}")
232
+ todo = [s for s in stems if s not in store]
233
+ if todo:
234
+ from funasr import AutoModel
235
+ m = AutoModel(model="iic/emotion2vec_plus_large", hub="hf", device=device) # ép GPU
236
+ miss = 0
237
+ for i, s in enumerate(tqdm(todo, desc=f"e2v {tag}")):
238
+ wav = os.path.join(WAV_DIR, s + ".wav")
239
+ if not os.path.exists(wav):
240
+ miss += 1; continue
241
+ r = m.generate(wav, granularity="utterance", extract_embedding=True)[0]
242
+ emb = np.asarray(r["feats"], dtype=np.float32).reshape(-1)
243
+ probs = {e: 0.0 for e in EMOTIONS5}
244
+ for lab, sc in zip(r["labels"], r["scores"]):
245
+ name = lab.split("/")[-1]
246
+ if name in probs:
247
+ probs[name] = float(sc)
248
+ tot = sum(probs.values())
249
+ p5 = np.array([probs[e] / tot if tot > 0 else 0.2 for e in EMOTIONS5], dtype=np.float32)
250
+ store[s] = np.concatenate([emb, p5]).astype(np.float32) # [D1 + 5]
251
+ if (i + 1) % 500 == 0:
252
+ np.savez(cache_path, **store)
253
+ np.savez(cache_path, **store)
254
+ del m
255
+ torch.cuda.empty_cache() if device == "cuda" else None
256
+ if miss:
257
+ print(f"[e2v/{tag}] {miss} file thiếu → bỏ qua.")
258
+ return {s: (v[:-5], v[-5:]) for s, v in store.items()}
259
+
260
+ # ---- SAILER ----
261
+ def _pool_feat(features):
262
+ """features (tensor) → vector 1 chiều (mean-pool nếu còn chiều thời gian)."""
263
+ f = features.detach().cpu().numpy()
264
+ if f.ndim <= 1:
265
+ return f.reshape(-1).astype(np.float32)
266
+ return f.mean(axis=tuple(range(f.ndim - 1))).reshape(-1).astype(np.float32)
267
+
268
+ def extract_sailer(stems, tag):
269
+ """→ dict {stem: (emb[D2], probs9[9], vad3[3] thang 1–5)}. Cache CACHE_DIR/sailer_<tag>.npz.
270
+ Mỗi mẫu lưu vector [emb | probs9(9) | vad3(3)] → cắt lại khi nạp."""
271
+ import librosa
272
+ from tqdm.auto import tqdm
273
+ cache_path = os.path.join(CACHE_DIR, f"sailer_{tag}.npz")
274
+ store = {}
275
+ if os.path.exists(cache_path):
276
+ z = np.load(cache_path, allow_pickle=True)
277
+ store = {k: z[k] for k in z.files}
278
+ print(f"[sailer/{tag}] nạp cache: {len(store)}")
279
+ todo = [s for s in stems if s not in store]
280
+ if todo:
281
+ from src.model.emotion.wavlm_emotion import WavLMWrapper
282
+ sailer = WavLMWrapper.from_pretrained("tiantiaf/wavlm-large-categorical-emotion").to(device).eval()
283
+ miss = 0
284
+ with torch.no_grad():
285
+ for i, s in enumerate(tqdm(todo, desc=f"sailer {tag}")):
286
+ wav = os.path.join(WAV_DIR, s + ".wav")
287
+ if not os.path.exists(wav):
288
+ miss += 1; continue
289
+ wave, _ = librosa.load(wav, sr=16000, mono=True)
290
+ wave = wave[: 15 * 16000]
291
+ data = torch.from_numpy(wave).float().unsqueeze(0).to(device)
292
+ logits, feat, _det, arousal, valence, dominance = sailer(data, return_feature=True)
293
+ emb = _pool_feat(feat)
294
+ p9 = F.softmax(logits, dim=1)[0].detach().cpu().numpy().astype(np.float32)
295
+ vad3 = np.array([1 + 4 * float(valence.item()),
296
+ 1 + 4 * float(arousal.item()),
297
+ 1 + 4 * float(dominance.item())], dtype=np.float32) # [VAL,ARO,DOM]
298
+ store[s] = np.concatenate([emb, p9, vad3]).astype(np.float32) # [D2 + 9 + 3]
299
+ if (i + 1) % 500 == 0:
300
+ np.savez(cache_path, **store)
301
+ np.savez(cache_path, **store)
302
+ del sailer
303
+ torch.cuda.empty_cache() if device == "cuda" else None
304
+ if miss:
305
+ print(f"[sailer/{tag}] {miss} file thiếu → bỏ qua.")
306
+ return {s: (v[:-12], v[-12:-3], v[-3:]) for s, v in store.items()}
307
+
308
+ # %% [markdown]
309
+ # ## 4. Dựng feature + nhãn cho train
310
+ # Feature audio (KHÔNG gồm target) = nối các phần đang bật:
311
+ # `[e2v_emb | e2v_probs5 | sailer_emb | sailer_probs9 | sailer_vad3]`.
312
+ # One-hot target để **riêng** (chỉ EMOS head dùng). Bỏ wav thiếu feature.
313
+
314
+ # %%
315
+ train_stems = list(train_df["wavID"])
316
+ if LIMIT_TRAIN:
317
+ train_stems = train_stems[:LIMIT_TRAIN]
318
+
319
+ e2v_tr = extract_e2v(train_stems, "train") if USE_E2V else {}
320
+ sailer_tr = extract_sailer(train_stems, "train") if USE_SAILER else {}
321
+
322
+ def audio_feature(sid, e2v_map, sailer_map):
323
+ """Nối đặc trưng audio cho 1 wav. None nếu thiếu phần bắt buộc."""
324
+ parts = []
325
+ if USE_E2V:
326
+ pk = e2v_map.get(sid)
327
+ if pk is None:
328
+ return None
329
+ emb, p5 = pk
330
+ parts.append(emb)
331
+ if USE_CLASSPROB:
332
+ parts.append(p5)
333
+ if USE_SAILER:
334
+ pk = sailer_map.get(sid)
335
+ if pk is None:
336
+ return None
337
+ emb, p9, vad3 = pk
338
+ parts.append(emb)
339
+ if USE_CLASSPROB:
340
+ parts.append(p9); parts.append(vad3)
341
+ return np.concatenate(parts).astype(np.float32)
342
+
343
+ def onehot_target(tgt):
344
+ v = np.zeros(len(EMOTIONS5), dtype=np.float32)
345
+ if tgt in EMOTIONS5:
346
+ v[EMOTIONS5.index(tgt)] = 1.0
347
+ return v
348
+
349
+ lab = train_df.set_index("wavID")
350
+ X, T, y_emos, y_vad, y_cat = [], [], [], [], []
351
+ for s in train_stems:
352
+ f = audio_feature(s, e2v_tr, sailer_tr)
353
+ tgt = target_map.get(s)
354
+ if f is None or tgt is None or s not in lab.index:
355
+ continue
356
+ X.append(f)
357
+ T.append(onehot_target(tgt))
358
+ y_emos.append(lab.loc[s, "emos"])
359
+ y_vad.append([lab.loc[s, "val"], lab.loc[s, "aro"], lab.loc[s, "dom"]])
360
+ y_cat.append([lab.loc[s, f"cat{i}"] for i in range(len(EMOTIONS5))])
361
+
362
+ X = np.stack(X).astype(np.float32)
363
+ T = np.stack(T).astype(np.float32)
364
+ y_emos = np.array(y_emos, dtype=np.float32)
365
+ y_vad = np.array(y_vad, dtype=np.float32) # [N,3] (VAL,ARO,DOM) — có thể toàn NaN nếu thiếu nhãn
366
+ y_cat = np.array(y_cat, dtype=np.float32) # [N,5] phân phối tổng=1
367
+ FEAT_DIM = X.shape[1]
368
+ print(f"Train: X={X.shape} target={T.shape} emos={y_emos.shape} vad={y_vad.shape} cat={y_cat.shape}")
369
+
370
+ # Chuẩn hóa feature audio (z-score) — lưu mean/std để áp dụng y hệt lúc dự đoán DEV.
371
+ feat_mean = X.mean(0, keepdims=True)
372
+ feat_std = X.std(0, keepdims=True) + 1e-6
373
+ Xn = (X - feat_mean) / feat_std
374
+
375
+ # Chuẩn hóa nhãn liên tục (eMOS, VAD) về z-score → các MSE cùng thang (uncertainty weighting ổn định hơn).
376
+ # SRCC bất biến với scale → khi xuất answer.txt chỉ cần đảo z-score về thang gốc cho đẹp.
377
+ emos_mu, emos_sd = float(y_emos.mean()), float(y_emos.std() + 1e-6)
378
+ y_emos_z = (y_emos - emos_mu) / emos_sd
379
+ if HAS_VAD:
380
+ vad_mu = np.nanmean(y_vad, axis=0)
381
+ vad_sd = np.nanstd(y_vad, axis=0) + 1e-6
382
+ y_vad_z = (y_vad - vad_mu) / vad_sd
383
+ else:
384
+ vad_mu = np.zeros(3, dtype=np.float32); vad_sd = np.ones(3, dtype=np.float32)
385
+ y_vad_z = np.zeros_like(y_vad)
386
+
387
+ # %% [markdown]
388
+ # ## 5. Model fusion multi-task + train loop
389
+ # - **Trunk** chung: `Linear(FEAT_DIM→TRUNK_HIDDEN)+ReLU+Dropout` (×2).
390
+ # - **EMOS head**: nối `[trunk | one-hot target]` → MLP → 1 (vì EMOS phụ thuộc target).
391
+ # - **CAT head**: trunk → 5 logits → softmax (dự đoán phân phối vote). Loss = soft-CE (KL).
392
+ # - **VAD head**: trunk → 3 (VAL/ARO/DOM). Loss = MSE (bỏ qua nếu thiếu nhãn VAD).
393
+ # - **Cân loss**: uncertainty weighting — tổng `Σ exp(-sᵢ)·Lᵢ + sᵢ`, `sᵢ=log σᵢ²` **học được**.
394
+
395
+ # %%
396
+ import torch.nn as nn
397
+ from scipy.stats import spearmanr
398
+ from sklearn.model_selection import train_test_split
399
+
400
+ torch.manual_seed(SEED); np.random.seed(SEED)
401
+ N_EMO = len(EMOTIONS5)
402
+
403
+ idx_all = np.arange(X.shape[0])
404
+ tr_idx, va_idx = train_test_split(idx_all, test_size=VAL_FRAC, random_state=SEED)
405
+
406
+ def to_t(a):
407
+ return torch.tensor(a, dtype=torch.float32, device=device)
408
+
409
+ Xn_t, T_t = to_t(Xn), to_t(T)
410
+ emos_t = to_t(y_emos_z).unsqueeze(1)
411
+ vad_t = to_t(y_vad_z)
412
+ cat_t = to_t(y_cat)
413
+
414
+ class FusionMTL(nn.Module):
415
+ def __init__(self, d_in, trunk_h, head_h, p, n_emo):
416
+ super().__init__()
417
+ self.trunk = nn.Sequential(
418
+ nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),
419
+ nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p),
420
+ )
421
+ self.emos = nn.Sequential( # nhận [trunk | target]
422
+ nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))
423
+ self.cat = nn.Sequential(
424
+ nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))
425
+ self.vad = nn.Sequential(
426
+ nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))
427
+
428
+ def forward(self, x, tgt):
429
+ h = self.trunk(x)
430
+ emos = self.emos(torch.cat([h, tgt], dim=1))
431
+ cat_logits = self.cat(h)
432
+ vad = self.vad(h)
433
+ return emos, cat_logits, vad
434
+
435
+ model = FusionMTL(FEAT_DIM, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device)
436
+
437
+ # Trọng số bất định (log σ²) cho 5 task: emos, cat, val, aro, dom.
438
+ TASKS = ["emos", "cat", "val", "aro", "dom"]
439
+ log_var = nn.Parameter(torch.zeros(len(TASKS), device=device))
440
+ params = list(model.parameters()) + ([log_var] if USE_UNCERTAINTY else [])
441
+ opt = torch.optim.Adam(params, lr=LR, weight_decay=1e-5)
442
+
443
+ mse = nn.MSELoss(reduction="none")
444
+
445
+ def soft_ce(logits, target_dist):
446
+ """Cross-entropy với nhãn mềm (phân phối): −Σ p·log q."""
447
+ logq = F.log_softmax(logits, dim=1)
448
+ return -(target_dist * logq).sum(dim=1)
449
+
450
+ def task_losses(emos_p, cat_logits, vad_p, b):
451
+ """Trả về dict loss TB từng task cho 1 batch (chỉ số b)."""
452
+ L = {}
453
+ L["emos"] = mse(emos_p, emos_t[b]).mean()
454
+ L["cat"] = soft_ce(cat_logits, cat_t[b]).mean()
455
+ if HAS_VAD:
456
+ L["val"] = mse(vad_p[:, 0:1], vad_t[b, 0:1]).mean()
457
+ L["aro"] = mse(vad_p[:, 1:2], vad_t[b, 1:2]).mean()
458
+ L["dom"] = mse(vad_p[:, 2:3], vad_t[b, 2:3]).mean()
459
+ else:
460
+ z = torch.zeros((), device=device)
461
+ L["val"] = L["aro"] = L["dom"] = z
462
+ return L
463
+
464
+ def combine(L):
465
+ """Gộp 5 loss thành 1 số: uncertainty weighting hoặc trọng số cố định."""
466
+ if USE_UNCERTAINTY:
467
+ tot = 0.0
468
+ for i, t in enumerate(TASKS):
469
+ tot = tot + torch.exp(-log_var[i]) * L[t] + log_var[i]
470
+ return tot
471
+ return sum(LOSS_W[t] * L[t] for t in TASKS)
472
+
473
+ @torch.no_grad()
474
+ def eval_val():
475
+ """SRCC từng task trên tập val nội bộ (CAT báo bằng −KL để 'cao=tốt' cho early-stop)."""
476
+ model.eval()
477
+ ep, cl, vp = model(Xn_t[va_idx], T_t[va_idx])
478
+ ep = ep.cpu().numpy().ravel()
479
+ out = {"emos": spearmanr(ep, y_emos[va_idx]).correlation}
480
+ if HAS_VAD:
481
+ vp = vp.cpu().numpy()
482
+ for j, t in enumerate(["val", "aro", "dom"]):
483
+ out[t] = spearmanr(vp[:, j], y_vad[va_idx, j]).correlation
484
+ # CAT: dùng −KL(p‖q) trung bình (càng gần 0 càng tốt) → đổi dấu để hợp early-stop
485
+ q = F.softmax(cl, dim=1).cpu().numpy()
486
+ p = y_cat[va_idx]
487
+ kl = (p * (np.log(p + 1e-9) - np.log(q + 1e-9))).sum(1).mean()
488
+ out["cat_negkl"] = float(-kl)
489
+ return out
490
+
491
+ def val_score(m):
492
+ """Điểm tổng để early-stop = TB SRCC các task liên tục có nhãn."""
493
+ keys = ["emos"] + (["val", "aro", "dom"] if HAS_VAD else [])
494
+ return float(np.mean([m[k] for k in keys]))
495
+
496
+ best_score, best_state, bad = -1e9, None, 0
497
+ tr_t = torch.tensor(tr_idx, device=device)
498
+ for ep in range(1, EPOCHS + 1):
499
+ model.train()
500
+ perm = tr_t[torch.randperm(len(tr_t), device=device)]
501
+ run = 0.0
502
+ for i in range(0, len(perm), BATCH):
503
+ b = perm[i:i + BATCH]
504
+ opt.zero_grad()
505
+ emos_p, cat_logits, vad_p = model(Xn_t[b], T_t[b])
506
+ L = task_losses(emos_p, cat_logits, vad_p, b)
507
+ loss = combine(L)
508
+ loss.backward(); opt.step()
509
+ run += loss.item() * len(b)
510
+ m = eval_val()
511
+ sc = val_score(m)
512
+ if sc > best_score:
513
+ best_score = sc
514
+ best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
515
+ bad = 0
516
+ else:
517
+ bad += 1
518
+ if ep % 5 == 0 or ep == 1:
519
+ msg = " ".join(f"{k}={m[k]:.3f}" for k in m)
520
+ print(f"epoch {ep:3d} | loss {run/len(perm):.4f} | {msg} | best {best_score:.4f}")
521
+ if bad >= PATIENCE:
522
+ print(f"Early stop ở epoch {ep}.")
523
+ break
524
+
525
+ model.load_state_dict(best_state)
526
+ final = eval_val()
527
+ print("\n✅ VAL (nội bộ) tốt nhất:")
528
+ print(f" EMOS SRCC = {final['emos']:.4f} (so mốc exp01 emotion2vec = 0.637)")
529
+ if HAS_VAD:
530
+ print(f" VAL/ARO/DOM SRCC = {final['val']:.4f} / {final['aro']:.4f} / {final['dom']:.4f}"
531
+ f" (so mốc SAILER = 0.341 / 0.712 / 0.630)")
532
+ if USE_UNCERTAINTY:
533
+ print(" log σ² mỗi task:", {t: round(float(log_var[i]), 3) for i, t in enumerate(TASKS)})
534
+
535
+ # Lưu model + tham số chuẩn hóa.
536
+ torch.save({"state": best_state, "feat_mean": feat_mean, "feat_std": feat_std,
537
+ "emos_mu": emos_mu, "emos_sd": emos_sd, "vad_mu": vad_mu, "vad_sd": vad_sd,
538
+ "FEAT_DIM": FEAT_DIM, "EMOTIONS5": EMOTIONS5, "HAS_VAD": HAS_VAD,
539
+ "USE_E2V": USE_E2V, "USE_SAILER": USE_SAILER, "USE_CLASSPROB": USE_CLASSPROB,
540
+ "TRUNK_HIDDEN": TRUNK_HIDDEN, "HEAD_HIDDEN": HEAD_HIDDEN, "val_score": best_score},
541
+ os.path.join(OUT_DIR, "fusion_mtl.pt"))
542
+ print("Đã lưu", os.path.join(OUT_DIR, "fusion_mtl.pt"))
543
+
544
+ # %% [markdown]
545
+ # ## 6. Dự đoán DEV → `answer.txt` đầy đủ 7 cột
546
+ # - **EMOS/CAT/VAD** = model fusion (đảo z-score về thang gốc cho EMOS/VAD; CAT = softmax 5 lớp).
547
+ # - **QMOS** = SpeechMOS (UTMOS) — để riêng, đúng thiết kế.
548
+
549
+ # %%
550
+ def list_dev():
551
+ with open(DEV_SCP) as f:
552
+ return [ln.strip() for ln in f if ln.strip()] # tên file .wav
553
+
554
+ dev_names = list_dev()
555
+ if LIMIT_DEV:
556
+ dev_names = dev_names[:LIMIT_DEV]
557
+ dev_stems = [stem(n) for n in dev_names]
558
+ print("DEV:", len(dev_names), "mẫu")
559
+
560
+ # 6a. Trích đặc trưng 2 backbone cho DEV (cache riêng)
561
+ e2v_dev = extract_e2v(dev_stems, "dev") if USE_E2V else {}
562
+ sailer_dev = extract_sailer(dev_stems, "dev") if USE_SAILER else {}
563
+
564
+ # 6b. Dự đoán 5 cột cảm xúc bằng model fusion
565
+ @torch.no_grad()
566
+ def predict_emotion(sid):
567
+ f = audio_feature(sid, e2v_dev, sailer_dev)
568
+ if f is None:
569
+ return None
570
+ fn = (f[None, :] - feat_mean) / feat_std
571
+ tgt = onehot_target(target_map.get(sid))[None, :]
572
+ model.eval()
573
+ emos_p, cat_logits, vad_p = model(to_t(fn), to_t(tgt))
574
+ emos = float(emos_p.item()) * emos_sd + emos_mu # đảo z-score
575
+ cat5 = F.softmax(cat_logits, dim=1)[0].cpu().numpy()
576
+ vad3 = vad_p[0].cpu().numpy() * vad_sd + vad_mu # [VAL,ARO,DOM]
577
+ return emos, cat5, vad3
578
+
579
+ # 6c. QMOS = SpeechMOS (để riêng)
580
+ @torch.no_grad()
581
+ def run_qmos(names):
582
+ import librosa
583
+ from tqdm.auto import tqdm
584
+ predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True).to(device).eval()
585
+ out = {}
586
+ for n in tqdm(names, desc="QMOS"):
587
+ p = os.path.join(WAV_DIR, n)
588
+ if not os.path.exists(p):
589
+ continue
590
+ wave, _ = librosa.load(p, sr=16000, mono=True)
591
+ out[n] = float(predictor(torch.from_numpy(wave).unsqueeze(0).to(device), sr=16000).mean().item())
592
+ return out
593
+
594
+ qmos_scores = run_qmos(dev_names)
595
+
596
+ # %%
597
+ def fmt_cat(probs5):
598
+ return "|".join(f"{e}:{probs5[i]:.6g}" for i, e in enumerate(EMOTIONS5))
599
+
600
+ def build_answer(out_path):
601
+ from tqdm.auto import tqdm
602
+ n_real = n_default = 0
603
+ with open(out_path, "w") as f:
604
+ f.write("wav,QMOS,EMOS,CAT,VAL,ARO,DOM\n")
605
+ for name in tqdm(dev_names, desc="answer"):
606
+ sid = stem(name)
607
+ pred = predict_emotion(sid)
608
+ if pred is None:
609
+ emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0])
610
+ n_default += 1
611
+ else:
612
+ emos, cat5, vad3 = pred
613
+ n_real += 1
614
+ qmos = qmos_scores.get(name, 3.0)
615
+ f.write(f"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},"
616
+ f"{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\n")
617
+ print(f"Ghi {len(dev_names)} dòng → {out_path} | fusion thật {n_real}, mặc định {n_default}")
618
+
619
+ answer_path = os.path.join(OUT_DIR, "answer.txt")
620
+ build_answer(answer_path)
621
+
622
+ # %% [markdown]
623
+ # ## 7. Validate + đóng zip
624
+
625
+ # %%
626
+ def validate(path):
627
+ import csv
628
+ with open(path) as f:
629
+ rows = list(csv.reader(f))
630
+ header = rows[0]
631
+ assert header[0] == "wav" and "QMOS" in header and "EMOS" in header, "Header sai"
632
+ for i, r in enumerate(rows[1:], 2):
633
+ assert len(r) == len(header), f"Dòng {i} sai số cột"
634
+ print(f"OK: {len(rows)-1} dòng, header = {header}")
635
+
636
+ validate(answer_path)
637
+ os.system(f"cd {OUT_DIR} && zip -j submission_track2_exp04_fusion.zip answer.txt && unzip -l submission_track2_exp04_fusion.zip")
638
+ print("Sẵn sàng nộp:", os.path.join(OUT_DIR, "submission_track2_exp04_fusion.zip"))
639
+
640
+ # %% [markdown]
641
+ # ## Ghi chú
642
+ # - **Lần đầu**: đặt `LIMIT_TRAIN=300`, `LIMIT_DEV=20` ở cell 0 để bắt lỗi setup (clone repo / import / model).
643
+ # Chạy OK rồi đặt `None` chạy full.
644
+ # - **VAL SRCC** ở mục 5 là ước lượng nội bộ (10% train) → so mốc EMOS 0.637 / ARO 0.712. Điểm DEV thật
645
+ # phải nộp CodaBench mới biết (My Submissions → Track 2, bỏ chọn track khác).
646
+ # - Embedding đã cache trong `/kaggle/working/fusion_cache/` → **Save Version** để giữ; lần sau đổi
647
+ # siêu tham số/đổi cách cân loss chỉ train lại head (vài phút), khỏi trích lại.
648
+ # - **Ablation cho paper** (đổi cờ ở cell 0, train lại head):
649
+ # `USE_E2V=False` (chỉ SAILER) · `USE_SAILER=False` (chỉ emotion2vec) · `USE_UNCERTAINTY=False` (trọng số tay)
650
+ # · `USE_CLASSPROB=False` (chỉ embedding) → điền bảng ablation `docs/04_experiments_log.md`.
651
+ # - License SAILER = **Open RAIL (phi thương mại)** → nhắc trong `docs/12_system_description.md`.
652
+ # - Nhớ ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (mục exp04).
track2/exp05_vad_audeering.ipynb ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "40a15eae",
6
+ "metadata": {},
7
+ "source": [
8
+ "# VMC2026 Track 2 — exp05 (VAD bằng audeering MSP-dim) — Kaggle\n",
9
+ "\n",
10
+ "**Mục tiêu:** đẩy **VAL** (SAILER chỉ 0.341 — thấp nhất) bằng model VAD chuyên\n",
11
+ "`audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim` (dimensional, xuất thẳng\n",
12
+ "arousal/dominance/valence ∈ [0,1]). **Thay cả 3 cột VAD** bằng audeering.\n",
13
+ "\n",
14
+ "## Phân công model (giữ cái tốt của exp03, chỉ đổi VAD)\n",
15
+ "```\n",
16
+ "QMOS ← SpeechMOS (UTMOS) (để riêng)\n",
17
+ "EMOS ← SAILER (1 + 4·P(target)) ┐ giữ nguyên exp03\n",
18
+ "CAT ← SAILER (5 lớp renorm) ┘\n",
19
+ "VAL ← audeering ┐\n",
20
+ "ARO ← audeering ├─ THAY cả 3 (model VAD chuyên)\n",
21
+ "DOM ← audeering ┘\n",
22
+ "```\n",
23
+ "- Mỗi wav chạy **2 forward**: SAILER (EMOS+CAT) + audeering (VAD). KHÔNG train.\n",
24
+ "- So với exp03 (VAD từ SAILER: VAL 0.341 / ARO 0.712 / DOM 0.630) → nộp để A/B từng cột.\n",
25
+ "\n",
26
+ "**Cách chạy Kaggle:** GPU **T4** + Internet **On** → + Add Input dataset Track 2 (có `sets/dev.scp`,\n",
27
+ "`metadata.csv`) → sửa `DATA_ROOT` → lần đầu `LIMIT = 20` kiểm tra VAD ra 1–5 hợp lý → rồi `None`.\n",
28
+ "\n",
29
+ "⚠️ License **SAILER = Open RAIL** · **audeering = CC BY-NC-SA 4.0** (đều phi thương mại) → khai báo `docs/12_`."
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "markdown",
34
+ "id": "2c098aff",
35
+ "metadata": {},
36
+ "source": [
37
+ "## 0. Cấu hình — SỬA Ở ĐÂY"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "code",
42
+ "execution_count": null,
43
+ "id": "fa143e27",
44
+ "metadata": {},
45
+ "outputs": [],
46
+ "source": [
47
+ "import os\n",
48
+ "\n",
49
+ "DATA_ROOT = \"/kaggle/input/vmc2026-track2-full/vmc2026-track2\" # << SỬA slug cho khớp Add Input\n",
50
+ "WAV_DIR = f\"{DATA_ROOT}/wav\"\n",
51
+ "METADATA_CSV = f\"{DATA_ROOT}/metadata.csv\" # wavID|emotion|transcript → target emotion (cho EMOS)\n",
52
+ "DEV_SCP = f\"{DATA_ROOT}/sets/dev.scp\" # danh sách wav tập DEV\n",
53
+ "\n",
54
+ "OUT_DIR = \"/kaggle/working\"\n",
55
+ "\n",
56
+ "DEVICE = \"cuda\"\n",
57
+ "MAX_SECONDS = 15\n",
58
+ "SR = 16000\n",
59
+ "LIMIT = None # đặt 20 để chạy thử nhanh; None = full DEV\n",
60
+ "\n",
61
+ "EMOTIONS5 = [\"angry\", \"happy\", \"neutral\", \"sad\", \"surprised\"]\n",
62
+ "SAILER9 = [\"Anger\", \"Contempt\", \"Disgust\", \"Fear\", \"Happiness\", \"Neutral\", \"Sadness\", \"Surprise\", \"Other\"]\n",
63
+ "EMO2SAILER = {\"angry\": 0, \"happy\": 4, \"neutral\": 5, \"sad\": 6, \"surprised\": 7}\n",
64
+ "\n",
65
+ "_EMO_ALIAS = {\n",
66
+ " \"angry\": \"angry\", \"anger\": \"angry\",\n",
67
+ " \"happy\": \"happy\", \"happiness\": \"happy\", \"joy\": \"happy\",\n",
68
+ " \"neutral\": \"neutral\", \"calm\": \"neutral\",\n",
69
+ " \"sad\": \"sad\", \"sadness\": \"sad\",\n",
70
+ " \"surprise\": \"surprised\", \"surprised\": \"surprised\", \"surprising\": \"surprised\",\n",
71
+ "}\n",
72
+ "\n",
73
+ "def norm_emotion(label):\n",
74
+ " key = str(label).strip().lower()\n",
75
+ " return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)\n",
76
+ "\n",
77
+ "def stem(path_or_name):\n",
78
+ " return os.path.splitext(os.path.basename(str(path_or_name)))[0]\n",
79
+ "\n",
80
+ "print(\"DATA_ROOT:\", DATA_ROOT)\n",
81
+ "for p in [WAV_DIR, METADATA_CSV, DEV_SCP]:\n",
82
+ " print((\" ✅ \" if os.path.exists(p) else \" ❌ THIẾU \") + p)"
83
+ ]
84
+ },
85
+ {
86
+ "cell_type": "markdown",
87
+ "id": "f2d1dd91",
88
+ "metadata": {},
89
+ "source": [
90
+ "## 1. Cài đặt + tải code SAILER (clone + sys.path, KHÔNG pip install -e .)"
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "code",
95
+ "execution_count": null,
96
+ "id": "d426b50b",
97
+ "metadata": {},
98
+ "outputs": [],
99
+ "source": [
100
+ "import sys, subprocess\n",
101
+ "\n",
102
+ "def pip_install(*pkgs):\n",
103
+ " subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *pkgs], check=True)\n",
104
+ "\n",
105
+ "REPO_DIR = \"/kaggle/working/vox-profile-release\"\n",
106
+ "if not os.path.exists(REPO_DIR):\n",
107
+ " subprocess.run([\"git\", \"clone\", \"--depth\", \"1\",\n",
108
+ " \"https://github.com/tiantiaf0627/vox-profile-release.git\", REPO_DIR], check=True)\n",
109
+ "\n",
110
+ "pip_install(\"loralib\", \"speechbrain\", \"speechmos\", \"librosa\", \"soundfile\", \"scipy\", \"tqdm\")\n",
111
+ "\n",
112
+ "if REPO_DIR not in sys.path:\n",
113
+ " sys.path.insert(0, REPO_DIR)"
114
+ ]
115
+ },
116
+ {
117
+ "cell_type": "markdown",
118
+ "id": "798ad5ef",
119
+ "metadata": {},
120
+ "source": [
121
+ "## 2. Nạp model SAILER (cho EMOS + CAT)"
122
+ ]
123
+ },
124
+ {
125
+ "cell_type": "code",
126
+ "execution_count": null,
127
+ "id": "5d9ffc83",
128
+ "metadata": {},
129
+ "outputs": [],
130
+ "source": [
131
+ "import torch\n",
132
+ "import torch.nn.functional as F\n",
133
+ "\n",
134
+ "device = DEVICE if torch.cuda.is_available() else \"cpu\"\n",
135
+ "print(\"Device:\", device)\n",
136
+ "if device == \"cuda\":\n",
137
+ " print(\" ✅ GPU:\", torch.cuda.get_device_name(0))\n",
138
+ "else:\n",
139
+ " print(\" ⚠️ KHÔNG thấy GPU → Settings → Accelerator = GPU T4 rồi chạy lại.\")\n",
140
+ "\n",
141
+ "from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402\n",
142
+ "\n",
143
+ "sailer = WavLMWrapper.from_pretrained(\"tiantiaf/wavlm-large-categorical-emotion\").to(device)\n",
144
+ "sailer.eval()\n",
145
+ "print(\"✅ Đã nạp SAILER (wavlm-large-categorical-emotion)\")"
146
+ ]
147
+ },
148
+ {
149
+ "cell_type": "markdown",
150
+ "id": "6c18fa40",
151
+ "metadata": {},
152
+ "source": [
153
+ "## 2b. Nạp model VAD chuyên: audeering wav2vec2 MSP-dim\n",
154
+ "⚠️ Kế thừa `Wav2Vec2PreTrainedModel` (theo model card) hay dính lỗi version transformers\n",
155
+ "(thiếu `__file__` / `all_tied_weights_keys`...). Cách dứt điểm: CHỈ dùng `Wav2Vec2Model` (backbone\n",
156
+ "được hỗ trợ tốt) + **tự nạp tay** trọng số regression head từ checkpoint → không đụng tie-weights/experts.\n",
157
+ "⚠️ Model xuất thứ tự **[arousal, dominance, valence]** ∈ [0,1] → đổi về [VAL,ARO,DOM] thang 1–5 khi ghi."
158
+ ]
159
+ },
160
+ {
161
+ "cell_type": "code",
162
+ "execution_count": null,
163
+ "id": "ddf569cd",
164
+ "metadata": {},
165
+ "outputs": [],
166
+ "source": [
167
+ "import torch.nn as nn\n",
168
+ "from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor\n",
169
+ "from huggingface_hub import hf_hub_download\n",
170
+ "\n",
171
+ "AUD_NAME = \"audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim\"\n",
172
+ "aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)\n",
173
+ "\n",
174
+ "# 1) backbone wav2vec2 (load chuẩn, không subclass)\n",
175
+ "aud_cfg = Wav2Vec2Config.from_pretrained(AUD_NAME)\n",
176
+ "aud_backbone = Wav2Vec2Model(aud_cfg)\n",
177
+ "\n",
178
+ "# 2) tải state_dict gốc của checkpoint (ưu tiên safetensors)\n",
179
+ "try:\n",
180
+ " _sd = __import__(\"safetensors.torch\", fromlist=[\"load_file\"]).load_file(\n",
181
+ " hf_hub_download(AUD_NAME, \"model.safetensors\"))\n",
182
+ "except Exception:\n",
183
+ " _sd = torch.load(hf_hub_download(AUD_NAME, \"pytorch_model.bin\"), map_location=\"cpu\")\n",
184
+ "\n",
185
+ "# 3) nạp phần backbone (key có tiền tố \"wav2vec2.\") vào Wav2Vec2Model\n",
186
+ "bb_sd = {k[len(\"wav2vec2.\"):]: v for k, v in _sd.items() if k.startswith(\"wav2vec2.\")}\n",
187
+ "missing, unexpected = aud_backbone.load_state_dict(bb_sd, strict=False)\n",
188
+ "print(f\" backbone: thiếu {len(missing)} key, dư {len(unexpected)} key (strict=False)\")\n",
189
+ "\n",
190
+ "# 4) dựng regression head theo đúng shape trong checkpoint rồi nạp trọng số \"classifier.*\"\n",
191
+ "_hid = _sd[\"classifier.dense.weight\"].shape[0]\n",
192
+ "_out = _sd[\"classifier.out_proj.weight\"].shape[0] # = 3 (arousal, dominance, valence)\n",
193
+ "aud_head = nn.Sequential(nn.Linear(_hid, _hid), nn.Tanh(), nn.Linear(_hid, _out))\n",
194
+ "aud_head[0].weight.data.copy_(_sd[\"classifier.dense.weight\"])\n",
195
+ "aud_head[0].bias.data.copy_(_sd[\"classifier.dense.bias\"])\n",
196
+ "aud_head[2].weight.data.copy_(_sd[\"classifier.out_proj.weight\"])\n",
197
+ "aud_head[2].bias.data.copy_(_sd[\"classifier.out_proj.bias\"])\n",
198
+ "\n",
199
+ "aud_backbone = aud_backbone.to(device).eval()\n",
200
+ "aud_head = aud_head.to(device).eval()\n",
201
+ "print(f\"✅ Đã nạp audeering MSP-dim (backbone + head {_hid}→{_out}) — model VAD chuyên\")"
202
+ ]
203
+ },
204
+ {
205
+ "cell_type": "markdown",
206
+ "id": "849f083a",
207
+ "metadata": {},
208
+ "source": [
209
+ "## 3. Đọc cảm xúc target cho mỗi wav (cho EMOS của SAILER)"
210
+ ]
211
+ },
212
+ {
213
+ "cell_type": "code",
214
+ "execution_count": null,
215
+ "id": "546df027",
216
+ "metadata": {
217
+ "lines_to_next_cell": 1
218
+ },
219
+ "outputs": [],
220
+ "source": [
221
+ "import numpy as np\n",
222
+ "import librosa\n",
223
+ "\n",
224
+ "def load_target_emotions():\n",
225
+ " tgt = {}\n",
226
+ " with open(METADATA_CSV, encoding=\"utf-8\") as f:\n",
227
+ " for ln in f:\n",
228
+ " parts = ln.strip().split(\"|\")\n",
229
+ " if len(parts) < 2:\n",
230
+ " continue\n",
231
+ " tgt[stem(parts[0])] = norm_emotion(parts[1])\n",
232
+ " return tgt\n",
233
+ "\n",
234
+ "target_map = load_target_emotions()\n",
235
+ "print(f\"Target emotions: {len(target_map)} wav | ví dụ:\", dict(list(target_map.items())[:3]))"
236
+ ]
237
+ },
238
+ {
239
+ "cell_type": "markdown",
240
+ "id": "1fee644a",
241
+ "metadata": {},
242
+ "source": [
243
+ "## 4. Hàm chấm: SAILER (EMOS+CAT) + audeering (VAD)"
244
+ ]
245
+ },
246
+ {
247
+ "cell_type": "code",
248
+ "execution_count": null,
249
+ "id": "54a8ad31",
250
+ "metadata": {
251
+ "lines_to_next_cell": 1
252
+ },
253
+ "outputs": [],
254
+ "source": [
255
+ "@torch.no_grad()\n",
256
+ "def sailer_probs(wav_path):\n",
257
+ " \"\"\"→ probs9 (float32[9]); None nếu thiếu/lỗi. Chỉ lấy 9 lớp (EMOS+CAT), bỏ VAD của SAILER.\"\"\"\n",
258
+ " if not os.path.exists(wav_path):\n",
259
+ " return None\n",
260
+ " wave, _ = librosa.load(wav_path, sr=SR, mono=True)\n",
261
+ " wave = wave[: MAX_SECONDS * SR]\n",
262
+ " data = torch.from_numpy(wave).float().unsqueeze(0).to(device)\n",
263
+ " logits, _feat, _det, _aro, _val, _dom = sailer(data, return_feature=True)\n",
264
+ " return F.softmax(logits, dim=1)[0].detach().cpu().numpy().astype(np.float32)\n",
265
+ "\n",
266
+ "def emos_from_probs(probs9, target):\n",
267
+ " if target is None or target not in EMO2SAILER:\n",
268
+ " return None\n",
269
+ " return 1.0 + 4.0 * float(probs9[EMO2SAILER[target]])\n",
270
+ "\n",
271
+ "def cat5_from_probs(probs9):\n",
272
+ " v = np.array([probs9[EMO2SAILER[e]] for e in EMOTIONS5], dtype=np.float32)\n",
273
+ " s = v.sum()\n",
274
+ " return v / s if s > 0 else np.full(5, 0.2, dtype=np.float32)\n",
275
+ "\n",
276
+ "@torch.no_grad()\n",
277
+ "def audeering_vad(wav_path):\n",
278
+ " \"\"\"VAD bằng audeering → [VAL, ARO, DOM] thang 1–5; None nếu thiếu/lỗi.\n",
279
+ " Model xuất [arousal, dominance, valence] ∈ [0,1].\"\"\"\n",
280
+ " if not os.path.exists(wav_path):\n",
281
+ " return None\n",
282
+ " wave, _ = librosa.load(wav_path, sr=SR, mono=True)\n",
283
+ " wave = wave[: MAX_SECONDS * SR]\n",
284
+ " x = aud_proc(wave, sampling_rate=SR).input_values[0]\n",
285
+ " x = torch.from_numpy(np.asarray(x, dtype=np.float32)).unsqueeze(0).to(device)\n",
286
+ " h = aud_backbone(x)[0].mean(dim=1) # mean-pool theo thời gian\n",
287
+ " out = aud_head(h)[0].detach().cpu().numpy() # [arousal, dominance, valence]\n",
288
+ " aro, dom, val = float(out[0]), float(out[1]), float(out[2])\n",
289
+ " return np.array([1 + 4 * val, 1 + 4 * aro, 1 + 4 * dom], dtype=np.float32) # [VAL,ARO,DOM]"
290
+ ]
291
+ },
292
+ {
293
+ "cell_type": "markdown",
294
+ "id": "e662c05e",
295
+ "metadata": {},
296
+ "source": [
297
+ "## 5. QMOS = SpeechMOS (UTMOS) — bắt buộc cho answer.txt"
298
+ ]
299
+ },
300
+ {
301
+ "cell_type": "code",
302
+ "execution_count": null,
303
+ "id": "aacc9e34",
304
+ "metadata": {
305
+ "lines_to_next_cell": 1
306
+ },
307
+ "outputs": [],
308
+ "source": [
309
+ "@torch.no_grad()\n",
310
+ "def run_qmos(names):\n",
311
+ " predictor = torch.hub.load(\"tarepan/SpeechMOS:v1.2.0\", \"utmos22_strong\", trust_repo=True).to(device).eval()\n",
312
+ " from tqdm.auto import tqdm\n",
313
+ " out = {}\n",
314
+ " for n in tqdm(names, desc=\"QMOS\"):\n",
315
+ " p = os.path.join(WAV_DIR, n)\n",
316
+ " if not os.path.exists(p):\n",
317
+ " continue\n",
318
+ " wave, _ = librosa.load(p, sr=SR, mono=True)\n",
319
+ " x = torch.from_numpy(wave).unsqueeze(0).to(device)\n",
320
+ " out[n] = float(predictor(x, sr=SR).mean().item())\n",
321
+ " return out"
322
+ ]
323
+ },
324
+ {
325
+ "cell_type": "markdown",
326
+ "id": "d6712414",
327
+ "metadata": {},
328
+ "source": [
329
+ "## 6. Chạy trên DEV → `answer.txt` (QMOS, EMOS, CAT ← SAILER/UTMOS · VAL,ARO,DOM ← audeering)"
330
+ ]
331
+ },
332
+ {
333
+ "cell_type": "code",
334
+ "execution_count": null,
335
+ "id": "011b2530",
336
+ "metadata": {
337
+ "lines_to_next_cell": 1
338
+ },
339
+ "outputs": [],
340
+ "source": [
341
+ "def list_dev():\n",
342
+ " with open(DEV_SCP) as f:\n",
343
+ " return [ln.strip() for ln in f if ln.strip()]\n",
344
+ "\n",
345
+ "dev_names = list_dev()\n",
346
+ "if LIMIT:\n",
347
+ " dev_names = dev_names[:LIMIT]\n",
348
+ "print(\"DEV:\", len(dev_names), \"mẫu\")\n",
349
+ "\n",
350
+ "qmos_scores = run_qmos(dev_names)\n",
351
+ "\n",
352
+ "def fmt_cat(probs5):\n",
353
+ " return \"|\".join(f\"{e}:{probs5[i]:.6g}\" for i, e in enumerate(EMOTIONS5))\n",
354
+ "\n",
355
+ "def build_answer(out_path):\n",
356
+ " from tqdm.auto import tqdm\n",
357
+ " n_emos = n_default = n_vad_def = 0\n",
358
+ " with open(out_path, \"w\") as f:\n",
359
+ " f.write(\"wav,QMOS,EMOS,CAT,VAL,ARO,DOM\\n\")\n",
360
+ " for name in tqdm(dev_names, desc=\"EMOS/CAT(SAILER)+VAD(audeering)\"):\n",
361
+ " sid = stem(name)\n",
362
+ " wav = os.path.join(WAV_DIR, name)\n",
363
+ " # EMOS + CAT từ SAILER\n",
364
+ " probs9 = sailer_probs(wav)\n",
365
+ " if probs9 is None:\n",
366
+ " emos, cat5 = 3.0, np.full(5, 0.2, dtype=np.float32); n_default += 1\n",
367
+ " else:\n",
368
+ " emos = emos_from_probs(probs9, target_map.get(sid))\n",
369
+ " if emos is None:\n",
370
+ " emos = 3.0; n_default += 1\n",
371
+ " else:\n",
372
+ " n_emos += 1\n",
373
+ " cat5 = cat5_from_probs(probs9)\n",
374
+ " # VAD từ audeering\n",
375
+ " vad3 = audeering_vad(wav)\n",
376
+ " if vad3 is None:\n",
377
+ " vad3 = np.array([3.0, 3.0, 3.0], dtype=np.float32); n_vad_def += 1\n",
378
+ " qmos = qmos_scores.get(name, 3.0)\n",
379
+ " f.write(f\"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},\"\n",
380
+ " f\"{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\\n\")\n",
381
+ " print(f\"Ghi {len(dev_names)} dòng → {out_path} | EMOS thật {n_emos}, mặc định {n_default} | VAD mặc định {n_vad_def}\")\n",
382
+ "\n",
383
+ "answer_path = os.path.join(OUT_DIR, \"answer.txt\")\n",
384
+ "build_answer(answer_path)"
385
+ ]
386
+ },
387
+ {
388
+ "cell_type": "markdown",
389
+ "id": "6afa397f",
390
+ "metadata": {},
391
+ "source": [
392
+ "## 7. Validate + đóng zip"
393
+ ]
394
+ },
395
+ {
396
+ "cell_type": "code",
397
+ "execution_count": null,
398
+ "id": "749f1366",
399
+ "metadata": {},
400
+ "outputs": [],
401
+ "source": [
402
+ "def validate(path):\n",
403
+ " import csv\n",
404
+ " with open(path) as f:\n",
405
+ " rows = list(csv.reader(f))\n",
406
+ " header = rows[0]\n",
407
+ " assert header[0] == \"wav\" and \"QMOS\" in header and \"EMOS\" in header, \"Header sai\"\n",
408
+ " for i, r in enumerate(rows[1:], 2):\n",
409
+ " assert len(r) == len(header), f\"Dòng {i} sai số cột\"\n",
410
+ " print(f\"OK: {len(rows)-1} dòng, header = {header}\")\n",
411
+ "\n",
412
+ "validate(answer_path)\n",
413
+ "os.system(f\"cd {OUT_DIR} && zip -j submission_track2_exp05_vad-audeering.zip answer.txt && unzip -l submission_track2_exp05_vad-audeering.zip\")\n",
414
+ "print(\"Sẵn sàng nộp:\", os.path.join(OUT_DIR, \"submission_track2_exp05_vad-audeering.zip\"))"
415
+ ]
416
+ },
417
+ {
418
+ "cell_type": "markdown",
419
+ "id": "69fb16b7",
420
+ "metadata": {},
421
+ "source": [
422
+ "## Ghi chú\n",
423
+ "- **Quan hệ với exp03:** exp03 = SAILER lo cả EMOS+CAT+VAD (giữ nguyên, file `exp03_emos_sailer`).\n",
424
+ " exp05 (file này) chỉ **đổi VAD sang audeering**, EMOS/CAT vẫn SAILER → nộp 2 bản để A/B từng cột VAD.\n",
425
+ "- **Lần đầu** đặt `LIMIT = 20`, kiểm tra VAL/ARO/DOM ∈ [1,5] hợp lý (không toàn 3 / không âm).\n",
426
+ " Nếu giá trị lệch → có thể sai thứ tự arousal/dominance/valence, báo lại để chỉnh.\n",
427
+ "- Khi chạy để ý dòng `backbone: thiếu N key, dư M key`: thiếu/dư vài key phụ là bình thường;\n",
428
+ " thiếu hàng trăm key = sai tiền tố → báo lại.\n",
429
+ "- Nếu audeering thắng VAL nhưng thua ARO/DOM so SAILER → bản tối ưu = trộn cột\n",
430
+ " (VAL từ audeering, ARO/DOM từ exp03). Ghi kết quả vào `docs/04_experiments_log.md` (exp05)."
431
+ ]
432
+ }
433
+ ],
434
+ "metadata": {
435
+ "jupytext": {
436
+ "cell_metadata_filter": "-all",
437
+ "main_language": "python",
438
+ "notebook_metadata_filter": "-all"
439
+ }
440
+ },
441
+ "nbformat": 4,
442
+ "nbformat_minor": 5
443
+ }
track2/exp05_vad_audeering_pipeline.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %% [markdown]
2
+ # # VMC2026 Track 2 — exp05 (VAD bằng audeering MSP-dim) — Kaggle
3
+ #
4
+ # **Mục tiêu:** đẩy **VAL** (SAILER chỉ 0.341 — thấp nhất) bằng model VAD chuyên
5
+ # `audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim` (dimensional, xuất thẳng
6
+ # arousal/dominance/valence ∈ [0,1]). **Thay cả 3 cột VAD** bằng audeering.
7
+ #
8
+ # ## Phân công model (giữ cái tốt của exp03, chỉ đổi VAD)
9
+ # ```
10
+ # QMOS ← SpeechMOS (UTMOS) (để riêng)
11
+ # EMOS ← SAILER (1 + 4·P(target)) ┐ giữ nguyên exp03
12
+ # CAT ← SAILER (5 lớp renorm) ┘
13
+ # VAL ← audeering ┐
14
+ # ARO ← audeering ├─ THAY cả 3 (model VAD chuyên)
15
+ # DOM ← audeering ┘
16
+ # ```
17
+ # - Mỗi wav chạy **2 forward**: SAILER (EMOS+CAT) + audeering (VAD). KHÔNG train.
18
+ # - So với exp03 (VAD từ SAILER: VAL 0.341 / ARO 0.712 / DOM 0.630) → nộp để A/B từng cột.
19
+ #
20
+ # **Cách chạy Kaggle:** GPU **T4** + Internet **On** → + Add Input dataset Track 2 (có `sets/dev.scp`,
21
+ # `metadata.csv`) → sửa `DATA_ROOT` → lần đầu `LIMIT = 20` kiểm tra VAD ra 1–5 hợp lý → rồi `None`.
22
+ #
23
+ # ⚠️ License **SAILER = Open RAIL** · **audeering = CC BY-NC-SA 4.0** (đều phi thương mại) → khai báo `docs/12_`.
24
+
25
+ # %% [markdown]
26
+ # ## 0. Cấu hình — SỬA Ở ĐÂY
27
+
28
+ # %%
29
+ import os
30
+
31
+ DATA_ROOT = "/kaggle/input/vmc2026-track2-full/vmc2026-track2" # << SỬA slug cho khớp Add Input
32
+ WAV_DIR = f"{DATA_ROOT}/wav"
33
+ METADATA_CSV = f"{DATA_ROOT}/metadata.csv" # wavID|emotion|transcript → target emotion (cho EMOS)
34
+ DEV_SCP = f"{DATA_ROOT}/sets/dev.scp" # danh sách wav tập DEV
35
+
36
+ OUT_DIR = "/kaggle/working"
37
+
38
+ DEVICE = "cuda"
39
+ MAX_SECONDS = 15
40
+ SR = 16000
41
+ LIMIT = None # đặt 20 để chạy thử nhanh; None = full DEV
42
+
43
+ EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"]
44
+ SAILER9 = ["Anger", "Contempt", "Disgust", "Fear", "Happiness", "Neutral", "Sadness", "Surprise", "Other"]
45
+ EMO2SAILER = {"angry": 0, "happy": 4, "neutral": 5, "sad": 6, "surprised": 7}
46
+
47
+ _EMO_ALIAS = {
48
+ "angry": "angry", "anger": "angry",
49
+ "happy": "happy", "happiness": "happy", "joy": "happy",
50
+ "neutral": "neutral", "calm": "neutral",
51
+ "sad": "sad", "sadness": "sad",
52
+ "surprise": "surprised", "surprised": "surprised", "surprising": "surprised",
53
+ }
54
+
55
+ def norm_emotion(label):
56
+ key = str(label).strip().lower()
57
+ return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)
58
+
59
+ def stem(path_or_name):
60
+ return os.path.splitext(os.path.basename(str(path_or_name)))[0]
61
+
62
+ print("DATA_ROOT:", DATA_ROOT)
63
+ for p in [WAV_DIR, METADATA_CSV, DEV_SCP]:
64
+ print((" ✅ " if os.path.exists(p) else " ❌ THIẾU ") + p)
65
+
66
+ # %% [markdown]
67
+ # ## 1. Cài đặt + tải code SAILER (clone + sys.path, KHÔNG pip install -e .)
68
+
69
+ # %%
70
+ import sys, subprocess
71
+
72
+ def pip_install(*pkgs):
73
+ subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs], check=True)
74
+
75
+ REPO_DIR = "/kaggle/working/vox-profile-release"
76
+ if not os.path.exists(REPO_DIR):
77
+ subprocess.run(["git", "clone", "--depth", "1",
78
+ "https://github.com/tiantiaf0627/vox-profile-release.git", REPO_DIR], check=True)
79
+
80
+ pip_install("loralib", "speechbrain", "speechmos", "librosa", "soundfile", "scipy", "tqdm")
81
+
82
+ if REPO_DIR not in sys.path:
83
+ sys.path.insert(0, REPO_DIR)
84
+
85
+ # %% [markdown]
86
+ # ## 2. Nạp model SAILER (cho EMOS + CAT)
87
+
88
+ # %%
89
+ import torch
90
+ import torch.nn.functional as F
91
+
92
+ device = DEVICE if torch.cuda.is_available() else "cpu"
93
+ print("Device:", device)
94
+ if device == "cuda":
95
+ print(" ✅ GPU:", torch.cuda.get_device_name(0))
96
+ else:
97
+ print(" ⚠️ KHÔNG thấy GPU → Settings → Accelerator = GPU T4 rồi chạy lại.")
98
+
99
+ from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402
100
+
101
+ sailer = WavLMWrapper.from_pretrained("tiantiaf/wavlm-large-categorical-emotion").to(device)
102
+ sailer.eval()
103
+ print("✅ Đã nạp SAILER (wavlm-large-categorical-emotion)")
104
+
105
+ # %% [markdown]
106
+ # ## 2b. Nạp model VAD chuyên: audeering wav2vec2 MSP-dim
107
+ # ⚠️ Kế thừa `Wav2Vec2PreTrainedModel` (theo model card) hay dính lỗi version transformers
108
+ # (thiếu `__file__` / `all_tied_weights_keys`...). Cách dứt điểm: CHỈ dùng `Wav2Vec2Model` (backbone
109
+ # được hỗ trợ tốt) + **tự nạp tay** trọng số regression head từ checkpoint → không đụng tie-weights/experts.
110
+ # ⚠️ Model xuất thứ tự **[arousal, dominance, valence]** ∈ [0,1] → đổi về [VAL,ARO,DOM] thang 1–5 khi ghi.
111
+
112
+ # %%
113
+ import torch.nn as nn
114
+ from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor
115
+ from huggingface_hub import hf_hub_download
116
+
117
+ AUD_NAME = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
118
+ aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)
119
+
120
+ # 1) backbone wav2vec2 (load chuẩn, không subclass)
121
+ aud_cfg = Wav2Vec2Config.from_pretrained(AUD_NAME)
122
+ aud_backbone = Wav2Vec2Model(aud_cfg)
123
+
124
+ # 2) tải state_dict gốc của checkpoint (ưu tiên safetensors)
125
+ try:
126
+ _sd = __import__("safetensors.torch", fromlist=["load_file"]).load_file(
127
+ hf_hub_download(AUD_NAME, "model.safetensors"))
128
+ except Exception:
129
+ _sd = torch.load(hf_hub_download(AUD_NAME, "pytorch_model.bin"), map_location="cpu")
130
+
131
+ # 3) nạp phần backbone (key có tiền tố "wav2vec2.") vào Wav2Vec2Model
132
+ bb_sd = {k[len("wav2vec2."):]: v for k, v in _sd.items() if k.startswith("wav2vec2.")}
133
+ missing, unexpected = aud_backbone.load_state_dict(bb_sd, strict=False)
134
+ print(f" backbone: thiếu {len(missing)} key, dư {len(unexpected)} key (strict=False)")
135
+
136
+ # 4) dựng regression head theo đúng shape trong checkpoint rồi nạp trọng số "classifier.*"
137
+ _hid = _sd["classifier.dense.weight"].shape[0]
138
+ _out = _sd["classifier.out_proj.weight"].shape[0] # = 3 (arousal, dominance, valence)
139
+ aud_head = nn.Sequential(nn.Linear(_hid, _hid), nn.Tanh(), nn.Linear(_hid, _out))
140
+ aud_head[0].weight.data.copy_(_sd["classifier.dense.weight"])
141
+ aud_head[0].bias.data.copy_(_sd["classifier.dense.bias"])
142
+ aud_head[2].weight.data.copy_(_sd["classifier.out_proj.weight"])
143
+ aud_head[2].bias.data.copy_(_sd["classifier.out_proj.bias"])
144
+
145
+ aud_backbone = aud_backbone.to(device).eval()
146
+ aud_head = aud_head.to(device).eval()
147
+ print(f"✅ Đã nạp audeering MSP-dim (backbone + head {_hid}→{_out}) — model VAD chuyên")
148
+
149
+ # %% [markdown]
150
+ # ## 3. Đọc cảm xúc target cho mỗi wav (cho EMOS của SAILER)
151
+
152
+ # %%
153
+ import numpy as np
154
+ import librosa
155
+
156
+ def load_target_emotions():
157
+ tgt = {}
158
+ with open(METADATA_CSV, encoding="utf-8") as f:
159
+ for ln in f:
160
+ parts = ln.strip().split("|")
161
+ if len(parts) < 2:
162
+ continue
163
+ tgt[stem(parts[0])] = norm_emotion(parts[1])
164
+ return tgt
165
+
166
+ target_map = load_target_emotions()
167
+ print(f"Target emotions: {len(target_map)} wav | ví dụ:", dict(list(target_map.items())[:3]))
168
+
169
+ # %% [markdown]
170
+ # ## 4. Hàm chấm: SAILER (EMOS+CAT) + audeering (VAD)
171
+
172
+ # %%
173
+ @torch.no_grad()
174
+ def sailer_probs(wav_path):
175
+ """→ probs9 (float32[9]); None nếu thiếu/lỗi. Chỉ lấy 9 lớp (EMOS+CAT), bỏ VAD của SAILER."""
176
+ if not os.path.exists(wav_path):
177
+ return None
178
+ wave, _ = librosa.load(wav_path, sr=SR, mono=True)
179
+ wave = wave[: MAX_SECONDS * SR]
180
+ data = torch.from_numpy(wave).float().unsqueeze(0).to(device)
181
+ logits, _feat, _det, _aro, _val, _dom = sailer(data, return_feature=True)
182
+ return F.softmax(logits, dim=1)[0].detach().cpu().numpy().astype(np.float32)
183
+
184
+ def emos_from_probs(probs9, target):
185
+ if target is None or target not in EMO2SAILER:
186
+ return None
187
+ return 1.0 + 4.0 * float(probs9[EMO2SAILER[target]])
188
+
189
+ def cat5_from_probs(probs9):
190
+ v = np.array([probs9[EMO2SAILER[e]] for e in EMOTIONS5], dtype=np.float32)
191
+ s = v.sum()
192
+ return v / s if s > 0 else np.full(5, 0.2, dtype=np.float32)
193
+
194
+ @torch.no_grad()
195
+ def audeering_vad(wav_path):
196
+ """VAD bằng audeering → [VAL, ARO, DOM] thang 1–5; None nếu thiếu/lỗi.
197
+ Model xuất [arousal, dominance, valence] ∈ [0,1]."""
198
+ if not os.path.exists(wav_path):
199
+ return None
200
+ wave, _ = librosa.load(wav_path, sr=SR, mono=True)
201
+ wave = wave[: MAX_SECONDS * SR]
202
+ x = aud_proc(wave, sampling_rate=SR).input_values[0]
203
+ x = torch.from_numpy(np.asarray(x, dtype=np.float32)).unsqueeze(0).to(device)
204
+ h = aud_backbone(x)[0].mean(dim=1) # mean-pool theo thời gian
205
+ out = aud_head(h)[0].detach().cpu().numpy() # [arousal, dominance, valence]
206
+ aro, dom, val = float(out[0]), float(out[1]), float(out[2])
207
+ return np.array([1 + 4 * val, 1 + 4 * aro, 1 + 4 * dom], dtype=np.float32) # [VAL,ARO,DOM]
208
+
209
+ # %% [markdown]
210
+ # ## 5. QMOS = SpeechMOS (UTMOS) — bắt buộc cho answer.txt
211
+
212
+ # %%
213
+ @torch.no_grad()
214
+ def run_qmos(names):
215
+ predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True).to(device).eval()
216
+ from tqdm.auto import tqdm
217
+ out = {}
218
+ for n in tqdm(names, desc="QMOS"):
219
+ p = os.path.join(WAV_DIR, n)
220
+ if not os.path.exists(p):
221
+ continue
222
+ wave, _ = librosa.load(p, sr=SR, mono=True)
223
+ x = torch.from_numpy(wave).unsqueeze(0).to(device)
224
+ out[n] = float(predictor(x, sr=SR).mean().item())
225
+ return out
226
+
227
+ # %% [markdown]
228
+ # ## 6. Chạy trên DEV → `answer.txt` (QMOS, EMOS, CAT ← SAILER/UTMOS · VAL,ARO,DOM ← audeering)
229
+
230
+ # %%
231
+ def list_dev():
232
+ with open(DEV_SCP) as f:
233
+ return [ln.strip() for ln in f if ln.strip()]
234
+
235
+ dev_names = list_dev()
236
+ if LIMIT:
237
+ dev_names = dev_names[:LIMIT]
238
+ print("DEV:", len(dev_names), "mẫu")
239
+
240
+ qmos_scores = run_qmos(dev_names)
241
+
242
+ def fmt_cat(probs5):
243
+ return "|".join(f"{e}:{probs5[i]:.6g}" for i, e in enumerate(EMOTIONS5))
244
+
245
+ def build_answer(out_path):
246
+ from tqdm.auto import tqdm
247
+ n_emos = n_default = n_vad_def = 0
248
+ with open(out_path, "w") as f:
249
+ f.write("wav,QMOS,EMOS,CAT,VAL,ARO,DOM\n")
250
+ for name in tqdm(dev_names, desc="EMOS/CAT(SAILER)+VAD(audeering)"):
251
+ sid = stem(name)
252
+ wav = os.path.join(WAV_DIR, name)
253
+ # EMOS + CAT từ SAILER
254
+ probs9 = sailer_probs(wav)
255
+ if probs9 is None:
256
+ emos, cat5 = 3.0, np.full(5, 0.2, dtype=np.float32); n_default += 1
257
+ else:
258
+ emos = emos_from_probs(probs9, target_map.get(sid))
259
+ if emos is None:
260
+ emos = 3.0; n_default += 1
261
+ else:
262
+ n_emos += 1
263
+ cat5 = cat5_from_probs(probs9)
264
+ # VAD từ audeering
265
+ vad3 = audeering_vad(wav)
266
+ if vad3 is None:
267
+ vad3 = np.array([3.0, 3.0, 3.0], dtype=np.float32); n_vad_def += 1
268
+ qmos = qmos_scores.get(name, 3.0)
269
+ f.write(f"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},"
270
+ f"{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\n")
271
+ print(f"Ghi {len(dev_names)} dòng → {out_path} | EMOS thật {n_emos}, mặc định {n_default} | VAD mặc định {n_vad_def}")
272
+
273
+ answer_path = os.path.join(OUT_DIR, "answer.txt")
274
+ build_answer(answer_path)
275
+
276
+ # %% [markdown]
277
+ # ## 7. Validate + đóng zip
278
+
279
+ # %%
280
+ def validate(path):
281
+ import csv
282
+ with open(path) as f:
283
+ rows = list(csv.reader(f))
284
+ header = rows[0]
285
+ assert header[0] == "wav" and "QMOS" in header and "EMOS" in header, "Header sai"
286
+ for i, r in enumerate(rows[1:], 2):
287
+ assert len(r) == len(header), f"Dòng {i} sai số cột"
288
+ print(f"OK: {len(rows)-1} dòng, header = {header}")
289
+
290
+ validate(answer_path)
291
+ os.system(f"cd {OUT_DIR} && zip -j submission_track2_exp05_vad-audeering.zip answer.txt && unzip -l submission_track2_exp05_vad-audeering.zip")
292
+ print("Sẵn sàng nộp:", os.path.join(OUT_DIR, "submission_track2_exp05_vad-audeering.zip"))
293
+
294
+ # %% [markdown]
295
+ # ## Ghi chú
296
+ # - **Quan hệ với exp03:** exp03 = SAILER lo cả EMOS+CAT+VAD (giữ nguyên, file `exp03_emos_sailer`).
297
+ # exp05 (file này) chỉ **đổi VAD sang audeering**, EMOS/CAT vẫn SAILER → nộp 2 bản để A/B từng cột VAD.
298
+ # - **Lần đầu** đặt `LIMIT = 20`, kiểm tra VAL/ARO/DOM ∈ [1,5] hợp lý (không toàn 3 / không âm).
299
+ # Nếu giá trị lệch → có thể sai thứ tự arousal/dominance/valence, báo lại để chỉnh.
300
+ # - Khi chạy để ý dòng `backbone: thiếu N key, dư M key`: thiếu/dư vài key phụ là bình thường;
301
+ # thiếu hàng trăm key = sai tiền tố → báo lại.
302
+ # - Nếu audeering thắng VAL nhưng thua ARO/DOM so SAILER → bản tối ưu = trộn cột
303
+ # (VAL từ audeering, ARO/DOM từ exp03). Ghi kết quả vào `docs/04_experiments_log.md` (exp05).
track2/exp06_qmos_train.ipynb ADDED
@@ -0,0 +1,628 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "e2d94d72",
6
+ "metadata": {},
7
+ "source": [
8
+ "# VMC2026 Track 2 — exp06 (TRAIN QMOS head) — Kaggle\n",
9
+ "\n",
10
+ "**Mục tiêu:** QMOS là cột **duy nhất chưa train** (đang dùng UTMOS zero-shot → SRCC kẹt 0.414).\n",
11
+ "`train.csv` CÓ sẵn cột `qMOS` → ta train 1 **head hồi quy nhỏ** trên đặc trưng SSL (đã cache ở exp04)\n",
12
+ "để vượt 0.414.\n",
13
+ "\n",
14
+ "## Ý tưởng (đọc 1 lần cho hiểu)\n",
15
+ "- Tái dùng đặc trưng **emotion2vec + SAILER** đã trích & cache trong `fusion_cache/` (exp04) → KHÔNG trích lại.\n",
16
+ "- Thêm **chính điểm UTMOS** (SpeechMOS) làm 1 đặc trưng đầu vào → head chỉ cần **học chỉnh sửa (residual)**\n",
17
+ " quanh 0.414 thay vì học lại từ đầu → an toàn, gần như chắc chắn ≥ UTMOS đơn lẻ.\n",
18
+ "- Nhãn vàng QMOS = **TB `qMOS` theo wav** (gộp các listener trong `train.csv`).\n",
19
+ "- Có **val nội bộ 10%** → đo SRCC, so thẳng với UTMOS trên CÙNG tập val → biết có cải thiện thật\n",
20
+ " **trước khi** tốn lượt nộp CodaBench.\n",
21
+ "- Cuối cùng: **GIỮ NGUYÊN exp04** (5 cột cảm xúc đang thắng), chỉ **thay cột QMOS** trong `answer.txt`.\n",
22
+ "\n",
23
+ "```\n",
24
+ " mỗi wav ─► [e2v_emb | e2v_probs5 | sailer_emb | sailer_probs9 | sailer_vad3 | UTMOS] ─► MLP ─► QMOS\n",
25
+ " (head train)\n",
26
+ "```\n",
27
+ "\n",
28
+ "**Cách chạy trên Kaggle:** Settings → Accelerator = **GPU T4**, Internet = **On**.\n",
29
+ "+ Add Input: (1) dataset Track 2 (15.477 wav, có `sets/train.csv`) ; (2) — nếu có — dataset chứa\n",
30
+ "`fusion_cache/*.npz` đã Save Version ở exp04 (đỡ ~15') ; (3) file `answer.txt` của exp04 để ghép cột.\n",
31
+ "Lần đầu đặt `LIMIT_TRAIN=300`, `LIMIT_DEV=20` để bắt lỗi setup, OK rồi đặt `None`."
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "markdown",
36
+ "id": "b42d5d49",
37
+ "metadata": {},
38
+ "source": [
39
+ "## 0. Cấu hình — SỬA Ở ĐÂY"
40
+ ]
41
+ },
42
+ {
43
+ "cell_type": "code",
44
+ "execution_count": null,
45
+ "id": "93e29194",
46
+ "metadata": {},
47
+ "outputs": [],
48
+ "source": [
49
+ "import os\n",
50
+ "\n",
51
+ "# ── Data Track 2 ─────────────────────────────────────────────────────────────\n",
52
+ "DATA_ROOT = \"/kaggle/input/vmc2026-track2-full/vmc2026-track2\" # << SỬA slug cho khớp Add Input\n",
53
+ "WAV_DIR = f\"{DATA_ROOT}/wav\"\n",
54
+ "TRAIN_CSV = f\"{DATA_ROOT}/sets/train.csv\" # nhãn người nghe: lisID|wavID|qMOS|emoCat|eMOS|val|dom|aro\n",
55
+ "DEV_SCP = f\"{DATA_ROOT}/sets/dev.scp\" # danh sách wav tập DEV\n",
56
+ "\n",
57
+ "OUT_DIR = \"/kaggle/working\"\n",
58
+ "# Dùng CHUNG cache với exp04. Nếu đã Save Version cache ở exp04, trỏ CACHE_DIR vào dataset đó\n",
59
+ "# (vd \"/kaggle/input/<slug-cache>/fusion_cache\") để khỏi trích lại; nếu không, để mặc định sẽ tự trích.\n",
60
+ "CACHE_DIR = \"/kaggle/working/fusion_cache\"\n",
61
+ "os.makedirs(CACHE_DIR, exist_ok=True)\n",
62
+ "\n",
63
+ "# File answer.txt của exp04 (5 cột cảm xúc đang thắng) để GHÉP cột QMOS mới vào.\n",
64
+ "# Trỏ tới nơi bạn đặt file exp04. Nếu không có, notebook vẫn xuất qmos_dev.csv riêng + cảnh báo.\n",
65
+ "EXP04_ANSWER = \"/kaggle/input/exp04-answer/answer.txt\" # << SỬA; hoặc \"/kaggle/working/answer.txt\"\n",
66
+ "\n",
67
+ "# ── Đặc trưng dùng cho QMOS ──────────────────────────────────────────────────\n",
68
+ "USE_E2V = True # nối embedding emotion2vec\n",
69
+ "USE_SAILER = True # nối embedding SAILER/WavLM\n",
70
+ "USE_CLASSPROB = True # nối thêm xác suất lớp (e2v5 + sailer9 + vad3)\n",
71
+ "USE_UTMOS_FEAT = True # nối thêm điểm UTMOS làm 1 đặc trưng (neo residual quanh 0.414)\n",
72
+ "\n",
73
+ "# ── Siêu tham số train head ──────────────────────────────────────────────────\n",
74
+ "DEVICE = \"cuda\"\n",
75
+ "HIDDEN = 256\n",
76
+ "DROPOUT = 0.3\n",
77
+ "LR = 1e-3\n",
78
+ "EPOCHS = 120\n",
79
+ "BATCH = 64\n",
80
+ "VAL_FRAC = 0.10\n",
81
+ "PATIENCE = 20\n",
82
+ "SEED = 42\n",
83
+ "RANK_LAMBDA = 0.0 # 0 = chỉ MSE. >0 (vd 0.2) = cộng thêm pairwise ranking loss (tối ưu thứ hạng=SRCC)\n",
84
+ "\n",
85
+ "LIMIT_TRAIN = None # số nhỏ (vd 300) để chạy thử; None = full\n",
86
+ "LIMIT_DEV = None\n",
87
+ "\n",
88
+ "def stem(p):\n",
89
+ " return os.path.splitext(os.path.basename(str(p)))[0]\n",
90
+ "\n",
91
+ "assert USE_E2V or USE_SAILER or USE_UTMOS_FEAT, \"Phải bật ít nhất 1 nguồn đặc trưng.\"\n",
92
+ "print(\"DATA_ROOT:\", DATA_ROOT)\n",
93
+ "for p in [WAV_DIR, TRAIN_CSV, DEV_SCP]:\n",
94
+ " print((\" ✅ \" if os.path.exists(p) else \" ❌ THIẾU \") + p)"
95
+ ]
96
+ },
97
+ {
98
+ "cell_type": "markdown",
99
+ "id": "47ac221d",
100
+ "metadata": {},
101
+ "source": [
102
+ "## 1. Cài đặt + (nếu cần) tải code SAILER\n",
103
+ "emotion2vec qua `funasr`; SAILER cần `WavLMWrapper` trong repo `vox-profile-release` (clone + sys.path).\n",
104
+ "Nếu cache đã đủ thì các model này sẽ KHÔNG được nạp (chỉ nạp khi còn file phải trích)."
105
+ ]
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "execution_count": null,
110
+ "id": "99ba1947",
111
+ "metadata": {},
112
+ "outputs": [],
113
+ "source": [
114
+ "import sys, subprocess\n",
115
+ "\n",
116
+ "def pip_install(*pkgs):\n",
117
+ " subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *pkgs], check=True)\n",
118
+ "\n",
119
+ "pip_install(\"speechmos\", \"funasr\", \"librosa\", \"soundfile\", \"pandas\", \"scipy\", \"scikit-learn\", \"tqdm\")\n",
120
+ "\n",
121
+ "if USE_SAILER:\n",
122
+ " pip_install(\"loralib\", \"speechbrain\")\n",
123
+ " REPO_DIR = \"/kaggle/working/vox-profile-release\"\n",
124
+ " if not os.path.exists(REPO_DIR):\n",
125
+ " subprocess.run([\"git\", \"clone\", \"--depth\", \"1\",\n",
126
+ " \"https://github.com/tiantiaf0627/vox-profile-release.git\", REPO_DIR], check=True)\n",
127
+ " if REPO_DIR not in sys.path:\n",
128
+ " sys.path.insert(0, REPO_DIR)"
129
+ ]
130
+ },
131
+ {
132
+ "cell_type": "markdown",
133
+ "id": "ac9dcefc",
134
+ "metadata": {},
135
+ "source": [
136
+ "## 2. Nhãn vàng QMOS (gộp `qMOS` theo wavID)"
137
+ ]
138
+ },
139
+ {
140
+ "cell_type": "code",
141
+ "execution_count": null,
142
+ "id": "db4a41a5",
143
+ "metadata": {},
144
+ "outputs": [],
145
+ "source": [
146
+ "import numpy as np\n",
147
+ "import pandas as pd\n",
148
+ "\n",
149
+ "def load_qmos_labels():\n",
150
+ " \"\"\"train.csv (sep='|') → DataFrame [wavID, qmos] với qmos = TB theo wav.\"\"\"\n",
151
+ " df = pd.read_csv(TRAIN_CSV, sep=\"|\")\n",
152
+ " cols = {c.lower().strip(): c for c in df.columns}\n",
153
+ " wav_col = cols.get(\"wavid\") or cols.get(\"wav\") or list(df.columns)[1]\n",
154
+ " qmos_col = cols.get(\"qmos\") or cols.get(\"qMOS\".lower()) or cols.get(\"mos\")\n",
155
+ " assert qmos_col, f\"Không thấy cột qMOS trong train.csv (cột: {list(df.columns)})\"\n",
156
+ " df[\"_stem\"] = df[wav_col].map(stem)\n",
157
+ " g = df.groupby(\"_stem\")[qmos_col].mean().reset_index()\n",
158
+ " g.columns = [\"wavID\", \"qmos\"]\n",
159
+ " return g\n",
160
+ "\n",
161
+ "qmos_df = load_qmos_labels()\n",
162
+ "print(f\"wav train (gộp): {len(qmos_df)}\")\n",
163
+ "print(\"qMOS:\", qmos_df[\"qmos\"].describe()[[\"mean\", \"std\", \"min\", \"max\"]].to_dict())\n",
164
+ "qmos_df.head()"
165
+ ]
166
+ },
167
+ {
168
+ "cell_type": "markdown",
169
+ "id": "dfd7df0c",
170
+ "metadata": {},
171
+ "source": [
172
+ "## 3. Trích / nạp đặc trưng (cache CHUNG với exp04) + điểm UTMOS\n",
173
+ "- `extract_e2v` / `extract_sailer`: y hệt exp04, cache `e2v_<tag>.npz` / `sailer_<tag>.npz`.\n",
174
+ "- `extract_utmos`: chấm UTMOS từng wav → cache `utmos_<tag>.npz` (dùng vừa làm đặc trưng, vừa làm baseline so sánh)."
175
+ ]
176
+ },
177
+ {
178
+ "cell_type": "code",
179
+ "execution_count": null,
180
+ "id": "ec1e63a1",
181
+ "metadata": {
182
+ "lines_to_next_cell": 1
183
+ },
184
+ "outputs": [],
185
+ "source": [
186
+ "import torch\n",
187
+ "import torch.nn.functional as F\n",
188
+ "\n",
189
+ "device = DEVICE if torch.cuda.is_available() else \"cpu\"\n",
190
+ "print(\"Device:\", device, (\"✅ \" + torch.cuda.get_device_name(0)) if device == \"cuda\" else \"⚠️ CPU\")\n",
191
+ "\n",
192
+ "EMOTIONS5 = [\"angry\", \"happy\", \"neutral\", \"sad\", \"surprised\"]\n",
193
+ "\n",
194
+ "def extract_e2v(stems, tag):\n",
195
+ " \"\"\"→ dict {stem: emb_full[D1+5]}. Cache CACHE_DIR/e2v_<tag>.npz (giống exp04).\"\"\"\n",
196
+ " from tqdm.auto import tqdm\n",
197
+ " cache_path = os.path.join(CACHE_DIR, f\"e2v_{tag}.npz\")\n",
198
+ " store = {}\n",
199
+ " if os.path.exists(cache_path):\n",
200
+ " z = np.load(cache_path, allow_pickle=True)\n",
201
+ " store = {k: z[k] for k in z.files}\n",
202
+ " print(f\"[e2v/{tag}] nạp cache: {len(store)}\")\n",
203
+ " todo = [s for s in stems if s not in store]\n",
204
+ " if todo:\n",
205
+ " from funasr import AutoModel\n",
206
+ " m = AutoModel(model=\"iic/emotion2vec_plus_large\", hub=\"hf\", device=device)\n",
207
+ " for i, s in enumerate(tqdm(todo, desc=f\"e2v {tag}\")):\n",
208
+ " wav = os.path.join(WAV_DIR, s + \".wav\")\n",
209
+ " if not os.path.exists(wav):\n",
210
+ " continue\n",
211
+ " r = m.generate(wav, granularity=\"utterance\", extract_embedding=True)[0]\n",
212
+ " emb = np.asarray(r[\"feats\"], dtype=np.float32).reshape(-1)\n",
213
+ " probs = {e: 0.0 for e in EMOTIONS5}\n",
214
+ " for lab, sc in zip(r[\"labels\"], r[\"scores\"]):\n",
215
+ " name = lab.split(\"/\")[-1]\n",
216
+ " if name in probs:\n",
217
+ " probs[name] = float(sc)\n",
218
+ " tot = sum(probs.values())\n",
219
+ " p5 = np.array([probs[e] / tot if tot > 0 else 0.2 for e in EMOTIONS5], dtype=np.float32)\n",
220
+ " store[s] = np.concatenate([emb, p5]).astype(np.float32)\n",
221
+ " if (i + 1) % 500 == 0:\n",
222
+ " np.savez(cache_path, **store)\n",
223
+ " np.savez(cache_path, **store)\n",
224
+ " del m\n",
225
+ " torch.cuda.empty_cache() if device == \"cuda\" else None\n",
226
+ " return store # mỗi value = [D1 | 5]\n",
227
+ "\n",
228
+ "def _pool_feat(features):\n",
229
+ " f = features.detach().cpu().numpy()\n",
230
+ " if f.ndim <= 1:\n",
231
+ " return f.reshape(-1).astype(np.float32)\n",
232
+ " return f.mean(axis=tuple(range(f.ndim - 1))).reshape(-1).astype(np.float32)\n",
233
+ "\n",
234
+ "def extract_sailer(stems, tag):\n",
235
+ " \"\"\"→ dict {stem: vec[D2+9+3]}. Cache CACHE_DIR/sailer_<tag>.npz (giống exp04).\"\"\"\n",
236
+ " import librosa\n",
237
+ " from tqdm.auto import tqdm\n",
238
+ " cache_path = os.path.join(CACHE_DIR, f\"sailer_{tag}.npz\")\n",
239
+ " store = {}\n",
240
+ " if os.path.exists(cache_path):\n",
241
+ " z = np.load(cache_path, allow_pickle=True)\n",
242
+ " store = {k: z[k] for k in z.files}\n",
243
+ " print(f\"[sailer/{tag}] nạp cache: {len(store)}\")\n",
244
+ " todo = [s for s in stems if s not in store]\n",
245
+ " if todo:\n",
246
+ " from src.model.emotion.wavlm_emotion import WavLMWrapper\n",
247
+ " sailer = WavLMWrapper.from_pretrained(\"tiantiaf/wavlm-large-categorical-emotion\").to(device).eval()\n",
248
+ " with torch.no_grad():\n",
249
+ " for i, s in enumerate(tqdm(todo, desc=f\"sailer {tag}\")):\n",
250
+ " wav = os.path.join(WAV_DIR, s + \".wav\")\n",
251
+ " if not os.path.exists(wav):\n",
252
+ " continue\n",
253
+ " wave, _ = librosa.load(wav, sr=16000, mono=True)\n",
254
+ " wave = wave[: 15 * 16000]\n",
255
+ " data = torch.from_numpy(wave).float().unsqueeze(0).to(device)\n",
256
+ " logits, feat, _det, arousal, valence, dominance = sailer(data, return_feature=True)\n",
257
+ " emb = _pool_feat(feat)\n",
258
+ " p9 = F.softmax(logits, dim=1)[0].detach().cpu().numpy().astype(np.float32)\n",
259
+ " vad3 = np.array([1 + 4 * float(valence.item()),\n",
260
+ " 1 + 4 * float(arousal.item()),\n",
261
+ " 1 + 4 * float(dominance.item())], dtype=np.float32)\n",
262
+ " store[s] = np.concatenate([emb, p9, vad3]).astype(np.float32)\n",
263
+ " if (i + 1) % 500 == 0:\n",
264
+ " np.savez(cache_path, **store)\n",
265
+ " np.savez(cache_path, **store)\n",
266
+ " del sailer\n",
267
+ " torch.cuda.empty_cache() if device == \"cuda\" else None\n",
268
+ " return store # mỗi value = [D2 | 9 | 3]\n",
269
+ "\n",
270
+ "def extract_utmos(names, tag):\n",
271
+ " \"\"\"Chấm UTMOS từng wav (theo TÊN file, vì DEV gọi .wav theo tên). → dict {stem: score}.\n",
272
+ " Cache CACHE_DIR/utmos_<tag>.npz. Dùng vừa làm đặc trưng vừa làm baseline so sánh.\"\"\"\n",
273
+ " import librosa\n",
274
+ " from tqdm.auto import tqdm\n",
275
+ " cache_path = os.path.join(CACHE_DIR, f\"utmos_{tag}.npz\")\n",
276
+ " store = {}\n",
277
+ " if os.path.exists(cache_path):\n",
278
+ " z = np.load(cache_path, allow_pickle=True)\n",
279
+ " store = {k: float(z[k]) for k in z.files}\n",
280
+ " print(f\"[utmos/{tag}] nạp cache: {len(store)}\")\n",
281
+ " todo = [n for n in names if stem(n) not in store]\n",
282
+ " if todo:\n",
283
+ " predictor = torch.hub.load(\"tarepan/SpeechMOS:v1.2.0\", \"utmos22_strong\",\n",
284
+ " trust_repo=True).to(device).eval()\n",
285
+ " with torch.no_grad():\n",
286
+ " for i, n in enumerate(tqdm(todo, desc=f\"utmos {tag}\")):\n",
287
+ " wav = os.path.join(WAV_DIR, n if n.endswith(\".wav\") else n + \".wav\")\n",
288
+ " if not os.path.exists(wav):\n",
289
+ " continue\n",
290
+ " wave, _ = librosa.load(wav, sr=16000, mono=True)\n",
291
+ " sc = float(predictor(torch.from_numpy(wave).unsqueeze(0).to(device), sr=16000).mean().item())\n",
292
+ " store[stem(n)] = sc\n",
293
+ " if (i + 1) % 500 == 0:\n",
294
+ " np.savez(cache_path, **{k: np.float32(v) for k, v in store.items()})\n",
295
+ " np.savez(cache_path, **{k: np.float32(v) for k, v in store.items()})\n",
296
+ " del predictor\n",
297
+ " torch.cuda.empty_cache() if device == \"cuda\" else None\n",
298
+ " return store"
299
+ ]
300
+ },
301
+ {
302
+ "cell_type": "markdown",
303
+ "id": "aed7338b",
304
+ "metadata": {},
305
+ "source": [
306
+ "## 4. Dựng feature + nhãn cho train"
307
+ ]
308
+ },
309
+ {
310
+ "cell_type": "code",
311
+ "execution_count": null,
312
+ "id": "c09bb508",
313
+ "metadata": {},
314
+ "outputs": [],
315
+ "source": [
316
+ "train_stems = list(qmos_df[\"wavID\"])\n",
317
+ "if LIMIT_TRAIN:\n",
318
+ " train_stems = train_stems[:LIMIT_TRAIN]\n",
319
+ "\n",
320
+ "e2v_tr = extract_e2v(train_stems, \"train\") if USE_E2V else {}\n",
321
+ "sailer_tr = extract_sailer(train_stems, \"train\") if USE_SAILER else {}\n",
322
+ "utmos_tr = extract_utmos(train_stems, \"train\") if USE_UTMOS_FEAT else {}\n",
323
+ "\n",
324
+ "def qmos_feature(sid, e2v_map, sailer_map, utmos_map):\n",
325
+ " \"\"\"Nối đặc trưng QMOS cho 1 wav. None nếu thiếu phần bắt buộc.\"\"\"\n",
326
+ " parts = []\n",
327
+ " if USE_E2V:\n",
328
+ " v = e2v_map.get(sid)\n",
329
+ " if v is None:\n",
330
+ " return None\n",
331
+ " parts.append(v[:-5]) # emb e2v\n",
332
+ " if USE_CLASSPROB:\n",
333
+ " parts.append(v[-5:]) # probs5\n",
334
+ " if USE_SAILER:\n",
335
+ " v = sailer_map.get(sid)\n",
336
+ " if v is None:\n",
337
+ " return None\n",
338
+ " parts.append(v[:-12]) # emb sailer\n",
339
+ " if USE_CLASSPROB:\n",
340
+ " parts.append(v[-12:]) # probs9 + vad3\n",
341
+ " if USE_UTMOS_FEAT:\n",
342
+ " u = utmos_map.get(sid)\n",
343
+ " if u is None:\n",
344
+ " return None\n",
345
+ " parts.append(np.array([u], dtype=np.float32))\n",
346
+ " return np.concatenate(parts).astype(np.float32)\n",
347
+ "\n",
348
+ "lab = qmos_df.set_index(\"wavID\")[\"qmos\"]\n",
349
+ "X, y = [], []\n",
350
+ "for s in train_stems:\n",
351
+ " f = qmos_feature(s, e2v_tr, sailer_tr, utmos_tr)\n",
352
+ " if f is None or s not in lab.index:\n",
353
+ " continue\n",
354
+ " X.append(f)\n",
355
+ " y.append(float(lab.loc[s]))\n",
356
+ "\n",
357
+ "X = np.stack(X).astype(np.float32)\n",
358
+ "y = np.array(y, dtype=np.float32)\n",
359
+ "FEAT_DIM = X.shape[1]\n",
360
+ "print(f\"Train: X={X.shape} y={y.shape}\")\n",
361
+ "\n",
362
+ "feat_mean = X.mean(0, keepdims=True)\n",
363
+ "feat_std = X.std(0, keepdims=True) + 1e-6\n",
364
+ "Xn = (X - feat_mean) / feat_std\n",
365
+ "y_mu, y_sd = float(y.mean()), float(y.std() + 1e-6)\n",
366
+ "yn = (y - y_mu) / y_sd"
367
+ ]
368
+ },
369
+ {
370
+ "cell_type": "markdown",
371
+ "id": "82cc65f8",
372
+ "metadata": {},
373
+ "source": [
374
+ "## 5. Train head QMOS + so với UTMOS trên CÙNG val nội bộ\n",
375
+ "- Head = MLP nhỏ (`Linear→ReLU→Dropout ×2 → 1`). Loss = MSE (+ tùy chọn pairwise ranking).\n",
376
+ "- In **SRCC head** và **SRCC UTMOS** trên cùng tập val → biết head có thật sự vượt 0.414 không."
377
+ ]
378
+ },
379
+ {
380
+ "cell_type": "code",
381
+ "execution_count": null,
382
+ "id": "324ab564",
383
+ "metadata": {
384
+ "lines_to_next_cell": 1
385
+ },
386
+ "outputs": [],
387
+ "source": [
388
+ "import torch.nn as nn\n",
389
+ "from scipy.stats import spearmanr\n",
390
+ "from sklearn.model_selection import train_test_split\n",
391
+ "\n",
392
+ "torch.manual_seed(SEED); np.random.seed(SEED)\n",
393
+ "idx_all = np.arange(X.shape[0])\n",
394
+ "tr_idx, va_idx = train_test_split(idx_all, test_size=VAL_FRAC, random_state=SEED)\n",
395
+ "\n",
396
+ "def to_t(a):\n",
397
+ " return torch.tensor(a, dtype=torch.float32, device=device)\n",
398
+ "\n",
399
+ "Xn_t = to_t(Xn); yn_t = to_t(yn).unsqueeze(1)\n",
400
+ "\n",
401
+ "class QMOSHead(nn.Module):\n",
402
+ " def __init__(self, d_in, h, p):\n",
403
+ " super().__init__()\n",
404
+ " self.net = nn.Sequential(\n",
405
+ " nn.Linear(d_in, h), nn.ReLU(), nn.Dropout(p),\n",
406
+ " nn.Linear(h, h), nn.ReLU(), nn.Dropout(p),\n",
407
+ " nn.Linear(h, 1),\n",
408
+ " )\n",
409
+ "\n",
410
+ " def forward(self, x):\n",
411
+ " return self.net(x)\n",
412
+ "\n",
413
+ "model = QMOSHead(FEAT_DIM, HIDDEN, DROPOUT).to(device)\n",
414
+ "opt = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-5)\n",
415
+ "mse = nn.MSELoss()\n",
416
+ "\n",
417
+ "def pairwise_rank_loss(pred, target):\n",
418
+ " \"\"\"Khuyến khích pred xếp hạng giống target (margin ranking trên các cặp trong batch).\"\"\"\n",
419
+ " n = pred.shape[0]\n",
420
+ " if n < 2:\n",
421
+ " return torch.zeros((), device=device)\n",
422
+ " pi, pj = pred.unsqueeze(0), pred.unsqueeze(1)\n",
423
+ " ti, tj = target.unsqueeze(0), target.unsqueeze(1)\n",
424
+ " sign = torch.sign(ti - tj) # +1 nếu i nên cao hơn j\n",
425
+ " diff = pi - pj\n",
426
+ " # hinge: phạt khi thứ tự sai\n",
427
+ " return torch.relu(-sign * diff).mean()\n",
428
+ "\n",
429
+ "@torch.no_grad()\n",
430
+ "def eval_val():\n",
431
+ " model.eval()\n",
432
+ " p = model(Xn_t[va_idx]).cpu().numpy().ravel()\n",
433
+ " srcc_head = spearmanr(p, y[va_idx]).correlation\n",
434
+ " out = {\"head\": float(srcc_head)}\n",
435
+ " if USE_UTMOS_FEAT:\n",
436
+ " u = X[va_idx, -1] # cột UTMOS (đặc trưng cuối, chưa chuẩn hóa)\n",
437
+ " out[\"utmos\"] = float(spearmanr(u, y[va_idx]).correlation)\n",
438
+ " return out\n",
439
+ "\n",
440
+ "best, best_state, bad = -1e9, None, 0\n",
441
+ "tr_t = torch.tensor(tr_idx, device=device)\n",
442
+ "for ep in range(1, EPOCHS + 1):\n",
443
+ " model.train()\n",
444
+ " perm = tr_t[torch.randperm(len(tr_t), device=device)]\n",
445
+ " run = 0.0\n",
446
+ " for i in range(0, len(perm), BATCH):\n",
447
+ " b = perm[i:i + BATCH]\n",
448
+ " opt.zero_grad()\n",
449
+ " pred = model(Xn_t[b])\n",
450
+ " loss = mse(pred, yn_t[b])\n",
451
+ " if RANK_LAMBDA > 0:\n",
452
+ " loss = loss + RANK_LAMBDA * pairwise_rank_loss(pred.ravel(), yn_t[b].ravel())\n",
453
+ " loss.backward(); opt.step()\n",
454
+ " run += loss.item() * len(b)\n",
455
+ " m = eval_val()\n",
456
+ " if m[\"head\"] > best:\n",
457
+ " best = m[\"head\"]\n",
458
+ " best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}\n",
459
+ " bad = 0\n",
460
+ " else:\n",
461
+ " bad += 1\n",
462
+ " if ep % 5 == 0 or ep == 1:\n",
463
+ " extra = f\" | UTMOS={m['utmos']:.4f}\" if \"utmos\" in m else \"\"\n",
464
+ " print(f\"epoch {ep:3d} | loss {run/len(perm):.4f} | head SRCC={m['head']:.4f}{extra} | best {best:.4f}\")\n",
465
+ " if bad >= PATIENCE:\n",
466
+ " print(f\"Early stop ở epoch {ep}.\")\n",
467
+ " break\n",
468
+ "\n",
469
+ "model.load_state_dict(best_state)\n",
470
+ "final = eval_val()\n",
471
+ "print(\"\\n✅ VAL (nội bộ):\")\n",
472
+ "print(f\" QMOS head SRCC = {final['head']:.4f}\")\n",
473
+ "if \"utmos\" in final:\n",
474
+ " print(f\" UTMOS baseline = {final['utmos']:.4f} (mốc leaderboard 0.414)\")\n",
475
+ " print(\" →\", \"✅ HEAD VƯỢT UTMOS\" if final[\"head\"] > final[\"utmos\"] else \"⚠️ chưa vượt — thử tăng EPOCHS / RANK_LAMBDA / bật thêm đặc trưng\")\n",
476
+ "\n",
477
+ "torch.save({\"state\": best_state, \"feat_mean\": feat_mean, \"feat_std\": feat_std,\n",
478
+ " \"y_mu\": y_mu, \"y_sd\": y_sd, \"FEAT_DIM\": FEAT_DIM,\n",
479
+ " \"USE_E2V\": USE_E2V, \"USE_SAILER\": USE_SAILER,\n",
480
+ " \"USE_CLASSPROB\": USE_CLASSPROB, \"USE_UTMOS_FEAT\": USE_UTMOS_FEAT,\n",
481
+ " \"val_srcc\": best}, os.path.join(OUT_DIR, \"qmos_head.pt\"))\n",
482
+ "print(\"Đã lưu\", os.path.join(OUT_DIR, \"qmos_head.pt\"))"
483
+ ]
484
+ },
485
+ {
486
+ "cell_type": "markdown",
487
+ "id": "d33a7aca",
488
+ "metadata": {},
489
+ "source": [
490
+ "## 6. Dự đoán QMOS cho DEV"
491
+ ]
492
+ },
493
+ {
494
+ "cell_type": "code",
495
+ "execution_count": null,
496
+ "id": "69efbd00",
497
+ "metadata": {
498
+ "lines_to_next_cell": 1
499
+ },
500
+ "outputs": [],
501
+ "source": [
502
+ "def list_dev():\n",
503
+ " with open(DEV_SCP) as f:\n",
504
+ " return [ln.strip() for ln in f if ln.strip()]\n",
505
+ "\n",
506
+ "dev_names = list_dev()\n",
507
+ "if LIMIT_DEV:\n",
508
+ " dev_names = dev_names[:LIMIT_DEV]\n",
509
+ "dev_stems = [stem(n) for n in dev_names]\n",
510
+ "print(\"DEV:\", len(dev_names), \"mẫu\")\n",
511
+ "\n",
512
+ "e2v_dev = extract_e2v(dev_stems, \"dev\") if USE_E2V else {}\n",
513
+ "sailer_dev = extract_sailer(dev_stems, \"dev\") if USE_SAILER else {}\n",
514
+ "utmos_dev = extract_utmos(dev_names, \"dev\") if USE_UTMOS_FEAT else {}\n",
515
+ "\n",
516
+ "@torch.no_grad()\n",
517
+ "def predict_qmos(sid):\n",
518
+ " f = qmos_feature(sid, e2v_dev, sailer_dev, utmos_dev)\n",
519
+ " if f is None:\n",
520
+ " return None\n",
521
+ " fn = (f[None, :] - feat_mean) / feat_std\n",
522
+ " model.eval()\n",
523
+ " return float(model(to_t(fn)).item()) * y_sd + y_mu # đảo z-score\n",
524
+ "\n",
525
+ "qmos_pred = {}\n",
526
+ "n_real = n_def = 0\n",
527
+ "for n in dev_names:\n",
528
+ " sid = stem(n)\n",
529
+ " p = predict_qmos(sid)\n",
530
+ " if p is None:\n",
531
+ " p = utmos_dev.get(sid, 3.0) # rơi về UTMOS nếu thiếu feature\n",
532
+ " n_def += 1\n",
533
+ " else:\n",
534
+ " n_real += 1\n",
535
+ " qmos_pred[n] = p\n",
536
+ "print(f\"QMOS dự đoán: head thật {n_real}, dự phòng UTMOS {n_def}\")\n",
537
+ "\n",
538
+ "# Lưu riêng (để ghép tay nếu cần)\n",
539
+ "import csv\n",
540
+ "qmos_csv = os.path.join(OUT_DIR, \"qmos_dev.csv\")\n",
541
+ "with open(qmos_csv, \"w\", newline=\"\") as f:\n",
542
+ " w = csv.writer(f); w.writerow([\"wav\", \"QMOS\"])\n",
543
+ " for n in dev_names:\n",
544
+ " w.writerow([n, f\"{qmos_pred[n]:.6g}\"])\n",
545
+ "print(\"Đã ghi\", qmos_csv)"
546
+ ]
547
+ },
548
+ {
549
+ "cell_type": "markdown",
550
+ "id": "f3e47def",
551
+ "metadata": {},
552
+ "source": [
553
+ "## 7. Ghép QMOS mới vào answer.txt của exp04 → bản nộp mới\n",
554
+ "Giữ NGUYÊN 5 cột cảm xúc đang thắng (EMOS/CAT/VAL/ARO/DOM), chỉ thay cột QMOS."
555
+ ]
556
+ },
557
+ {
558
+ "cell_type": "code",
559
+ "execution_count": null,
560
+ "id": "a3b94589",
561
+ "metadata": {},
562
+ "outputs": [],
563
+ "source": [
564
+ "def merge_into_exp04(exp04_path, out_path):\n",
565
+ " if not os.path.exists(exp04_path):\n",
566
+ " print(f\"⚠️ Không thấy {exp04_path} → BỎ QUA ghép. Hãy dùng qmos_dev.csv để thay cột QMOS thủ công,\")\n",
567
+ " print(\" hoặc trỏ EXP04_ANSWER đúng đường dẫn answer.txt của exp04 rồi chạy lại cell này.\")\n",
568
+ " return False\n",
569
+ " with open(exp04_path) as f:\n",
570
+ " rows = list(csv.reader(f))\n",
571
+ " header = rows[0]\n",
572
+ " qi = header.index(\"QMOS\")\n",
573
+ " wi = header.index(\"wav\")\n",
574
+ " n_swapped = n_miss = 0\n",
575
+ " with open(out_path, \"w\", newline=\"\") as f:\n",
576
+ " w = csv.writer(f); w.writerow(header)\n",
577
+ " for r in rows[1:]:\n",
578
+ " name = r[wi]\n",
579
+ " if name in qmos_pred:\n",
580
+ " r[qi] = f\"{qmos_pred[name]:.6g}\"; n_swapped += 1\n",
581
+ " else:\n",
582
+ " n_miss += 1\n",
583
+ " w.writerow(r)\n",
584
+ " print(f\"Ghép xong → {out_path} | thay {n_swapped} cột QMOS, thiếu {n_miss} (giữ QMOS cũ)\")\n",
585
+ " return True\n",
586
+ "\n",
587
+ "merged = os.path.join(OUT_DIR, \"answer.txt\")\n",
588
+ "ok = merge_into_exp04(EXP04_ANSWER, merged)\n",
589
+ "\n",
590
+ "if ok:\n",
591
+ " # validate + zip\n",
592
+ " with open(merged) as f:\n",
593
+ " rows = list(csv.reader(f))\n",
594
+ " assert rows[0][0] == \"wav\" and \"QMOS\" in rows[0]\n",
595
+ " for i, r in enumerate(rows[1:], 2):\n",
596
+ " assert len(r) == len(rows[0]), f\"Dòng {i} sai số cột\"\n",
597
+ " print(f\"OK: {len(rows)-1} dòng, header = {rows[0]}\")\n",
598
+ " os.system(f\"cd {OUT_DIR} && zip -j submission_track2_exp06_qmos.zip answer.txt \"\n",
599
+ " f\"&& unzip -l submission_track2_exp06_qmos.zip\")\n",
600
+ " print(\"Sẵn sàng nộp:\", os.path.join(OUT_DIR, \"submission_track2_exp06_qmos.zip\"))"
601
+ ]
602
+ },
603
+ {
604
+ "cell_type": "markdown",
605
+ "id": "0a517b97",
606
+ "metadata": {},
607
+ "source": [
608
+ "## Ghi chú\n",
609
+ "- **Lần đầu** đặt `LIMIT_TRAIN=300`, `LIMIT_DEV=20` để bắt lỗi; OK rồi đặt `None`.\n",
610
+ "- **So sánh công bằng**: mục 5 in cả `head SRCC` và `UTMOS SRCC` trên CÙNG val nội bộ → chỉ nộp khi head > UTMOS.\n",
611
+ "- Nếu head **chưa vượt** 0.414: thử (a) tăng `EPOCHS`; (b) bật `RANK_LAMBDA=0.2` (tối ưu thứ hạng);\n",
612
+ " (c) đảm bảo `USE_UTMOS_FEAT=True` (neo residual); (d) thử bỏ bớt đặc trưng nhiễu (tắt `USE_CLASSPROB`).\n",
613
+ "- **Ablation QMOS cho paper**: bật/tắt `USE_E2V/USE_SAILER/USE_UTMOS_FEAT/USE_CLASSPROB` → ghi `docs/04_experiments_log.md` (exp06).\n",
614
+ "- Cache dùng CHUNG `fusion_cache/` với exp04 → nhớ **Save Version** giữ lại (gồm `utmos_*.npz` mới).\n",
615
+ "- Ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (mục exp06)."
616
+ ]
617
+ }
618
+ ],
619
+ "metadata": {
620
+ "jupytext": {
621
+ "cell_metadata_filter": "-all",
622
+ "main_language": "python",
623
+ "notebook_metadata_filter": "-all"
624
+ }
625
+ },
626
+ "nbformat": 4,
627
+ "nbformat_minor": 5
628
+ }
track2/exp06_qmos_train_pipeline.py ADDED
@@ -0,0 +1,502 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %% [markdown]
2
+ # # VMC2026 Track 2 — exp06 (TRAIN QMOS head) — Kaggle
3
+ #
4
+ # **Mục tiêu:** QMOS là cột **duy nhất chưa train** (đang dùng UTMOS zero-shot → SRCC kẹt 0.414).
5
+ # `train.csv` CÓ sẵn cột `qMOS` → ta train 1 **head hồi quy nhỏ** trên đặc trưng SSL (đã cache ở exp04)
6
+ # để vượt 0.414.
7
+ #
8
+ # ## Ý tưởng (đọc 1 lần cho hiểu)
9
+ # - Tái dùng đặc trưng **emotion2vec + SAILER** đã trích & cache trong `fusion_cache/` (exp04) → KHÔNG trích lại.
10
+ # - Thêm **chính điểm UTMOS** (SpeechMOS) làm 1 đặc trưng đầu vào → head chỉ cần **học chỉnh sửa (residual)**
11
+ # quanh 0.414 thay vì học lại từ đầu → an toàn, gần như chắc chắn ≥ UTMOS đơn lẻ.
12
+ # - Nhãn vàng QMOS = **TB `qMOS` theo wav** (gộp các listener trong `train.csv`).
13
+ # - Có **val nội bộ 10%** → đo SRCC, so thẳng với UTMOS trên CÙNG tập val → biết có cải thiện thật
14
+ # **trước khi** tốn lượt nộp CodaBench.
15
+ # - Cuối cùng: **GIỮ NGUYÊN exp04** (5 cột cảm xúc đang thắng), chỉ **thay cột QMOS** trong `answer.txt`.
16
+ #
17
+ # ```
18
+ # mỗi wav ─► [e2v_emb | e2v_probs5 | sailer_emb | sailer_probs9 | sailer_vad3 | UTMOS] ─► MLP ─► QMOS
19
+ # (head train)
20
+ # ```
21
+ #
22
+ # **Cách chạy trên Kaggle:** Settings → Accelerator = **GPU T4**, Internet = **On**.
23
+ # + Add Input: (1) dataset Track 2 (15.477 wav, có `sets/train.csv`) ; (2) — nếu có — dataset chứa
24
+ # `fusion_cache/*.npz` đã Save Version ở exp04 (đỡ ~15') ; (3) file `answer.txt` của exp04 để ghép cột.
25
+ # Lần đầu đặt `LIMIT_TRAIN=300`, `LIMIT_DEV=20` để bắt lỗi setup, OK rồi đặt `None`.
26
+
27
+ # %% [markdown]
28
+ # ## 0. Cấu hình — SỬA Ở ĐÂY
29
+
30
+ # %%
31
+ import os
32
+
33
+ # ── Data Track 2 ─────────────────────────────────────────────────────────────
34
+ DATA_ROOT = "/kaggle/input/vmc2026-track2-full/vmc2026-track2" # << SỬA slug cho khớp Add Input
35
+ WAV_DIR = f"{DATA_ROOT}/wav"
36
+ TRAIN_CSV = f"{DATA_ROOT}/sets/train.csv" # nhãn người nghe: lisID|wavID|qMOS|emoCat|eMOS|val|dom|aro
37
+ DEV_SCP = f"{DATA_ROOT}/sets/dev.scp" # danh sách wav tập DEV
38
+
39
+ OUT_DIR = "/kaggle/working"
40
+ # Dùng CHUNG cache với exp04. Nếu đã Save Version cache ở exp04, trỏ CACHE_DIR vào dataset đó
41
+ # (vd "/kaggle/input/<slug-cache>/fusion_cache") để khỏi trích lại; nếu không, để mặc định sẽ tự trích.
42
+ CACHE_DIR = "/kaggle/working/fusion_cache"
43
+ os.makedirs(CACHE_DIR, exist_ok=True)
44
+
45
+ # File answer.txt của exp04 (5 cột cảm xúc đang thắng) để GHÉP cột QMOS mới vào.
46
+ # Trỏ tới nơi bạn đặt file exp04. Nếu không có, notebook vẫn xuất qmos_dev.csv riêng + cảnh báo.
47
+ EXP04_ANSWER = "/kaggle/input/exp04-answer/answer.txt" # << SỬA; hoặc "/kaggle/working/answer.txt"
48
+
49
+ # ── Đặc trưng dùng cho QMOS ──────────────────────────────────────────────────
50
+ USE_E2V = True # nối embedding emotion2vec
51
+ USE_SAILER = True # nối embedding SAILER/WavLM
52
+ USE_CLASSPROB = True # nối thêm xác suất lớp (e2v5 + sailer9 + vad3)
53
+ USE_UTMOS_FEAT = True # nối thêm điểm UTMOS làm 1 đặc trưng (neo residual quanh 0.414)
54
+
55
+ # ── Siêu tham số train head ──────────────────────────────────────────────────
56
+ DEVICE = "cuda"
57
+ HIDDEN = 256
58
+ DROPOUT = 0.3
59
+ LR = 1e-3
60
+ EPOCHS = 120
61
+ BATCH = 64
62
+ VAL_FRAC = 0.10
63
+ PATIENCE = 20
64
+ SEED = 42
65
+ RANK_LAMBDA = 0.0 # 0 = chỉ MSE. >0 (vd 0.2) = cộng thêm pairwise ranking loss (tối ưu thứ hạng=SRCC)
66
+
67
+ LIMIT_TRAIN = None # số nhỏ (vd 300) để chạy thử; None = full
68
+ LIMIT_DEV = None
69
+
70
+ def stem(p):
71
+ return os.path.splitext(os.path.basename(str(p)))[0]
72
+
73
+ assert USE_E2V or USE_SAILER or USE_UTMOS_FEAT, "Phải bật ít nhất 1 nguồn đặc trưng."
74
+ print("DATA_ROOT:", DATA_ROOT)
75
+ for p in [WAV_DIR, TRAIN_CSV, DEV_SCP]:
76
+ print((" ✅ " if os.path.exists(p) else " ❌ THIẾU ") + p)
77
+
78
+ # %% [markdown]
79
+ # ## 1. Cài đặt + (nếu cần) tải code SAILER
80
+ # emotion2vec qua `funasr`; SAILER cần `WavLMWrapper` trong repo `vox-profile-release` (clone + sys.path).
81
+ # Nếu cache đã đủ thì các model này sẽ KHÔNG được nạp (chỉ nạp khi còn file phải trích).
82
+
83
+ # %%
84
+ import sys, subprocess
85
+
86
+ def pip_install(*pkgs):
87
+ subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs], check=True)
88
+
89
+ pip_install("speechmos", "funasr", "librosa", "soundfile", "pandas", "scipy", "scikit-learn", "tqdm")
90
+
91
+ if USE_SAILER:
92
+ pip_install("loralib", "speechbrain")
93
+ REPO_DIR = "/kaggle/working/vox-profile-release"
94
+ if not os.path.exists(REPO_DIR):
95
+ subprocess.run(["git", "clone", "--depth", "1",
96
+ "https://github.com/tiantiaf0627/vox-profile-release.git", REPO_DIR], check=True)
97
+ if REPO_DIR not in sys.path:
98
+ sys.path.insert(0, REPO_DIR)
99
+
100
+ # %% [markdown]
101
+ # ## 2. Nhãn vàng QMOS (gộp `qMOS` theo wavID)
102
+
103
+ # %%
104
+ import numpy as np
105
+ import pandas as pd
106
+
107
+ def load_qmos_labels():
108
+ """train.csv (sep='|') → DataFrame [wavID, qmos] với qmos = TB theo wav."""
109
+ df = pd.read_csv(TRAIN_CSV, sep="|")
110
+ cols = {c.lower().strip(): c for c in df.columns}
111
+ wav_col = cols.get("wavid") or cols.get("wav") or list(df.columns)[1]
112
+ qmos_col = cols.get("qmos") or cols.get("qMOS".lower()) or cols.get("mos")
113
+ assert qmos_col, f"Không thấy cột qMOS trong train.csv (cột: {list(df.columns)})"
114
+ df["_stem"] = df[wav_col].map(stem)
115
+ g = df.groupby("_stem")[qmos_col].mean().reset_index()
116
+ g.columns = ["wavID", "qmos"]
117
+ return g
118
+
119
+ qmos_df = load_qmos_labels()
120
+ print(f"wav train (gộp): {len(qmos_df)}")
121
+ print("qMOS:", qmos_df["qmos"].describe()[["mean", "std", "min", "max"]].to_dict())
122
+ qmos_df.head()
123
+
124
+ # %% [markdown]
125
+ # ## 3. Trích / nạp đặc trưng (cache CHUNG với exp04) + điểm UTMOS
126
+ # - `extract_e2v` / `extract_sailer`: y hệt exp04, cache `e2v_<tag>.npz` / `sailer_<tag>.npz`.
127
+ # - `extract_utmos`: chấm UTMOS từng wav → cache `utmos_<tag>.npz` (dùng vừa làm đặc trưng, vừa làm baseline so sánh).
128
+
129
+ # %%
130
+ import torch
131
+ import torch.nn.functional as F
132
+
133
+ device = DEVICE if torch.cuda.is_available() else "cpu"
134
+ print("Device:", device, ("✅ " + torch.cuda.get_device_name(0)) if device == "cuda" else "⚠️ CPU")
135
+
136
+ EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"]
137
+
138
+ def extract_e2v(stems, tag):
139
+ """→ dict {stem: emb_full[D1+5]}. Cache CACHE_DIR/e2v_<tag>.npz (giống exp04)."""
140
+ from tqdm.auto import tqdm
141
+ cache_path = os.path.join(CACHE_DIR, f"e2v_{tag}.npz")
142
+ store = {}
143
+ if os.path.exists(cache_path):
144
+ z = np.load(cache_path, allow_pickle=True)
145
+ store = {k: z[k] for k in z.files}
146
+ print(f"[e2v/{tag}] nạp cache: {len(store)}")
147
+ todo = [s for s in stems if s not in store]
148
+ if todo:
149
+ from funasr import AutoModel
150
+ m = AutoModel(model="iic/emotion2vec_plus_large", hub="hf", device=device)
151
+ for i, s in enumerate(tqdm(todo, desc=f"e2v {tag}")):
152
+ wav = os.path.join(WAV_DIR, s + ".wav")
153
+ if not os.path.exists(wav):
154
+ continue
155
+ r = m.generate(wav, granularity="utterance", extract_embedding=True)[0]
156
+ emb = np.asarray(r["feats"], dtype=np.float32).reshape(-1)
157
+ probs = {e: 0.0 for e in EMOTIONS5}
158
+ for lab, sc in zip(r["labels"], r["scores"]):
159
+ name = lab.split("/")[-1]
160
+ if name in probs:
161
+ probs[name] = float(sc)
162
+ tot = sum(probs.values())
163
+ p5 = np.array([probs[e] / tot if tot > 0 else 0.2 for e in EMOTIONS5], dtype=np.float32)
164
+ store[s] = np.concatenate([emb, p5]).astype(np.float32)
165
+ if (i + 1) % 500 == 0:
166
+ np.savez(cache_path, **store)
167
+ np.savez(cache_path, **store)
168
+ del m
169
+ torch.cuda.empty_cache() if device == "cuda" else None
170
+ return store # mỗi value = [D1 | 5]
171
+
172
+ def _pool_feat(features):
173
+ f = features.detach().cpu().numpy()
174
+ if f.ndim <= 1:
175
+ return f.reshape(-1).astype(np.float32)
176
+ return f.mean(axis=tuple(range(f.ndim - 1))).reshape(-1).astype(np.float32)
177
+
178
+ def extract_sailer(stems, tag):
179
+ """→ dict {stem: vec[D2+9+3]}. Cache CACHE_DIR/sailer_<tag>.npz (giống exp04)."""
180
+ import librosa
181
+ from tqdm.auto import tqdm
182
+ cache_path = os.path.join(CACHE_DIR, f"sailer_{tag}.npz")
183
+ store = {}
184
+ if os.path.exists(cache_path):
185
+ z = np.load(cache_path, allow_pickle=True)
186
+ store = {k: z[k] for k in z.files}
187
+ print(f"[sailer/{tag}] nạp cache: {len(store)}")
188
+ todo = [s for s in stems if s not in store]
189
+ if todo:
190
+ from src.model.emotion.wavlm_emotion import WavLMWrapper
191
+ sailer = WavLMWrapper.from_pretrained("tiantiaf/wavlm-large-categorical-emotion").to(device).eval()
192
+ with torch.no_grad():
193
+ for i, s in enumerate(tqdm(todo, desc=f"sailer {tag}")):
194
+ wav = os.path.join(WAV_DIR, s + ".wav")
195
+ if not os.path.exists(wav):
196
+ continue
197
+ wave, _ = librosa.load(wav, sr=16000, mono=True)
198
+ wave = wave[: 15 * 16000]
199
+ data = torch.from_numpy(wave).float().unsqueeze(0).to(device)
200
+ logits, feat, _det, arousal, valence, dominance = sailer(data, return_feature=True)
201
+ emb = _pool_feat(feat)
202
+ p9 = F.softmax(logits, dim=1)[0].detach().cpu().numpy().astype(np.float32)
203
+ vad3 = np.array([1 + 4 * float(valence.item()),
204
+ 1 + 4 * float(arousal.item()),
205
+ 1 + 4 * float(dominance.item())], dtype=np.float32)
206
+ store[s] = np.concatenate([emb, p9, vad3]).astype(np.float32)
207
+ if (i + 1) % 500 == 0:
208
+ np.savez(cache_path, **store)
209
+ np.savez(cache_path, **store)
210
+ del sailer
211
+ torch.cuda.empty_cache() if device == "cuda" else None
212
+ return store # mỗi value = [D2 | 9 | 3]
213
+
214
+ def extract_utmos(names, tag):
215
+ """Chấm UTMOS từng wav (theo TÊN file, vì DEV gọi .wav theo tên). → dict {stem: score}.
216
+ Cache CACHE_DIR/utmos_<tag>.npz. Dùng vừa làm đặc trưng vừa làm baseline so sánh."""
217
+ import librosa
218
+ from tqdm.auto import tqdm
219
+ cache_path = os.path.join(CACHE_DIR, f"utmos_{tag}.npz")
220
+ store = {}
221
+ if os.path.exists(cache_path):
222
+ z = np.load(cache_path, allow_pickle=True)
223
+ store = {k: float(z[k]) for k in z.files}
224
+ print(f"[utmos/{tag}] nạp cache: {len(store)}")
225
+ todo = [n for n in names if stem(n) not in store]
226
+ if todo:
227
+ predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong",
228
+ trust_repo=True).to(device).eval()
229
+ with torch.no_grad():
230
+ for i, n in enumerate(tqdm(todo, desc=f"utmos {tag}")):
231
+ wav = os.path.join(WAV_DIR, n if n.endswith(".wav") else n + ".wav")
232
+ if not os.path.exists(wav):
233
+ continue
234
+ wave, _ = librosa.load(wav, sr=16000, mono=True)
235
+ sc = float(predictor(torch.from_numpy(wave).unsqueeze(0).to(device), sr=16000).mean().item())
236
+ store[stem(n)] = sc
237
+ if (i + 1) % 500 == 0:
238
+ np.savez(cache_path, **{k: np.float32(v) for k, v in store.items()})
239
+ np.savez(cache_path, **{k: np.float32(v) for k, v in store.items()})
240
+ del predictor
241
+ torch.cuda.empty_cache() if device == "cuda" else None
242
+ return store
243
+
244
+ # %% [markdown]
245
+ # ## 4. Dựng feature + nhãn cho train
246
+
247
+ # %%
248
+ train_stems = list(qmos_df["wavID"])
249
+ if LIMIT_TRAIN:
250
+ train_stems = train_stems[:LIMIT_TRAIN]
251
+
252
+ e2v_tr = extract_e2v(train_stems, "train") if USE_E2V else {}
253
+ sailer_tr = extract_sailer(train_stems, "train") if USE_SAILER else {}
254
+ utmos_tr = extract_utmos(train_stems, "train") if USE_UTMOS_FEAT else {}
255
+
256
+ def qmos_feature(sid, e2v_map, sailer_map, utmos_map):
257
+ """Nối đặc trưng QMOS cho 1 wav. None nếu thiếu phần bắt buộc."""
258
+ parts = []
259
+ if USE_E2V:
260
+ v = e2v_map.get(sid)
261
+ if v is None:
262
+ return None
263
+ parts.append(v[:-5]) # emb e2v
264
+ if USE_CLASSPROB:
265
+ parts.append(v[-5:]) # probs5
266
+ if USE_SAILER:
267
+ v = sailer_map.get(sid)
268
+ if v is None:
269
+ return None
270
+ parts.append(v[:-12]) # emb sailer
271
+ if USE_CLASSPROB:
272
+ parts.append(v[-12:]) # probs9 + vad3
273
+ if USE_UTMOS_FEAT:
274
+ u = utmos_map.get(sid)
275
+ if u is None:
276
+ return None
277
+ parts.append(np.array([u], dtype=np.float32))
278
+ return np.concatenate(parts).astype(np.float32)
279
+
280
+ lab = qmos_df.set_index("wavID")["qmos"]
281
+ X, y = [], []
282
+ for s in train_stems:
283
+ f = qmos_feature(s, e2v_tr, sailer_tr, utmos_tr)
284
+ if f is None or s not in lab.index:
285
+ continue
286
+ X.append(f)
287
+ y.append(float(lab.loc[s]))
288
+
289
+ X = np.stack(X).astype(np.float32)
290
+ y = np.array(y, dtype=np.float32)
291
+ FEAT_DIM = X.shape[1]
292
+ print(f"Train: X={X.shape} y={y.shape}")
293
+
294
+ feat_mean = X.mean(0, keepdims=True)
295
+ feat_std = X.std(0, keepdims=True) + 1e-6
296
+ Xn = (X - feat_mean) / feat_std
297
+ y_mu, y_sd = float(y.mean()), float(y.std() + 1e-6)
298
+ yn = (y - y_mu) / y_sd
299
+
300
+ # %% [markdown]
301
+ # ## 5. Train head QMOS + so với UTMOS trên CÙNG val nội bộ
302
+ # - Head = MLP nhỏ (`Linear→ReLU→Dropout ×2 → 1`). Loss = MSE (+ tùy chọn pairwise ranking).
303
+ # - In **SRCC head** và **SRCC UTMOS** trên cùng tập val → biết head có thật sự vượt 0.414 không.
304
+
305
+ # %%
306
+ import torch.nn as nn
307
+ from scipy.stats import spearmanr
308
+ from sklearn.model_selection import train_test_split
309
+
310
+ torch.manual_seed(SEED); np.random.seed(SEED)
311
+ idx_all = np.arange(X.shape[0])
312
+ tr_idx, va_idx = train_test_split(idx_all, test_size=VAL_FRAC, random_state=SEED)
313
+
314
+ def to_t(a):
315
+ return torch.tensor(a, dtype=torch.float32, device=device)
316
+
317
+ Xn_t = to_t(Xn); yn_t = to_t(yn).unsqueeze(1)
318
+
319
+ class QMOSHead(nn.Module):
320
+ def __init__(self, d_in, h, p):
321
+ super().__init__()
322
+ self.net = nn.Sequential(
323
+ nn.Linear(d_in, h), nn.ReLU(), nn.Dropout(p),
324
+ nn.Linear(h, h), nn.ReLU(), nn.Dropout(p),
325
+ nn.Linear(h, 1),
326
+ )
327
+
328
+ def forward(self, x):
329
+ return self.net(x)
330
+
331
+ model = QMOSHead(FEAT_DIM, HIDDEN, DROPOUT).to(device)
332
+ opt = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-5)
333
+ mse = nn.MSELoss()
334
+
335
+ def pairwise_rank_loss(pred, target):
336
+ """Khuyến khích pred xếp hạng giống target (margin ranking trên các cặp trong batch)."""
337
+ n = pred.shape[0]
338
+ if n < 2:
339
+ return torch.zeros((), device=device)
340
+ pi, pj = pred.unsqueeze(0), pred.unsqueeze(1)
341
+ ti, tj = target.unsqueeze(0), target.unsqueeze(1)
342
+ sign = torch.sign(ti - tj) # +1 nếu i nên cao hơn j
343
+ diff = pi - pj
344
+ # hinge: phạt khi thứ tự sai
345
+ return torch.relu(-sign * diff).mean()
346
+
347
+ @torch.no_grad()
348
+ def eval_val():
349
+ model.eval()
350
+ p = model(Xn_t[va_idx]).cpu().numpy().ravel()
351
+ srcc_head = spearmanr(p, y[va_idx]).correlation
352
+ out = {"head": float(srcc_head)}
353
+ if USE_UTMOS_FEAT:
354
+ u = X[va_idx, -1] # cột UTMOS (đặc trưng cuối, chưa chuẩn hóa)
355
+ out["utmos"] = float(spearmanr(u, y[va_idx]).correlation)
356
+ return out
357
+
358
+ best, best_state, bad = -1e9, None, 0
359
+ tr_t = torch.tensor(tr_idx, device=device)
360
+ for ep in range(1, EPOCHS + 1):
361
+ model.train()
362
+ perm = tr_t[torch.randperm(len(tr_t), device=device)]
363
+ run = 0.0
364
+ for i in range(0, len(perm), BATCH):
365
+ b = perm[i:i + BATCH]
366
+ opt.zero_grad()
367
+ pred = model(Xn_t[b])
368
+ loss = mse(pred, yn_t[b])
369
+ if RANK_LAMBDA > 0:
370
+ loss = loss + RANK_LAMBDA * pairwise_rank_loss(pred.ravel(), yn_t[b].ravel())
371
+ loss.backward(); opt.step()
372
+ run += loss.item() * len(b)
373
+ m = eval_val()
374
+ if m["head"] > best:
375
+ best = m["head"]
376
+ best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
377
+ bad = 0
378
+ else:
379
+ bad += 1
380
+ if ep % 5 == 0 or ep == 1:
381
+ extra = f" | UTMOS={m['utmos']:.4f}" if "utmos" in m else ""
382
+ print(f"epoch {ep:3d} | loss {run/len(perm):.4f} | head SRCC={m['head']:.4f}{extra} | best {best:.4f}")
383
+ if bad >= PATIENCE:
384
+ print(f"Early stop ở epoch {ep}.")
385
+ break
386
+
387
+ model.load_state_dict(best_state)
388
+ final = eval_val()
389
+ print("\n✅ VAL (nội bộ):")
390
+ print(f" QMOS head SRCC = {final['head']:.4f}")
391
+ if "utmos" in final:
392
+ print(f" UTMOS baseline = {final['utmos']:.4f} (mốc leaderboard 0.414)")
393
+ print(" →", "✅ HEAD VƯỢT UTMOS" if final["head"] > final["utmos"] else "⚠️ chưa vượt — thử tăng EPOCHS / RANK_LAMBDA / bật thêm đặc trưng")
394
+
395
+ torch.save({"state": best_state, "feat_mean": feat_mean, "feat_std": feat_std,
396
+ "y_mu": y_mu, "y_sd": y_sd, "FEAT_DIM": FEAT_DIM,
397
+ "USE_E2V": USE_E2V, "USE_SAILER": USE_SAILER,
398
+ "USE_CLASSPROB": USE_CLASSPROB, "USE_UTMOS_FEAT": USE_UTMOS_FEAT,
399
+ "val_srcc": best}, os.path.join(OUT_DIR, "qmos_head.pt"))
400
+ print("Đã lưu", os.path.join(OUT_DIR, "qmos_head.pt"))
401
+
402
+ # %% [markdown]
403
+ # ## 6. Dự đoán QMOS cho DEV
404
+
405
+ # %%
406
+ def list_dev():
407
+ with open(DEV_SCP) as f:
408
+ return [ln.strip() for ln in f if ln.strip()]
409
+
410
+ dev_names = list_dev()
411
+ if LIMIT_DEV:
412
+ dev_names = dev_names[:LIMIT_DEV]
413
+ dev_stems = [stem(n) for n in dev_names]
414
+ print("DEV:", len(dev_names), "mẫu")
415
+
416
+ e2v_dev = extract_e2v(dev_stems, "dev") if USE_E2V else {}
417
+ sailer_dev = extract_sailer(dev_stems, "dev") if USE_SAILER else {}
418
+ utmos_dev = extract_utmos(dev_names, "dev") if USE_UTMOS_FEAT else {}
419
+
420
+ @torch.no_grad()
421
+ def predict_qmos(sid):
422
+ f = qmos_feature(sid, e2v_dev, sailer_dev, utmos_dev)
423
+ if f is None:
424
+ return None
425
+ fn = (f[None, :] - feat_mean) / feat_std
426
+ model.eval()
427
+ return float(model(to_t(fn)).item()) * y_sd + y_mu # đảo z-score
428
+
429
+ qmos_pred = {}
430
+ n_real = n_def = 0
431
+ for n in dev_names:
432
+ sid = stem(n)
433
+ p = predict_qmos(sid)
434
+ if p is None:
435
+ p = utmos_dev.get(sid, 3.0) # rơi về UTMOS nếu thiếu feature
436
+ n_def += 1
437
+ else:
438
+ n_real += 1
439
+ qmos_pred[n] = p
440
+ print(f"QMOS dự đoán: head thật {n_real}, dự phòng UTMOS {n_def}")
441
+
442
+ # Lưu riêng (để ghép tay nếu cần)
443
+ import csv
444
+ qmos_csv = os.path.join(OUT_DIR, "qmos_dev.csv")
445
+ with open(qmos_csv, "w", newline="") as f:
446
+ w = csv.writer(f); w.writerow(["wav", "QMOS"])
447
+ for n in dev_names:
448
+ w.writerow([n, f"{qmos_pred[n]:.6g}"])
449
+ print("Đã ghi", qmos_csv)
450
+
451
+ # %% [markdown]
452
+ # ## 7. Ghép QMOS mới vào answer.txt của exp04 → bản nộp mới
453
+ # Giữ NGUYÊN 5 cột cảm xúc đang thắng (EMOS/CAT/VAL/ARO/DOM), chỉ thay cột QMOS.
454
+
455
+ # %%
456
+ def merge_into_exp04(exp04_path, out_path):
457
+ if not os.path.exists(exp04_path):
458
+ print(f"⚠️ Không thấy {exp04_path} → BỎ QUA ghép. Hãy dùng qmos_dev.csv để thay cột QMOS thủ công,")
459
+ print(" hoặc trỏ EXP04_ANSWER đúng đường dẫn answer.txt của exp04 rồi chạy lại cell này.")
460
+ return False
461
+ with open(exp04_path) as f:
462
+ rows = list(csv.reader(f))
463
+ header = rows[0]
464
+ qi = header.index("QMOS")
465
+ wi = header.index("wav")
466
+ n_swapped = n_miss = 0
467
+ with open(out_path, "w", newline="") as f:
468
+ w = csv.writer(f); w.writerow(header)
469
+ for r in rows[1:]:
470
+ name = r[wi]
471
+ if name in qmos_pred:
472
+ r[qi] = f"{qmos_pred[name]:.6g}"; n_swapped += 1
473
+ else:
474
+ n_miss += 1
475
+ w.writerow(r)
476
+ print(f"Ghép xong → {out_path} | thay {n_swapped} cột QMOS, thiếu {n_miss} (giữ QMOS cũ)")
477
+ return True
478
+
479
+ merged = os.path.join(OUT_DIR, "answer.txt")
480
+ ok = merge_into_exp04(EXP04_ANSWER, merged)
481
+
482
+ if ok:
483
+ # validate + zip
484
+ with open(merged) as f:
485
+ rows = list(csv.reader(f))
486
+ assert rows[0][0] == "wav" and "QMOS" in rows[0]
487
+ for i, r in enumerate(rows[1:], 2):
488
+ assert len(r) == len(rows[0]), f"Dòng {i} sai số cột"
489
+ print(f"OK: {len(rows)-1} dòng, header = {rows[0]}")
490
+ os.system(f"cd {OUT_DIR} && zip -j submission_track2_exp06_qmos.zip answer.txt "
491
+ f"&& unzip -l submission_track2_exp06_qmos.zip")
492
+ print("Sẵn sàng nộp:", os.path.join(OUT_DIR, "submission_track2_exp06_qmos.zip"))
493
+
494
+ # %% [markdown]
495
+ # ## Ghi chú
496
+ # - **Lần đầu** đặt `LIMIT_TRAIN=300`, `LIMIT_DEV=20` để bắt lỗi; OK rồi đặt `None`.
497
+ # - **So sánh công bằng**: mục 5 in cả `head SRCC` và `UTMOS SRCC` trên CÙNG val nội bộ → chỉ nộp khi head > UTMOS.
498
+ # - Nếu head **chưa vượt** 0.414: thử (a) tăng `EPOCHS`; (b) bật `RANK_LAMBDA=0.2` (tối ưu thứ hạng);
499
+ # (c) đảm bảo `USE_UTMOS_FEAT=True` (neo residual); (d) thử bỏ bớt đặc trưng nhiễu (tắt `USE_CLASSPROB`).
500
+ # - **Ablation QMOS cho paper**: bật/tắt `USE_E2V/USE_SAILER/USE_UTMOS_FEAT/USE_CLASSPROB` → ghi `docs/04_experiments_log.md` (exp06).
501
+ # - Cache dùng CHUNG `fusion_cache/` với exp04 → nhớ **Save Version** giữ lại (gồm `utmos_*.npz` mới).
502
+ # - Ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (mục exp06).
track2/exp07_fusion_qmos.ipynb ADDED
@@ -0,0 +1,780 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "c75f9ad6",
6
+ "metadata": {},
7
+ "source": [
8
+ "# VMC2026 Track 2 — exp07 (FUSION + QMOS head, HỢP NHẤT 6 cột) — Kaggle\n",
9
+ "\n",
10
+ "**Khác exp04 ở đâu:** exp04 để **QMOS riêng** (UTMOS zero-shot). exp07 **gộp luôn QMOS vào trunk chung**\n",
11
+ "→ 1 model multi-task dự đoán **đủ 6 đầu ra**: QMOS · EMOS · CAT · VAL · ARO · DOM.\n",
12
+ "\n",
13
+ "## Giả thuyết (của bạn) cần kiểm chứng\n",
14
+ "\"Chất giọng tự nhiên có liên quan tới cảm nhận cảm xúc\" → nếu đúng, QMOS sẽ **hưởng lợi** từ biểu diễn\n",
15
+ "cảm xúc chung (emotion2vec + SAILER). **Rủi ro:** 2 backbone này chuyên *cảm xúc*, chưa chắc bắt tốt\n",
16
+ "*lỗi chất lượng/artifact* (thứ UTMOS chuyên trị) → QMOS có thể **thua** UTMOS, hoặc gộp làm **tụt** EMOS/VAD.\n",
17
+ "\n",
18
+ "## Lưới an toàn trong thiết kế\n",
19
+ "- **Vẫn đưa điểm UTMOS làm 1 đầu vào** cho QMOS head (`USE_UTMOS_FEAT`) → head học **chỉnh sửa** quanh\n",
20
+ " 0.414 thay vì học lại từ đầu → khó tệ hơn UTMOS.\n",
21
+ "- **In SRCC cả 6 cột + so mốc exp04** (EMOS 0.788 · CAT err 0.145 · VAL 0.578 · ARO 0.754 · DOM 0.706)\n",
22
+ " → cảnh báo ngay nếu gộp QMOS làm tụt 5 cột cảm xúc.\n",
23
+ "- **File riêng**, KHÔNG đụng `exp04_fusion_pipeline.py` (exp04 vẫn nguyên).\n",
24
+ "\n",
25
+ "```\n",
26
+ " mỗi wav ─► [e2v_emb | e2v_p5 | sailer_emb | sailer_p9 | sailer_vad3] ─► TRUNK chung\n",
27
+ " │\n",
28
+ " ┌──────────────┬───────────────┬─────────────┬───────────────────┤\n",
29
+ " [QMOS head] [EMOS head] [CAT head] [VAD head]\n",
30
+ " trunk + UTMOS trunk + target trunk trunk\n",
31
+ "```\n",
32
+ "\n",
33
+ "**Cách chạy:** GPU T4 + Internet On → Add Input dataset Track 2 → sửa `DATA_ROOT` → Run All.\n",
34
+ "Lần đầu đặt `LIMIT_TRAIN=300`, `LIMIT_DEV=20`. Dùng CHUNG cache `fusion_cache/` với exp04 (thêm `utmos_*.npz`)."
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "markdown",
39
+ "id": "b4e814c4",
40
+ "metadata": {},
41
+ "source": [
42
+ "## 0. Cấu hình — SỬA Ở ĐÂY"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": null,
48
+ "id": "57b9eedb",
49
+ "metadata": {},
50
+ "outputs": [],
51
+ "source": [
52
+ "import os\n",
53
+ "\n",
54
+ "DATA_ROOT = \"/kaggle/input/vmc2026-track2-full/vmc2026-track2\" # << SỬA slug\n",
55
+ "WAV_DIR = f\"{DATA_ROOT}/wav\"\n",
56
+ "METADATA_CSV = f\"{DATA_ROOT}/metadata.csv\" # wavID|emotion|transcript (KHÔNG header)\n",
57
+ "TRAIN_CSV = f\"{DATA_ROOT}/sets/train.csv\" # lisID|wavID|qMOS|emoCat|eMOS|val|dom|aro\n",
58
+ "DEV_SCP = f\"{DATA_ROOT}/sets/dev.scp\"\n",
59
+ "\n",
60
+ "OUT_DIR = \"/kaggle/working\"\n",
61
+ "CACHE_DIR = \"/kaggle/working/fusion_cache\" # dùng CHUNG với exp04 (thêm utmos_*.npz)\n",
62
+ "os.makedirs(CACHE_DIR, exist_ok=True)\n",
63
+ "\n",
64
+ "# ── Siêu tham số ─────────────────────────────────────────────────────────────\n",
65
+ "DEVICE = \"cuda\"\n",
66
+ "TRUNK_HIDDEN = 512\n",
67
+ "HEAD_HIDDEN = 128\n",
68
+ "DROPOUT = 0.3\n",
69
+ "LR = 1e-3\n",
70
+ "EPOCHS = 80\n",
71
+ "BATCH = 64\n",
72
+ "VAL_FRAC = 0.10\n",
73
+ "PATIENCE = 15\n",
74
+ "SEED = 42\n",
75
+ "\n",
76
+ "USE_UNCERTAINTY = True # tự cân 6 loss (Kendall); False = dùng LOSS_W cố định\n",
77
+ "LOSS_W = {\"qmos\": 1.0, \"emos\": 1.0, \"cat\": 1.0, \"val\": 1.0, \"aro\": 1.0, \"dom\": 1.0}\n",
78
+ "USE_E2V = True\n",
79
+ "USE_SAILER = True\n",
80
+ "USE_CLASSPROB = True\n",
81
+ "USE_UTMOS_FEAT = True # đưa điểm UTMOS làm đầu vào QMOS head (neo residual quanh 0.414)\n",
82
+ "\n",
83
+ "LIMIT_TRAIN = None\n",
84
+ "LIMIT_DEV = None\n",
85
+ "\n",
86
+ "# Mốc exp04 để so (cảnh báo nếu tụt khi gộp QMOS)\n",
87
+ "EXP04 = {\"emos\": 0.788, \"cat_err\": 0.145, \"val\": 0.578, \"aro\": 0.754, \"dom\": 0.706, \"qmos_utmos\": 0.414}\n",
88
+ "\n",
89
+ "EMOTIONS5 = [\"angry\", \"happy\", \"neutral\", \"sad\", \"surprised\"]\n",
90
+ "SAILER9 = [\"Anger\", \"Contempt\", \"Disgust\", \"Fear\", \"Happiness\", \"Neutral\", \"Sadness\", \"Surprise\", \"Other\"]\n",
91
+ "EMO2SAILER = {\"angry\": 0, \"happy\": 4, \"neutral\": 5, \"sad\": 6, \"surprised\": 7}\n",
92
+ "\n",
93
+ "_EMO_ALIAS = {\n",
94
+ " \"angry\": \"angry\", \"anger\": \"angry\",\n",
95
+ " \"happy\": \"happy\", \"happiness\": \"happy\", \"joy\": \"happy\",\n",
96
+ " \"neutral\": \"neutral\", \"calm\": \"neutral\",\n",
97
+ " \"sad\": \"sad\", \"sadness\": \"sad\",\n",
98
+ " \"surprise\": \"surprised\", \"surprised\": \"surprised\", \"surprising\": \"surprised\",\n",
99
+ "}\n",
100
+ "\n",
101
+ "def norm_emotion(label):\n",
102
+ " key = str(label).strip().lower()\n",
103
+ " return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)\n",
104
+ "\n",
105
+ "def stem(p):\n",
106
+ " return os.path.splitext(os.path.basename(str(p)))[0]\n",
107
+ "\n",
108
+ "assert USE_E2V or USE_SAILER, \"Phải bật ít nhất 1 backbone.\"\n",
109
+ "print(\"DATA_ROOT:\", DATA_ROOT)\n",
110
+ "for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP]:\n",
111
+ " print((\" ✅ \" if os.path.exists(p) else \" ❌ THIẾU \") + p)"
112
+ ]
113
+ },
114
+ {
115
+ "cell_type": "markdown",
116
+ "id": "547ccf32",
117
+ "metadata": {},
118
+ "source": [
119
+ "## 1. Cài đặt + tải code SAILER"
120
+ ]
121
+ },
122
+ {
123
+ "cell_type": "code",
124
+ "execution_count": null,
125
+ "id": "132b9321",
126
+ "metadata": {},
127
+ "outputs": [],
128
+ "source": [
129
+ "import sys, subprocess\n",
130
+ "\n",
131
+ "def pip_install(*pkgs):\n",
132
+ " subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *pkgs], check=True)\n",
133
+ "\n",
134
+ "pip_install(\"speechmos\", \"funasr\", \"librosa\", \"soundfile\", \"pandas\", \"scipy\", \"scikit-learn\", \"tqdm\")\n",
135
+ "\n",
136
+ "if USE_SAILER:\n",
137
+ " pip_install(\"loralib\", \"speechbrain\")\n",
138
+ " REPO_DIR = \"/kaggle/working/vox-profile-release\"\n",
139
+ " if not os.path.exists(REPO_DIR):\n",
140
+ " subprocess.run([\"git\", \"clone\", \"--depth\", \"1\",\n",
141
+ " \"https://github.com/tiantiaf0627/vox-profile-release.git\", REPO_DIR], check=True)\n",
142
+ " if REPO_DIR not in sys.path:\n",
143
+ " sys.path.insert(0, REPO_DIR)"
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "markdown",
148
+ "id": "75c6f07c",
149
+ "metadata": {},
150
+ "source": [
151
+ "## 2. Đọc & gộp nhãn (gộp theo wavID) — THÊM cột qMOS\n",
152
+ "Khác exp04: gộp thêm **qMOS** (= TB `qMOS` theo wav) làm nhãn cho QMOS head."
153
+ ]
154
+ },
155
+ {
156
+ "cell_type": "code",
157
+ "execution_count": null,
158
+ "id": "3c73a7fb",
159
+ "metadata": {},
160
+ "outputs": [],
161
+ "source": [
162
+ "import numpy as np\n",
163
+ "import pandas as pd\n",
164
+ "\n",
165
+ "def load_target_emotions():\n",
166
+ " tgt = {}\n",
167
+ " with open(METADATA_CSV, encoding=\"utf-8\") as f:\n",
168
+ " for ln in f:\n",
169
+ " parts = ln.strip().split(\"|\")\n",
170
+ " if len(parts) < 2:\n",
171
+ " continue\n",
172
+ " tgt[stem(parts[0])] = norm_emotion(parts[1])\n",
173
+ " return tgt\n",
174
+ "\n",
175
+ "def _col(cols_map, *names, default_idx=None, df=None):\n",
176
+ " for n in names:\n",
177
+ " if n in cols_map:\n",
178
+ " return cols_map[n]\n",
179
+ " return list(df.columns)[default_idx] if default_idx is not None else None\n",
180
+ "\n",
181
+ "def parse_emocat_votes(cell):\n",
182
+ " v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
183
+ " for tok in str(cell).replace(\"/\", \",\").replace(\";\", \",\").replace(\"|\", \",\").replace(\" \", \",\").split(\",\"):\n",
184
+ " e = norm_emotion(tok)\n",
185
+ " if e in EMOTIONS5:\n",
186
+ " v[EMOTIONS5.index(e)] += 1.0\n",
187
+ " return v\n",
188
+ "\n",
189
+ "def load_train_labels():\n",
190
+ " \"\"\"train.csv → DataFrame [wavID, qmos, emos, val, aro, dom, cat0..cat4] gộp theo wav.\"\"\"\n",
191
+ " df = pd.read_csv(TRAIN_CSV, sep=\"|\")\n",
192
+ " cols = {c.lower().strip(): c for c in df.columns}\n",
193
+ " wav_col = _col(cols, \"wavid\", \"wav\", default_idx=1, df=df)\n",
194
+ " qmos_col = _col(cols, \"qmos\", \"mos\")\n",
195
+ " emos_col = _col(cols, \"emos\", \"emo\", \"emomos\")\n",
196
+ " val_col = _col(cols, \"val\", \"valence\")\n",
197
+ " aro_col = _col(cols, \"aro\", \"arousal\")\n",
198
+ " dom_col = _col(cols, \"dom\", \"dominance\")\n",
199
+ " cat_col = _col(cols, \"emocat\", \"cat\", \"emotion\")\n",
200
+ " assert qmos_col, f\"Không thấy cột qMOS trong train.csv (cột: {list(df.columns)})\"\n",
201
+ " assert emos_col, f\"Không thấy cột eMOS trong train.csv (cột: {list(df.columns)})\"\n",
202
+ "\n",
203
+ " df[\"_stem\"] = df[wav_col].map(stem)\n",
204
+ " rows = []\n",
205
+ " for sid, g in df.groupby(\"_stem\"):\n",
206
+ " rec = {\"wavID\": sid,\n",
207
+ " \"qmos\": float(g[qmos_col].mean()),\n",
208
+ " \"emos\": float(g[emos_col].mean())}\n",
209
+ " rec[\"val\"] = float(g[val_col].mean()) if val_col else np.nan\n",
210
+ " rec[\"aro\"] = float(g[aro_col].mean()) if aro_col else np.nan\n",
211
+ " rec[\"dom\"] = float(g[dom_col].mean()) if dom_col else np.nan\n",
212
+ " votes = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
213
+ " if cat_col:\n",
214
+ " for cell in g[cat_col]:\n",
215
+ " votes += parse_emocat_votes(cell)\n",
216
+ " s = votes.sum()\n",
217
+ " cat = votes / s if s > 0 else np.full(len(EMOTIONS5), 1.0 / len(EMOTIONS5), dtype=np.float32)\n",
218
+ " for i in range(len(EMOTIONS5)):\n",
219
+ " rec[f\"cat{i}\"] = float(cat[i])\n",
220
+ " rows.append(rec)\n",
221
+ " return pd.DataFrame(rows)\n",
222
+ "\n",
223
+ "target_map = load_target_emotions()\n",
224
+ "train_df = load_train_labels()\n",
225
+ "HAS_VAD = bool(train_df[\"val\"].notna().any())\n",
226
+ "print(f\"Target: {len(target_map)} | wav train (gộp): {len(train_df)} | có VAD: {HAS_VAD}\")\n",
227
+ "print(\"qMOS:\", train_df[\"qmos\"].describe()[[\"mean\", \"std\", \"min\", \"max\"]].to_dict())\n",
228
+ "print(\"eMOS:\", train_df[\"emos\"].describe()[[\"mean\", \"std\", \"min\", \"max\"]].to_dict())\n",
229
+ "train_df.head()"
230
+ ]
231
+ },
232
+ {
233
+ "cell_type": "markdown",
234
+ "id": "0726b340",
235
+ "metadata": {},
236
+ "source": [
237
+ "## 3. Trích đặc trưng 2 backbone + điểm UTMOS (cache CHUNG với exp04)"
238
+ ]
239
+ },
240
+ {
241
+ "cell_type": "code",
242
+ "execution_count": null,
243
+ "id": "ae27e424",
244
+ "metadata": {
245
+ "lines_to_next_cell": 1
246
+ },
247
+ "outputs": [],
248
+ "source": [
249
+ "import torch\n",
250
+ "import torch.nn.functional as F\n",
251
+ "\n",
252
+ "device = DEVICE if torch.cuda.is_available() else \"cpu\"\n",
253
+ "print(\"Device:\", device, (\"✅ \" + torch.cuda.get_device_name(0)) if device == \"cuda\" else \"⚠️ CPU\")\n",
254
+ "\n",
255
+ "def extract_e2v(stems, tag):\n",
256
+ " from tqdm.auto import tqdm\n",
257
+ " cache_path = os.path.join(CACHE_DIR, f\"e2v_{tag}.npz\")\n",
258
+ " store = {}\n",
259
+ " if os.path.exists(cache_path):\n",
260
+ " z = np.load(cache_path, allow_pickle=True)\n",
261
+ " store = {k: z[k] for k in z.files}\n",
262
+ " print(f\"[e2v/{tag}] nạp cache: {len(store)}\")\n",
263
+ " todo = [s for s in stems if s not in store]\n",
264
+ " if todo:\n",
265
+ " from funasr import AutoModel\n",
266
+ " m = AutoModel(model=\"iic/emotion2vec_plus_large\", hub=\"hf\", device=device)\n",
267
+ " for i, s in enumerate(tqdm(todo, desc=f\"e2v {tag}\")):\n",
268
+ " wav = os.path.join(WAV_DIR, s + \".wav\")\n",
269
+ " if not os.path.exists(wav):\n",
270
+ " continue\n",
271
+ " r = m.generate(wav, granularity=\"utterance\", extract_embedding=True)[0]\n",
272
+ " emb = np.asarray(r[\"feats\"], dtype=np.float32).reshape(-1)\n",
273
+ " probs = {e: 0.0 for e in EMOTIONS5}\n",
274
+ " for lab, sc in zip(r[\"labels\"], r[\"scores\"]):\n",
275
+ " name = lab.split(\"/\")[-1]\n",
276
+ " if name in probs:\n",
277
+ " probs[name] = float(sc)\n",
278
+ " tot = sum(probs.values())\n",
279
+ " p5 = np.array([probs[e] / tot if tot > 0 else 0.2 for e in EMOTIONS5], dtype=np.float32)\n",
280
+ " store[s] = np.concatenate([emb, p5]).astype(np.float32)\n",
281
+ " if (i + 1) % 500 == 0:\n",
282
+ " np.savez(cache_path, **store)\n",
283
+ " np.savez(cache_path, **store)\n",
284
+ " del m\n",
285
+ " torch.cuda.empty_cache() if device == \"cuda\" else None\n",
286
+ " return {s: (v[:-5], v[-5:]) for s, v in store.items()}\n",
287
+ "\n",
288
+ "def _pool_feat(features):\n",
289
+ " f = features.detach().cpu().numpy()\n",
290
+ " if f.ndim <= 1:\n",
291
+ " return f.reshape(-1).astype(np.float32)\n",
292
+ " return f.mean(axis=tuple(range(f.ndim - 1))).reshape(-1).astype(np.float32)\n",
293
+ "\n",
294
+ "def extract_sailer(stems, tag):\n",
295
+ " import librosa\n",
296
+ " from tqdm.auto import tqdm\n",
297
+ " cache_path = os.path.join(CACHE_DIR, f\"sailer_{tag}.npz\")\n",
298
+ " store = {}\n",
299
+ " if os.path.exists(cache_path):\n",
300
+ " z = np.load(cache_path, allow_pickle=True)\n",
301
+ " store = {k: z[k] for k in z.files}\n",
302
+ " print(f\"[sailer/{tag}] nạp cache: {len(store)}\")\n",
303
+ " todo = [s for s in stems if s not in store]\n",
304
+ " if todo:\n",
305
+ " from src.model.emotion.wavlm_emotion import WavLMWrapper\n",
306
+ " sailer = WavLMWrapper.from_pretrained(\"tiantiaf/wavlm-large-categorical-emotion\").to(device).eval()\n",
307
+ " with torch.no_grad():\n",
308
+ " for i, s in enumerate(tqdm(todo, desc=f\"sailer {tag}\")):\n",
309
+ " wav = os.path.join(WAV_DIR, s + \".wav\")\n",
310
+ " if not os.path.exists(wav):\n",
311
+ " continue\n",
312
+ " wave, _ = librosa.load(wav, sr=16000, mono=True)\n",
313
+ " wave = wave[: 15 * 16000]\n",
314
+ " data = torch.from_numpy(wave).float().unsqueeze(0).to(device)\n",
315
+ " logits, feat, _det, arousal, valence, dominance = sailer(data, return_feature=True)\n",
316
+ " emb = _pool_feat(feat)\n",
317
+ " p9 = F.softmax(logits, dim=1)[0].detach().cpu().numpy().astype(np.float32)\n",
318
+ " vad3 = np.array([1 + 4 * float(valence.item()),\n",
319
+ " 1 + 4 * float(arousal.item()),\n",
320
+ " 1 + 4 * float(dominance.item())], dtype=np.float32)\n",
321
+ " store[s] = np.concatenate([emb, p9, vad3]).astype(np.float32)\n",
322
+ " if (i + 1) % 500 == 0:\n",
323
+ " np.savez(cache_path, **store)\n",
324
+ " np.savez(cache_path, **store)\n",
325
+ " del sailer\n",
326
+ " torch.cuda.empty_cache() if device == \"cuda\" else None\n",
327
+ " return {s: (v[:-12], v[-12:-3], v[-3:]) for s, v in store.items()}\n",
328
+ "\n",
329
+ "def extract_utmos(names, tag):\n",
330
+ " \"\"\"Chấm UTMOS từng wav (theo TÊN, vì DEV gọi .wav theo tên). → dict {stem: score}.\n",
331
+ " Cache CACHE_DIR/utmos_<tag>.npz. Dùng vừa làm đầu vào QMOS head, vừa làm baseline so sánh.\"\"\"\n",
332
+ " import librosa\n",
333
+ " from tqdm.auto import tqdm\n",
334
+ " cache_path = os.path.join(CACHE_DIR, f\"utmos_{tag}.npz\")\n",
335
+ " store = {}\n",
336
+ " if os.path.exists(cache_path):\n",
337
+ " z = np.load(cache_path, allow_pickle=True)\n",
338
+ " store = {k: float(z[k]) for k in z.files}\n",
339
+ " print(f\"[utmos/{tag}] nạp cache: {len(store)}\")\n",
340
+ " todo = [n for n in names if stem(n) not in store]\n",
341
+ " if todo:\n",
342
+ " predictor = torch.hub.load(\"tarepan/SpeechMOS:v1.2.0\", \"utmos22_strong\",\n",
343
+ " trust_repo=True).to(device).eval()\n",
344
+ " with torch.no_grad():\n",
345
+ " for i, n in enumerate(tqdm(todo, desc=f\"utmos {tag}\")):\n",
346
+ " wav = os.path.join(WAV_DIR, n if str(n).endswith(\".wav\") else n + \".wav\")\n",
347
+ " if not os.path.exists(wav):\n",
348
+ " continue\n",
349
+ " wave, _ = librosa.load(wav, sr=16000, mono=True)\n",
350
+ " store[stem(n)] = float(predictor(torch.from_numpy(wave).unsqueeze(0).to(device),\n",
351
+ " sr=16000).mean().item())\n",
352
+ " if (i + 1) % 500 == 0:\n",
353
+ " np.savez(cache_path, **{k: np.float32(v) for k, v in store.items()})\n",
354
+ " np.savez(cache_path, **{k: np.float32(v) for k, v in store.items()})\n",
355
+ " del predictor\n",
356
+ " torch.cuda.empty_cache() if device == \"cuda\" else None\n",
357
+ " return store"
358
+ ]
359
+ },
360
+ {
361
+ "cell_type": "markdown",
362
+ "id": "e9fad1a0",
363
+ "metadata": {},
364
+ "source": [
365
+ "## 4. Dựng feature + nhãn cho train\n",
366
+ "Feature audio (cảm xúc) = `[e2v_emb | e2v_p5 | sailer_emb | sailer_p9 | sailer_vad3]` (như exp04).\n",
367
+ "Thêm: vector **UTMOS** (1 số/ wav) cho QMOS head, và nhãn **qMOS**."
368
+ ]
369
+ },
370
+ {
371
+ "cell_type": "code",
372
+ "execution_count": null,
373
+ "id": "768f374b",
374
+ "metadata": {},
375
+ "outputs": [],
376
+ "source": [
377
+ "train_stems = list(train_df[\"wavID\"])\n",
378
+ "if LIMIT_TRAIN:\n",
379
+ " train_stems = train_stems[:LIMIT_TRAIN]\n",
380
+ "\n",
381
+ "e2v_tr = extract_e2v(train_stems, \"train\") if USE_E2V else {}\n",
382
+ "sailer_tr = extract_sailer(train_stems, \"train\") if USE_SAILER else {}\n",
383
+ "utmos_tr = extract_utmos(train_stems, \"train\") if USE_UTMOS_FEAT else {}\n",
384
+ "\n",
385
+ "def audio_feature(sid, e2v_map, sailer_map):\n",
386
+ " parts = []\n",
387
+ " if USE_E2V:\n",
388
+ " pk = e2v_map.get(sid)\n",
389
+ " if pk is None:\n",
390
+ " return None\n",
391
+ " emb, p5 = pk\n",
392
+ " parts.append(emb)\n",
393
+ " if USE_CLASSPROB:\n",
394
+ " parts.append(p5)\n",
395
+ " if USE_SAILER:\n",
396
+ " pk = sailer_map.get(sid)\n",
397
+ " if pk is None:\n",
398
+ " return None\n",
399
+ " emb, p9, vad3 = pk\n",
400
+ " parts.append(emb)\n",
401
+ " if USE_CLASSPROB:\n",
402
+ " parts.append(p9); parts.append(vad3)\n",
403
+ " return np.concatenate(parts).astype(np.float32)\n",
404
+ "\n",
405
+ "def onehot_target(tgt):\n",
406
+ " v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
407
+ " if tgt in EMOTIONS5:\n",
408
+ " v[EMOTIONS5.index(tgt)] = 1.0\n",
409
+ " return v\n",
410
+ "\n",
411
+ "lab = train_df.set_index(\"wavID\")\n",
412
+ "X, T, U, y_qmos, y_emos, y_vad, y_cat = [], [], [], [], [], [], []\n",
413
+ "for s in train_stems:\n",
414
+ " f = audio_feature(s, e2v_tr, sailer_tr)\n",
415
+ " tgt = target_map.get(s)\n",
416
+ " if f is None or tgt is None or s not in lab.index:\n",
417
+ " continue\n",
418
+ " if USE_UTMOS_FEAT and s not in utmos_tr:\n",
419
+ " continue\n",
420
+ " X.append(f)\n",
421
+ " T.append(onehot_target(tgt))\n",
422
+ " U.append(utmos_tr.get(s, 3.0) if USE_UTMOS_FEAT else 0.0)\n",
423
+ " y_qmos.append(lab.loc[s, \"qmos\"])\n",
424
+ " y_emos.append(lab.loc[s, \"emos\"])\n",
425
+ " y_vad.append([lab.loc[s, \"val\"], lab.loc[s, \"aro\"], lab.loc[s, \"dom\"]])\n",
426
+ " y_cat.append([lab.loc[s, f\"cat{i}\"] for i in range(len(EMOTIONS5))])\n",
427
+ "\n",
428
+ "X = np.stack(X).astype(np.float32)\n",
429
+ "T = np.stack(T).astype(np.float32)\n",
430
+ "U = np.array(U, dtype=np.float32).reshape(-1, 1)\n",
431
+ "y_qmos = np.array(y_qmos, dtype=np.float32)\n",
432
+ "y_emos = np.array(y_emos, dtype=np.float32)\n",
433
+ "y_vad = np.array(y_vad, dtype=np.float32)\n",
434
+ "y_cat = np.array(y_cat, dtype=np.float32)\n",
435
+ "FEAT_DIM = X.shape[1]\n",
436
+ "print(f\"Train: X={X.shape} U={U.shape} qmos={y_qmos.shape} emos={y_emos.shape} vad={y_vad.shape}\")\n",
437
+ "\n",
438
+ "# Chuẩn hóa feature audio + UTMOS (z-score), lưu mean/std.\n",
439
+ "feat_mean = X.mean(0, keepdims=True); feat_std = X.std(0, keepdims=True) + 1e-6\n",
440
+ "Xn = (X - feat_mean) / feat_std\n",
441
+ "u_mu, u_sd = float(U.mean()), float(U.std() + 1e-6)\n",
442
+ "Un = (U - u_mu) / u_sd\n",
443
+ "\n",
444
+ "# Chuẩn hóa nhãn liên tục về z-score.\n",
445
+ "qmos_mu, qmos_sd = float(y_qmos.mean()), float(y_qmos.std() + 1e-6)\n",
446
+ "y_qmos_z = (y_qmos - qmos_mu) / qmos_sd\n",
447
+ "emos_mu, emos_sd = float(y_emos.mean()), float(y_emos.std() + 1e-6)\n",
448
+ "y_emos_z = (y_emos - emos_mu) / emos_sd\n",
449
+ "if HAS_VAD:\n",
450
+ " vad_mu = np.nanmean(y_vad, axis=0); vad_sd = np.nanstd(y_vad, axis=0) + 1e-6\n",
451
+ " y_vad_z = (y_vad - vad_mu) / vad_sd\n",
452
+ "else:\n",
453
+ " vad_mu = np.zeros(3, dtype=np.float32); vad_sd = np.ones(3, dtype=np.float32)\n",
454
+ " y_vad_z = np.zeros_like(y_vad)"
455
+ ]
456
+ },
457
+ {
458
+ "cell_type": "markdown",
459
+ "id": "5903e6ca",
460
+ "metadata": {},
461
+ "source": [
462
+ "## 5. Model fusion multi-task (6 head) + train loop\n",
463
+ "Thêm so exp04: **QMOS head** nhận `[trunk | UTMOS]` → 1; `qmos` vào uncertainty weighting (6 task)."
464
+ ]
465
+ },
466
+ {
467
+ "cell_type": "code",
468
+ "execution_count": null,
469
+ "id": "68a3a836",
470
+ "metadata": {
471
+ "lines_to_next_cell": 1
472
+ },
473
+ "outputs": [],
474
+ "source": [
475
+ "import torch.nn as nn\n",
476
+ "from scipy.stats import spearmanr\n",
477
+ "from sklearn.model_selection import train_test_split\n",
478
+ "\n",
479
+ "torch.manual_seed(SEED); np.random.seed(SEED)\n",
480
+ "N_EMO = len(EMOTIONS5)\n",
481
+ "idx_all = np.arange(X.shape[0])\n",
482
+ "tr_idx, va_idx = train_test_split(idx_all, test_size=VAL_FRAC, random_state=SEED)\n",
483
+ "\n",
484
+ "def to_t(a):\n",
485
+ " return torch.tensor(a, dtype=torch.float32, device=device)\n",
486
+ "\n",
487
+ "Xn_t, T_t, Un_t = to_t(Xn), to_t(T), to_t(Un)\n",
488
+ "qmos_t = to_t(y_qmos_z).unsqueeze(1)\n",
489
+ "emos_t = to_t(y_emos_z).unsqueeze(1)\n",
490
+ "vad_t = to_t(y_vad_z)\n",
491
+ "cat_t = to_t(y_cat)\n",
492
+ "\n",
493
+ "class FusionMTL6(nn.Module):\n",
494
+ " def __init__(self, d_in, trunk_h, head_h, p, n_emo, use_utmos):\n",
495
+ " super().__init__()\n",
496
+ " self.use_utmos = use_utmos\n",
497
+ " self.trunk = nn.Sequential(\n",
498
+ " nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),\n",
499
+ " nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p),\n",
500
+ " )\n",
501
+ " self.qmos = nn.Sequential( # nhận [trunk | utmos] nếu bật\n",
502
+ " nn.Linear(trunk_h + (1 if use_utmos else 0), head_h), nn.ReLU(), nn.Dropout(p),\n",
503
+ " nn.Linear(head_h, 1))\n",
504
+ " self.emos = nn.Sequential( # nhận [trunk | target]\n",
505
+ " nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))\n",
506
+ " self.cat = nn.Sequential(\n",
507
+ " nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))\n",
508
+ " self.vad = nn.Sequential(\n",
509
+ " nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))\n",
510
+ "\n",
511
+ " def forward(self, x, tgt, utmos):\n",
512
+ " h = self.trunk(x)\n",
513
+ " qmos_in = torch.cat([h, utmos], dim=1) if self.use_utmos else h\n",
514
+ " qmos = self.qmos(qmos_in)\n",
515
+ " emos = self.emos(torch.cat([h, tgt], dim=1))\n",
516
+ " cat_logits = self.cat(h)\n",
517
+ " vad = self.vad(h)\n",
518
+ " return qmos, emos, cat_logits, vad\n",
519
+ "\n",
520
+ "model = FusionMTL6(FEAT_DIM, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO, USE_UTMOS_FEAT).to(device)\n",
521
+ "\n",
522
+ "TASKS = [\"qmos\", \"emos\", \"cat\", \"val\", \"aro\", \"dom\"]\n",
523
+ "log_var = nn.Parameter(torch.zeros(len(TASKS), device=device))\n",
524
+ "params = list(model.parameters()) + ([log_var] if USE_UNCERTAINTY else [])\n",
525
+ "opt = torch.optim.Adam(params, lr=LR, weight_decay=1e-5)\n",
526
+ "mse = nn.MSELoss(reduction=\"none\")\n",
527
+ "\n",
528
+ "def soft_ce(logits, target_dist):\n",
529
+ " logq = F.log_softmax(logits, dim=1)\n",
530
+ " return -(target_dist * logq).sum(dim=1)\n",
531
+ "\n",
532
+ "def task_losses(qmos_p, emos_p, cat_logits, vad_p, b):\n",
533
+ " L = {}\n",
534
+ " L[\"qmos\"] = mse(qmos_p, qmos_t[b]).mean()\n",
535
+ " L[\"emos\"] = mse(emos_p, emos_t[b]).mean()\n",
536
+ " L[\"cat\"] = soft_ce(cat_logits, cat_t[b]).mean()\n",
537
+ " if HAS_VAD:\n",
538
+ " L[\"val\"] = mse(vad_p[:, 0:1], vad_t[b, 0:1]).mean()\n",
539
+ " L[\"aro\"] = mse(vad_p[:, 1:2], vad_t[b, 1:2]).mean()\n",
540
+ " L[\"dom\"] = mse(vad_p[:, 2:3], vad_t[b, 2:3]).mean()\n",
541
+ " else:\n",
542
+ " z = torch.zeros((), device=device)\n",
543
+ " L[\"val\"] = L[\"aro\"] = L[\"dom\"] = z\n",
544
+ " return L\n",
545
+ "\n",
546
+ "def combine(L):\n",
547
+ " if USE_UNCERTAINTY:\n",
548
+ " tot = 0.0\n",
549
+ " for i, t in enumerate(TASKS):\n",
550
+ " tot = tot + torch.exp(-log_var[i]) * L[t] + log_var[i]\n",
551
+ " return tot\n",
552
+ " return sum(LOSS_W[t] * L[t] for t in TASKS)\n",
553
+ "\n",
554
+ "@torch.no_grad()\n",
555
+ "def eval_val():\n",
556
+ " model.eval()\n",
557
+ " qp, ep, cl, vp = model(Xn_t[va_idx], T_t[va_idx], Un_t[va_idx])\n",
558
+ " qp = qp.cpu().numpy().ravel(); ep = ep.cpu().numpy().ravel()\n",
559
+ " out = {\"qmos\": spearmanr(qp, y_qmos[va_idx]).correlation,\n",
560
+ " \"emos\": spearmanr(ep, y_emos[va_idx]).correlation}\n",
561
+ " if USE_UTMOS_FEAT:\n",
562
+ " out[\"qmos_utmos\"] = spearmanr(U[va_idx, 0], y_qmos[va_idx]).correlation # baseline UTMOS đơn lẻ\n",
563
+ " if HAS_VAD:\n",
564
+ " vp = vp.cpu().numpy()\n",
565
+ " for j, t in enumerate([\"val\", \"aro\", \"dom\"]):\n",
566
+ " out[t] = spearmanr(vp[:, j], y_vad[va_idx, j]).correlation\n",
567
+ " q = F.softmax(cl, dim=1).cpu().numpy(); p = y_cat[va_idx]\n",
568
+ " kl = (p * (np.log(p + 1e-9) - np.log(q + 1e-9))).sum(1).mean()\n",
569
+ " out[\"cat_negkl\"] = float(-kl)\n",
570
+ " return out\n",
571
+ "\n",
572
+ "def val_score(m):\n",
573
+ " \"\"\"Điểm tổng early-stop = TB SRCC các task liên tục (qmos+emos+VAD).\"\"\"\n",
574
+ " keys = [\"qmos\", \"emos\"] + ([\"val\", \"aro\", \"dom\"] if HAS_VAD else [])\n",
575
+ " return float(np.mean([m[k] for k in keys]))\n",
576
+ "\n",
577
+ "best_score, best_state, bad = -1e9, None, 0\n",
578
+ "tr_t = torch.tensor(tr_idx, device=device)\n",
579
+ "for ep_i in range(1, EPOCHS + 1):\n",
580
+ " model.train()\n",
581
+ " perm = tr_t[torch.randperm(len(tr_t), device=device)]\n",
582
+ " run = 0.0\n",
583
+ " for i in range(0, len(perm), BATCH):\n",
584
+ " b = perm[i:i + BATCH]\n",
585
+ " opt.zero_grad()\n",
586
+ " qmos_p, emos_p, cat_logits, vad_p = model(Xn_t[b], T_t[b], Un_t[b])\n",
587
+ " L = task_losses(qmos_p, emos_p, cat_logits, vad_p, b)\n",
588
+ " loss = combine(L)\n",
589
+ " loss.backward(); opt.step()\n",
590
+ " run += loss.item() * len(b)\n",
591
+ " m = eval_val()\n",
592
+ " sc = val_score(m)\n",
593
+ " if sc > best_score:\n",
594
+ " best_score = sc\n",
595
+ " best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}\n",
596
+ " bad = 0\n",
597
+ " else:\n",
598
+ " bad += 1\n",
599
+ " if ep_i % 5 == 0 or ep_i == 1:\n",
600
+ " msg = \" \".join(f\"{k}={m[k]:.3f}\" for k in [\"qmos\", \"emos\", \"val\", \"aro\", \"dom\"] if k in m)\n",
601
+ " print(f\"epoch {ep_i:3d} | loss {run/len(perm):.4f} | {msg} | best {best_score:.4f}\")\n",
602
+ " if bad >= PATIENCE:\n",
603
+ " print(f\"Early stop ở epoch {ep_i}.\")\n",
604
+ " break\n",
605
+ "\n",
606
+ "model.load_state_dict(best_state)\n",
607
+ "final = eval_val()\n",
608
+ "print(\"\\n✅ VAL (nội bộ) — exp07 (fusion + QMOS head):\")\n",
609
+ "print(f\" QMOS SRCC = {final['qmos']:.4f}\", end=\"\")\n",
610
+ "if \"qmos_utmos\" in final:\n",
611
+ " tag = \"✅ vượt UTMOS\" if final[\"qmos\"] > final[\"qmos_utmos\"] else \"⚠️ CHƯA vượt UTMOS\"\n",
612
+ " print(f\" (UTMOS đơn lẻ = {final['qmos_utmos']:.4f} → {tag})\")\n",
613
+ "else:\n",
614
+ " print()\n",
615
+ "print(f\" EMOS SRCC = {final['emos']:.4f} (mốc exp04 = {EXP04['emos']})\")\n",
616
+ "if HAS_VAD:\n",
617
+ " print(f\" VAL/ARO/DOM = {final['val']:.4f}/{final['aro']:.4f}/{final['dom']:.4f}\"\n",
618
+ " f\" (mốc exp04 = {EXP04['val']}/{EXP04['aro']}/{EXP04['dom']})\")\n",
619
+ "# Cảnh báo negative transfer (gộp QMOS làm tụt cảm xúc)\n",
620
+ "warn = []\n",
621
+ "if final[\"emos\"] < EXP04[\"emos\"] - 0.02:\n",
622
+ " warn.append(f\"EMOS {final['emos']:.3f} < {EXP04['emos']}\")\n",
623
+ "if HAS_VAD:\n",
624
+ " for t in [\"val\", \"aro\", \"dom\"]:\n",
625
+ " if final[t] < EXP04[t] - 0.02:\n",
626
+ " warn.append(f\"{t.upper()} {final[t]:.3f} < {EXP04[t]}\")\n",
627
+ "if warn:\n",
628
+ " print(\" ⚠️ NEGATIVE TRANSFER? Cảm xúc tụt so exp04:\", \"; \".join(warn),\n",
629
+ " \"\\n → cân nhắc giữ exp04 cho 5 cột cảm xúc + chỉ lấy QMOS từ exp07/exp06.\")\n",
630
+ "else:\n",
631
+ " print(\" ✅ Không thấy 5 cột cảm xúc tụt rõ so exp04.\")\n",
632
+ "if USE_UNCERTAINTY:\n",
633
+ " print(\" log σ² mỗi task:\", {t: round(float(log_var[i]), 3) for i, t in enumerate(TASKS)})\n",
634
+ "\n",
635
+ "torch.save({\"state\": best_state, \"feat_mean\": feat_mean, \"feat_std\": feat_std,\n",
636
+ " \"u_mu\": u_mu, \"u_sd\": u_sd,\n",
637
+ " \"qmos_mu\": qmos_mu, \"qmos_sd\": qmos_sd, \"emos_mu\": emos_mu, \"emos_sd\": emos_sd,\n",
638
+ " \"vad_mu\": vad_mu, \"vad_sd\": vad_sd, \"FEAT_DIM\": FEAT_DIM,\n",
639
+ " \"USE_E2V\": USE_E2V, \"USE_SAILER\": USE_SAILER, \"USE_CLASSPROB\": USE_CLASSPROB,\n",
640
+ " \"USE_UTMOS_FEAT\": USE_UTMOS_FEAT, \"val_score\": best_score},\n",
641
+ " os.path.join(OUT_DIR, \"fusion_qmos_mtl.pt\"))\n",
642
+ "print(\"Đã lưu\", os.path.join(OUT_DIR, \"fusion_qmos_mtl.pt\"))"
643
+ ]
644
+ },
645
+ {
646
+ "cell_type": "markdown",
647
+ "id": "9a788c48",
648
+ "metadata": {},
649
+ "source": [
650
+ "## 6. Dự đoán DEV → `answer.txt` đủ 6 cột (QMOS giờ từ HEAD, không phải SpeechMOS riêng)"
651
+ ]
652
+ },
653
+ {
654
+ "cell_type": "code",
655
+ "execution_count": null,
656
+ "id": "8acd813f",
657
+ "metadata": {
658
+ "lines_to_next_cell": 1
659
+ },
660
+ "outputs": [],
661
+ "source": [
662
+ "def list_dev():\n",
663
+ " with open(DEV_SCP) as f:\n",
664
+ " return [ln.strip() for ln in f if ln.strip()]\n",
665
+ "\n",
666
+ "dev_names = list_dev()\n",
667
+ "if LIMIT_DEV:\n",
668
+ " dev_names = dev_names[:LIMIT_DEV]\n",
669
+ "dev_stems = [stem(n) for n in dev_names]\n",
670
+ "print(\"DEV:\", len(dev_names), \"mẫu\")\n",
671
+ "\n",
672
+ "e2v_dev = extract_e2v(dev_stems, \"dev\") if USE_E2V else {}\n",
673
+ "sailer_dev = extract_sailer(dev_stems, \"dev\") if USE_SAILER else {}\n",
674
+ "utmos_dev = extract_utmos(dev_names, \"dev\") if USE_UTMOS_FEAT else {}\n",
675
+ "\n",
676
+ "@torch.no_grad()\n",
677
+ "def predict_all(sid):\n",
678
+ " f = audio_feature(sid, e2v_dev, sailer_dev)\n",
679
+ " if f is None:\n",
680
+ " return None\n",
681
+ " fn = (f[None, :] - feat_mean) / feat_std\n",
682
+ " tgt = onehot_target(target_map.get(sid))[None, :]\n",
683
+ " u = np.array([[utmos_dev.get(sid, 3.0)]], dtype=np.float32)\n",
684
+ " un = (u - u_mu) / u_sd\n",
685
+ " model.eval()\n",
686
+ " qmos_p, emos_p, cat_logits, vad_p = model(to_t(fn), to_t(tgt), to_t(un))\n",
687
+ " qmos = float(qmos_p.item()) * qmos_sd + qmos_mu\n",
688
+ " emos = float(emos_p.item()) * emos_sd + emos_mu\n",
689
+ " cat5 = F.softmax(cat_logits, dim=1)[0].cpu().numpy()\n",
690
+ " vad3 = vad_p[0].cpu().numpy() * vad_sd + vad_mu\n",
691
+ " return qmos, emos, cat5, vad3\n",
692
+ "\n",
693
+ "def fmt_cat(probs5):\n",
694
+ " return \"|\".join(f\"{e}:{probs5[i]:.6g}\" for i, e in enumerate(EMOTIONS5))\n",
695
+ "\n",
696
+ "def build_answer(out_path):\n",
697
+ " from tqdm.auto import tqdm\n",
698
+ " n_real = n_default = 0\n",
699
+ " with open(out_path, \"w\") as f:\n",
700
+ " f.write(\"wav,QMOS,EMOS,CAT,VAL,ARO,DOM\\n\")\n",
701
+ " for name in tqdm(dev_names, desc=\"answer\"):\n",
702
+ " sid = stem(name)\n",
703
+ " pred = predict_all(sid)\n",
704
+ " if pred is None:\n",
705
+ " # rơi về: QMOS=UTMOS nếu có, còn lại mặc định\n",
706
+ " qmos = utmos_dev.get(sid, 3.0)\n",
707
+ " emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0])\n",
708
+ " n_default += 1\n",
709
+ " else:\n",
710
+ " qmos, emos, cat5, vad3 = pred\n",
711
+ " n_real += 1\n",
712
+ " f.write(f\"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},\"\n",
713
+ " f\"{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\\n\")\n",
714
+ " print(f\"Ghi {len(dev_names)} dòng → {out_path} | head thật {n_real}, mặc định {n_default}\")\n",
715
+ "\n",
716
+ "answer_path = os.path.join(OUT_DIR, \"answer.txt\")\n",
717
+ "build_answer(answer_path)"
718
+ ]
719
+ },
720
+ {
721
+ "cell_type": "markdown",
722
+ "id": "0ac67bb5",
723
+ "metadata": {},
724
+ "source": [
725
+ "## 7. Validate + đóng zip"
726
+ ]
727
+ },
728
+ {
729
+ "cell_type": "code",
730
+ "execution_count": null,
731
+ "id": "8244b0c3",
732
+ "metadata": {},
733
+ "outputs": [],
734
+ "source": [
735
+ "def validate(path):\n",
736
+ " import csv\n",
737
+ " with open(path) as f:\n",
738
+ " rows = list(csv.reader(f))\n",
739
+ " header = rows[0]\n",
740
+ " assert header[0] == \"wav\" and \"QMOS\" in header and \"EMOS\" in header, \"Header sai\"\n",
741
+ " for i, r in enumerate(rows[1:], 2):\n",
742
+ " assert len(r) == len(header), f\"Dòng {i} sai số cột\"\n",
743
+ " print(f\"OK: {len(rows)-1} dòng, header = {header}\")\n",
744
+ "\n",
745
+ "validate(answer_path)\n",
746
+ "os.system(f\"cd {OUT_DIR} && zip -j submission_track2_exp07_fusion_qmos.zip answer.txt \"\n",
747
+ " f\"&& unzip -l submission_track2_exp07_fusion_qmos.zip\")\n",
748
+ "print(\"Sẵn sàng nộp:\", os.path.join(OUT_DIR, \"submission_track2_exp07_fusion_qmos.zip\"))"
749
+ ]
750
+ },
751
+ {
752
+ "cell_type": "markdown",
753
+ "id": "a73c1c11",
754
+ "metadata": {},
755
+ "source": [
756
+ "## Ghi chú\n",
757
+ "- **Lần đầu** đặt `LIMIT_TRAIN=300`, `LIMIT_DEV=20`; OK rồi đặt `None`.\n",
758
+ "- **Đọc kết quả mục 5 theo 2 câu hỏi:**\n",
759
+ " 1. QMOS head có **vượt UTMOS đơn lẻ (0.414)** không? (dòng \"vượt/CHƯA vượt UTMOS\")\n",
760
+ " 2. Gộp QMOS có **làm tụt** EMOS/VAD so exp04 không? (dòng \"NEGATIVE TRANSFER?\")\n",
761
+ "- **Quyết định nộp:**\n",
762
+ " - Nếu QMOS↑ và cảm xúc KHÔNG tụt → nộp answer.txt exp07 (1 model trọn 6 cột — đẹp cho paper).\n",
763
+ " - Nếu QMOS↑ nhưng cảm xúc TỤT → giữ exp04 cho 5 cột cảm xúc, chỉ lấy **cột QMOS** của exp07/exp06 ghép vào.\n",
764
+ " - Nếu QMOS không vượt UTMOS → kết luận \"chất lượng trực giao cảm xúc\" (vẫn là phát hiện cho paper); giữ exp04.\n",
765
+ "- **Ablation cho paper**: `USE_UTMOS_FEAT=False` (QMOS chỉ từ trunk cảm xúc) → đo trực tiếp giả thuyết của bạn.\n",
766
+ "- Cache dùng CHUNG `fusion_cache/` với exp04 → **Save Version** giữ lại.\n",
767
+ "- Ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (mục exp07)."
768
+ ]
769
+ }
770
+ ],
771
+ "metadata": {
772
+ "jupytext": {
773
+ "cell_metadata_filter": "-all",
774
+ "main_language": "python",
775
+ "notebook_metadata_filter": "-all"
776
+ }
777
+ },
778
+ "nbformat": 4,
779
+ "nbformat_minor": 5
780
+ }
track2/exp07_fusion_qmos_pipeline.py ADDED
@@ -0,0 +1,654 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %% [markdown]
2
+ # # VMC2026 Track 2 — exp07 (FUSION + QMOS head, HỢP NHẤT 6 cột) — Kaggle
3
+ #
4
+ # **Khác exp04 ở đâu:** exp04 để **QMOS riêng** (UTMOS zero-shot). exp07 **gộp luôn QMOS vào trunk chung**
5
+ # → 1 model multi-task dự đoán **đủ 6 đầu ra**: QMOS · EMOS · CAT · VAL · ARO · DOM.
6
+ #
7
+ # ## Giả thuyết (của bạn) cần kiểm chứng
8
+ # "Chất giọng tự nhiên có liên quan tới cảm nhận cảm xúc" → nếu đúng, QMOS sẽ **hưởng lợi** từ biểu diễn
9
+ # cảm xúc chung (emotion2vec + SAILER). **Rủi ro:** 2 backbone này chuyên *cảm xúc*, chưa chắc bắt tốt
10
+ # *lỗi chất lượng/artifact* (thứ UTMOS chuyên trị) → QMOS có thể **thua** UTMOS, hoặc gộp làm **tụt** EMOS/VAD.
11
+ #
12
+ # ## Lưới an toàn trong thiết kế
13
+ # - **Vẫn đưa điểm UTMOS làm 1 đầu vào** cho QMOS head (`USE_UTMOS_FEAT`) → head học **chỉnh sửa** quanh
14
+ # 0.414 thay vì học lại từ đầu → khó tệ hơn UTMOS.
15
+ # - **In SRCC cả 6 cột + so mốc exp04** (EMOS 0.788 · CAT err 0.145 · VAL 0.578 · ARO 0.754 · DOM 0.706)
16
+ # → cảnh báo ngay nếu gộp QMOS làm tụt 5 cột cảm xúc.
17
+ # - **File riêng**, KHÔNG đụng `exp04_fusion_pipeline.py` (exp04 vẫn nguyên).
18
+ #
19
+ # ```
20
+ # mỗi wav ─► [e2v_emb | e2v_p5 | sailer_emb | sailer_p9 | sailer_vad3] ─► TRUNK chung
21
+ # │
22
+ # ┌──────────────┬───────────────┬─────────────┬───────────────────┤
23
+ # [QMOS head] [EMOS head] [CAT head] [VAD head]
24
+ # trunk + UTMOS trunk + target trunk trunk
25
+ # ```
26
+ #
27
+ # **Cách chạy:** GPU T4 + Internet On → Add Input dataset Track 2 → sửa `DATA_ROOT` → Run All.
28
+ # Lần đầu đặt `LIMIT_TRAIN=300`, `LIMIT_DEV=20`. Dùng CHUNG cache `fusion_cache/` với exp04 (thêm `utmos_*.npz`).
29
+
30
+ # %% [markdown]
31
+ # ## 0. Cấu hình — SỬA Ở ĐÂY
32
+
33
+ # %%
34
+ import os
35
+
36
+ DATA_ROOT = "/kaggle/input/vmc2026-track2-full/vmc2026-track2" # << SỬA slug
37
+ WAV_DIR = f"{DATA_ROOT}/wav"
38
+ METADATA_CSV = f"{DATA_ROOT}/metadata.csv" # wavID|emotion|transcript (KHÔNG header)
39
+ TRAIN_CSV = f"{DATA_ROOT}/sets/train.csv" # lisID|wavID|qMOS|emoCat|eMOS|val|dom|aro
40
+ DEV_SCP = f"{DATA_ROOT}/sets/dev.scp"
41
+
42
+ OUT_DIR = "/kaggle/working"
43
+ CACHE_DIR = "/kaggle/working/fusion_cache" # dùng CHUNG với exp04 (thêm utmos_*.npz)
44
+ os.makedirs(CACHE_DIR, exist_ok=True)
45
+
46
+ # ── Siêu tham số ─────────────────────────────────────────────────────────────
47
+ DEVICE = "cuda"
48
+ TRUNK_HIDDEN = 512
49
+ HEAD_HIDDEN = 128
50
+ DROPOUT = 0.3
51
+ LR = 1e-3
52
+ EPOCHS = 80
53
+ BATCH = 64
54
+ VAL_FRAC = 0.10
55
+ PATIENCE = 15
56
+ SEED = 42
57
+
58
+ USE_UNCERTAINTY = True # tự cân 6 loss (Kendall); False = dùng LOSS_W cố định
59
+ LOSS_W = {"qmos": 1.0, "emos": 1.0, "cat": 1.0, "val": 1.0, "aro": 1.0, "dom": 1.0}
60
+ USE_E2V = True
61
+ USE_SAILER = True
62
+ USE_CLASSPROB = True
63
+ USE_UTMOS_FEAT = True # đưa điểm UTMOS làm đầu vào QMOS head (neo residual quanh 0.414)
64
+
65
+ LIMIT_TRAIN = None
66
+ LIMIT_DEV = None
67
+
68
+ # Mốc exp04 để so (cảnh báo nếu tụt khi gộp QMOS)
69
+ EXP04 = {"emos": 0.788, "cat_err": 0.145, "val": 0.578, "aro": 0.754, "dom": 0.706, "qmos_utmos": 0.414}
70
+
71
+ EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"]
72
+ SAILER9 = ["Anger", "Contempt", "Disgust", "Fear", "Happiness", "Neutral", "Sadness", "Surprise", "Other"]
73
+ EMO2SAILER = {"angry": 0, "happy": 4, "neutral": 5, "sad": 6, "surprised": 7}
74
+
75
+ _EMO_ALIAS = {
76
+ "angry": "angry", "anger": "angry",
77
+ "happy": "happy", "happiness": "happy", "joy": "happy",
78
+ "neutral": "neutral", "calm": "neutral",
79
+ "sad": "sad", "sadness": "sad",
80
+ "surprise": "surprised", "surprised": "surprised", "surprising": "surprised",
81
+ }
82
+
83
+ def norm_emotion(label):
84
+ key = str(label).strip().lower()
85
+ return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)
86
+
87
+ def stem(p):
88
+ return os.path.splitext(os.path.basename(str(p)))[0]
89
+
90
+ assert USE_E2V or USE_SAILER, "Phải bật ít nhất 1 backbone."
91
+ print("DATA_ROOT:", DATA_ROOT)
92
+ for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP]:
93
+ print((" ✅ " if os.path.exists(p) else " ❌ THIẾU ") + p)
94
+
95
+ # %% [markdown]
96
+ # ## 1. Cài đặt + tải code SAILER
97
+
98
+ # %%
99
+ import sys, subprocess
100
+
101
+ def pip_install(*pkgs):
102
+ subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs], check=True)
103
+
104
+ pip_install("speechmos", "funasr", "librosa", "soundfile", "pandas", "scipy", "scikit-learn", "tqdm")
105
+
106
+ if USE_SAILER:
107
+ pip_install("loralib", "speechbrain")
108
+ REPO_DIR = "/kaggle/working/vox-profile-release"
109
+ if not os.path.exists(REPO_DIR):
110
+ subprocess.run(["git", "clone", "--depth", "1",
111
+ "https://github.com/tiantiaf0627/vox-profile-release.git", REPO_DIR], check=True)
112
+ if REPO_DIR not in sys.path:
113
+ sys.path.insert(0, REPO_DIR)
114
+
115
+ # %% [markdown]
116
+ # ## 2. Đọc & gộp nhãn (gộp theo wavID) — THÊM cột qMOS
117
+ # Khác exp04: gộp thêm **qMOS** (= TB `qMOS` theo wav) làm nhãn cho QMOS head.
118
+
119
+ # %%
120
+ import numpy as np
121
+ import pandas as pd
122
+
123
+ def load_target_emotions():
124
+ tgt = {}
125
+ with open(METADATA_CSV, encoding="utf-8") as f:
126
+ for ln in f:
127
+ parts = ln.strip().split("|")
128
+ if len(parts) < 2:
129
+ continue
130
+ tgt[stem(parts[0])] = norm_emotion(parts[1])
131
+ return tgt
132
+
133
+ def _col(cols_map, *names, default_idx=None, df=None):
134
+ for n in names:
135
+ if n in cols_map:
136
+ return cols_map[n]
137
+ return list(df.columns)[default_idx] if default_idx is not None else None
138
+
139
+ def parse_emocat_votes(cell):
140
+ v = np.zeros(len(EMOTIONS5), dtype=np.float32)
141
+ for tok in str(cell).replace("/", ",").replace(";", ",").replace("|", ",").replace(" ", ",").split(","):
142
+ e = norm_emotion(tok)
143
+ if e in EMOTIONS5:
144
+ v[EMOTIONS5.index(e)] += 1.0
145
+ return v
146
+
147
+ def load_train_labels():
148
+ """train.csv → DataFrame [wavID, qmos, emos, val, aro, dom, cat0..cat4] gộp theo wav."""
149
+ df = pd.read_csv(TRAIN_CSV, sep="|")
150
+ cols = {c.lower().strip(): c for c in df.columns}
151
+ wav_col = _col(cols, "wavid", "wav", default_idx=1, df=df)
152
+ qmos_col = _col(cols, "qmos", "mos")
153
+ emos_col = _col(cols, "emos", "emo", "emomos")
154
+ val_col = _col(cols, "val", "valence")
155
+ aro_col = _col(cols, "aro", "arousal")
156
+ dom_col = _col(cols, "dom", "dominance")
157
+ cat_col = _col(cols, "emocat", "cat", "emotion")
158
+ assert qmos_col, f"Không thấy cột qMOS trong train.csv (cột: {list(df.columns)})"
159
+ assert emos_col, f"Không thấy cột eMOS trong train.csv (cột: {list(df.columns)})"
160
+
161
+ df["_stem"] = df[wav_col].map(stem)
162
+ rows = []
163
+ for sid, g in df.groupby("_stem"):
164
+ rec = {"wavID": sid,
165
+ "qmos": float(g[qmos_col].mean()),
166
+ "emos": float(g[emos_col].mean())}
167
+ rec["val"] = float(g[val_col].mean()) if val_col else np.nan
168
+ rec["aro"] = float(g[aro_col].mean()) if aro_col else np.nan
169
+ rec["dom"] = float(g[dom_col].mean()) if dom_col else np.nan
170
+ votes = np.zeros(len(EMOTIONS5), dtype=np.float32)
171
+ if cat_col:
172
+ for cell in g[cat_col]:
173
+ votes += parse_emocat_votes(cell)
174
+ s = votes.sum()
175
+ cat = votes / s if s > 0 else np.full(len(EMOTIONS5), 1.0 / len(EMOTIONS5), dtype=np.float32)
176
+ for i in range(len(EMOTIONS5)):
177
+ rec[f"cat{i}"] = float(cat[i])
178
+ rows.append(rec)
179
+ return pd.DataFrame(rows)
180
+
181
+ target_map = load_target_emotions()
182
+ train_df = load_train_labels()
183
+ HAS_VAD = bool(train_df["val"].notna().any())
184
+ print(f"Target: {len(target_map)} | wav train (gộp): {len(train_df)} | có VAD: {HAS_VAD}")
185
+ print("qMOS:", train_df["qmos"].describe()[["mean", "std", "min", "max"]].to_dict())
186
+ print("eMOS:", train_df["emos"].describe()[["mean", "std", "min", "max"]].to_dict())
187
+ train_df.head()
188
+
189
+ # %% [markdown]
190
+ # ## 3. Trích đặc trưng 2 backbone + điểm UTMOS (cache CHUNG với exp04)
191
+
192
+ # %%
193
+ import torch
194
+ import torch.nn.functional as F
195
+
196
+ device = DEVICE if torch.cuda.is_available() else "cpu"
197
+ print("Device:", device, ("✅ " + torch.cuda.get_device_name(0)) if device == "cuda" else "⚠️ CPU")
198
+
199
+ def extract_e2v(stems, tag):
200
+ from tqdm.auto import tqdm
201
+ cache_path = os.path.join(CACHE_DIR, f"e2v_{tag}.npz")
202
+ store = {}
203
+ if os.path.exists(cache_path):
204
+ z = np.load(cache_path, allow_pickle=True)
205
+ store = {k: z[k] for k in z.files}
206
+ print(f"[e2v/{tag}] nạp cache: {len(store)}")
207
+ todo = [s for s in stems if s not in store]
208
+ if todo:
209
+ from funasr import AutoModel
210
+ m = AutoModel(model="iic/emotion2vec_plus_large", hub="hf", device=device)
211
+ for i, s in enumerate(tqdm(todo, desc=f"e2v {tag}")):
212
+ wav = os.path.join(WAV_DIR, s + ".wav")
213
+ if not os.path.exists(wav):
214
+ continue
215
+ r = m.generate(wav, granularity="utterance", extract_embedding=True)[0]
216
+ emb = np.asarray(r["feats"], dtype=np.float32).reshape(-1)
217
+ probs = {e: 0.0 for e in EMOTIONS5}
218
+ for lab, sc in zip(r["labels"], r["scores"]):
219
+ name = lab.split("/")[-1]
220
+ if name in probs:
221
+ probs[name] = float(sc)
222
+ tot = sum(probs.values())
223
+ p5 = np.array([probs[e] / tot if tot > 0 else 0.2 for e in EMOTIONS5], dtype=np.float32)
224
+ store[s] = np.concatenate([emb, p5]).astype(np.float32)
225
+ if (i + 1) % 500 == 0:
226
+ np.savez(cache_path, **store)
227
+ np.savez(cache_path, **store)
228
+ del m
229
+ torch.cuda.empty_cache() if device == "cuda" else None
230
+ return {s: (v[:-5], v[-5:]) for s, v in store.items()}
231
+
232
+ def _pool_feat(features):
233
+ f = features.detach().cpu().numpy()
234
+ if f.ndim <= 1:
235
+ return f.reshape(-1).astype(np.float32)
236
+ return f.mean(axis=tuple(range(f.ndim - 1))).reshape(-1).astype(np.float32)
237
+
238
+ def extract_sailer(stems, tag):
239
+ import librosa
240
+ from tqdm.auto import tqdm
241
+ cache_path = os.path.join(CACHE_DIR, f"sailer_{tag}.npz")
242
+ store = {}
243
+ if os.path.exists(cache_path):
244
+ z = np.load(cache_path, allow_pickle=True)
245
+ store = {k: z[k] for k in z.files}
246
+ print(f"[sailer/{tag}] nạp cache: {len(store)}")
247
+ todo = [s for s in stems if s not in store]
248
+ if todo:
249
+ from src.model.emotion.wavlm_emotion import WavLMWrapper
250
+ sailer = WavLMWrapper.from_pretrained("tiantiaf/wavlm-large-categorical-emotion").to(device).eval()
251
+ with torch.no_grad():
252
+ for i, s in enumerate(tqdm(todo, desc=f"sailer {tag}")):
253
+ wav = os.path.join(WAV_DIR, s + ".wav")
254
+ if not os.path.exists(wav):
255
+ continue
256
+ wave, _ = librosa.load(wav, sr=16000, mono=True)
257
+ wave = wave[: 15 * 16000]
258
+ data = torch.from_numpy(wave).float().unsqueeze(0).to(device)
259
+ logits, feat, _det, arousal, valence, dominance = sailer(data, return_feature=True)
260
+ emb = _pool_feat(feat)
261
+ p9 = F.softmax(logits, dim=1)[0].detach().cpu().numpy().astype(np.float32)
262
+ vad3 = np.array([1 + 4 * float(valence.item()),
263
+ 1 + 4 * float(arousal.item()),
264
+ 1 + 4 * float(dominance.item())], dtype=np.float32)
265
+ store[s] = np.concatenate([emb, p9, vad3]).astype(np.float32)
266
+ if (i + 1) % 500 == 0:
267
+ np.savez(cache_path, **store)
268
+ np.savez(cache_path, **store)
269
+ del sailer
270
+ torch.cuda.empty_cache() if device == "cuda" else None
271
+ return {s: (v[:-12], v[-12:-3], v[-3:]) for s, v in store.items()}
272
+
273
+ def extract_utmos(names, tag):
274
+ """Chấm UTMOS từng wav (theo TÊN, vì DEV gọi .wav theo tên). → dict {stem: score}.
275
+ Cache CACHE_DIR/utmos_<tag>.npz. Dùng vừa làm đầu vào QMOS head, vừa làm baseline so sánh."""
276
+ import librosa
277
+ from tqdm.auto import tqdm
278
+ cache_path = os.path.join(CACHE_DIR, f"utmos_{tag}.npz")
279
+ store = {}
280
+ if os.path.exists(cache_path):
281
+ z = np.load(cache_path, allow_pickle=True)
282
+ store = {k: float(z[k]) for k in z.files}
283
+ print(f"[utmos/{tag}] nạp cache: {len(store)}")
284
+ todo = [n for n in names if stem(n) not in store]
285
+ if todo:
286
+ predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong",
287
+ trust_repo=True).to(device).eval()
288
+ with torch.no_grad():
289
+ for i, n in enumerate(tqdm(todo, desc=f"utmos {tag}")):
290
+ wav = os.path.join(WAV_DIR, n if str(n).endswith(".wav") else n + ".wav")
291
+ if not os.path.exists(wav):
292
+ continue
293
+ wave, _ = librosa.load(wav, sr=16000, mono=True)
294
+ store[stem(n)] = float(predictor(torch.from_numpy(wave).unsqueeze(0).to(device),
295
+ sr=16000).mean().item())
296
+ if (i + 1) % 500 == 0:
297
+ np.savez(cache_path, **{k: np.float32(v) for k, v in store.items()})
298
+ np.savez(cache_path, **{k: np.float32(v) for k, v in store.items()})
299
+ del predictor
300
+ torch.cuda.empty_cache() if device == "cuda" else None
301
+ return store
302
+
303
+ # %% [markdown]
304
+ # ## 4. Dựng feature + nhãn cho train
305
+ # Feature audio (cảm xúc) = `[e2v_emb | e2v_p5 | sailer_emb | sailer_p9 | sailer_vad3]` (như exp04).
306
+ # Thêm: vector **UTMOS** (1 số/ wav) cho QMOS head, và nhãn **qMOS**.
307
+
308
+ # %%
309
+ train_stems = list(train_df["wavID"])
310
+ if LIMIT_TRAIN:
311
+ train_stems = train_stems[:LIMIT_TRAIN]
312
+
313
+ e2v_tr = extract_e2v(train_stems, "train") if USE_E2V else {}
314
+ sailer_tr = extract_sailer(train_stems, "train") if USE_SAILER else {}
315
+ utmos_tr = extract_utmos(train_stems, "train") if USE_UTMOS_FEAT else {}
316
+
317
+ def audio_feature(sid, e2v_map, sailer_map):
318
+ parts = []
319
+ if USE_E2V:
320
+ pk = e2v_map.get(sid)
321
+ if pk is None:
322
+ return None
323
+ emb, p5 = pk
324
+ parts.append(emb)
325
+ if USE_CLASSPROB:
326
+ parts.append(p5)
327
+ if USE_SAILER:
328
+ pk = sailer_map.get(sid)
329
+ if pk is None:
330
+ return None
331
+ emb, p9, vad3 = pk
332
+ parts.append(emb)
333
+ if USE_CLASSPROB:
334
+ parts.append(p9); parts.append(vad3)
335
+ return np.concatenate(parts).astype(np.float32)
336
+
337
+ def onehot_target(tgt):
338
+ v = np.zeros(len(EMOTIONS5), dtype=np.float32)
339
+ if tgt in EMOTIONS5:
340
+ v[EMOTIONS5.index(tgt)] = 1.0
341
+ return v
342
+
343
+ lab = train_df.set_index("wavID")
344
+ X, T, U, y_qmos, y_emos, y_vad, y_cat = [], [], [], [], [], [], []
345
+ for s in train_stems:
346
+ f = audio_feature(s, e2v_tr, sailer_tr)
347
+ tgt = target_map.get(s)
348
+ if f is None or tgt is None or s not in lab.index:
349
+ continue
350
+ if USE_UTMOS_FEAT and s not in utmos_tr:
351
+ continue
352
+ X.append(f)
353
+ T.append(onehot_target(tgt))
354
+ U.append(utmos_tr.get(s, 3.0) if USE_UTMOS_FEAT else 0.0)
355
+ y_qmos.append(lab.loc[s, "qmos"])
356
+ y_emos.append(lab.loc[s, "emos"])
357
+ y_vad.append([lab.loc[s, "val"], lab.loc[s, "aro"], lab.loc[s, "dom"]])
358
+ y_cat.append([lab.loc[s, f"cat{i}"] for i in range(len(EMOTIONS5))])
359
+
360
+ X = np.stack(X).astype(np.float32)
361
+ T = np.stack(T).astype(np.float32)
362
+ U = np.array(U, dtype=np.float32).reshape(-1, 1)
363
+ y_qmos = np.array(y_qmos, dtype=np.float32)
364
+ y_emos = np.array(y_emos, dtype=np.float32)
365
+ y_vad = np.array(y_vad, dtype=np.float32)
366
+ y_cat = np.array(y_cat, dtype=np.float32)
367
+ FEAT_DIM = X.shape[1]
368
+ print(f"Train: X={X.shape} U={U.shape} qmos={y_qmos.shape} emos={y_emos.shape} vad={y_vad.shape}")
369
+
370
+ # Chuẩn hóa feature audio + UTMOS (z-score), lưu mean/std.
371
+ feat_mean = X.mean(0, keepdims=True); feat_std = X.std(0, keepdims=True) + 1e-6
372
+ Xn = (X - feat_mean) / feat_std
373
+ u_mu, u_sd = float(U.mean()), float(U.std() + 1e-6)
374
+ Un = (U - u_mu) / u_sd
375
+
376
+ # Chuẩn hóa nhãn liên tục về z-score.
377
+ qmos_mu, qmos_sd = float(y_qmos.mean()), float(y_qmos.std() + 1e-6)
378
+ y_qmos_z = (y_qmos - qmos_mu) / qmos_sd
379
+ emos_mu, emos_sd = float(y_emos.mean()), float(y_emos.std() + 1e-6)
380
+ y_emos_z = (y_emos - emos_mu) / emos_sd
381
+ if HAS_VAD:
382
+ vad_mu = np.nanmean(y_vad, axis=0); vad_sd = np.nanstd(y_vad, axis=0) + 1e-6
383
+ y_vad_z = (y_vad - vad_mu) / vad_sd
384
+ else:
385
+ vad_mu = np.zeros(3, dtype=np.float32); vad_sd = np.ones(3, dtype=np.float32)
386
+ y_vad_z = np.zeros_like(y_vad)
387
+
388
+ # %% [markdown]
389
+ # ## 5. Model fusion multi-task (6 head) + train loop
390
+ # Thêm so exp04: **QMOS head** nhận `[trunk | UTMOS]` → 1; `qmos` vào uncertainty weighting (6 task).
391
+
392
+ # %%
393
+ import torch.nn as nn
394
+ from scipy.stats import spearmanr
395
+ from sklearn.model_selection import train_test_split
396
+
397
+ torch.manual_seed(SEED); np.random.seed(SEED)
398
+ N_EMO = len(EMOTIONS5)
399
+ idx_all = np.arange(X.shape[0])
400
+ tr_idx, va_idx = train_test_split(idx_all, test_size=VAL_FRAC, random_state=SEED)
401
+
402
+ def to_t(a):
403
+ return torch.tensor(a, dtype=torch.float32, device=device)
404
+
405
+ Xn_t, T_t, Un_t = to_t(Xn), to_t(T), to_t(Un)
406
+ qmos_t = to_t(y_qmos_z).unsqueeze(1)
407
+ emos_t = to_t(y_emos_z).unsqueeze(1)
408
+ vad_t = to_t(y_vad_z)
409
+ cat_t = to_t(y_cat)
410
+
411
+ class FusionMTL6(nn.Module):
412
+ def __init__(self, d_in, trunk_h, head_h, p, n_emo, use_utmos):
413
+ super().__init__()
414
+ self.use_utmos = use_utmos
415
+ self.trunk = nn.Sequential(
416
+ nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),
417
+ nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p),
418
+ )
419
+ self.qmos = nn.Sequential( # nhận [trunk | utmos] nếu bật
420
+ nn.Linear(trunk_h + (1 if use_utmos else 0), head_h), nn.ReLU(), nn.Dropout(p),
421
+ nn.Linear(head_h, 1))
422
+ self.emos = nn.Sequential( # nhận [trunk | target]
423
+ nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))
424
+ self.cat = nn.Sequential(
425
+ nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))
426
+ self.vad = nn.Sequential(
427
+ nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))
428
+
429
+ def forward(self, x, tgt, utmos):
430
+ h = self.trunk(x)
431
+ qmos_in = torch.cat([h, utmos], dim=1) if self.use_utmos else h
432
+ qmos = self.qmos(qmos_in)
433
+ emos = self.emos(torch.cat([h, tgt], dim=1))
434
+ cat_logits = self.cat(h)
435
+ vad = self.vad(h)
436
+ return qmos, emos, cat_logits, vad
437
+
438
+ model = FusionMTL6(FEAT_DIM, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO, USE_UTMOS_FEAT).to(device)
439
+
440
+ TASKS = ["qmos", "emos", "cat", "val", "aro", "dom"]
441
+ log_var = nn.Parameter(torch.zeros(len(TASKS), device=device))
442
+ params = list(model.parameters()) + ([log_var] if USE_UNCERTAINTY else [])
443
+ opt = torch.optim.Adam(params, lr=LR, weight_decay=1e-5)
444
+ mse = nn.MSELoss(reduction="none")
445
+
446
+ def soft_ce(logits, target_dist):
447
+ logq = F.log_softmax(logits, dim=1)
448
+ return -(target_dist * logq).sum(dim=1)
449
+
450
+ def task_losses(qmos_p, emos_p, cat_logits, vad_p, b):
451
+ L = {}
452
+ L["qmos"] = mse(qmos_p, qmos_t[b]).mean()
453
+ L["emos"] = mse(emos_p, emos_t[b]).mean()
454
+ L["cat"] = soft_ce(cat_logits, cat_t[b]).mean()
455
+ if HAS_VAD:
456
+ L["val"] = mse(vad_p[:, 0:1], vad_t[b, 0:1]).mean()
457
+ L["aro"] = mse(vad_p[:, 1:2], vad_t[b, 1:2]).mean()
458
+ L["dom"] = mse(vad_p[:, 2:3], vad_t[b, 2:3]).mean()
459
+ else:
460
+ z = torch.zeros((), device=device)
461
+ L["val"] = L["aro"] = L["dom"] = z
462
+ return L
463
+
464
+ def combine(L):
465
+ if USE_UNCERTAINTY:
466
+ tot = 0.0
467
+ for i, t in enumerate(TASKS):
468
+ tot = tot + torch.exp(-log_var[i]) * L[t] + log_var[i]
469
+ return tot
470
+ return sum(LOSS_W[t] * L[t] for t in TASKS)
471
+
472
+ @torch.no_grad()
473
+ def eval_val():
474
+ model.eval()
475
+ qp, ep, cl, vp = model(Xn_t[va_idx], T_t[va_idx], Un_t[va_idx])
476
+ qp = qp.cpu().numpy().ravel(); ep = ep.cpu().numpy().ravel()
477
+ out = {"qmos": spearmanr(qp, y_qmos[va_idx]).correlation,
478
+ "emos": spearmanr(ep, y_emos[va_idx]).correlation}
479
+ if USE_UTMOS_FEAT:
480
+ out["qmos_utmos"] = spearmanr(U[va_idx, 0], y_qmos[va_idx]).correlation # baseline UTMOS đơn lẻ
481
+ if HAS_VAD:
482
+ vp = vp.cpu().numpy()
483
+ for j, t in enumerate(["val", "aro", "dom"]):
484
+ out[t] = spearmanr(vp[:, j], y_vad[va_idx, j]).correlation
485
+ q = F.softmax(cl, dim=1).cpu().numpy(); p = y_cat[va_idx]
486
+ kl = (p * (np.log(p + 1e-9) - np.log(q + 1e-9))).sum(1).mean()
487
+ out["cat_negkl"] = float(-kl)
488
+ return out
489
+
490
+ def val_score(m):
491
+ """Điểm tổng early-stop = TB SRCC các task liên tục (qmos+emos+VAD)."""
492
+ keys = ["qmos", "emos"] + (["val", "aro", "dom"] if HAS_VAD else [])
493
+ return float(np.mean([m[k] for k in keys]))
494
+
495
+ best_score, best_state, bad = -1e9, None, 0
496
+ tr_t = torch.tensor(tr_idx, device=device)
497
+ for ep_i in range(1, EPOCHS + 1):
498
+ model.train()
499
+ perm = tr_t[torch.randperm(len(tr_t), device=device)]
500
+ run = 0.0
501
+ for i in range(0, len(perm), BATCH):
502
+ b = perm[i:i + BATCH]
503
+ opt.zero_grad()
504
+ qmos_p, emos_p, cat_logits, vad_p = model(Xn_t[b], T_t[b], Un_t[b])
505
+ L = task_losses(qmos_p, emos_p, cat_logits, vad_p, b)
506
+ loss = combine(L)
507
+ loss.backward(); opt.step()
508
+ run += loss.item() * len(b)
509
+ m = eval_val()
510
+ sc = val_score(m)
511
+ if sc > best_score:
512
+ best_score = sc
513
+ best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
514
+ bad = 0
515
+ else:
516
+ bad += 1
517
+ if ep_i % 5 == 0 or ep_i == 1:
518
+ msg = " ".join(f"{k}={m[k]:.3f}" for k in ["qmos", "emos", "val", "aro", "dom"] if k in m)
519
+ print(f"epoch {ep_i:3d} | loss {run/len(perm):.4f} | {msg} | best {best_score:.4f}")
520
+ if bad >= PATIENCE:
521
+ print(f"Early stop ở epoch {ep_i}.")
522
+ break
523
+
524
+ model.load_state_dict(best_state)
525
+ final = eval_val()
526
+ print("\n✅ VAL (nội bộ) — exp07 (fusion + QMOS head):")
527
+ print(f" QMOS SRCC = {final['qmos']:.4f}", end="")
528
+ if "qmos_utmos" in final:
529
+ tag = "✅ vượt UTMOS" if final["qmos"] > final["qmos_utmos"] else "⚠️ CHƯA vượt UTMOS"
530
+ print(f" (UTMOS đơn lẻ = {final['qmos_utmos']:.4f} → {tag})")
531
+ else:
532
+ print()
533
+ print(f" EMOS SRCC = {final['emos']:.4f} (mốc exp04 = {EXP04['emos']})")
534
+ if HAS_VAD:
535
+ print(f" VAL/ARO/DOM = {final['val']:.4f}/{final['aro']:.4f}/{final['dom']:.4f}"
536
+ f" (mốc exp04 = {EXP04['val']}/{EXP04['aro']}/{EXP04['dom']})")
537
+ # Cảnh báo negative transfer (gộp QMOS làm tụt cảm xúc)
538
+ warn = []
539
+ if final["emos"] < EXP04["emos"] - 0.02:
540
+ warn.append(f"EMOS {final['emos']:.3f} < {EXP04['emos']}")
541
+ if HAS_VAD:
542
+ for t in ["val", "aro", "dom"]:
543
+ if final[t] < EXP04[t] - 0.02:
544
+ warn.append(f"{t.upper()} {final[t]:.3f} < {EXP04[t]}")
545
+ if warn:
546
+ print(" ⚠️ NEGATIVE TRANSFER? Cảm xúc tụt so exp04:", "; ".join(warn),
547
+ "\n → cân nhắc giữ exp04 cho 5 cột cảm xúc + chỉ lấy QMOS từ exp07/exp06.")
548
+ else:
549
+ print(" ✅ Không thấy 5 cột cảm xúc tụt rõ so exp04.")
550
+ if USE_UNCERTAINTY:
551
+ print(" log σ² mỗi task:", {t: round(float(log_var[i]), 3) for i, t in enumerate(TASKS)})
552
+
553
+ torch.save({"state": best_state, "feat_mean": feat_mean, "feat_std": feat_std,
554
+ "u_mu": u_mu, "u_sd": u_sd,
555
+ "qmos_mu": qmos_mu, "qmos_sd": qmos_sd, "emos_mu": emos_mu, "emos_sd": emos_sd,
556
+ "vad_mu": vad_mu, "vad_sd": vad_sd, "FEAT_DIM": FEAT_DIM,
557
+ "USE_E2V": USE_E2V, "USE_SAILER": USE_SAILER, "USE_CLASSPROB": USE_CLASSPROB,
558
+ "USE_UTMOS_FEAT": USE_UTMOS_FEAT, "val_score": best_score},
559
+ os.path.join(OUT_DIR, "fusion_qmos_mtl.pt"))
560
+ print("Đã lưu", os.path.join(OUT_DIR, "fusion_qmos_mtl.pt"))
561
+
562
+ # %% [markdown]
563
+ # ## 6. Dự đoán DEV → `answer.txt` đủ 6 cột (QMOS giờ từ HEAD, không phải SpeechMOS riêng)
564
+
565
+ # %%
566
+ def list_dev():
567
+ with open(DEV_SCP) as f:
568
+ return [ln.strip() for ln in f if ln.strip()]
569
+
570
+ dev_names = list_dev()
571
+ if LIMIT_DEV:
572
+ dev_names = dev_names[:LIMIT_DEV]
573
+ dev_stems = [stem(n) for n in dev_names]
574
+ print("DEV:", len(dev_names), "mẫu")
575
+
576
+ e2v_dev = extract_e2v(dev_stems, "dev") if USE_E2V else {}
577
+ sailer_dev = extract_sailer(dev_stems, "dev") if USE_SAILER else {}
578
+ utmos_dev = extract_utmos(dev_names, "dev") if USE_UTMOS_FEAT else {}
579
+
580
+ @torch.no_grad()
581
+ def predict_all(sid):
582
+ f = audio_feature(sid, e2v_dev, sailer_dev)
583
+ if f is None:
584
+ return None
585
+ fn = (f[None, :] - feat_mean) / feat_std
586
+ tgt = onehot_target(target_map.get(sid))[None, :]
587
+ u = np.array([[utmos_dev.get(sid, 3.0)]], dtype=np.float32)
588
+ un = (u - u_mu) / u_sd
589
+ model.eval()
590
+ qmos_p, emos_p, cat_logits, vad_p = model(to_t(fn), to_t(tgt), to_t(un))
591
+ qmos = float(qmos_p.item()) * qmos_sd + qmos_mu
592
+ emos = float(emos_p.item()) * emos_sd + emos_mu
593
+ cat5 = F.softmax(cat_logits, dim=1)[0].cpu().numpy()
594
+ vad3 = vad_p[0].cpu().numpy() * vad_sd + vad_mu
595
+ return qmos, emos, cat5, vad3
596
+
597
+ def fmt_cat(probs5):
598
+ return "|".join(f"{e}:{probs5[i]:.6g}" for i, e in enumerate(EMOTIONS5))
599
+
600
+ def build_answer(out_path):
601
+ from tqdm.auto import tqdm
602
+ n_real = n_default = 0
603
+ with open(out_path, "w") as f:
604
+ f.write("wav,QMOS,EMOS,CAT,VAL,ARO,DOM\n")
605
+ for name in tqdm(dev_names, desc="answer"):
606
+ sid = stem(name)
607
+ pred = predict_all(sid)
608
+ if pred is None:
609
+ # rơi về: QMOS=UTMOS nếu có, còn lại mặc định
610
+ qmos = utmos_dev.get(sid, 3.0)
611
+ emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0])
612
+ n_default += 1
613
+ else:
614
+ qmos, emos, cat5, vad3 = pred
615
+ n_real += 1
616
+ f.write(f"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},"
617
+ f"{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\n")
618
+ print(f"Ghi {len(dev_names)} dòng → {out_path} | head thật {n_real}, mặc định {n_default}")
619
+
620
+ answer_path = os.path.join(OUT_DIR, "answer.txt")
621
+ build_answer(answer_path)
622
+
623
+ # %% [markdown]
624
+ # ## 7. Validate + đóng zip
625
+
626
+ # %%
627
+ def validate(path):
628
+ import csv
629
+ with open(path) as f:
630
+ rows = list(csv.reader(f))
631
+ header = rows[0]
632
+ assert header[0] == "wav" and "QMOS" in header and "EMOS" in header, "Header sai"
633
+ for i, r in enumerate(rows[1:], 2):
634
+ assert len(r) == len(header), f"Dòng {i} sai số cột"
635
+ print(f"OK: {len(rows)-1} dòng, header = {header}")
636
+
637
+ validate(answer_path)
638
+ os.system(f"cd {OUT_DIR} && zip -j submission_track2_exp07_fusion_qmos.zip answer.txt "
639
+ f"&& unzip -l submission_track2_exp07_fusion_qmos.zip")
640
+ print("Sẵn sàng nộp:", os.path.join(OUT_DIR, "submission_track2_exp07_fusion_qmos.zip"))
641
+
642
+ # %% [markdown]
643
+ # ## Ghi chú
644
+ # - **Lần đầu** đặt `LIMIT_TRAIN=300`, `LIMIT_DEV=20`; OK rồi đặt `None`.
645
+ # - **Đọc kết quả mục 5 theo 2 câu hỏi:**
646
+ # 1. QMOS head có **vượt UTMOS đơn lẻ (0.414)** không? (dòng "vượt/CHƯA vượt UTMOS")
647
+ # 2. Gộp QMOS có **làm tụt** EMOS/VAD so exp04 không? (dòng "NEGATIVE TRANSFER?")
648
+ # - **Quyết định nộp:**
649
+ # - Nếu QMOS↑ và cảm xúc KHÔNG tụt → nộp answer.txt exp07 (1 model trọn 6 cột — đẹp cho paper).
650
+ # - Nếu QMOS↑ nhưng cảm xúc TỤT → giữ exp04 cho 5 cột cảm xúc, chỉ lấy **cột QMOS** của exp07/exp06 ghép vào.
651
+ # - Nếu QMOS không vượt UTMOS → kết luận "chất lượng trực giao cảm xúc" (vẫn là phát hiện cho paper); giữ exp04.
652
+ # - **Ablation cho paper**: `USE_UTMOS_FEAT=False` (QMOS chỉ từ trunk cảm xúc) → đo trực tiếp giả thuyết của bạn.
653
+ # - Cache dùng CHUNG `fusion_cache/` với exp04 → **Save Version** giữ lại.
654
+ # - Ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (mục exp07).
track2/exp08_finetune_emotion.ipynb ADDED
@@ -0,0 +1,820 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "ee3b7231",
6
+ "metadata": {},
7
+ "source": [
8
+ "# VMC2026 Track 2 — exp08 (FINE-TUNE WavLM cho 5 cột cảm xúc) — Kaggle\n",
9
+ "\n",
10
+ "**Khác mọi exp trước:** exp03–07 đều **đóng băng** backbone (chỉ trích đặc trưng + train head nhỏ trên cache).\n",
11
+ "exp08 **MỞ BĂNG (fine-tune)** WavLM-large để nó học lại đặc trưng riêng cho bài MOS cảm xúc 2026.\n",
12
+ "\n",
13
+ "## Thiết kế (chốt với mentor 5/6)\n",
14
+ "```\n",
15
+ " wav ─┬─► WavLM-large (warm-start SAILER, TRAINABLE: chỉ mở băng N lớp trên) ─► pool ─► emb_wavlm ┐\n",
16
+ " └─► audeering MSP-dim (FROZEN, cache .npz) ─► [emb_aud | vad3] ├─► TRUNK ─┬─► EMOS (+target)\n",
17
+ " ┘ ├─► CAT (5)\n",
18
+ " └─► VAD (3)\n",
19
+ " QMOS: KHÔNG train ở đây → mượn cột QMOS của exp07 (0.548) hoặc UTMOSv2 (T05, vô địch VMC2024).\n",
20
+ "```\n",
21
+ "- **Warm-start:** khởi tạo WavLM từ checkpoint **SAILER** (`tiantiaf/wavlm-large-categorical-emotion`,\n",
22
+ " đã giỏi cảm xúc) thay vì WavLM \"trắng\" → điểm xuất phát tốt hơn nhiều.\n",
23
+ "- **Phụ (frozen):** audeering — dimensional, bổ trợ góc nhìn categorical của WavLM, kỳ vọng kéo **VAL**.\n",
24
+ "- **Đóng băng partial:** chỉ train `UNFREEZE_TOP_LAYERS` lớp Transformer trên cùng + feature-extractor giữ băng\n",
25
+ " → tiết kiệm VRAM T4 + chống overfit (chỉ 12.7k mẫu).\n",
26
+ "\n",
27
+ "## ⚠️ Đánh đổi phải biết trước (so freeze+head)\n",
28
+ "- **Mất lợi thế cache:** mỗi epoch chạy lại cả WavLM (forward+backward) → chậm & đốt giờ GPU (30h/tuần).\n",
29
+ " → **Lần đầu BẮT BUỘC đặt `LIMIT_TRAIN=300`, `LIMIT_DEV=20`** để chỉnh trơn rồi mới `None`.\n",
30
+ "- **Dễ overfit / OOM:** nếu OOM → giảm `BATCH`, tăng `ACCUM`, giảm `MAX_SECONDS`, giảm `UNFREEZE_TOP_LAYERS`.\n",
31
+ "- **Lưới an toàn:** exp07 vẫn là bản nộp vô địch tới khi exp08 **thắng trên VAL nội bộ** (đừng đốt lượt nộp).\n",
32
+ "\n",
33
+ "**Cách chạy Kaggle:** GPU **T4** + Internet **On** → Add Input dataset Track 2 → sửa `DATA_ROOT` → Run All."
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "markdown",
38
+ "id": "656d8385",
39
+ "metadata": {},
40
+ "source": [
41
+ "## 0. Cấu hình — SỬA Ở ĐÂY"
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "execution_count": null,
47
+ "id": "38d86264",
48
+ "metadata": {},
49
+ "outputs": [],
50
+ "source": [
51
+ "import os\n",
52
+ "\n",
53
+ "DATA_ROOT = \"/kaggle/input/vmc2026-track2-full/vmc2026-track2\" # << SỬA slug cho khớp Add Input\n",
54
+ "WAV_DIR = f\"{DATA_ROOT}/wav\"\n",
55
+ "METADATA_CSV = f\"{DATA_ROOT}/metadata.csv\" # wavID|emotion|transcript (KHÔNG header)\n",
56
+ "TRAIN_CSV = f\"{DATA_ROOT}/sets/train.csv\" # lisID|wavID|qMOS|emoCat|eMOS|val|dom|aro\n",
57
+ "DEV_SCP = f\"{DATA_ROOT}/sets/dev.scp\"\n",
58
+ "\n",
59
+ "OUT_DIR = \"/kaggle/working\"\n",
60
+ "CACHE_DIR = \"/kaggle/working/ft_cache\" # cache audeering (.npz) — backbone WavLM KHÔNG cache (đang train)\n",
61
+ "os.makedirs(CACHE_DIR, exist_ok=True)\n",
62
+ "\n",
63
+ "# (Tùy chọn) TÁI DÙNG cache audeering cũ: trỏ tới dataset chứa aud_train.npz/aud_dev.npz → tự copy sang CACHE_DIR.\n",
64
+ "# Để \"\" nếu chạy mới hoàn toàn. /kaggle/input read-only nên phải copy sang working để ghi/append.\n",
65
+ "CACHE_INPUT = \"/kaggle/input/datasets/minhtoan2/cache-exp8\" # << SỬA slug cho khớp (hoặc \"\")\n",
66
+ "if CACHE_INPUT and os.path.isdir(CACHE_INPUT):\n",
67
+ " import shutil\n",
68
+ " _n = 0\n",
69
+ " for _fn in os.listdir(CACHE_INPUT):\n",
70
+ " if _fn.startswith(\"aud_\") and _fn.endswith(\".npz\"):\n",
71
+ " shutil.copy(os.path.join(CACHE_INPUT, _fn), os.path.join(CACHE_DIR, _fn)); _n += 1\n",
72
+ " print(f\"📦 Tái dùng cache: copy {_n} file aud_*.npz từ {CACHE_INPUT} → {CACHE_DIR}\")\n",
73
+ "\n",
74
+ "# Mượn cột QMOS của exp07 (tốt nhất 0.548). Trỏ tới answer.txt exp07 nếu có; không thì dùng UTMOSv2.\n",
75
+ "EXP07_ANSWER = \"/kaggle/input/exp07-answer/answer.txt\" # << (tùy chọn) Add Input answer.txt exp07; không có → UTMOSv2\n",
76
+ "\n",
77
+ "# ── Fine-tune / siêu tham số ─────────────────────────────────────────────────\n",
78
+ "DEVICE = \"cuda\"\n",
79
+ "SR = 16000\n",
80
+ "MAX_SECONDS = 8 # cắt audio để chặn bộ nhớ backprop; OOM thì giảm còn 6\n",
81
+ "UNFREEZE_TOP_LAYERS = 6 # số lớp Transformer trên cùng được train (0 = freeze hết = quay về head-only)\n",
82
+ "TRUNK_HIDDEN = 512\n",
83
+ "HEAD_HIDDEN = 128\n",
84
+ "DROPOUT = 0.3\n",
85
+ "LR_BACKBONE = 1e-5 # LR nhỏ cho backbone fine-tune\n",
86
+ "LR_HEAD = 1e-3 # LR lớn cho trunk + head (train từ đầu)\n",
87
+ "WEIGHT_DECAY = 1e-5\n",
88
+ "EPOCHS = 12 # TRẦN; early-stop quyết định số epoch thực (8 hơi thấp cho lần chạy thật)\n",
89
+ "PATIENCE = 3 # dừng khi val SRCC không lên 3 epoch; LUÔN giữ best_state\n",
90
+ "BATCH = 4 # nhỏ vì backbone to; tăng ACCUM để bù\n",
91
+ "ACCUM = 8 # effective batch = BATCH*ACCUM = 32\n",
92
+ "VAL_FRAC = 0.10\n",
93
+ "SEED = 42\n",
94
+ "USE_AMP = True # mixed precision fp16 — tiết kiệm VRAM\n",
95
+ "USE_GRAD_CKPT = True # gradient checkpointing — tiết kiệm VRAM (đổi lấy chậm hơn)\n",
96
+ "USE_AUDEERING = True # nhánh phụ frozen audeering; False = chỉ WavLM\n",
97
+ "USE_UNCERTAINTY = True # tự cân 5 loss (Kendall); False = trọng số 1.0\n",
98
+ "\n",
99
+ "LIMIT_TRAIN = 300 # << LẦN ĐẦU để 300; chạy thật đặt None\n",
100
+ "LIMIT_DEV = 20 # << LẦN ĐẦU để 20; chạy thật đặt None\n",
101
+ "\n",
102
+ "# Mốc exp07 để so (cảnh báo nếu fine-tune KHÔNG thắng → giữ exp07)\n",
103
+ "EXP07 = {\"emos\": 0.795, \"cat_err\": 0.153, \"val\": 0.581, \"aro\": 0.752, \"dom\": 0.705}\n",
104
+ "\n",
105
+ "EMOTIONS5 = [\"angry\", \"happy\", \"neutral\", \"sad\", \"surprised\"]\n",
106
+ "\n",
107
+ "_EMO_ALIAS = {\n",
108
+ " \"angry\": \"angry\", \"anger\": \"angry\",\n",
109
+ " \"happy\": \"happy\", \"happiness\": \"happy\", \"joy\": \"happy\",\n",
110
+ " \"neutral\": \"neutral\", \"calm\": \"neutral\",\n",
111
+ " \"sad\": \"sad\", \"sadness\": \"sad\",\n",
112
+ " \"surprise\": \"surprised\", \"surprised\": \"surprised\", \"surprising\": \"surprised\",\n",
113
+ "}\n",
114
+ "\n",
115
+ "def norm_emotion(label):\n",
116
+ " key = str(label).strip().lower()\n",
117
+ " return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)\n",
118
+ "\n",
119
+ "def stem(p):\n",
120
+ " return os.path.splitext(os.path.basename(str(p)))[0]\n",
121
+ "\n",
122
+ "print(\"DATA_ROOT:\", DATA_ROOT)\n",
123
+ "for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP]:\n",
124
+ " print((\" ✅ \" if os.path.exists(p) else \" ❌ THIẾU \") + p)\n",
125
+ "print(f\"Fine-tune: mở băng {UNFREEZE_TOP_LAYERS} lớp trên · BATCH {BATCH}×ACCUM {ACCUM} · MAX {MAX_SECONDS}s\")"
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "markdown",
130
+ "id": "ed538923",
131
+ "metadata": {},
132
+ "source": [
133
+ "## 1. Cài đặt + tải code SAILER (clone + sys.path, KHÔNG pip install -e .)"
134
+ ]
135
+ },
136
+ {
137
+ "cell_type": "code",
138
+ "execution_count": null,
139
+ "id": "f052d016",
140
+ "metadata": {},
141
+ "outputs": [],
142
+ "source": [
143
+ "import sys, subprocess\n",
144
+ "\n",
145
+ "def pip_install(*pkgs):\n",
146
+ " subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *pkgs], check=True)\n",
147
+ "\n",
148
+ "pip_install(\"loralib\", \"speechbrain\", \"speechmos\", \"librosa\", \"soundfile\",\n",
149
+ " \"scipy\", \"scikit-learn\", \"pandas\", \"tqdm\")\n",
150
+ "\n",
151
+ "REPO_DIR = \"/kaggle/working/vox-profile-release\"\n",
152
+ "if not os.path.exists(REPO_DIR):\n",
153
+ " subprocess.run([\"git\", \"clone\", \"--depth\", \"1\",\n",
154
+ " \"https://github.com/tiantiaf0627/vox-profile-release.git\", REPO_DIR], check=True)\n",
155
+ "if REPO_DIR not in sys.path:\n",
156
+ " sys.path.insert(0, REPO_DIR)"
157
+ ]
158
+ },
159
+ {
160
+ "cell_type": "markdown",
161
+ "id": "0bf41e8c",
162
+ "metadata": {},
163
+ "source": [
164
+ "## 2. Nạp SAILER → lấy backbone WavLM bên trong để FINE-TUNE\n",
165
+ "Thay vì gọi wrapper như hộp đen, ta **lôi module WavLM-large (HuggingFace) bên trong wrapper** ra\n",
166
+ "→ toàn quyền đóng băng/mở băng từng lớp + tự pool. Nếu không tìm thấy (cấu trúc lạ) → **fallback**\n",
167
+ "nạp `microsoft/wavlm-large` trắng (mất warm-start, có cảnh báo)."
168
+ ]
169
+ },
170
+ {
171
+ "cell_type": "code",
172
+ "execution_count": null,
173
+ "id": "50a7cac6",
174
+ "metadata": {
175
+ "lines_to_next_cell": 1
176
+ },
177
+ "outputs": [],
178
+ "source": [
179
+ "import torch\n",
180
+ "import torch.nn as nn\n",
181
+ "import torch.nn.functional as F\n",
182
+ "\n",
183
+ "device = DEVICE if torch.cuda.is_available() else \"cpu\"\n",
184
+ "print(\"Device:\", device, (\"✅ \" + torch.cuda.get_device_name(0)) if device == \"cuda\" else \"⚠️ CPU (rất chậm!)\")\n",
185
+ "\n",
186
+ "def find_hf_backbone(module):\n",
187
+ " \"\"\"Tìm submodule kiểu HF Wav2Vec2/WavLM backbone: có .feature_extractor và .encoder.layers.\"\"\"\n",
188
+ " cands = []\n",
189
+ " for name, m in module.named_modules():\n",
190
+ " enc = getattr(m, \"encoder\", None)\n",
191
+ " if getattr(m, \"feature_extractor\", None) is not None and enc is not None \\\n",
192
+ " and getattr(enc, \"layers\", None) is not None:\n",
193
+ " cands.append((name, m))\n",
194
+ " if not cands:\n",
195
+ " return None, None\n",
196
+ " cands.sort(key=lambda nm: sum(p.numel() for p in nm[1].parameters()), reverse=True)\n",
197
+ " return cands[0]\n",
198
+ "\n",
199
+ "wavlm = None\n",
200
+ "try:\n",
201
+ " from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402\n",
202
+ " _wrapper = WavLMWrapper.from_pretrained(\"tiantiaf/wavlm-large-categorical-emotion\")\n",
203
+ " name, wavlm = find_hf_backbone(_wrapper)\n",
204
+ " if wavlm is not None:\n",
205
+ " print(f\"✅ Warm-start SAILER: lấy backbone WavLM bên trong wrapper tại '.{name}' \"\n",
206
+ " f\"({sum(p.numel() for p in wavlm.parameters())/1e6:.0f}M params)\")\n",
207
+ " else:\n",
208
+ " print(\"⚠️ Không tìm thấy backbone HF bên trong wrapper SAILER → sẽ fallback WavLM trắng.\")\n",
209
+ "except Exception as e:\n",
210
+ " print(\"⚠️ Lỗi nạp SAILER wrapper:\", repr(e), \"→ fallback WavLM trắng.\")\n",
211
+ "\n",
212
+ "if wavlm is None:\n",
213
+ " from transformers import WavLMModel\n",
214
+ " wavlm = WavLMModel.from_pretrained(\"microsoft/wavlm-large\")\n",
215
+ " print(\"ℹ️ Fallback: nạp microsoft/wavlm-large (KHÔNG warm-start SAILER).\")\n",
216
+ "\n",
217
+ "wavlm = wavlm.to(device)\n",
218
+ "WAVLM_DIM = int(wavlm.config.hidden_size)\n",
219
+ "\n",
220
+ "# ── Đóng băng partial: feature-extractor + tất cả trừ UNFREEZE_TOP_LAYERS lớp trên ──\n",
221
+ "for p in wavlm.parameters():\n",
222
+ " p.requires_grad = False\n",
223
+ "enc_layers = wavlm.encoder.layers\n",
224
+ "n_layers = len(enc_layers)\n",
225
+ "for layer in enc_layers[max(0, n_layers - UNFREEZE_TOP_LAYERS):]:\n",
226
+ " for p in layer.parameters():\n",
227
+ " p.requires_grad = True\n",
228
+ "n_train = sum(p.numel() for p in wavlm.parameters() if p.requires_grad)\n",
229
+ "print(f\"WavLM: {n_layers} lớp encoder · mở băng {min(UNFREEZE_TOP_LAYERS, n_layers)} lớp trên \"\n",
230
+ " f\"→ {n_train/1e6:.1f}M param train (trên dim {WAVLM_DIM})\")\n",
231
+ "\n",
232
+ "if USE_GRAD_CKPT:\n",
233
+ " wavlm.gradient_checkpointing_enable()\n",
234
+ " if hasattr(wavlm, \"enable_input_require_grads\"):\n",
235
+ " wavlm.enable_input_require_grads() # cần khi grad-ckpt + lớp dưới đóng băng\n",
236
+ "\n",
237
+ "def masked_mean(hidden, attn_mask):\n",
238
+ " \"\"\"Mean-pool theo thời gian, bỏ qua phần pad (giữ gradient).\"\"\"\n",
239
+ " if attn_mask is None:\n",
240
+ " return hidden.mean(dim=1)\n",
241
+ " try:\n",
242
+ " fm = wavlm._get_feature_vector_attention_mask(hidden.shape[1], attn_mask)\n",
243
+ " except Exception:\n",
244
+ " return hidden.mean(dim=1)\n",
245
+ " fm = fm.unsqueeze(-1).to(hidden.dtype)\n",
246
+ " return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)\n",
247
+ "\n",
248
+ "def wavlm_embed(input_values, attn_mask):\n",
249
+ " out = wavlm(input_values, attention_mask=attn_mask).last_hidden_state # [B,T,D]\n",
250
+ " return masked_mean(out, attn_mask)"
251
+ ]
252
+ },
253
+ {
254
+ "cell_type": "markdown",
255
+ "id": "d8b8b8de",
256
+ "metadata": {},
257
+ "source": [
258
+ "## 3. Nạp audeering MSP-dim (FROZEN) — đặc trưng phụ\n",
259
+ "Lấy `[emb_pool(1024) | vad3(1–5)]` mỗi wav rồi **cache .npz** (chạy 1 lần). Kỹ thuật nạp head tay\n",
260
+ "y hệt exp05 (tránh lỗi version transformers khi subclass `Wav2Vec2PreTrainedModel`)."
261
+ ]
262
+ },
263
+ {
264
+ "cell_type": "code",
265
+ "execution_count": null,
266
+ "id": "8731aa54",
267
+ "metadata": {},
268
+ "outputs": [],
269
+ "source": [
270
+ "AUD_DIM = 0\n",
271
+ "aud_backbone = aud_head = aud_proc = None\n",
272
+ "if USE_AUDEERING:\n",
273
+ " from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor\n",
274
+ " from huggingface_hub import hf_hub_download\n",
275
+ " AUD_NAME = \"audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim\"\n",
276
+ " aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)\n",
277
+ " aud_cfg = Wav2Vec2Config.from_pretrained(AUD_NAME)\n",
278
+ " aud_backbone = Wav2Vec2Model(aud_cfg)\n",
279
+ " try:\n",
280
+ " _sd = __import__(\"safetensors.torch\", fromlist=[\"load_file\"]).load_file(\n",
281
+ " hf_hub_download(AUD_NAME, \"model.safetensors\"))\n",
282
+ " except Exception:\n",
283
+ " _sd = torch.load(hf_hub_download(AUD_NAME, \"pytorch_model.bin\"), map_location=\"cpu\")\n",
284
+ " bb_sd = {k[len(\"wav2vec2.\"):]: v for k, v in _sd.items() if k.startswith(\"wav2vec2.\")}\n",
285
+ " missing, unexpected = aud_backbone.load_state_dict(bb_sd, strict=False)\n",
286
+ " print(f\" audeering backbone: thiếu {len(missing)} / dư {len(unexpected)} key (strict=False)\")\n",
287
+ " _hid = _sd[\"classifier.dense.weight\"].shape[0]\n",
288
+ " _out = _sd[\"classifier.out_proj.weight\"].shape[0]\n",
289
+ " aud_head = nn.Sequential(nn.Linear(_hid, _hid), nn.Tanh(), nn.Linear(_hid, _out))\n",
290
+ " aud_head[0].weight.data.copy_(_sd[\"classifier.dense.weight\"]); aud_head[0].bias.data.copy_(_sd[\"classifier.dense.bias\"])\n",
291
+ " aud_head[2].weight.data.copy_(_sd[\"classifier.out_proj.weight\"]); aud_head[2].bias.data.copy_(_sd[\"classifier.out_proj.bias\"])\n",
292
+ " aud_backbone = aud_backbone.to(device).eval()\n",
293
+ " aud_head = aud_head.to(device).eval()\n",
294
+ " AUD_DIM = _hid + 3 # emb_pool + [VAL,ARO,DOM]\n",
295
+ " print(f\"✅ audeering frozen (đặc trưng phụ {AUD_DIM}-D = emb {_hid} + vad 3)\")"
296
+ ]
297
+ },
298
+ {
299
+ "cell_type": "code",
300
+ "execution_count": null,
301
+ "id": "d12e4737",
302
+ "metadata": {
303
+ "lines_to_next_cell": 1
304
+ },
305
+ "outputs": [],
306
+ "source": [
307
+ "import numpy as np\n",
308
+ "import librosa\n",
309
+ "from tqdm.auto import tqdm\n",
310
+ "\n",
311
+ "def load_wav(name_or_stem, in_wav_dir=True):\n",
312
+ " p = name_or_stem if os.path.isabs(str(name_or_stem)) else os.path.join(\n",
313
+ " WAV_DIR, name_or_stem if str(name_or_stem).endswith(\".wav\") else str(name_or_stem) + \".wav\")\n",
314
+ " if not os.path.exists(p):\n",
315
+ " return None\n",
316
+ " wave, _ = librosa.load(p, sr=SR, mono=True)\n",
317
+ " return wave[: MAX_SECONDS * SR].astype(np.float32)\n",
318
+ "\n",
319
+ "@torch.no_grad()\n",
320
+ "def extract_audeering(stems, tag):\n",
321
+ " \"\"\"→ dict {stem: float32[AUD_DIM]}; cache CACHE_DIR/aud_<tag>.npz (resume mỗi 500).\"\"\"\n",
322
+ " if not USE_AUDEERING:\n",
323
+ " return {}\n",
324
+ " cache_path = os.path.join(CACHE_DIR, f\"aud_{tag}.npz\")\n",
325
+ " store = {}\n",
326
+ " if os.path.exists(cache_path):\n",
327
+ " z = np.load(cache_path, allow_pickle=True)\n",
328
+ " store = {k: z[k] for k in z.files}\n",
329
+ " print(f\"[aud/{tag}] nạp cache: {len(store)}\")\n",
330
+ " todo = [s for s in stems if s not in store]\n",
331
+ " for i, s in enumerate(tqdm(todo, desc=f\"audeering {tag}\")):\n",
332
+ " wave = load_wav(s)\n",
333
+ " if wave is None:\n",
334
+ " continue\n",
335
+ " x = aud_proc(wave, sampling_rate=SR).input_values[0]\n",
336
+ " x = torch.from_numpy(np.asarray(x, dtype=np.float32)).unsqueeze(0).to(device)\n",
337
+ " h = aud_backbone(x)[0].mean(dim=1) # [1, hid]\n",
338
+ " out = aud_head(h)[0].cpu().numpy() # [arousal, dominance, valence] ∈[0,1]\n",
339
+ " vad = np.array([1 + 4 * out[2], 1 + 4 * out[0], 1 + 4 * out[1]], dtype=np.float32) # [VAL,ARO,DOM]\n",
340
+ " store[s] = np.concatenate([h[0].cpu().numpy(), vad]).astype(np.float32)\n",
341
+ " if (i + 1) % 500 == 0:\n",
342
+ " np.savez(cache_path, **store)\n",
343
+ " if todo:\n",
344
+ " np.savez(cache_path, **store)\n",
345
+ " return store"
346
+ ]
347
+ },
348
+ {
349
+ "cell_type": "markdown",
350
+ "id": "3397dbe7",
351
+ "metadata": {},
352
+ "source": [
353
+ "## 4. Đọc & gộp nhãn theo wavID (EMOS / VAD / CAT) — như exp04/07 nhưng KHÔNG cần qMOS"
354
+ ]
355
+ },
356
+ {
357
+ "cell_type": "code",
358
+ "execution_count": null,
359
+ "id": "df3b95e3",
360
+ "metadata": {},
361
+ "outputs": [],
362
+ "source": [
363
+ "import pandas as pd\n",
364
+ "\n",
365
+ "def load_target_emotions():\n",
366
+ " tgt = {}\n",
367
+ " with open(METADATA_CSV, encoding=\"utf-8\") as f:\n",
368
+ " for ln in f:\n",
369
+ " parts = ln.strip().split(\"|\")\n",
370
+ " if len(parts) >= 2:\n",
371
+ " tgt[stem(parts[0])] = norm_emotion(parts[1])\n",
372
+ " return tgt\n",
373
+ "\n",
374
+ "def _col(cols_map, *names, df=None, default_idx=None):\n",
375
+ " for n in names:\n",
376
+ " if n in cols_map:\n",
377
+ " return cols_map[n]\n",
378
+ " return list(df.columns)[default_idx] if default_idx is not None else None\n",
379
+ "\n",
380
+ "def parse_emocat_votes(cell):\n",
381
+ " v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
382
+ " for tok in str(cell).replace(\"/\", \",\").replace(\";\", \",\").replace(\"|\", \",\").replace(\" \", \",\").split(\",\"):\n",
383
+ " e = norm_emotion(tok)\n",
384
+ " if e in EMOTIONS5:\n",
385
+ " v[EMOTIONS5.index(e)] += 1.0\n",
386
+ " return v\n",
387
+ "\n",
388
+ "def load_train_labels():\n",
389
+ " df = pd.read_csv(TRAIN_CSV, sep=\"|\")\n",
390
+ " cols = {c.lower().strip(): c for c in df.columns}\n",
391
+ " wav_col = _col(cols, \"wavid\", \"wav\", df=df, default_idx=1)\n",
392
+ " emos_col = _col(cols, \"emos\", \"emo\", \"emomos\")\n",
393
+ " val_col = _col(cols, \"val\", \"valence\"); aro_col = _col(cols, \"aro\", \"arousal\"); dom_col = _col(cols, \"dom\", \"dominance\")\n",
394
+ " cat_col = _col(cols, \"emocat\", \"cat\", \"emotion\")\n",
395
+ " assert emos_col, f\"Không thấy cột eMOS (cột: {list(df.columns)})\"\n",
396
+ " df[\"_stem\"] = df[wav_col].map(stem)\n",
397
+ " rows = []\n",
398
+ " for sid, g in df.groupby(\"_stem\"):\n",
399
+ " rec = {\"wavID\": sid, \"emos\": float(g[emos_col].mean())}\n",
400
+ " rec[\"val\"] = float(g[val_col].mean()) if val_col else np.nan\n",
401
+ " rec[\"aro\"] = float(g[aro_col].mean()) if aro_col else np.nan\n",
402
+ " rec[\"dom\"] = float(g[dom_col].mean()) if dom_col else np.nan\n",
403
+ " votes = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
404
+ " if cat_col:\n",
405
+ " for cell in g[cat_col]:\n",
406
+ " votes += parse_emocat_votes(cell)\n",
407
+ " s = votes.sum()\n",
408
+ " cat = votes / s if s > 0 else np.full(len(EMOTIONS5), 0.2, dtype=np.float32)\n",
409
+ " for i in range(len(EMOTIONS5)):\n",
410
+ " rec[f\"cat{i}\"] = float(cat[i])\n",
411
+ " rows.append(rec)\n",
412
+ " return pd.DataFrame(rows)\n",
413
+ "\n",
414
+ "target_map = load_target_emotions()\n",
415
+ "train_df = load_train_labels()\n",
416
+ "HAS_VAD = bool(train_df[\"val\"].notna().any())\n",
417
+ "print(f\"Target: {len(target_map)} | wav train (gộp): {len(train_df)} | có VAD: {HAS_VAD}\")"
418
+ ]
419
+ },
420
+ {
421
+ "cell_type": "markdown",
422
+ "id": "48ea29a7",
423
+ "metadata": {},
424
+ "source": [
425
+ "## 5. Dataset / DataLoader (load wav theo batch — KHÔNG cache WavLM vì đang train)"
426
+ ]
427
+ },
428
+ {
429
+ "cell_type": "code",
430
+ "execution_count": null,
431
+ "id": "478f2af9",
432
+ "metadata": {},
433
+ "outputs": [],
434
+ "source": [
435
+ "from torch.utils.data import Dataset, DataLoader\n",
436
+ "\n",
437
+ "train_stems = [s for s in train_df[\"wavID\"] if target_map.get(s) is not None]\n",
438
+ "if LIMIT_TRAIN:\n",
439
+ " train_stems = train_stems[:LIMIT_TRAIN]\n",
440
+ "aud_tr = extract_audeering(train_stems, \"train\")\n",
441
+ "\n",
442
+ "lab = train_df.set_index(\"wavID\")\n",
443
+ "\n",
444
+ "# Chuẩn hóa nhãn liên tục về z-score (để các MSE cùng thang) — lưu để giải mã lúc dự đoán.\n",
445
+ "def _zfit(arr):\n",
446
+ " a = np.asarray(arr, dtype=np.float32)\n",
447
+ " return float(np.nanmean(a)), float(np.nanstd(a) + 1e-6)\n",
448
+ "\n",
449
+ "emos_mu, emos_sd = _zfit([lab.loc[s, \"emos\"] for s in train_stems])\n",
450
+ "if HAS_VAD:\n",
451
+ " vad_mu = np.array([_zfit([lab.loc[s, c] for s in train_stems])[0] for c in [\"val\", \"aro\", \"dom\"]], dtype=np.float32)\n",
452
+ " vad_sd = np.array([_zfit([lab.loc[s, c] for s in train_stems])[1] for c in [\"val\", \"aro\", \"dom\"]], dtype=np.float32)\n",
453
+ "else:\n",
454
+ " vad_mu = np.zeros(3, dtype=np.float32); vad_sd = np.ones(3, dtype=np.float32)\n",
455
+ "\n",
456
+ "def onehot_target(tgt):\n",
457
+ " v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
458
+ " if tgt in EMOTIONS5:\n",
459
+ " v[EMOTIONS5.index(tgt)] = 1.0\n",
460
+ " return v\n",
461
+ "\n",
462
+ "class EmoDataset(Dataset):\n",
463
+ " def __init__(self, stems):\n",
464
+ " self.stems = [s for s in stems if (load_wav(s) is not None) and ((not USE_AUDEERING) or s in aud_tr)]\n",
465
+ " def __len__(self):\n",
466
+ " return len(self.stems)\n",
467
+ " def __getitem__(self, i):\n",
468
+ " s = self.stems[i]\n",
469
+ " wave = load_wav(s)\n",
470
+ " emos = (float(lab.loc[s, \"emos\"]) - emos_mu) / emos_sd\n",
471
+ " if HAS_VAD:\n",
472
+ " vad = (np.array([lab.loc[s, \"val\"], lab.loc[s, \"aro\"], lab.loc[s, \"dom\"]], np.float32) - vad_mu) / vad_sd\n",
473
+ " else:\n",
474
+ " vad = np.zeros(3, dtype=np.float32)\n",
475
+ " cat = np.array([lab.loc[s, f\"cat{j}\"] for j in range(len(EMOTIONS5))], dtype=np.float32)\n",
476
+ " aud = aud_tr[s] if USE_AUDEERING else np.zeros(0, dtype=np.float32)\n",
477
+ " return {\"wave\": wave, \"tgt\": onehot_target(target_map.get(s)), \"aud\": aud,\n",
478
+ " \"emos\": np.float32(emos), \"vad\": vad, \"cat\": cat,\n",
479
+ " \"emos_raw\": np.float32(lab.loc[s, \"emos\"]),\n",
480
+ " \"vad_raw\": np.array([lab.loc[s, \"val\"], lab.loc[s, \"aro\"], lab.loc[s, \"dom\"]], np.float32)}\n",
481
+ "\n",
482
+ "def collate(batch):\n",
483
+ " lens = [len(b[\"wave\"]) for b in batch]\n",
484
+ " L = max(lens)\n",
485
+ " waves = np.zeros((len(batch), L), dtype=np.float32)\n",
486
+ " mask = np.zeros((len(batch), L), dtype=np.float32)\n",
487
+ " for i, b in enumerate(batch):\n",
488
+ " waves[i, : len(b[\"wave\"])] = b[\"wave\"]; mask[i, : len(b[\"wave\"])] = 1.0\n",
489
+ " out = {\n",
490
+ " \"input_values\": torch.from_numpy(waves), \"attn_mask\": torch.from_numpy(mask).long(),\n",
491
+ " \"tgt\": torch.from_numpy(np.stack([b[\"tgt\"] for b in batch])),\n",
492
+ " \"aud\": torch.from_numpy(np.stack([b[\"aud\"] for b in batch])) if USE_AUDEERING else None,\n",
493
+ " \"emos\": torch.from_numpy(np.stack([b[\"emos\"] for b in batch])).unsqueeze(1),\n",
494
+ " \"vad\": torch.from_numpy(np.stack([b[\"vad\"] for b in batch])),\n",
495
+ " \"cat\": torch.from_numpy(np.stack([b[\"cat\"] for b in batch])),\n",
496
+ " \"emos_raw\": np.stack([b[\"emos_raw\"] for b in batch]),\n",
497
+ " \"vad_raw\": np.stack([b[\"vad_raw\"] for b in batch]),\n",
498
+ " }\n",
499
+ " return out\n",
500
+ "\n",
501
+ "from sklearn.model_selection import train_test_split\n",
502
+ "ds = EmoDataset(train_stems)\n",
503
+ "print(\"Dataset hợp lệ:\", len(ds), \"wav\")\n",
504
+ "tr_i, va_i = train_test_split(np.arange(len(ds)), test_size=VAL_FRAC, random_state=SEED)\n",
505
+ "tr_loader = DataLoader(torch.utils.data.Subset(ds, tr_i), batch_size=BATCH, shuffle=True, collate_fn=collate, num_workers=2)\n",
506
+ "va_loader = DataLoader(torch.utils.data.Subset(ds, va_i), batch_size=BATCH, shuffle=False, collate_fn=collate, num_workers=2)"
507
+ ]
508
+ },
509
+ {
510
+ "cell_type": "markdown",
511
+ "id": "f3342c6f",
512
+ "metadata": {},
513
+ "source": [
514
+ "## 6. Head fusion (trunk + 3 head cảm xúc) + train loop (AMP + grad accumulation)"
515
+ ]
516
+ },
517
+ {
518
+ "cell_type": "code",
519
+ "execution_count": null,
520
+ "id": "3671b2da",
521
+ "metadata": {
522
+ "lines_to_next_cell": 1
523
+ },
524
+ "outputs": [],
525
+ "source": [
526
+ "from scipy.stats import spearmanr\n",
527
+ "\n",
528
+ "torch.manual_seed(SEED); np.random.seed(SEED)\n",
529
+ "N_EMO = len(EMOTIONS5)\n",
530
+ "TRUNK_IN = WAVLM_DIM + (AUD_DIM if USE_AUDEERING else 0)\n",
531
+ "\n",
532
+ "class EmoHeads(nn.Module):\n",
533
+ " def __init__(self, d_in, trunk_h, head_h, p, n_emo):\n",
534
+ " super().__init__()\n",
535
+ " self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),\n",
536
+ " nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))\n",
537
+ " self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))\n",
538
+ " self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))\n",
539
+ " self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))\n",
540
+ " def forward(self, feat, tgt):\n",
541
+ " h = self.trunk(feat)\n",
542
+ " return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)\n",
543
+ "\n",
544
+ "heads = EmoHeads(TRUNK_IN, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device)\n",
545
+ "print(f\"Trunk input = {TRUNK_IN} (wavlm {WAVLM_DIM} + aud {AUD_DIM if USE_AUDEERING else 0})\")\n",
546
+ "\n",
547
+ "TASKS = [\"emos\", \"cat\", \"val\", \"aro\", \"dom\"]\n",
548
+ "log_var = nn.Parameter(torch.zeros(len(TASKS), device=device))\n",
549
+ "bb_params = [p for p in wavlm.parameters() if p.requires_grad]\n",
550
+ "head_params = list(heads.parameters()) + ([log_var] if USE_UNCERTAINTY else [])\n",
551
+ "opt = torch.optim.AdamW([\n",
552
+ " {\"params\": bb_params, \"lr\": LR_BACKBONE},\n",
553
+ " {\"params\": head_params, \"lr\": LR_HEAD},\n",
554
+ "], weight_decay=WEIGHT_DECAY)\n",
555
+ "scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP and device == \"cuda\")\n",
556
+ "mse = nn.MSELoss()\n",
557
+ "\n",
558
+ "def soft_ce(logits, target_dist):\n",
559
+ " return -(target_dist * F.log_softmax(logits, dim=1)).sum(1).mean()\n",
560
+ "\n",
561
+ "def forward_batch(b):\n",
562
+ " feat_wavlm = wavlm_embed(b[\"input_values\"].to(device), b[\"attn_mask\"].to(device))\n",
563
+ " if USE_AUDEERING:\n",
564
+ " feat = torch.cat([feat_wavlm, b[\"aud\"].to(device)], dim=1)\n",
565
+ " else:\n",
566
+ " feat = feat_wavlm\n",
567
+ " return heads(feat, b[\"tgt\"].to(device))\n",
568
+ "\n",
569
+ "def compute_loss(emos_p, cat_l, vad_p, b):\n",
570
+ " L = {}\n",
571
+ " L[\"emos\"] = mse(emos_p, b[\"emos\"].to(device))\n",
572
+ " L[\"cat\"] = soft_ce(cat_l, b[\"cat\"].to(device))\n",
573
+ " if HAS_VAD:\n",
574
+ " vt = b[\"vad\"].to(device)\n",
575
+ " L[\"val\"] = mse(vad_p[:, 0:1], vt[:, 0:1]); L[\"aro\"] = mse(vad_p[:, 1:2], vt[:, 1:2]); L[\"dom\"] = mse(vad_p[:, 2:3], vt[:, 2:3])\n",
576
+ " else:\n",
577
+ " z = torch.zeros((), device=device); L[\"val\"] = L[\"aro\"] = L[\"dom\"] = z\n",
578
+ " if USE_UNCERTAINTY:\n",
579
+ " return sum(torch.exp(-log_var[i]) * L[t] + log_var[i] for i, t in enumerate(TASKS))\n",
580
+ " return sum(L.values())\n",
581
+ "\n",
582
+ "@torch.no_grad()\n",
583
+ "def evaluate():\n",
584
+ " wavlm.eval(); heads.eval()\n",
585
+ " P = {\"emos\": [], \"val\": [], \"aro\": [], \"dom\": []}; Y = {\"emos\": [], \"val\": [], \"aro\": [], \"dom\": []}\n",
586
+ " catP, catY = [], []\n",
587
+ " for b in va_loader:\n",
588
+ " with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
589
+ " emos_p, cat_l, vad_p = forward_batch(b)\n",
590
+ " P[\"emos\"] += emos_p.float().cpu().numpy().ravel().tolist(); Y[\"emos\"] += b[\"emos_raw\"].tolist()\n",
591
+ " vad_p = vad_p.float().cpu().numpy()\n",
592
+ " for j, t in enumerate([\"val\", \"aro\", \"dom\"]):\n",
593
+ " P[t] += vad_p[:, j].tolist(); Y[t] += b[\"vad_raw\"][:, j].tolist()\n",
594
+ " catP.append(F.softmax(cat_l, 1).float().cpu().numpy()); catY.append(b[\"cat\"])\n",
595
+ " out = {}\n",
596
+ " for t in [\"emos\"] + ([\"val\", \"aro\", \"dom\"] if HAS_VAD else []):\n",
597
+ " out[t] = spearmanr(P[t], Y[t]).correlation\n",
598
+ " q = np.concatenate(catP); p = np.concatenate(catY)\n",
599
+ " out[\"cat_err\"] = float(np.abs(q - p).sum(1).mean()) # ~ tổng |Δ| trung bình (xấp xỉ CAT-ERR)\n",
600
+ " return out\n",
601
+ "\n",
602
+ "def mean_srcc(m):\n",
603
+ " keys = [\"emos\"] + ([\"val\", \"aro\", \"dom\"] if HAS_VAD else [])\n",
604
+ " return float(np.mean([m[k] for k in keys]))\n",
605
+ "\n",
606
+ "# Lưu checkpoint FULL (có backbone WavLM) — gọi NGAY mỗi best để kernel chết giữa chừng vẫn còn file.\n",
607
+ "CKPT_PATH = os.path.join(OUT_DIR, \"ft_emotion_full.pt\")\n",
608
+ "def save_full_ckpt(state, val_emos=float(\"nan\")):\n",
609
+ " torch.save({\"wavlm\": state[\"wavlm\"], \"heads\": state[\"heads\"],\n",
610
+ " \"emos_mu\": emos_mu, \"emos_sd\": emos_sd, \"vad_mu\": vad_mu, \"vad_sd\": vad_sd,\n",
611
+ " \"WAVLM_DIM\": WAVLM_DIM, \"AUD_DIM\": AUD_DIM,\n",
612
+ " \"UNFREEZE_TOP_LAYERS\": UNFREEZE_TOP_LAYERS, \"val_emos\": float(val_emos)}, CKPT_PATH)\n",
613
+ "\n",
614
+ "best, best_state, bad = -1e9, None, 0\n",
615
+ "for ep in range(1, EPOCHS + 1):\n",
616
+ " wavlm.train(); heads.train()\n",
617
+ " opt.zero_grad(); run = 0.0; nb = 0\n",
618
+ " for step, b in enumerate(tqdm(tr_loader, desc=f\"epoch {ep}\")):\n",
619
+ " with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
620
+ " emos_p, cat_l, vad_p = forward_batch(b)\n",
621
+ " loss = compute_loss(emos_p, cat_l, vad_p, b) / ACCUM\n",
622
+ " scaler.scale(loss).backward()\n",
623
+ " if (step + 1) % ACCUM == 0:\n",
624
+ " scaler.step(opt); scaler.update(); opt.zero_grad()\n",
625
+ " run += loss.item() * ACCUM; nb += 1\n",
626
+ " m = evaluate(); sc = mean_srcc(m)\n",
627
+ " msg = \" \".join(f\"{k}={m[k]:.3f}\" for k in [\"emos\", \"val\", \"aro\", \"dom\"] if k in m)\n",
628
+ " print(f\"epoch {ep:2d} | loss {run/max(nb,1):.4f} | {msg} | cat_err {m['cat_err']:.3f} | mean {sc:.4f} (best {max(best,sc):.4f})\")\n",
629
+ " if sc > best:\n",
630
+ " best = sc\n",
631
+ " best_state = {\"wavlm\": {k: v.cpu().clone() for k, v in wavlm.state_dict().items()},\n",
632
+ " \"heads\": {k: v.cpu().clone() for k, v in heads.state_dict().items()}}\n",
633
+ " save_full_ckpt(best_state, m[\"emos\"]) # LƯU NGAY mỗi best → an toàn nếu kernel chết\n",
634
+ " print(f\" 💾 lưu best → {CKPT_PATH} (epoch {ep}, mean {sc:.4f})\")\n",
635
+ " bad = 0\n",
636
+ " else:\n",
637
+ " bad += 1\n",
638
+ " if bad >= PATIENCE:\n",
639
+ " print(f\"Early stop ở epoch {ep}.\"); break\n",
640
+ "\n",
641
+ "if best_state:\n",
642
+ " wavlm.load_state_dict(best_state[\"wavlm\"]); heads.load_state_dict(best_state[\"heads\"])\n",
643
+ "final = evaluate()\n",
644
+ "print(\"\\n✅ VAL (nội bộ) — exp08 (fine-tune WavLM cho cảm xúc):\")\n",
645
+ "print(f\" EMOS={final['emos']:.4f} (exp07 {EXP07['emos']})\")\n",
646
+ "if HAS_VAD:\n",
647
+ " print(f\" VAL/ARO/DOM={final['val']:.4f}/{final['aro']:.4f}/{final['dom']:.4f} \"\n",
648
+ " f\"(exp07 {EXP07['val']}/{EXP07['aro']}/{EXP07['dom']})\")\n",
649
+ "warn = [f\"EMOS {final['emos']:.3f}<{EXP07['emos']}\"] if final[\"emos\"] < EXP07[\"emos\"] - 0.005 else []\n",
650
+ "if HAS_VAD:\n",
651
+ " warn += [f\"{t.upper()} {final[t]:.3f}<{EXP07[t]}\" for t in [\"val\", \"aro\", \"dom\"] if final[t] < EXP07[t] - 0.005]\n",
652
+ "print(\" ⚠️ CHƯA thắng exp07 ở:\", \"; \".join(warn), \"→ cân nhắc giữ exp07.\" if warn else \"\")\n",
653
+ "if not warn:\n",
654
+ " print(\" ✅ Fine-tune thắng/ngang exp07 ở mọi cột cảm xúc → đáng nộp.\")\n",
655
+ "# Lưu lần cuối từ best (đã lưu sẵn mỗi best trong loop; đây là phát cuối cho chắc).\n",
656
+ "save_full_ckpt(best_state if best_state else\n",
657
+ " {\"wavlm\": wavlm.state_dict(), \"heads\": heads.state_dict()}, final[\"emos\"])\n",
658
+ "print(f\"✅ Đã lưu {CKPT_PATH} (CÓ backbone WavLM + heads → resume được). \"\n",
659
+ " f\"NHỚ Save Version để file ra Output!\")"
660
+ ]
661
+ },
662
+ {
663
+ "cell_type": "markdown",
664
+ "id": "7b0b1b42",
665
+ "metadata": {},
666
+ "source": [
667
+ "## 7. Dự đoán DEV → answer.txt (5 cột cảm xúc từ exp08; QMOS mượn exp07 hoặc UTMOS)"
668
+ ]
669
+ },
670
+ {
671
+ "cell_type": "code",
672
+ "execution_count": null,
673
+ "id": "b29f616c",
674
+ "metadata": {
675
+ "lines_to_next_cell": 1
676
+ },
677
+ "outputs": [],
678
+ "source": [
679
+ "def list_dev():\n",
680
+ " with open(DEV_SCP) as f:\n",
681
+ " return [ln.strip() for ln in f if ln.strip()]\n",
682
+ "\n",
683
+ "dev_names = list_dev()\n",
684
+ "if LIMIT_DEV:\n",
685
+ " dev_names = dev_names[:LIMIT_DEV]\n",
686
+ "dev_stems = [stem(n) for n in dev_names]\n",
687
+ "print(\"DEV:\", len(dev_names), \"mẫu\")\n",
688
+ "aud_dev = extract_audeering(dev_stems, \"dev\")\n",
689
+ "\n",
690
+ "# QMOS: ưu tiên mượn cột QMOS của exp07; không có file → chấm UTMOSv2 (T05, vô địch VMC2024).\n",
691
+ "def load_exp07_qmos():\n",
692
+ " if EXP07_ANSWER and os.path.exists(EXP07_ANSWER):\n",
693
+ " import csv\n",
694
+ " d = {}\n",
695
+ " with open(EXP07_ANSWER) as f:\n",
696
+ " r = csv.DictReader(f)\n",
697
+ " for row in r:\n",
698
+ " d[row[\"wav\"]] = float(row[\"QMOS\"]); d[stem(row[\"wav\"])] = float(row[\"QMOS\"])\n",
699
+ " print(f\"✅ Mượn QMOS từ exp07 ({EXP07_ANSWER}): {len(d)//2} wav\")\n",
700
+ " return d\n",
701
+ " return None\n",
702
+ "\n",
703
+ "qmos_map = load_exp07_qmos()\n",
704
+ "if qmos_map is None:\n",
705
+ " print(\"ℹ️ Không có answer.txt exp07 → chấm QMOS bằng UTMOSv2 (T05, vô địch VMC2024 Track 1).\")\n",
706
+ " pip_install(\"git+https://github.com/sarulab-speech/UTMOSv2.git\") # cần Internet On, checkpoint tự tải\n",
707
+ " import utmosv2\n",
708
+ " v2 = utmosv2.create_model(pretrained=True)\n",
709
+ " qmos_map = {}\n",
710
+ " for n in tqdm(dev_names, desc=\"UTMOSv2\"):\n",
711
+ " wav = os.path.join(WAV_DIR, n if str(n).endswith(\".wav\") else str(n) + \".wav\")\n",
712
+ " if not os.path.exists(wav):\n",
713
+ " continue\n",
714
+ " out = v2.predict(input_path=wav) # trả float hoặc dict {'predicted_mos': ...} tùy phiên bản\n",
715
+ " qmos_map[n] = float(out[\"predicted_mos\"]) if isinstance(out, dict) else float(out)\n",
716
+ " del v2; torch.cuda.empty_cache() if device == \"cuda\" else None\n",
717
+ "\n",
718
+ "@torch.no_grad()\n",
719
+ "def predict_emotion(sid):\n",
720
+ " wave = load_wav(sid)\n",
721
+ " if wave is None or (USE_AUDEERING and sid not in aud_dev):\n",
722
+ " return None\n",
723
+ " wavlm.eval(); heads.eval()\n",
724
+ " iv = torch.from_numpy(wave).unsqueeze(0).to(device)\n",
725
+ " am = torch.ones((1, len(wave)), dtype=torch.long, device=device)\n",
726
+ " tgt = torch.from_numpy(onehot_target(target_map.get(sid))).unsqueeze(0).to(device)\n",
727
+ " with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
728
+ " fw = wavlm_embed(iv, am)\n",
729
+ " if USE_AUDEERING:\n",
730
+ " aud = torch.from_numpy(aud_dev[sid]).unsqueeze(0).to(device)\n",
731
+ " feat = torch.cat([fw, aud], dim=1)\n",
732
+ " else:\n",
733
+ " feat = fw\n",
734
+ " emos_p, cat_l, vad_p = heads(feat, tgt)\n",
735
+ " emos = float(emos_p.item()) * emos_sd + emos_mu\n",
736
+ " cat5 = F.softmax(cat_l, 1)[0].float().cpu().numpy()\n",
737
+ " vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu\n",
738
+ " return emos, cat5, vad3\n",
739
+ "\n",
740
+ "def fmt_cat(p5):\n",
741
+ " return \"|\".join(f\"{e}:{p5[i]:.6g}\" for i, e in enumerate(EMOTIONS5))\n",
742
+ "\n",
743
+ "def build_answer(out_path):\n",
744
+ " n_real = n_def = 0\n",
745
+ " with open(out_path, \"w\") as f:\n",
746
+ " f.write(\"wav,QMOS,EMOS,CAT,VAL,ARO,DOM\\n\")\n",
747
+ " for name in tqdm(dev_names, desc=\"answer\"):\n",
748
+ " sid = stem(name)\n",
749
+ " pr = predict_emotion(sid)\n",
750
+ " if pr is None:\n",
751
+ " emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0]); n_def += 1\n",
752
+ " else:\n",
753
+ " emos, cat5, vad3 = pr; n_real += 1\n",
754
+ " qmos = qmos_map.get(name, qmos_map.get(sid, 3.0))\n",
755
+ " f.write(f\"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\\n\")\n",
756
+ " print(f\"Ghi {len(dev_names)} dòng → {out_path} | cảm xúc thật {n_real}, mặc định {n_def}\")\n",
757
+ "\n",
758
+ "answer_path = os.path.join(OUT_DIR, \"answer.txt\")\n",
759
+ "build_answer(answer_path)"
760
+ ]
761
+ },
762
+ {
763
+ "cell_type": "markdown",
764
+ "id": "2e9cab58",
765
+ "metadata": {},
766
+ "source": [
767
+ "## 8. Validate + đóng zip"
768
+ ]
769
+ },
770
+ {
771
+ "cell_type": "code",
772
+ "execution_count": null,
773
+ "id": "88cb0280",
774
+ "metadata": {},
775
+ "outputs": [],
776
+ "source": [
777
+ "def validate(path):\n",
778
+ " import csv\n",
779
+ " with open(path) as f:\n",
780
+ " rows = list(csv.reader(f))\n",
781
+ " assert rows[0][0] == \"wav\" and \"QMOS\" in rows[0] and \"EMOS\" in rows[0], \"Header sai\"\n",
782
+ " for i, r in enumerate(rows[1:], 2):\n",
783
+ " assert len(r) == len(rows[0]), f\"Dòng {i} sai số cột\"\n",
784
+ " print(f\"OK: {len(rows)-1} dòng, header = {rows[0]}\")\n",
785
+ "\n",
786
+ "validate(answer_path)\n",
787
+ "os.system(f\"cd {OUT_DIR} && zip -j submission_track2_exp08_ft-emotion.zip answer.txt \"\n",
788
+ " f\"&& unzip -l submission_track2_exp08_ft-emotion.zip\")\n",
789
+ "print(\"Sẵn sàng nộp:\", os.path.join(OUT_DIR, \"submission_track2_exp08_ft-emotion.zip\"))"
790
+ ]
791
+ },
792
+ {
793
+ "cell_type": "markdown",
794
+ "id": "e2018df3",
795
+ "metadata": {},
796
+ "source": [
797
+ "## Ghi chú\n",
798
+ "- **Lần đầu** `LIMIT_TRAIN=300`, `LIMIT_DEV=20` để kiểm tra chạy trơn (1 epoch xong không OOM); rồi đặt `None`.\n",
799
+ "- **OOM trên T4?** giảm theo thứ tự: `MAX_SECONDS` (8→6) → `UNFREEZE_TOP_LAYERS` (6→4→2) → `BATCH` (4→2, tăng `ACCUM`).\n",
800
+ "- **Đọc mục 6:** so EMOS/VAD VAL nội bộ với mốc exp07 (EMOS 0.795 · VAL 0.581 · ARO 0.752 · DOM 0.705).\n",
801
+ " - Nếu fine-tune **thắng** → nộp answer.txt exp08 (5 cột cảm xúc của exp08 + QMOS mượn exp07).\n",
802
+ " - Nếu **thua** → giữ exp07; vẫn là kết quả cho paper (\"fine-tune chưa vượt frozen-fusion trên data nhỏ\").\n",
803
+ "- **QMOS:** Add Input answer.txt exp07 vào `/kaggle/input/exp07-answer/answer.txt` để mượn cột QMOS 0.548;\n",
804
+ " không có thì tự chấm UTMOSv2 (T05, vô địch VMC2024 — mạnh hơn UTMOS, cần Internet On).\n",
805
+ "- **Ablation cho paper:** `UNFREEZE_TOP_LAYERS=0` (≈ head-only) vs `=6` (fine-tune) → bảng \"frozen vs fine-tuned\".\n",
806
+ " `USE_AUDEERING=False` → đo đóng góp nhánh phụ.\n",
807
+ "- Ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (mục exp08)."
808
+ ]
809
+ }
810
+ ],
811
+ "metadata": {
812
+ "jupytext": {
813
+ "cell_metadata_filter": "-all",
814
+ "main_language": "python",
815
+ "notebook_metadata_filter": "-all"
816
+ }
817
+ },
818
+ "nbformat": 4,
819
+ "nbformat_minor": 5
820
+ }
track2/exp08_finetune_emotion_pipeline.py ADDED
@@ -0,0 +1,673 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %% [markdown]
2
+ # # VMC2026 Track 2 — exp08 (FINE-TUNE WavLM cho 5 cột cảm xúc) — Kaggle
3
+ #
4
+ # **Khác mọi exp trước:** exp03–07 đều **đóng băng** backbone (chỉ trích đặc trưng + train head nhỏ trên cache).
5
+ # exp08 **MỞ BĂNG (fine-tune)** WavLM-large để nó học lại đặc trưng riêng cho bài MOS cảm xúc 2026.
6
+ #
7
+ # ## Thiết kế (chốt với mentor 5/6)
8
+ # ```
9
+ # wav ─┬─► WavLM-large (warm-start SAILER, TRAINABLE: chỉ mở băng N lớp trên) ─► pool ─► emb_wavlm ┐
10
+ # └─► audeering MSP-dim (FROZEN, cache .npz) ─► [emb_aud | vad3] ├─► TRUNK ─┬─► EMOS (+target)
11
+ # ┘ ├─► CAT (5)
12
+ # └─► VAD (3)
13
+ # QMOS: KHÔNG train ở đây → mượn cột QMOS của exp07 (0.548) hoặc UTMOSv2 (T05, vô địch VMC2024).
14
+ # ```
15
+ # - **Warm-start:** khởi tạo WavLM từ checkpoint **SAILER** (`tiantiaf/wavlm-large-categorical-emotion`,
16
+ # đã giỏi cảm xúc) thay vì WavLM "trắng" → điểm xuất phát tốt hơn nhiều.
17
+ # - **Phụ (frozen):** audeering — dimensional, bổ trợ góc nhìn categorical của WavLM, kỳ vọng kéo **VAL**.
18
+ # - **Đóng băng partial:** chỉ train `UNFREEZE_TOP_LAYERS` lớp Transformer trên cùng + feature-extractor giữ băng
19
+ # → tiết kiệm VRAM T4 + chống overfit (chỉ 12.7k mẫu).
20
+ #
21
+ # ## ⚠️ Đánh đổi phải biết trước (so freeze+head)
22
+ # - **Mất lợi thế cache:** mỗi epoch chạy lại cả WavLM (forward+backward) → chậm & đốt giờ GPU (30h/tuần).
23
+ # → **Lần đầu BẮT BUỘC đặt `LIMIT_TRAIN=300`, `LIMIT_DEV=20`** để chỉnh trơn rồi mới `None`.
24
+ # - **Dễ overfit / OOM:** nếu OOM → giảm `BATCH`, tăng `ACCUM`, giảm `MAX_SECONDS`, giảm `UNFREEZE_TOP_LAYERS`.
25
+ # - **Lưới an toàn:** exp07 vẫn là bản nộp vô địch tới khi exp08 **thắng trên VAL nội bộ** (đừng đốt lượt nộp).
26
+ #
27
+ # **Cách chạy Kaggle:** GPU **T4** + Internet **On** → Add Input dataset Track 2 → sửa `DATA_ROOT` → Run All.
28
+
29
+ # %% [markdown]
30
+ # ## 0. Cấu hình — SỬA Ở ĐÂY
31
+
32
+ # %%
33
+ import os
34
+
35
+ DATA_ROOT = "/kaggle/input/vmc2026-track2-full/vmc2026-track2" # << SỬA slug cho khớp Add Input
36
+ WAV_DIR = f"{DATA_ROOT}/wav"
37
+ METADATA_CSV = f"{DATA_ROOT}/metadata.csv" # wavID|emotion|transcript (KHÔNG header)
38
+ TRAIN_CSV = f"{DATA_ROOT}/sets/train.csv" # lisID|wavID|qMOS|emoCat|eMOS|val|dom|aro
39
+ DEV_SCP = f"{DATA_ROOT}/sets/dev.scp"
40
+
41
+ OUT_DIR = "/kaggle/working"
42
+ CACHE_DIR = "/kaggle/working/ft_cache" # cache audeering (.npz) — backbone WavLM KHÔNG cache (đang train)
43
+ os.makedirs(CACHE_DIR, exist_ok=True)
44
+
45
+ # (Tùy chọn) TÁI DÙNG cache audeering cũ: trỏ tới dataset chứa aud_train.npz/aud_dev.npz → tự copy sang CACHE_DIR.
46
+ # Để "" nếu chạy mới hoàn toàn. /kaggle/input read-only nên phải copy sang working để ghi/append.
47
+ CACHE_INPUT = "/kaggle/input/datasets/minhtoan2/cache-exp8" # << SỬA slug cho khớp (hoặc "")
48
+ if CACHE_INPUT and os.path.isdir(CACHE_INPUT):
49
+ import shutil
50
+ _n = 0
51
+ for _fn in os.listdir(CACHE_INPUT):
52
+ if _fn.startswith("aud_") and _fn.endswith(".npz"):
53
+ shutil.copy(os.path.join(CACHE_INPUT, _fn), os.path.join(CACHE_DIR, _fn)); _n += 1
54
+ print(f"📦 Tái dùng cache: copy {_n} file aud_*.npz từ {CACHE_INPUT} → {CACHE_DIR}")
55
+
56
+ # Mượn cột QMOS của exp07 (tốt nhất 0.548). Trỏ tới answer.txt exp07 nếu có; không thì dùng UTMOSv2.
57
+ EXP07_ANSWER = "/kaggle/input/exp07-answer/answer.txt" # << (tùy chọn) Add Input answer.txt exp07; không có → UTMOSv2
58
+
59
+ # ── Fine-tune / siêu tham số ─────────────────────────────────────────────────
60
+ DEVICE = "cuda"
61
+ SR = 16000
62
+ MAX_SECONDS = 8 # cắt audio để chặn bộ nhớ backprop; OOM thì giảm còn 6
63
+ UNFREEZE_TOP_LAYERS = 6 # số lớp Transformer trên cùng được train (0 = freeze hết = quay về head-only)
64
+ TRUNK_HIDDEN = 512
65
+ HEAD_HIDDEN = 128
66
+ DROPOUT = 0.3
67
+ LR_BACKBONE = 1e-5 # LR nhỏ cho backbone fine-tune
68
+ LR_HEAD = 1e-3 # LR lớn cho trunk + head (train từ đầu)
69
+ WEIGHT_DECAY = 1e-5
70
+ EPOCHS = 12 # TRẦN; early-stop quyết định số epoch thực (8 hơi thấp cho lần chạy thật)
71
+ PATIENCE = 3 # dừng khi val SRCC không lên 3 epoch; LUÔN giữ best_state
72
+ BATCH = 4 # nhỏ vì backbone to; tăng ACCUM để bù
73
+ ACCUM = 8 # effective batch = BATCH*ACCUM = 32
74
+ VAL_FRAC = 0.10
75
+ SEED = 42
76
+ USE_AMP = True # mixed precision fp16 — tiết kiệm VRAM
77
+ USE_GRAD_CKPT = True # gradient checkpointing — tiết kiệm VRAM (đổi lấy chậm hơn)
78
+ USE_AUDEERING = True # nhánh phụ frozen audeering; False = chỉ WavLM
79
+ USE_UNCERTAINTY = True # tự cân 5 loss (Kendall); False = trọng số 1.0
80
+
81
+ LIMIT_TRAIN = 300 # << LẦN ĐẦU để 300; chạy thật đặt None
82
+ LIMIT_DEV = 20 # << LẦN ĐẦU để 20; chạy thật đặt None
83
+
84
+ # Mốc exp07 để so (cảnh báo nếu fine-tune KHÔNG thắng → giữ exp07)
85
+ EXP07 = {"emos": 0.795, "cat_err": 0.153, "val": 0.581, "aro": 0.752, "dom": 0.705}
86
+
87
+ EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"]
88
+
89
+ _EMO_ALIAS = {
90
+ "angry": "angry", "anger": "angry",
91
+ "happy": "happy", "happiness": "happy", "joy": "happy",
92
+ "neutral": "neutral", "calm": "neutral",
93
+ "sad": "sad", "sadness": "sad",
94
+ "surprise": "surprised", "surprised": "surprised", "surprising": "surprised",
95
+ }
96
+
97
+ def norm_emotion(label):
98
+ key = str(label).strip().lower()
99
+ return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)
100
+
101
+ def stem(p):
102
+ return os.path.splitext(os.path.basename(str(p)))[0]
103
+
104
+ print("DATA_ROOT:", DATA_ROOT)
105
+ for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP]:
106
+ print((" ✅ " if os.path.exists(p) else " ❌ THIẾU ") + p)
107
+ print(f"Fine-tune: mở băng {UNFREEZE_TOP_LAYERS} lớp trên · BATCH {BATCH}×ACCUM {ACCUM} · MAX {MAX_SECONDS}s")
108
+
109
+ # %% [markdown]
110
+ # ## 1. Cài đặt + tải code SAILER (clone + sys.path, KHÔNG pip install -e .)
111
+
112
+ # %%
113
+ import sys, subprocess
114
+
115
+ def pip_install(*pkgs):
116
+ subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs], check=True)
117
+
118
+ pip_install("loralib", "speechbrain", "speechmos", "librosa", "soundfile",
119
+ "scipy", "scikit-learn", "pandas", "tqdm")
120
+
121
+ REPO_DIR = "/kaggle/working/vox-profile-release"
122
+ if not os.path.exists(REPO_DIR):
123
+ subprocess.run(["git", "clone", "--depth", "1",
124
+ "https://github.com/tiantiaf0627/vox-profile-release.git", REPO_DIR], check=True)
125
+ if REPO_DIR not in sys.path:
126
+ sys.path.insert(0, REPO_DIR)
127
+
128
+ # %% [markdown]
129
+ # ## 2. Nạp SAILER → lấy backbone WavLM bên trong để FINE-TUNE
130
+ # Thay vì gọi wrapper như hộp đen, ta **lôi module WavLM-large (HuggingFace) bên trong wrapper** ra
131
+ # → toàn quyền đóng băng/mở băng từng lớp + tự pool. Nếu không tìm thấy (cấu trúc lạ) → **fallback**
132
+ # nạp `microsoft/wavlm-large` trắng (mất warm-start, có cảnh báo).
133
+
134
+ # %%
135
+ import torch
136
+ import torch.nn as nn
137
+ import torch.nn.functional as F
138
+
139
+ device = DEVICE if torch.cuda.is_available() else "cpu"
140
+ print("Device:", device, ("✅ " + torch.cuda.get_device_name(0)) if device == "cuda" else "⚠️ CPU (rất chậm!)")
141
+
142
+ def find_hf_backbone(module):
143
+ """Tìm submodule kiểu HF Wav2Vec2/WavLM backbone: có .feature_extractor và .encoder.layers."""
144
+ cands = []
145
+ for name, m in module.named_modules():
146
+ enc = getattr(m, "encoder", None)
147
+ if getattr(m, "feature_extractor", None) is not None and enc is not None \
148
+ and getattr(enc, "layers", None) is not None:
149
+ cands.append((name, m))
150
+ if not cands:
151
+ return None, None
152
+ cands.sort(key=lambda nm: sum(p.numel() for p in nm[1].parameters()), reverse=True)
153
+ return cands[0]
154
+
155
+ wavlm = None
156
+ try:
157
+ from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402
158
+ _wrapper = WavLMWrapper.from_pretrained("tiantiaf/wavlm-large-categorical-emotion")
159
+ name, wavlm = find_hf_backbone(_wrapper)
160
+ if wavlm is not None:
161
+ print(f"✅ Warm-start SAILER: lấy backbone WavLM bên trong wrapper tại '.{name}' "
162
+ f"({sum(p.numel() for p in wavlm.parameters())/1e6:.0f}M params)")
163
+ else:
164
+ print("⚠️ Không tìm thấy backbone HF bên trong wrapper SAILER → sẽ fallback WavLM trắng.")
165
+ except Exception as e:
166
+ print("⚠️ Lỗi nạp SAILER wrapper:", repr(e), "→ fallback WavLM trắng.")
167
+
168
+ if wavlm is None:
169
+ from transformers import WavLMModel
170
+ wavlm = WavLMModel.from_pretrained("microsoft/wavlm-large")
171
+ print("ℹ️ Fallback: nạp microsoft/wavlm-large (KHÔNG warm-start SAILER).")
172
+
173
+ wavlm = wavlm.to(device)
174
+ WAVLM_DIM = int(wavlm.config.hidden_size)
175
+
176
+ # ── Đóng băng partial: feature-extractor + tất cả trừ UNFREEZE_TOP_LAYERS lớp trên ──
177
+ for p in wavlm.parameters():
178
+ p.requires_grad = False
179
+ enc_layers = wavlm.encoder.layers
180
+ n_layers = len(enc_layers)
181
+ for layer in enc_layers[max(0, n_layers - UNFREEZE_TOP_LAYERS):]:
182
+ for p in layer.parameters():
183
+ p.requires_grad = True
184
+ n_train = sum(p.numel() for p in wavlm.parameters() if p.requires_grad)
185
+ print(f"WavLM: {n_layers} lớp encoder · mở băng {min(UNFREEZE_TOP_LAYERS, n_layers)} lớp trên "
186
+ f"→ {n_train/1e6:.1f}M param train (trên dim {WAVLM_DIM})")
187
+
188
+ if USE_GRAD_CKPT:
189
+ wavlm.gradient_checkpointing_enable()
190
+ if hasattr(wavlm, "enable_input_require_grads"):
191
+ wavlm.enable_input_require_grads() # cần khi grad-ckpt + lớp dưới đóng băng
192
+
193
+ def masked_mean(hidden, attn_mask):
194
+ """Mean-pool theo thời gian, bỏ qua phần pad (giữ gradient)."""
195
+ if attn_mask is None:
196
+ return hidden.mean(dim=1)
197
+ try:
198
+ fm = wavlm._get_feature_vector_attention_mask(hidden.shape[1], attn_mask)
199
+ except Exception:
200
+ return hidden.mean(dim=1)
201
+ fm = fm.unsqueeze(-1).to(hidden.dtype)
202
+ return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)
203
+
204
+ def wavlm_embed(input_values, attn_mask):
205
+ out = wavlm(input_values, attention_mask=attn_mask).last_hidden_state # [B,T,D]
206
+ return masked_mean(out, attn_mask)
207
+
208
+ # %% [markdown]
209
+ # ## 3. Nạp audeering MSP-dim (FROZEN) — đặc trưng phụ
210
+ # Lấy `[emb_pool(1024) | vad3(1–5)]` mỗi wav rồi **cache .npz** (chạy 1 lần). Kỹ thuật nạp head tay
211
+ # y hệt exp05 (tránh lỗi version transformers khi subclass `Wav2Vec2PreTrainedModel`).
212
+
213
+ # %%
214
+ AUD_DIM = 0
215
+ aud_backbone = aud_head = aud_proc = None
216
+ if USE_AUDEERING:
217
+ from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor
218
+ from huggingface_hub import hf_hub_download
219
+ AUD_NAME = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
220
+ aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)
221
+ aud_cfg = Wav2Vec2Config.from_pretrained(AUD_NAME)
222
+ aud_backbone = Wav2Vec2Model(aud_cfg)
223
+ try:
224
+ _sd = __import__("safetensors.torch", fromlist=["load_file"]).load_file(
225
+ hf_hub_download(AUD_NAME, "model.safetensors"))
226
+ except Exception:
227
+ _sd = torch.load(hf_hub_download(AUD_NAME, "pytorch_model.bin"), map_location="cpu")
228
+ bb_sd = {k[len("wav2vec2."):]: v for k, v in _sd.items() if k.startswith("wav2vec2.")}
229
+ missing, unexpected = aud_backbone.load_state_dict(bb_sd, strict=False)
230
+ print(f" audeering backbone: thiếu {len(missing)} / dư {len(unexpected)} key (strict=False)")
231
+ _hid = _sd["classifier.dense.weight"].shape[0]
232
+ _out = _sd["classifier.out_proj.weight"].shape[0]
233
+ aud_head = nn.Sequential(nn.Linear(_hid, _hid), nn.Tanh(), nn.Linear(_hid, _out))
234
+ aud_head[0].weight.data.copy_(_sd["classifier.dense.weight"]); aud_head[0].bias.data.copy_(_sd["classifier.dense.bias"])
235
+ aud_head[2].weight.data.copy_(_sd["classifier.out_proj.weight"]); aud_head[2].bias.data.copy_(_sd["classifier.out_proj.bias"])
236
+ aud_backbone = aud_backbone.to(device).eval()
237
+ aud_head = aud_head.to(device).eval()
238
+ AUD_DIM = _hid + 3 # emb_pool + [VAL,ARO,DOM]
239
+ print(f"✅ audeering frozen (đặc trưng phụ {AUD_DIM}-D = emb {_hid} + vad 3)")
240
+
241
+ # %%
242
+ import numpy as np
243
+ import librosa
244
+ from tqdm.auto import tqdm
245
+
246
+ def load_wav(name_or_stem, in_wav_dir=True):
247
+ p = name_or_stem if os.path.isabs(str(name_or_stem)) else os.path.join(
248
+ WAV_DIR, name_or_stem if str(name_or_stem).endswith(".wav") else str(name_or_stem) + ".wav")
249
+ if not os.path.exists(p):
250
+ return None
251
+ wave, _ = librosa.load(p, sr=SR, mono=True)
252
+ return wave[: MAX_SECONDS * SR].astype(np.float32)
253
+
254
+ @torch.no_grad()
255
+ def extract_audeering(stems, tag):
256
+ """→ dict {stem: float32[AUD_DIM]}; cache CACHE_DIR/aud_<tag>.npz (resume mỗi 500)."""
257
+ if not USE_AUDEERING:
258
+ return {}
259
+ cache_path = os.path.join(CACHE_DIR, f"aud_{tag}.npz")
260
+ store = {}
261
+ if os.path.exists(cache_path):
262
+ z = np.load(cache_path, allow_pickle=True)
263
+ store = {k: z[k] for k in z.files}
264
+ print(f"[aud/{tag}] nạp cache: {len(store)}")
265
+ todo = [s for s in stems if s not in store]
266
+ for i, s in enumerate(tqdm(todo, desc=f"audeering {tag}")):
267
+ wave = load_wav(s)
268
+ if wave is None:
269
+ continue
270
+ x = aud_proc(wave, sampling_rate=SR).input_values[0]
271
+ x = torch.from_numpy(np.asarray(x, dtype=np.float32)).unsqueeze(0).to(device)
272
+ h = aud_backbone(x)[0].mean(dim=1) # [1, hid]
273
+ out = aud_head(h)[0].cpu().numpy() # [arousal, dominance, valence] ∈[0,1]
274
+ vad = np.array([1 + 4 * out[2], 1 + 4 * out[0], 1 + 4 * out[1]], dtype=np.float32) # [VAL,ARO,DOM]
275
+ store[s] = np.concatenate([h[0].cpu().numpy(), vad]).astype(np.float32)
276
+ if (i + 1) % 500 == 0:
277
+ np.savez(cache_path, **store)
278
+ if todo:
279
+ np.savez(cache_path, **store)
280
+ return store
281
+
282
+ # %% [markdown]
283
+ # ## 4. Đọc & gộp nhãn theo wavID (EMOS / VAD / CAT) — như exp04/07 nhưng KHÔNG cần qMOS
284
+
285
+ # %%
286
+ import pandas as pd
287
+
288
+ def load_target_emotions():
289
+ tgt = {}
290
+ with open(METADATA_CSV, encoding="utf-8") as f:
291
+ for ln in f:
292
+ parts = ln.strip().split("|")
293
+ if len(parts) >= 2:
294
+ tgt[stem(parts[0])] = norm_emotion(parts[1])
295
+ return tgt
296
+
297
+ def _col(cols_map, *names, df=None, default_idx=None):
298
+ for n in names:
299
+ if n in cols_map:
300
+ return cols_map[n]
301
+ return list(df.columns)[default_idx] if default_idx is not None else None
302
+
303
+ def parse_emocat_votes(cell):
304
+ v = np.zeros(len(EMOTIONS5), dtype=np.float32)
305
+ for tok in str(cell).replace("/", ",").replace(";", ",").replace("|", ",").replace(" ", ",").split(","):
306
+ e = norm_emotion(tok)
307
+ if e in EMOTIONS5:
308
+ v[EMOTIONS5.index(e)] += 1.0
309
+ return v
310
+
311
+ def load_train_labels():
312
+ df = pd.read_csv(TRAIN_CSV, sep="|")
313
+ cols = {c.lower().strip(): c for c in df.columns}
314
+ wav_col = _col(cols, "wavid", "wav", df=df, default_idx=1)
315
+ emos_col = _col(cols, "emos", "emo", "emomos")
316
+ val_col = _col(cols, "val", "valence"); aro_col = _col(cols, "aro", "arousal"); dom_col = _col(cols, "dom", "dominance")
317
+ cat_col = _col(cols, "emocat", "cat", "emotion")
318
+ assert emos_col, f"Không thấy cột eMOS (cột: {list(df.columns)})"
319
+ df["_stem"] = df[wav_col].map(stem)
320
+ rows = []
321
+ for sid, g in df.groupby("_stem"):
322
+ rec = {"wavID": sid, "emos": float(g[emos_col].mean())}
323
+ rec["val"] = float(g[val_col].mean()) if val_col else np.nan
324
+ rec["aro"] = float(g[aro_col].mean()) if aro_col else np.nan
325
+ rec["dom"] = float(g[dom_col].mean()) if dom_col else np.nan
326
+ votes = np.zeros(len(EMOTIONS5), dtype=np.float32)
327
+ if cat_col:
328
+ for cell in g[cat_col]:
329
+ votes += parse_emocat_votes(cell)
330
+ s = votes.sum()
331
+ cat = votes / s if s > 0 else np.full(len(EMOTIONS5), 0.2, dtype=np.float32)
332
+ for i in range(len(EMOTIONS5)):
333
+ rec[f"cat{i}"] = float(cat[i])
334
+ rows.append(rec)
335
+ return pd.DataFrame(rows)
336
+
337
+ target_map = load_target_emotions()
338
+ train_df = load_train_labels()
339
+ HAS_VAD = bool(train_df["val"].notna().any())
340
+ print(f"Target: {len(target_map)} | wav train (gộp): {len(train_df)} | có VAD: {HAS_VAD}")
341
+
342
+ # %% [markdown]
343
+ # ## 5. Dataset / DataLoader (load wav theo batch — KHÔNG cache WavLM vì đang train)
344
+
345
+ # %%
346
+ from torch.utils.data import Dataset, DataLoader
347
+
348
+ train_stems = [s for s in train_df["wavID"] if target_map.get(s) is not None]
349
+ if LIMIT_TRAIN:
350
+ train_stems = train_stems[:LIMIT_TRAIN]
351
+ aud_tr = extract_audeering(train_stems, "train")
352
+
353
+ lab = train_df.set_index("wavID")
354
+
355
+ # Chuẩn hóa nhãn liên tục về z-score (để các MSE cùng thang) — lưu để giải mã lúc dự đoán.
356
+ def _zfit(arr):
357
+ a = np.asarray(arr, dtype=np.float32)
358
+ return float(np.nanmean(a)), float(np.nanstd(a) + 1e-6)
359
+
360
+ emos_mu, emos_sd = _zfit([lab.loc[s, "emos"] for s in train_stems])
361
+ if HAS_VAD:
362
+ vad_mu = np.array([_zfit([lab.loc[s, c] for s in train_stems])[0] for c in ["val", "aro", "dom"]], dtype=np.float32)
363
+ vad_sd = np.array([_zfit([lab.loc[s, c] for s in train_stems])[1] for c in ["val", "aro", "dom"]], dtype=np.float32)
364
+ else:
365
+ vad_mu = np.zeros(3, dtype=np.float32); vad_sd = np.ones(3, dtype=np.float32)
366
+
367
+ def onehot_target(tgt):
368
+ v = np.zeros(len(EMOTIONS5), dtype=np.float32)
369
+ if tgt in EMOTIONS5:
370
+ v[EMOTIONS5.index(tgt)] = 1.0
371
+ return v
372
+
373
+ class EmoDataset(Dataset):
374
+ def __init__(self, stems):
375
+ self.stems = [s for s in stems if (load_wav(s) is not None) and ((not USE_AUDEERING) or s in aud_tr)]
376
+ def __len__(self):
377
+ return len(self.stems)
378
+ def __getitem__(self, i):
379
+ s = self.stems[i]
380
+ wave = load_wav(s)
381
+ emos = (float(lab.loc[s, "emos"]) - emos_mu) / emos_sd
382
+ if HAS_VAD:
383
+ vad = (np.array([lab.loc[s, "val"], lab.loc[s, "aro"], lab.loc[s, "dom"]], np.float32) - vad_mu) / vad_sd
384
+ else:
385
+ vad = np.zeros(3, dtype=np.float32)
386
+ cat = np.array([lab.loc[s, f"cat{j}"] for j in range(len(EMOTIONS5))], dtype=np.float32)
387
+ aud = aud_tr[s] if USE_AUDEERING else np.zeros(0, dtype=np.float32)
388
+ return {"wave": wave, "tgt": onehot_target(target_map.get(s)), "aud": aud,
389
+ "emos": np.float32(emos), "vad": vad, "cat": cat,
390
+ "emos_raw": np.float32(lab.loc[s, "emos"]),
391
+ "vad_raw": np.array([lab.loc[s, "val"], lab.loc[s, "aro"], lab.loc[s, "dom"]], np.float32)}
392
+
393
+ def collate(batch):
394
+ lens = [len(b["wave"]) for b in batch]
395
+ L = max(lens)
396
+ waves = np.zeros((len(batch), L), dtype=np.float32)
397
+ mask = np.zeros((len(batch), L), dtype=np.float32)
398
+ for i, b in enumerate(batch):
399
+ waves[i, : len(b["wave"])] = b["wave"]; mask[i, : len(b["wave"])] = 1.0
400
+ out = {
401
+ "input_values": torch.from_numpy(waves), "attn_mask": torch.from_numpy(mask).long(),
402
+ "tgt": torch.from_numpy(np.stack([b["tgt"] for b in batch])),
403
+ "aud": torch.from_numpy(np.stack([b["aud"] for b in batch])) if USE_AUDEERING else None,
404
+ "emos": torch.from_numpy(np.stack([b["emos"] for b in batch])).unsqueeze(1),
405
+ "vad": torch.from_numpy(np.stack([b["vad"] for b in batch])),
406
+ "cat": torch.from_numpy(np.stack([b["cat"] for b in batch])),
407
+ "emos_raw": np.stack([b["emos_raw"] for b in batch]),
408
+ "vad_raw": np.stack([b["vad_raw"] for b in batch]),
409
+ }
410
+ return out
411
+
412
+ from sklearn.model_selection import train_test_split
413
+ ds = EmoDataset(train_stems)
414
+ print("Dataset hợp lệ:", len(ds), "wav")
415
+ tr_i, va_i = train_test_split(np.arange(len(ds)), test_size=VAL_FRAC, random_state=SEED)
416
+ tr_loader = DataLoader(torch.utils.data.Subset(ds, tr_i), batch_size=BATCH, shuffle=True, collate_fn=collate, num_workers=2)
417
+ va_loader = DataLoader(torch.utils.data.Subset(ds, va_i), batch_size=BATCH, shuffle=False, collate_fn=collate, num_workers=2)
418
+
419
+ # %% [markdown]
420
+ # ## 6. Head fusion (trunk + 3 head cảm xúc) + train loop (AMP + grad accumulation)
421
+
422
+ # %%
423
+ from scipy.stats import spearmanr
424
+
425
+ torch.manual_seed(SEED); np.random.seed(SEED)
426
+ N_EMO = len(EMOTIONS5)
427
+ TRUNK_IN = WAVLM_DIM + (AUD_DIM if USE_AUDEERING else 0)
428
+
429
+ class EmoHeads(nn.Module):
430
+ def __init__(self, d_in, trunk_h, head_h, p, n_emo):
431
+ super().__init__()
432
+ self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),
433
+ nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))
434
+ self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))
435
+ self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))
436
+ self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))
437
+ def forward(self, feat, tgt):
438
+ h = self.trunk(feat)
439
+ return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)
440
+
441
+ heads = EmoHeads(TRUNK_IN, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device)
442
+ print(f"Trunk input = {TRUNK_IN} (wavlm {WAVLM_DIM} + aud {AUD_DIM if USE_AUDEERING else 0})")
443
+
444
+ TASKS = ["emos", "cat", "val", "aro", "dom"]
445
+ log_var = nn.Parameter(torch.zeros(len(TASKS), device=device))
446
+ bb_params = [p for p in wavlm.parameters() if p.requires_grad]
447
+ head_params = list(heads.parameters()) + ([log_var] if USE_UNCERTAINTY else [])
448
+ opt = torch.optim.AdamW([
449
+ {"params": bb_params, "lr": LR_BACKBONE},
450
+ {"params": head_params, "lr": LR_HEAD},
451
+ ], weight_decay=WEIGHT_DECAY)
452
+ scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP and device == "cuda")
453
+ mse = nn.MSELoss()
454
+
455
+ def soft_ce(logits, target_dist):
456
+ return -(target_dist * F.log_softmax(logits, dim=1)).sum(1).mean()
457
+
458
+ def forward_batch(b):
459
+ feat_wavlm = wavlm_embed(b["input_values"].to(device), b["attn_mask"].to(device))
460
+ if USE_AUDEERING:
461
+ feat = torch.cat([feat_wavlm, b["aud"].to(device)], dim=1)
462
+ else:
463
+ feat = feat_wavlm
464
+ return heads(feat, b["tgt"].to(device))
465
+
466
+ def compute_loss(emos_p, cat_l, vad_p, b):
467
+ L = {}
468
+ L["emos"] = mse(emos_p, b["emos"].to(device))
469
+ L["cat"] = soft_ce(cat_l, b["cat"].to(device))
470
+ if HAS_VAD:
471
+ vt = b["vad"].to(device)
472
+ L["val"] = mse(vad_p[:, 0:1], vt[:, 0:1]); L["aro"] = mse(vad_p[:, 1:2], vt[:, 1:2]); L["dom"] = mse(vad_p[:, 2:3], vt[:, 2:3])
473
+ else:
474
+ z = torch.zeros((), device=device); L["val"] = L["aro"] = L["dom"] = z
475
+ if USE_UNCERTAINTY:
476
+ return sum(torch.exp(-log_var[i]) * L[t] + log_var[i] for i, t in enumerate(TASKS))
477
+ return sum(L.values())
478
+
479
+ @torch.no_grad()
480
+ def evaluate():
481
+ wavlm.eval(); heads.eval()
482
+ P = {"emos": [], "val": [], "aro": [], "dom": []}; Y = {"emos": [], "val": [], "aro": [], "dom": []}
483
+ catP, catY = [], []
484
+ for b in va_loader:
485
+ with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
486
+ emos_p, cat_l, vad_p = forward_batch(b)
487
+ P["emos"] += emos_p.float().cpu().numpy().ravel().tolist(); Y["emos"] += b["emos_raw"].tolist()
488
+ vad_p = vad_p.float().cpu().numpy()
489
+ for j, t in enumerate(["val", "aro", "dom"]):
490
+ P[t] += vad_p[:, j].tolist(); Y[t] += b["vad_raw"][:, j].tolist()
491
+ catP.append(F.softmax(cat_l, 1).float().cpu().numpy()); catY.append(b["cat"])
492
+ out = {}
493
+ for t in ["emos"] + (["val", "aro", "dom"] if HAS_VAD else []):
494
+ out[t] = spearmanr(P[t], Y[t]).correlation
495
+ q = np.concatenate(catP); p = np.concatenate(catY)
496
+ out["cat_err"] = float(np.abs(q - p).sum(1).mean()) # ~ tổng |Δ| trung bình (xấp xỉ CAT-ERR)
497
+ return out
498
+
499
+ def mean_srcc(m):
500
+ keys = ["emos"] + (["val", "aro", "dom"] if HAS_VAD else [])
501
+ return float(np.mean([m[k] for k in keys]))
502
+
503
+ # Lưu checkpoint FULL (có backbone WavLM) — gọi NGAY mỗi best để kernel chết giữa chừng vẫn còn file.
504
+ CKPT_PATH = os.path.join(OUT_DIR, "ft_emotion_full.pt")
505
+ def save_full_ckpt(state, val_emos=float("nan")):
506
+ torch.save({"wavlm": state["wavlm"], "heads": state["heads"],
507
+ "emos_mu": emos_mu, "emos_sd": emos_sd, "vad_mu": vad_mu, "vad_sd": vad_sd,
508
+ "WAVLM_DIM": WAVLM_DIM, "AUD_DIM": AUD_DIM,
509
+ "UNFREEZE_TOP_LAYERS": UNFREEZE_TOP_LAYERS, "val_emos": float(val_emos)}, CKPT_PATH)
510
+
511
+ best, best_state, bad = -1e9, None, 0
512
+ for ep in range(1, EPOCHS + 1):
513
+ wavlm.train(); heads.train()
514
+ opt.zero_grad(); run = 0.0; nb = 0
515
+ for step, b in enumerate(tqdm(tr_loader, desc=f"epoch {ep}")):
516
+ with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
517
+ emos_p, cat_l, vad_p = forward_batch(b)
518
+ loss = compute_loss(emos_p, cat_l, vad_p, b) / ACCUM
519
+ scaler.scale(loss).backward()
520
+ if (step + 1) % ACCUM == 0:
521
+ scaler.step(opt); scaler.update(); opt.zero_grad()
522
+ run += loss.item() * ACCUM; nb += 1
523
+ m = evaluate(); sc = mean_srcc(m)
524
+ msg = " ".join(f"{k}={m[k]:.3f}" for k in ["emos", "val", "aro", "dom"] if k in m)
525
+ print(f"epoch {ep:2d} | loss {run/max(nb,1):.4f} | {msg} | cat_err {m['cat_err']:.3f} | mean {sc:.4f} (best {max(best,sc):.4f})")
526
+ if sc > best:
527
+ best = sc
528
+ best_state = {"wavlm": {k: v.cpu().clone() for k, v in wavlm.state_dict().items()},
529
+ "heads": {k: v.cpu().clone() for k, v in heads.state_dict().items()}}
530
+ save_full_ckpt(best_state, m["emos"]) # LƯU NGAY mỗi best → an toàn nếu kernel chết
531
+ print(f" 💾 lưu best → {CKPT_PATH} (epoch {ep}, mean {sc:.4f})")
532
+ bad = 0
533
+ else:
534
+ bad += 1
535
+ if bad >= PATIENCE:
536
+ print(f"Early stop ở epoch {ep}."); break
537
+
538
+ if best_state:
539
+ wavlm.load_state_dict(best_state["wavlm"]); heads.load_state_dict(best_state["heads"])
540
+ final = evaluate()
541
+ print("\n✅ VAL (nội bộ) — exp08 (fine-tune WavLM cho cảm xúc):")
542
+ print(f" EMOS={final['emos']:.4f} (exp07 {EXP07['emos']})")
543
+ if HAS_VAD:
544
+ print(f" VAL/ARO/DOM={final['val']:.4f}/{final['aro']:.4f}/{final['dom']:.4f} "
545
+ f"(exp07 {EXP07['val']}/{EXP07['aro']}/{EXP07['dom']})")
546
+ warn = [f"EMOS {final['emos']:.3f}<{EXP07['emos']}"] if final["emos"] < EXP07["emos"] - 0.005 else []
547
+ if HAS_VAD:
548
+ warn += [f"{t.upper()} {final[t]:.3f}<{EXP07[t]}" for t in ["val", "aro", "dom"] if final[t] < EXP07[t] - 0.005]
549
+ print(" ⚠️ CHƯA thắng exp07 ở:", "; ".join(warn), "→ cân nhắc giữ exp07." if warn else "")
550
+ if not warn:
551
+ print(" ✅ Fine-tune thắng/ngang exp07 ở mọi cột cảm xúc → đáng nộp.")
552
+ # Lưu lần cuối từ best (đã lưu sẵn mỗi best trong loop; đây là phát cuối cho chắc).
553
+ save_full_ckpt(best_state if best_state else
554
+ {"wavlm": wavlm.state_dict(), "heads": heads.state_dict()}, final["emos"])
555
+ print(f"✅ Đã lưu {CKPT_PATH} (CÓ backbone WavLM + heads → resume được). "
556
+ f"NHỚ Save Version để file ra Output!")
557
+
558
+ # %% [markdown]
559
+ # ## 7. Dự đoán DEV → answer.txt (5 cột cảm xúc từ exp08; QMOS mượn exp07 hoặc UTMOS)
560
+
561
+ # %%
562
+ def list_dev():
563
+ with open(DEV_SCP) as f:
564
+ return [ln.strip() for ln in f if ln.strip()]
565
+
566
+ dev_names = list_dev()
567
+ if LIMIT_DEV:
568
+ dev_names = dev_names[:LIMIT_DEV]
569
+ dev_stems = [stem(n) for n in dev_names]
570
+ print("DEV:", len(dev_names), "mẫu")
571
+ aud_dev = extract_audeering(dev_stems, "dev")
572
+
573
+ # QMOS: ưu tiên mượn cột QMOS của exp07; không có file → chấm UTMOSv2 (T05, vô địch VMC2024).
574
+ def load_exp07_qmos():
575
+ if EXP07_ANSWER and os.path.exists(EXP07_ANSWER):
576
+ import csv
577
+ d = {}
578
+ with open(EXP07_ANSWER) as f:
579
+ r = csv.DictReader(f)
580
+ for row in r:
581
+ d[row["wav"]] = float(row["QMOS"]); d[stem(row["wav"])] = float(row["QMOS"])
582
+ print(f"✅ Mượn QMOS từ exp07 ({EXP07_ANSWER}): {len(d)//2} wav")
583
+ return d
584
+ return None
585
+
586
+ qmos_map = load_exp07_qmos()
587
+ if qmos_map is None:
588
+ print("ℹ️ Không có answer.txt exp07 → chấm QMOS bằng UTMOSv2 (T05, vô địch VMC2024 Track 1).")
589
+ pip_install("git+https://github.com/sarulab-speech/UTMOSv2.git") # cần Internet On, checkpoint tự tải
590
+ import utmosv2
591
+ v2 = utmosv2.create_model(pretrained=True)
592
+ qmos_map = {}
593
+ for n in tqdm(dev_names, desc="UTMOSv2"):
594
+ wav = os.path.join(WAV_DIR, n if str(n).endswith(".wav") else str(n) + ".wav")
595
+ if not os.path.exists(wav):
596
+ continue
597
+ out = v2.predict(input_path=wav) # trả float hoặc dict {'predicted_mos': ...} tùy phiên bản
598
+ qmos_map[n] = float(out["predicted_mos"]) if isinstance(out, dict) else float(out)
599
+ del v2; torch.cuda.empty_cache() if device == "cuda" else None
600
+
601
+ @torch.no_grad()
602
+ def predict_emotion(sid):
603
+ wave = load_wav(sid)
604
+ if wave is None or (USE_AUDEERING and sid not in aud_dev):
605
+ return None
606
+ wavlm.eval(); heads.eval()
607
+ iv = torch.from_numpy(wave).unsqueeze(0).to(device)
608
+ am = torch.ones((1, len(wave)), dtype=torch.long, device=device)
609
+ tgt = torch.from_numpy(onehot_target(target_map.get(sid))).unsqueeze(0).to(device)
610
+ with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
611
+ fw = wavlm_embed(iv, am)
612
+ if USE_AUDEERING:
613
+ aud = torch.from_numpy(aud_dev[sid]).unsqueeze(0).to(device)
614
+ feat = torch.cat([fw, aud], dim=1)
615
+ else:
616
+ feat = fw
617
+ emos_p, cat_l, vad_p = heads(feat, tgt)
618
+ emos = float(emos_p.item()) * emos_sd + emos_mu
619
+ cat5 = F.softmax(cat_l, 1)[0].float().cpu().numpy()
620
+ vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu
621
+ return emos, cat5, vad3
622
+
623
+ def fmt_cat(p5):
624
+ return "|".join(f"{e}:{p5[i]:.6g}" for i, e in enumerate(EMOTIONS5))
625
+
626
+ def build_answer(out_path):
627
+ n_real = n_def = 0
628
+ with open(out_path, "w") as f:
629
+ f.write("wav,QMOS,EMOS,CAT,VAL,ARO,DOM\n")
630
+ for name in tqdm(dev_names, desc="answer"):
631
+ sid = stem(name)
632
+ pr = predict_emotion(sid)
633
+ if pr is None:
634
+ emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0]); n_def += 1
635
+ else:
636
+ emos, cat5, vad3 = pr; n_real += 1
637
+ qmos = qmos_map.get(name, qmos_map.get(sid, 3.0))
638
+ f.write(f"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\n")
639
+ print(f"Ghi {len(dev_names)} dòng → {out_path} | cảm xúc thật {n_real}, mặc định {n_def}")
640
+
641
+ answer_path = os.path.join(OUT_DIR, "answer.txt")
642
+ build_answer(answer_path)
643
+
644
+ # %% [markdown]
645
+ # ## 8. Validate + đóng zip
646
+
647
+ # %%
648
+ def validate(path):
649
+ import csv
650
+ with open(path) as f:
651
+ rows = list(csv.reader(f))
652
+ assert rows[0][0] == "wav" and "QMOS" in rows[0] and "EMOS" in rows[0], "Header sai"
653
+ for i, r in enumerate(rows[1:], 2):
654
+ assert len(r) == len(rows[0]), f"Dòng {i} sai số cột"
655
+ print(f"OK: {len(rows)-1} dòng, header = {rows[0]}")
656
+
657
+ validate(answer_path)
658
+ os.system(f"cd {OUT_DIR} && zip -j submission_track2_exp08_ft-emotion.zip answer.txt "
659
+ f"&& unzip -l submission_track2_exp08_ft-emotion.zip")
660
+ print("Sẵn sàng nộp:", os.path.join(OUT_DIR, "submission_track2_exp08_ft-emotion.zip"))
661
+
662
+ # %% [markdown]
663
+ # ## Ghi chú
664
+ # - **Lần đầu** `LIMIT_TRAIN=300`, `LIMIT_DEV=20` để kiểm tra chạy trơn (1 epoch xong không OOM); rồi đặt `None`.
665
+ # - **OOM trên T4?** giảm theo thứ tự: `MAX_SECONDS` (8→6) → `UNFREEZE_TOP_LAYERS` (6→4→2) → `BATCH` (4→2, tăng `ACCUM`).
666
+ # - **Đọc mục 6:** so EMOS/VAD VAL nội bộ với mốc exp07 (EMOS 0.795 · VAL 0.581 · ARO 0.752 · DOM 0.705).
667
+ # - Nếu fine-tune **thắng** → nộp answer.txt exp08 (5 cột cảm xúc của exp08 + QMOS mượn exp07).
668
+ # - Nếu **thua** → giữ exp07; vẫn là kết quả cho paper ("fine-tune chưa vượt frozen-fusion trên data nhỏ").
669
+ # - **QMOS:** Add Input answer.txt exp07 vào `/kaggle/input/exp07-answer/answer.txt` để mượn cột QMOS 0.548;
670
+ # không có thì tự chấm UTMOSv2 (T05, vô địch VMC2024 — mạnh hơn UTMOS, cần Internet On).
671
+ # - **Ablation cho paper:** `UNFREEZE_TOP_LAYERS=0` (≈ head-only) vs `=6` (fine-tune) → bảng "frozen vs fine-tuned".
672
+ # `USE_AUDEERING=False` → đo đóng góp nhánh phụ.
673
+ # - Ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (mục exp08).
track2/exp08b_finetune_resume.ipynb ADDED
@@ -0,0 +1,782 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "ce468400",
6
+ "metadata": {},
7
+ "source": [
8
+ "# VMC2026 Track 2 — exp08-RESUME (fine-tune TIẾP từ checkpoint + cache) — Kaggle\n",
9
+ "\n",
10
+ "**Mục đích:** train tiếp model fine-tune cảm xúc (exp08) từ **checkpoint đã lưu** thay vì train lại từ\n",
11
+ "đầu — tiết kiệm giờ GPU. Tận dụng:\n",
12
+ "- `ft_emotion_full.pt` (CÓ cả backbone WavLM + heads + thống kê chuẩn hóa) → nạp lại đúng trạng thái.\n",
13
+ "- **cache audeering** `aud_*.npz` (đặc trưng frozen) → KHÔNG trích lại (~đỡ chục phút).\n",
14
+ "\n",
15
+ "> ⚠️ Bắt buộc dùng checkpoint **đủ backbone** (`ft_emotion_full.pt` từ cell \"TRAIN TIẾP\", hoặc bản\n",
16
+ "> `ft_emotion_meta.pt` MỚI đã vá để lưu cả `wavlm`). Bản `ft_emotion_meta.pt` CŨ chỉ có `heads` → KHÔNG dùng được.\n",
17
+ "\n",
18
+ "## Chuẩn bị input trên Kaggle (Add Input)\n",
19
+ "1. Dataset Track 2 (`vmc2026-track2-full`) — wav + nhãn.\n",
20
+ "2. **Checkpoint**: upload `ft_emotion_full.pt` thành 1 Dataset → trỏ `RESUME_CKPT`.\n",
21
+ "3. **Cache** (tùy chọn nhưng nên có): upload thư mục chứa `aud_train.npz`, `aud_dev.npz` → trỏ `CACHE_INPUT`.\n",
22
+ "4. (tùy chọn) `answer.txt` exp07 để mượn cột QMOS 0.548.\n",
23
+ "\n",
24
+ "**Cách chạy:** GPU T4 + Internet On → sửa các slug ở cell 0 → Run All. Lần đầu để `LIMIT_TRAIN=300`."
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "markdown",
29
+ "id": "1c6752ee",
30
+ "metadata": {},
31
+ "source": [
32
+ "## 0. Cấu hình — SỬA Ở ĐÂY"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "code",
37
+ "execution_count": null,
38
+ "id": "8d6317ac",
39
+ "metadata": {},
40
+ "outputs": [],
41
+ "source": [
42
+ "import os, shutil\n",
43
+ "\n",
44
+ "DATA_ROOT = \"/kaggle/input/vmc2026-track2-full/vmc2026-track2\" # << SỬA slug\n",
45
+ "WAV_DIR = f\"{DATA_ROOT}/wav\"\n",
46
+ "METADATA_CSV = f\"{DATA_ROOT}/metadata.csv\"\n",
47
+ "TRAIN_CSV = f\"{DATA_ROOT}/sets/train.csv\"\n",
48
+ "DEV_SCP = f\"{DATA_ROOT}/sets/dev.scp\"\n",
49
+ "\n",
50
+ "# ── Checkpoint + cache để RESUME ─────────────────────────────────────────────\n",
51
+ "RESUME_CKPT = \"/kaggle/input/ft-emotion-full/ft_emotion_full.pt\" # << CHECKPOINT đủ backbone\n",
52
+ "CACHE_INPUT = \"/kaggle/input/ft-emotion-cache\" # << thư mục chứa aud_*.npz (hoặc \"\" nếu không có)\n",
53
+ "EXP07_ANSWER = \"/kaggle/input/exp07-answer/answer.txt\" # << (tùy chọn) mượn QMOS 0.548; không có → UTMOSv2\n",
54
+ "\n",
55
+ "OUT_DIR = \"/kaggle/working\"\n",
56
+ "CACHE_DIR = \"/kaggle/working/ft_cache\" # /kaggle/input read-only → copy cache sang đây để ghi/append được\n",
57
+ "os.makedirs(CACHE_DIR, exist_ok=True)\n",
58
+ "\n",
59
+ "# ── Fine-tune / siêu tham số (train TIẾP) ────────────────────────────────────\n",
60
+ "DEVICE = \"cuda\"\n",
61
+ "SR = 16000\n",
62
+ "MAX_SECONDS = 8\n",
63
+ "UNFREEZE_TOP_LAYERS = 6 # PHẢI khớp checkpoint (mặc định exp08 = 6)\n",
64
+ "TRUNK_HIDDEN = 512 # PHẢI khớp checkpoint\n",
65
+ "HEAD_HIDDEN = 128 # PHẢI khớp checkpoint\n",
66
+ "DROPOUT = 0.3\n",
67
+ "LR_BACKBONE = 1e-5\n",
68
+ "LR_HEAD = 1e-3\n",
69
+ "RESUME_LR_SCALE = 1.0 # <1.0 để giảm LR khi train tiếp (vd 0.5 nếu val đã chững)\n",
70
+ "WEIGHT_DECAY = 1e-5\n",
71
+ "EPOCHS = 10 # số epoch train THÊM (run này)\n",
72
+ "PATIENCE = 5 # dừng khi val không lên; LUÔN giữ best\n",
73
+ "BATCH = 4\n",
74
+ "ACCUM = 8 # effective batch = 32\n",
75
+ "VAL_FRAC = 0.10\n",
76
+ "SEED = 42\n",
77
+ "USE_AMP = True\n",
78
+ "USE_GRAD_CKPT = True\n",
79
+ "USE_AUDEERING = True # PHẢI khớp checkpoint (exp08 = True)\n",
80
+ "USE_UNCERTAINTY = True\n",
81
+ "\n",
82
+ "LIMIT_TRAIN = 300 # << LẦN ĐẦU 300; chạy thật None\n",
83
+ "LIMIT_DEV = 20 # << LẦN ĐẦU 20; chạy thật None\n",
84
+ "\n",
85
+ "# Mốc exp07 + exp08 để so\n",
86
+ "EXP07 = {\"emos\": 0.795, \"cat_err\": 0.153, \"val\": 0.581, \"aro\": 0.752, \"dom\": 0.705}\n",
87
+ "EXP08 = {\"emos\": 0.811, \"cat_err\": 0.133, \"val\": 0.659, \"aro\": 0.793, \"dom\": 0.751} # bản đã nộp\n",
88
+ "\n",
89
+ "EMOTIONS5 = [\"angry\", \"happy\", \"neutral\", \"sad\", \"surprised\"]\n",
90
+ "_EMO_ALIAS = {\n",
91
+ " \"angry\": \"angry\", \"anger\": \"angry\",\n",
92
+ " \"happy\": \"happy\", \"happiness\": \"happy\", \"joy\": \"happy\",\n",
93
+ " \"neutral\": \"neutral\", \"calm\": \"neutral\",\n",
94
+ " \"sad\": \"sad\", \"sadness\": \"sad\",\n",
95
+ " \"surprise\": \"surprised\", \"surprised\": \"surprised\", \"surprising\": \"surprised\",\n",
96
+ "}\n",
97
+ "\n",
98
+ "def norm_emotion(label):\n",
99
+ " key = str(label).strip().lower()\n",
100
+ " return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)\n",
101
+ "\n",
102
+ "def stem(p):\n",
103
+ " return os.path.splitext(os.path.basename(str(p)))[0]\n",
104
+ "\n",
105
+ "print(\"DATA_ROOT:\", DATA_ROOT)\n",
106
+ "for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP, RESUME_CKPT]:\n",
107
+ " print((\" ✅ \" if os.path.exists(p) else \" ❌ THIẾU \") + p)\n",
108
+ "\n",
109
+ "# Copy cache (aud_*.npz) từ input read-only sang working để append được\n",
110
+ "if CACHE_INPUT and os.path.isdir(CACHE_INPUT):\n",
111
+ " n = 0\n",
112
+ " for fn in os.listdir(CACHE_INPUT):\n",
113
+ " if fn.startswith(\"aud_\") and fn.endswith(\".npz\"):\n",
114
+ " shutil.copy(os.path.join(CACHE_INPUT, fn), os.path.join(CACHE_DIR, fn)); n += 1\n",
115
+ " print(f\"📦 Copy {n} file cache audeering từ {CACHE_INPUT} → {CACHE_DIR}\")\n",
116
+ "else:\n",
117
+ " print(\"ℹ️ Không có CACHE_INPUT → sẽ tự trích audeering (chậm hơn lần đầu).\")"
118
+ ]
119
+ },
120
+ {
121
+ "cell_type": "markdown",
122
+ "id": "57e6416a",
123
+ "metadata": {},
124
+ "source": [
125
+ "## 1. Cài đặt + tải code SAILER (để dựng đúng kiến trúc WavLM rồi nạp checkpoint đè lên)"
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "code",
130
+ "execution_count": null,
131
+ "id": "76497a3f",
132
+ "metadata": {},
133
+ "outputs": [],
134
+ "source": [
135
+ "import sys, subprocess\n",
136
+ "\n",
137
+ "def pip_install(*pkgs):\n",
138
+ " subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *pkgs], check=True)\n",
139
+ "\n",
140
+ "pip_install(\"loralib\", \"speechbrain\", \"speechmos\", \"librosa\", \"soundfile\",\n",
141
+ " \"scipy\", \"scikit-learn\", \"pandas\", \"tqdm\")\n",
142
+ "\n",
143
+ "REPO_DIR = \"/kaggle/working/vox-profile-release\"\n",
144
+ "if not os.path.exists(REPO_DIR):\n",
145
+ " subprocess.run([\"git\", \"clone\", \"--depth\", \"1\",\n",
146
+ " \"https://github.com/tiantiaf0627/vox-profile-release.git\", REPO_DIR], check=True)\n",
147
+ "if REPO_DIR not in sys.path:\n",
148
+ " sys.path.insert(0, REPO_DIR)"
149
+ ]
150
+ },
151
+ {
152
+ "cell_type": "markdown",
153
+ "id": "cf6cf213",
154
+ "metadata": {},
155
+ "source": [
156
+ "## 2. Dựng WavLM (như exp08) → NẠP trọng số backbone từ checkpoint\n",
157
+ "Dựng đúng kiến trúc (SAILER wrapper → lấy backbone HF; fallback WavLM trắng), rồi `load_state_dict`\n",
158
+ "bằng `ckpt[\"wavlm\"]` → khôi phục đúng trạng thái fine-tune đã lưu."
159
+ ]
160
+ },
161
+ {
162
+ "cell_type": "code",
163
+ "execution_count": null,
164
+ "id": "20a2e84b",
165
+ "metadata": {
166
+ "lines_to_next_cell": 1
167
+ },
168
+ "outputs": [],
169
+ "source": [
170
+ "import torch\n",
171
+ "import torch.nn as nn\n",
172
+ "import torch.nn.functional as F\n",
173
+ "\n",
174
+ "device = DEVICE if torch.cuda.is_available() else \"cpu\"\n",
175
+ "print(\"Device:\", device, (\"✅ \" + torch.cuda.get_device_name(0)) if device == \"cuda\" else \"⚠️ CPU (rất chậm!)\")\n",
176
+ "\n",
177
+ "ckpt = torch.load(RESUME_CKPT, map_location=\"cpu\", weights_only=False) # ckpt có numpy (vad_mu) → cần False\n",
178
+ "assert \"wavlm\" in ckpt, (\"❌ Checkpoint KHÔNG có 'wavlm' (backbone). Đây là bản ft_emotion_meta.pt CŨ \"\n",
179
+ " \"chỉ lưu heads → không resume được. Hãy dùng ft_emotion_full.pt.\")\n",
180
+ "print(\"✅ Nạp checkpoint:\", RESUME_CKPT, \"| keys:\", list(ckpt.keys()))\n",
181
+ "\n",
182
+ "def find_hf_backbone(module):\n",
183
+ " cands = []\n",
184
+ " for name, m in module.named_modules():\n",
185
+ " enc = getattr(m, \"encoder\", None)\n",
186
+ " if getattr(m, \"feature_extractor\", None) is not None and enc is not None \\\n",
187
+ " and getattr(enc, \"layers\", None) is not None:\n",
188
+ " cands.append((name, m))\n",
189
+ " if not cands:\n",
190
+ " return None, None\n",
191
+ " cands.sort(key=lambda nm: sum(p.numel() for p in nm[1].parameters()), reverse=True)\n",
192
+ " return cands[0]\n",
193
+ "\n",
194
+ "wavlm = None\n",
195
+ "try:\n",
196
+ " from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402\n",
197
+ " _wrapper = WavLMWrapper.from_pretrained(\"tiantiaf/wavlm-large-categorical-emotion\")\n",
198
+ " name, wavlm = find_hf_backbone(_wrapper)\n",
199
+ " if wavlm is not None:\n",
200
+ " print(f\"✅ Dựng backbone WavLM từ SAILER wrapper tại '.{name}'\")\n",
201
+ "except Exception as e:\n",
202
+ " print(\"⚠️ Lỗi nạp SAILER wrapper:\", repr(e), \"→ fallback WavLM trắng.\")\n",
203
+ "\n",
204
+ "if wavlm is None:\n",
205
+ " from transformers import WavLMModel\n",
206
+ " wavlm = WavLMModel.from_pretrained(\"microsoft/wavlm-large\")\n",
207
+ " print(\"ℹ️ Fallback: microsoft/wavlm-large.\")\n",
208
+ "\n",
209
+ "wavlm = wavlm.to(device)\n",
210
+ "WAVLM_DIM = int(wavlm.config.hidden_size)\n",
211
+ "\n",
212
+ "# Nạp trọng số đã fine-tune từ checkpoint (đè lên kiến trúc vừa dựng)\n",
213
+ "miss, unexp = wavlm.load_state_dict(ckpt[\"wavlm\"], strict=False)\n",
214
+ "print(f\"🔁 load wavlm từ checkpoint: thiếu {len(miss)} / dư {len(unexp)} key (kỳ vọng ~0).\")\n",
215
+ "if len(miss) > 20 or len(unexp) > 20:\n",
216
+ " print(\" ⚠️ Lệch key nhiều → kiến trúc có thể không khớp checkpoint. Kiểm tra UNFREEZE/USE_AUDEERING.\")\n",
217
+ "\n",
218
+ "# Đóng băng partial: chỉ mở UNFREEZE_TOP_LAYERS lớp trên\n",
219
+ "for p in wavlm.parameters():\n",
220
+ " p.requires_grad = False\n",
221
+ "enc_layers = wavlm.encoder.layers\n",
222
+ "n_layers = len(enc_layers)\n",
223
+ "for layer in enc_layers[max(0, n_layers - UNFREEZE_TOP_LAYERS):]:\n",
224
+ " for p in layer.parameters():\n",
225
+ " p.requires_grad = True\n",
226
+ "n_train = sum(p.numel() for p in wavlm.parameters() if p.requires_grad)\n",
227
+ "print(f\"WavLM: {n_layers} lớp · mở băng {min(UNFREEZE_TOP_LAYERS, n_layers)} → {n_train/1e6:.1f}M param train (dim {WAVLM_DIM})\")\n",
228
+ "\n",
229
+ "if USE_GRAD_CKPT:\n",
230
+ " wavlm.gradient_checkpointing_enable()\n",
231
+ " if hasattr(wavlm, \"enable_input_require_grads\"):\n",
232
+ " wavlm.enable_input_require_grads()\n",
233
+ "\n",
234
+ "def masked_mean(hidden, attn_mask):\n",
235
+ " if attn_mask is None:\n",
236
+ " return hidden.mean(dim=1)\n",
237
+ " try:\n",
238
+ " fm = wavlm._get_feature_vector_attention_mask(hidden.shape[1], attn_mask)\n",
239
+ " except Exception:\n",
240
+ " return hidden.mean(dim=1)\n",
241
+ " fm = fm.unsqueeze(-1).to(hidden.dtype)\n",
242
+ " return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)\n",
243
+ "\n",
244
+ "def wavlm_embed(input_values, attn_mask):\n",
245
+ " out = wavlm(input_values, attention_mask=attn_mask).last_hidden_state\n",
246
+ " return masked_mean(out, attn_mask)"
247
+ ]
248
+ },
249
+ {
250
+ "cell_type": "markdown",
251
+ "id": "156a5f4d",
252
+ "metadata": {},
253
+ "source": [
254
+ "## 3. audeering FROZEN (đặc trưng phụ) — dùng cache nếu có"
255
+ ]
256
+ },
257
+ {
258
+ "cell_type": "code",
259
+ "execution_count": null,
260
+ "id": "670569c7",
261
+ "metadata": {
262
+ "lines_to_next_cell": 1
263
+ },
264
+ "outputs": [],
265
+ "source": [
266
+ "import numpy as np\n",
267
+ "import librosa\n",
268
+ "from tqdm.auto import tqdm\n",
269
+ "\n",
270
+ "AUD_DIM = 0\n",
271
+ "aud_backbone = aud_head = aud_proc = None\n",
272
+ "if USE_AUDEERING:\n",
273
+ " from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor\n",
274
+ " from huggingface_hub import hf_hub_download\n",
275
+ " AUD_NAME = \"audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim\"\n",
276
+ " aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)\n",
277
+ " aud_cfg = Wav2Vec2Config.from_pretrained(AUD_NAME)\n",
278
+ " aud_backbone = Wav2Vec2Model(aud_cfg)\n",
279
+ " try:\n",
280
+ " _sd = __import__(\"safetensors.torch\", fromlist=[\"load_file\"]).load_file(\n",
281
+ " hf_hub_download(AUD_NAME, \"model.safetensors\"))\n",
282
+ " except Exception:\n",
283
+ " _sd = torch.load(hf_hub_download(AUD_NAME, \"pytorch_model.bin\"), map_location=\"cpu\")\n",
284
+ " bb_sd = {k[len(\"wav2vec2.\"):]: v for k, v in _sd.items() if k.startswith(\"wav2vec2.\")}\n",
285
+ " aud_backbone.load_state_dict(bb_sd, strict=False)\n",
286
+ " _hid = _sd[\"classifier.dense.weight\"].shape[0]\n",
287
+ " _out = _sd[\"classifier.out_proj.weight\"].shape[0]\n",
288
+ " aud_head = nn.Sequential(nn.Linear(_hid, _hid), nn.Tanh(), nn.Linear(_hid, _out))\n",
289
+ " aud_head[0].weight.data.copy_(_sd[\"classifier.dense.weight\"]); aud_head[0].bias.data.copy_(_sd[\"classifier.dense.bias\"])\n",
290
+ " aud_head[2].weight.data.copy_(_sd[\"classifier.out_proj.weight\"]); aud_head[2].bias.data.copy_(_sd[\"classifier.out_proj.bias\"])\n",
291
+ " aud_backbone = aud_backbone.to(device).eval()\n",
292
+ " aud_head = aud_head.to(device).eval()\n",
293
+ " AUD_DIM = _hid + 3\n",
294
+ " print(f\"✅ audeering frozen ({AUD_DIM}-D)\")\n",
295
+ "\n",
296
+ "def load_wav(name_or_stem):\n",
297
+ " p = name_or_stem if os.path.isabs(str(name_or_stem)) else os.path.join(\n",
298
+ " WAV_DIR, name_or_stem if str(name_or_stem).endswith(\".wav\") else str(name_or_stem) + \".wav\")\n",
299
+ " if not os.path.exists(p):\n",
300
+ " return None\n",
301
+ " wave, _ = librosa.load(p, sr=SR, mono=True)\n",
302
+ " return wave[: MAX_SECONDS * SR].astype(np.float32)\n",
303
+ "\n",
304
+ "@torch.no_grad()\n",
305
+ "def extract_audeering(stems, tag):\n",
306
+ " if not USE_AUDEERING:\n",
307
+ " return {}\n",
308
+ " cache_path = os.path.join(CACHE_DIR, f\"aud_{tag}.npz\")\n",
309
+ " store = {}\n",
310
+ " if os.path.exists(cache_path):\n",
311
+ " z = np.load(cache_path, allow_pickle=True)\n",
312
+ " store = {k: z[k] for k in z.files}\n",
313
+ " print(f\"[aud/{tag}] nạp cache: {len(store)}\")\n",
314
+ " todo = [s for s in stems if s not in store]\n",
315
+ " for i, s in enumerate(tqdm(todo, desc=f\"audeering {tag}\")):\n",
316
+ " wave = load_wav(s)\n",
317
+ " if wave is None:\n",
318
+ " continue\n",
319
+ " x = aud_proc(wave, sampling_rate=SR).input_values[0]\n",
320
+ " x = torch.from_numpy(np.asarray(x, dtype=np.float32)).unsqueeze(0).to(device)\n",
321
+ " h = aud_backbone(x)[0].mean(dim=1)\n",
322
+ " out = aud_head(h)[0].cpu().numpy()\n",
323
+ " vad = np.array([1 + 4 * out[2], 1 + 4 * out[0], 1 + 4 * out[1]], dtype=np.float32)\n",
324
+ " store[s] = np.concatenate([h[0].cpu().numpy(), vad]).astype(np.float32)\n",
325
+ " if (i + 1) % 500 == 0:\n",
326
+ " np.savez(cache_path, **store)\n",
327
+ " if todo:\n",
328
+ " np.savez(cache_path, **store)\n",
329
+ " return store"
330
+ ]
331
+ },
332
+ {
333
+ "cell_type": "markdown",
334
+ "id": "c5ed6f49",
335
+ "metadata": {},
336
+ "source": [
337
+ "## 4. Đọc & gộp nhãn theo wavID"
338
+ ]
339
+ },
340
+ {
341
+ "cell_type": "code",
342
+ "execution_count": null,
343
+ "id": "910d097f",
344
+ "metadata": {},
345
+ "outputs": [],
346
+ "source": [
347
+ "import pandas as pd\n",
348
+ "\n",
349
+ "def load_target_emotions():\n",
350
+ " tgt = {}\n",
351
+ " with open(METADATA_CSV, encoding=\"utf-8\") as f:\n",
352
+ " for ln in f:\n",
353
+ " parts = ln.strip().split(\"|\")\n",
354
+ " if len(parts) >= 2:\n",
355
+ " tgt[stem(parts[0])] = norm_emotion(parts[1])\n",
356
+ " return tgt\n",
357
+ "\n",
358
+ "def _col(cols_map, *names, df=None, default_idx=None):\n",
359
+ " for n in names:\n",
360
+ " if n in cols_map:\n",
361
+ " return cols_map[n]\n",
362
+ " return list(df.columns)[default_idx] if default_idx is not None else None\n",
363
+ "\n",
364
+ "def parse_emocat_votes(cell):\n",
365
+ " v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
366
+ " for tok in str(cell).replace(\"/\", \",\").replace(\";\", \",\").replace(\"|\", \",\").replace(\" \", \",\").split(\",\"):\n",
367
+ " e = norm_emotion(tok)\n",
368
+ " if e in EMOTIONS5:\n",
369
+ " v[EMOTIONS5.index(e)] += 1.0\n",
370
+ " return v\n",
371
+ "\n",
372
+ "def load_train_labels():\n",
373
+ " df = pd.read_csv(TRAIN_CSV, sep=\"|\")\n",
374
+ " cols = {c.lower().strip(): c for c in df.columns}\n",
375
+ " wav_col = _col(cols, \"wavid\", \"wav\", df=df, default_idx=1)\n",
376
+ " emos_col = _col(cols, \"emos\", \"emo\", \"emomos\")\n",
377
+ " val_col = _col(cols, \"val\", \"valence\"); aro_col = _col(cols, \"aro\", \"arousal\"); dom_col = _col(cols, \"dom\", \"dominance\")\n",
378
+ " cat_col = _col(cols, \"emocat\", \"cat\", \"emotion\")\n",
379
+ " assert emos_col, f\"Không thấy cột eMOS (cột: {list(df.columns)})\"\n",
380
+ " df[\"_stem\"] = df[wav_col].map(stem)\n",
381
+ " rows = []\n",
382
+ " for sid, g in df.groupby(\"_stem\"):\n",
383
+ " rec = {\"wavID\": sid, \"emos\": float(g[emos_col].mean())}\n",
384
+ " rec[\"val\"] = float(g[val_col].mean()) if val_col else np.nan\n",
385
+ " rec[\"aro\"] = float(g[aro_col].mean()) if aro_col else np.nan\n",
386
+ " rec[\"dom\"] = float(g[dom_col].mean()) if dom_col else np.nan\n",
387
+ " votes = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
388
+ " if cat_col:\n",
389
+ " for cell in g[cat_col]:\n",
390
+ " votes += parse_emocat_votes(cell)\n",
391
+ " s = votes.sum()\n",
392
+ " cat = votes / s if s > 0 else np.full(len(EMOTIONS5), 0.2, dtype=np.float32)\n",
393
+ " for i in range(len(EMOTIONS5)):\n",
394
+ " rec[f\"cat{i}\"] = float(cat[i])\n",
395
+ " rows.append(rec)\n",
396
+ " return pd.DataFrame(rows)\n",
397
+ "\n",
398
+ "target_map = load_target_emotions()\n",
399
+ "train_df = load_train_labels()\n",
400
+ "HAS_VAD = bool(train_df[\"val\"].notna().any())\n",
401
+ "print(f\"Target: {len(target_map)} | wav train (gộp): {len(train_df)} | có VAD: {HAS_VAD}\")"
402
+ ]
403
+ },
404
+ {
405
+ "cell_type": "markdown",
406
+ "id": "1d9509a7",
407
+ "metadata": {},
408
+ "source": [
409
+ "## 5. Dataset/loader — DÙNG thống kê chuẩn hóa TỪ CHECKPOINT (để khớp head đã train)"
410
+ ]
411
+ },
412
+ {
413
+ "cell_type": "code",
414
+ "execution_count": null,
415
+ "id": "4c09387b",
416
+ "metadata": {},
417
+ "outputs": [],
418
+ "source": [
419
+ "from torch.utils.data import Dataset, DataLoader\n",
420
+ "from sklearn.model_selection import train_test_split\n",
421
+ "\n",
422
+ "train_stems = [s for s in train_df[\"wavID\"] if target_map.get(s) is not None]\n",
423
+ "if LIMIT_TRAIN:\n",
424
+ " train_stems = train_stems[:LIMIT_TRAIN]\n",
425
+ "aud_tr = extract_audeering(train_stems, \"train\")\n",
426
+ "\n",
427
+ "lab = train_df.set_index(\"wavID\")\n",
428
+ "\n",
429
+ "# QUAN TRỌNG: lấy mean/std từ checkpoint (head đã train theo thang này) thay vì tính lại.\n",
430
+ "emos_mu = float(ckpt[\"emos_mu\"]); emos_sd = float(ckpt[\"emos_sd\"])\n",
431
+ "vad_mu = np.asarray(ckpt[\"vad_mu\"], dtype=np.float32); vad_sd = np.asarray(ckpt[\"vad_sd\"], dtype=np.float32)\n",
432
+ "print(f\"Dùng chuẩn hóa từ ckpt: emos μ={emos_mu:.3f} σ={emos_sd:.3f} | vad μ={np.round(vad_mu,2)}\")\n",
433
+ "\n",
434
+ "def onehot_target(tgt):\n",
435
+ " v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
436
+ " if tgt in EMOTIONS5:\n",
437
+ " v[EMOTIONS5.index(tgt)] = 1.0\n",
438
+ " return v\n",
439
+ "\n",
440
+ "class EmoDataset(Dataset):\n",
441
+ " def __init__(self, stems):\n",
442
+ " self.stems = [s for s in stems if (load_wav(s) is not None) and ((not USE_AUDEERING) or s in aud_tr)]\n",
443
+ " def __len__(self):\n",
444
+ " return len(self.stems)\n",
445
+ " def __getitem__(self, i):\n",
446
+ " s = self.stems[i]\n",
447
+ " wave = load_wav(s)\n",
448
+ " emos = (float(lab.loc[s, \"emos\"]) - emos_mu) / emos_sd\n",
449
+ " if HAS_VAD:\n",
450
+ " vad = (np.array([lab.loc[s, \"val\"], lab.loc[s, \"aro\"], lab.loc[s, \"dom\"]], np.float32) - vad_mu) / vad_sd\n",
451
+ " else:\n",
452
+ " vad = np.zeros(3, dtype=np.float32)\n",
453
+ " cat = np.array([lab.loc[s, f\"cat{j}\"] for j in range(len(EMOTIONS5))], dtype=np.float32)\n",
454
+ " aud = aud_tr[s] if USE_AUDEERING else np.zeros(0, dtype=np.float32)\n",
455
+ " return {\"wave\": wave, \"tgt\": onehot_target(target_map.get(s)), \"aud\": aud,\n",
456
+ " \"emos\": np.float32(emos), \"vad\": vad, \"cat\": cat,\n",
457
+ " \"emos_raw\": np.float32(lab.loc[s, \"emos\"]),\n",
458
+ " \"vad_raw\": np.array([lab.loc[s, \"val\"], lab.loc[s, \"aro\"], lab.loc[s, \"dom\"]], np.float32)}\n",
459
+ "\n",
460
+ "def collate(batch):\n",
461
+ " lens = [len(b[\"wave\"]) for b in batch]\n",
462
+ " L = max(lens)\n",
463
+ " waves = np.zeros((len(batch), L), dtype=np.float32)\n",
464
+ " mask = np.zeros((len(batch), L), dtype=np.float32)\n",
465
+ " for i, b in enumerate(batch):\n",
466
+ " waves[i, : len(b[\"wave\"])] = b[\"wave\"]; mask[i, : len(b[\"wave\"])] = 1.0\n",
467
+ " return {\n",
468
+ " \"input_values\": torch.from_numpy(waves), \"attn_mask\": torch.from_numpy(mask).long(),\n",
469
+ " \"tgt\": torch.from_numpy(np.stack([b[\"tgt\"] for b in batch])),\n",
470
+ " \"aud\": torch.from_numpy(np.stack([b[\"aud\"] for b in batch])) if USE_AUDEERING else None,\n",
471
+ " \"emos\": torch.from_numpy(np.stack([b[\"emos\"] for b in batch])).unsqueeze(1),\n",
472
+ " \"vad\": torch.from_numpy(np.stack([b[\"vad\"] for b in batch])),\n",
473
+ " \"cat\": torch.from_numpy(np.stack([b[\"cat\"] for b in batch])),\n",
474
+ " \"emos_raw\": np.stack([b[\"emos_raw\"] for b in batch]),\n",
475
+ " \"vad_raw\": np.stack([b[\"vad_raw\"] for b in batch]),\n",
476
+ " }\n",
477
+ "\n",
478
+ "ds = EmoDataset(train_stems)\n",
479
+ "print(\"Dataset hợp lệ:\", len(ds), \"wav\")\n",
480
+ "tr_i, va_i = train_test_split(np.arange(len(ds)), test_size=VAL_FRAC, random_state=SEED)\n",
481
+ "tr_loader = DataLoader(torch.utils.data.Subset(ds, tr_i), batch_size=BATCH, shuffle=True, collate_fn=collate, num_workers=2)\n",
482
+ "va_loader = DataLoader(torch.utils.data.Subset(ds, va_i), batch_size=BATCH, shuffle=False, collate_fn=collate, num_workers=2)"
483
+ ]
484
+ },
485
+ {
486
+ "cell_type": "markdown",
487
+ "id": "5c16d942",
488
+ "metadata": {},
489
+ "source": [
490
+ "## 6. Heads (NẠP từ checkpoint) + optimizer + train TIẾP"
491
+ ]
492
+ },
493
+ {
494
+ "cell_type": "code",
495
+ "execution_count": null,
496
+ "id": "7adfb320",
497
+ "metadata": {
498
+ "lines_to_next_cell": 1
499
+ },
500
+ "outputs": [],
501
+ "source": [
502
+ "from scipy.stats import spearmanr\n",
503
+ "\n",
504
+ "torch.manual_seed(SEED); np.random.seed(SEED)\n",
505
+ "N_EMO = len(EMOTIONS5)\n",
506
+ "TRUNK_IN = WAVLM_DIM + (AUD_DIM if USE_AUDEERING else 0)\n",
507
+ "\n",
508
+ "class EmoHeads(nn.Module):\n",
509
+ " def __init__(self, d_in, trunk_h, head_h, p, n_emo):\n",
510
+ " super().__init__()\n",
511
+ " self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),\n",
512
+ " nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))\n",
513
+ " self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))\n",
514
+ " self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))\n",
515
+ " self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))\n",
516
+ " def forward(self, feat, tgt):\n",
517
+ " h = self.trunk(feat)\n",
518
+ " return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)\n",
519
+ "\n",
520
+ "heads = EmoHeads(TRUNK_IN, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device)\n",
521
+ "hmiss, hunexp = heads.load_state_dict(ckpt[\"heads\"], strict=False)\n",
522
+ "print(f\"🔁 load heads từ checkpoint: thiếu {len(hmiss)} / dư {len(hunexp)} key (kỳ vọng 0).\")\n",
523
+ "print(f\"Trunk input = {TRUNK_IN} (wavlm {WAVLM_DIM} + aud {AUD_DIM if USE_AUDEERING else 0})\")\n",
524
+ "\n",
525
+ "TASKS = [\"emos\", \"cat\", \"val\", \"aro\", \"dom\"]\n",
526
+ "log_var = nn.Parameter(torch.zeros(len(TASKS), device=device))\n",
527
+ "bb_params = [p for p in wavlm.parameters() if p.requires_grad]\n",
528
+ "head_params = list(heads.parameters()) + ([log_var] if USE_UNCERTAINTY else [])\n",
529
+ "opt = torch.optim.AdamW([\n",
530
+ " {\"params\": bb_params, \"lr\": LR_BACKBONE * RESUME_LR_SCALE},\n",
531
+ " {\"params\": head_params, \"lr\": LR_HEAD * RESUME_LR_SCALE},\n",
532
+ "], weight_decay=WEIGHT_DECAY)\n",
533
+ "scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP and device == \"cuda\")\n",
534
+ "mse = nn.MSELoss()\n",
535
+ "\n",
536
+ "def soft_ce(logits, target_dist):\n",
537
+ " return -(target_dist * F.log_softmax(logits, dim=1)).sum(1).mean()\n",
538
+ "\n",
539
+ "def forward_batch(b):\n",
540
+ " feat_wavlm = wavlm_embed(b[\"input_values\"].to(device), b[\"attn_mask\"].to(device))\n",
541
+ " feat = torch.cat([feat_wavlm, b[\"aud\"].to(device)], dim=1) if USE_AUDEERING else feat_wavlm\n",
542
+ " return heads(feat, b[\"tgt\"].to(device))\n",
543
+ "\n",
544
+ "def compute_loss(emos_p, cat_l, vad_p, b):\n",
545
+ " L = {}\n",
546
+ " L[\"emos\"] = mse(emos_p, b[\"emos\"].to(device))\n",
547
+ " L[\"cat\"] = soft_ce(cat_l, b[\"cat\"].to(device))\n",
548
+ " if HAS_VAD:\n",
549
+ " vt = b[\"vad\"].to(device)\n",
550
+ " L[\"val\"] = mse(vad_p[:, 0:1], vt[:, 0:1]); L[\"aro\"] = mse(vad_p[:, 1:2], vt[:, 1:2]); L[\"dom\"] = mse(vad_p[:, 2:3], vt[:, 2:3])\n",
551
+ " else:\n",
552
+ " z = torch.zeros((), device=device); L[\"val\"] = L[\"aro\"] = L[\"dom\"] = z\n",
553
+ " if USE_UNCERTAINTY:\n",
554
+ " return sum(torch.exp(-log_var[i]) * L[t] + log_var[i] for i, t in enumerate(TASKS))\n",
555
+ " return sum(L.values())\n",
556
+ "\n",
557
+ "@torch.no_grad()\n",
558
+ "def evaluate():\n",
559
+ " wavlm.eval(); heads.eval()\n",
560
+ " P = {\"emos\": [], \"val\": [], \"aro\": [], \"dom\": []}; Y = {\"emos\": [], \"val\": [], \"aro\": [], \"dom\": []}\n",
561
+ " catP, catY = [], []\n",
562
+ " for b in va_loader:\n",
563
+ " with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
564
+ " emos_p, cat_l, vad_p = forward_batch(b)\n",
565
+ " P[\"emos\"] += emos_p.float().cpu().numpy().ravel().tolist(); Y[\"emos\"] += b[\"emos_raw\"].tolist()\n",
566
+ " vad_p = vad_p.float().cpu().numpy()\n",
567
+ " for j, t in enumerate([\"val\", \"aro\", \"dom\"]):\n",
568
+ " P[t] += vad_p[:, j].tolist(); Y[t] += b[\"vad_raw\"][:, j].tolist()\n",
569
+ " catP.append(F.softmax(cat_l, 1).float().cpu().numpy()); catY.append(b[\"cat\"])\n",
570
+ " out = {}\n",
571
+ " for t in [\"emos\"] + ([\"val\", \"aro\", \"dom\"] if HAS_VAD else []):\n",
572
+ " out[t] = spearmanr(P[t], Y[t]).correlation\n",
573
+ " q = np.concatenate(catP); p = np.concatenate(catY)\n",
574
+ " out[\"cat_err\"] = float(np.abs(q - p).sum(1).mean())\n",
575
+ " return out\n",
576
+ "\n",
577
+ "def mean_srcc(m):\n",
578
+ " keys = [\"emos\"] + ([\"val\", \"aro\", \"dom\"] if HAS_VAD else [])\n",
579
+ " return float(np.mean([m[k] for k in keys]))\n",
580
+ "\n",
581
+ "# Init best TỪ checkpoint hiện tại → chỉ lưu nếu train tiếp TỐT HƠN\n",
582
+ "m0 = evaluate(); best = mean_srcc(m0)\n",
583
+ "best_state = {\"wavlm\": {k: v.cpu().clone() for k, v in wavlm.state_dict().items()},\n",
584
+ " \"heads\": {k: v.cpu().clone() for k, v in heads.state_dict().items()}}\n",
585
+ "print(f\"📍 Checkpoint hiện tại: mean SRCC = {best:.4f} | \"\n",
586
+ " + \" \".join(f\"{k}={m0[k]:.3f}\" for k in ['emos','val','aro','dom'] if k in m0))\n",
587
+ "\n",
588
+ "bad = 0\n",
589
+ "for ep in range(1, EPOCHS + 1):\n",
590
+ " wavlm.train(); heads.train()\n",
591
+ " opt.zero_grad(); run = 0.0; nb = 0\n",
592
+ " for step, b in enumerate(tqdm(tr_loader, desc=f\"+epoch {ep}\")):\n",
593
+ " with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
594
+ " emos_p, cat_l, vad_p = forward_batch(b)\n",
595
+ " loss = compute_loss(emos_p, cat_l, vad_p, b) / ACCUM\n",
596
+ " scaler.scale(loss).backward()\n",
597
+ " if (step + 1) % ACCUM == 0:\n",
598
+ " scaler.step(opt); scaler.update(); opt.zero_grad()\n",
599
+ " run += loss.item() * ACCUM; nb += 1\n",
600
+ " m = evaluate(); sc = mean_srcc(m)\n",
601
+ " msg = \" \".join(f\"{k}={m[k]:.3f}\" for k in [\"emos\", \"val\", \"aro\", \"dom\"] if k in m)\n",
602
+ " print(f\"+epoch {ep:2d} | loss {run/max(nb,1):.4f} | {msg} | cat_err {m['cat_err']:.3f} | mean {sc:.4f} (best {max(best,sc):.4f})\")\n",
603
+ " if sc > best:\n",
604
+ " best = sc\n",
605
+ " best_state = {\"wavlm\": {k: v.cpu().clone() for k, v in wavlm.state_dict().items()},\n",
606
+ " \"heads\": {k: v.cpu().clone() for k, v in heads.state_dict().items()}}\n",
607
+ " bad = 0\n",
608
+ " else:\n",
609
+ " bad += 1\n",
610
+ " if bad >= PATIENCE:\n",
611
+ " print(f\"Early stop (resume) ở +epoch {ep}.\"); break\n",
612
+ "\n",
613
+ "wavlm.load_state_dict(best_state[\"wavlm\"]); heads.load_state_dict(best_state[\"heads\"])\n",
614
+ "final = evaluate()\n",
615
+ "print(\"\\n✅ VAL sau resume:\")\n",
616
+ "print(f\" EMOS={final['emos']:.4f} (ckpt {m0['emos']:.3f} · exp08 nộp {EXP08['emos']})\")\n",
617
+ "if HAS_VAD:\n",
618
+ " print(f\" VAL/ARO/DOM={final['val']:.4f}/{final['aro']:.4f}/{final['dom']:.4f} \"\n",
619
+ " f\"(exp08 nộp {EXP08['val']}/{EXP08['aro']}/{EXP08['dom']})\")\n",
620
+ "print(f\" mean SRCC: ckpt {mean_srcc(m0):.4f} → sau resume {mean_srcc(final):.4f} \"\n",
621
+ " + (\"🚀 cải thiện\" if mean_srcc(final) > mean_srcc(m0) + 1e-4 else \"➖ không cải thiện (giữ ckpt cũ)\"))\n",
622
+ "\n",
623
+ "torch.save({\"wavlm\": best_state[\"wavlm\"], \"heads\": best_state[\"heads\"],\n",
624
+ " \"emos_mu\": emos_mu, \"emos_sd\": emos_sd, \"vad_mu\": vad_mu, \"vad_sd\": vad_sd,\n",
625
+ " \"WAVLM_DIM\": WAVLM_DIM, \"AUD_DIM\": AUD_DIM, \"UNFREEZE_TOP_LAYERS\": UNFREEZE_TOP_LAYERS,\n",
626
+ " \"val_emos\": final[\"emos\"]}, os.path.join(OUT_DIR, \"ft_emotion_full.pt\"))\n",
627
+ "print(\"Đã lưu FULL (có backbone):\", os.path.join(OUT_DIR, \"ft_emotion_full.pt\"))"
628
+ ]
629
+ },
630
+ {
631
+ "cell_type": "markdown",
632
+ "id": "dd6bb5d8",
633
+ "metadata": {},
634
+ "source": [
635
+ "## 7. Dự đoán DEV → answer.txt (5 cột cảm xúc từ resume; QMOS mượn exp07 hoặc UTMOSv2)"
636
+ ]
637
+ },
638
+ {
639
+ "cell_type": "code",
640
+ "execution_count": null,
641
+ "id": "6b7753e8",
642
+ "metadata": {
643
+ "lines_to_next_cell": 1
644
+ },
645
+ "outputs": [],
646
+ "source": [
647
+ "def list_dev():\n",
648
+ " with open(DEV_SCP) as f:\n",
649
+ " return [ln.strip() for ln in f if ln.strip()]\n",
650
+ "\n",
651
+ "dev_names = list_dev()\n",
652
+ "if LIMIT_DEV:\n",
653
+ " dev_names = dev_names[:LIMIT_DEV]\n",
654
+ "dev_stems = [stem(n) for n in dev_names]\n",
655
+ "print(\"DEV:\", len(dev_names), \"mẫu\")\n",
656
+ "aud_dev = extract_audeering(dev_stems, \"dev\")\n",
657
+ "\n",
658
+ "def load_exp07_qmos():\n",
659
+ " if EXP07_ANSWER and os.path.exists(EXP07_ANSWER):\n",
660
+ " import csv\n",
661
+ " d = {}\n",
662
+ " with open(EXP07_ANSWER) as f:\n",
663
+ " for row in csv.DictReader(f):\n",
664
+ " d[row[\"wav\"]] = float(row[\"QMOS\"]); d[stem(row[\"wav\"])] = float(row[\"QMOS\"])\n",
665
+ " print(f\"✅ Mượn QMOS từ exp07 ({EXP07_ANSWER}): {len(d)//2} wav\")\n",
666
+ " return d\n",
667
+ " return None\n",
668
+ "\n",
669
+ "qmos_map = load_exp07_qmos()\n",
670
+ "if qmos_map is None:\n",
671
+ " print(\"ℹ️ Không có answer.txt exp07 → chấm QMOS bằng UTMOSv2 (T05, vô địch VMC2024).\")\n",
672
+ " pip_install(\"git+https://github.com/sarulab-speech/UTMOSv2.git\")\n",
673
+ " import utmosv2\n",
674
+ " v2 = utmosv2.create_model(pretrained=True)\n",
675
+ " qmos_map = {}\n",
676
+ " for n in tqdm(dev_names, desc=\"UTMOSv2\"):\n",
677
+ " wav = os.path.join(WAV_DIR, n if str(n).endswith(\".wav\") else str(n) + \".wav\")\n",
678
+ " if not os.path.exists(wav):\n",
679
+ " continue\n",
680
+ " out = v2.predict(input_path=wav)\n",
681
+ " qmos_map[n] = float(out[\"predicted_mos\"]) if isinstance(out, dict) else float(out)\n",
682
+ " del v2; torch.cuda.empty_cache() if device == \"cuda\" else None\n",
683
+ "\n",
684
+ "@torch.no_grad()\n",
685
+ "def predict_emotion(sid):\n",
686
+ " wave = load_wav(sid)\n",
687
+ " if wave is None or (USE_AUDEERING and sid not in aud_dev):\n",
688
+ " return None\n",
689
+ " wavlm.eval(); heads.eval()\n",
690
+ " iv = torch.from_numpy(wave).unsqueeze(0).to(device)\n",
691
+ " am = torch.ones((1, len(wave)), dtype=torch.long, device=device)\n",
692
+ " tgt = torch.from_numpy(onehot_target(target_map.get(sid))).unsqueeze(0).to(device)\n",
693
+ " with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
694
+ " fw = wavlm_embed(iv, am)\n",
695
+ " feat = torch.cat([fw, torch.from_numpy(aud_dev[sid]).unsqueeze(0).to(device)], dim=1) if USE_AUDEERING else fw\n",
696
+ " emos_p, cat_l, vad_p = heads(feat, tgt)\n",
697
+ " emos = float(emos_p.item()) * emos_sd + emos_mu\n",
698
+ " cat5 = F.softmax(cat_l, 1)[0].float().cpu().numpy()\n",
699
+ " vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu\n",
700
+ " return emos, cat5, vad3\n",
701
+ "\n",
702
+ "def fmt_cat(p5):\n",
703
+ " return \"|\".join(f\"{e}:{p5[i]:.6g}\" for i, e in enumerate(EMOTIONS5))\n",
704
+ "\n",
705
+ "def build_answer(out_path):\n",
706
+ " n_real = n_def = 0\n",
707
+ " with open(out_path, \"w\") as f:\n",
708
+ " f.write(\"wav,QMOS,EMOS,CAT,VAL,ARO,DOM\\n\")\n",
709
+ " for name in tqdm(dev_names, desc=\"answer\"):\n",
710
+ " sid = stem(name)\n",
711
+ " pr = predict_emotion(sid)\n",
712
+ " if pr is None:\n",
713
+ " emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0]); n_def += 1\n",
714
+ " else:\n",
715
+ " emos, cat5, vad3 = pr; n_real += 1\n",
716
+ " qmos = qmos_map.get(name, qmos_map.get(sid, 3.0))\n",
717
+ " f.write(f\"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\\n\")\n",
718
+ " print(f\"Ghi {len(dev_names)} dòng → {out_path} | cảm xúc thật {n_real}, mặc định {n_def}\")\n",
719
+ "\n",
720
+ "answer_path = os.path.join(OUT_DIR, \"answer.txt\")\n",
721
+ "build_answer(answer_path)"
722
+ ]
723
+ },
724
+ {
725
+ "cell_type": "markdown",
726
+ "id": "7dac208f",
727
+ "metadata": {},
728
+ "source": [
729
+ "## 8. Validate + zip"
730
+ ]
731
+ },
732
+ {
733
+ "cell_type": "code",
734
+ "execution_count": null,
735
+ "id": "c123c058",
736
+ "metadata": {},
737
+ "outputs": [],
738
+ "source": [
739
+ "def validate(path):\n",
740
+ " import csv\n",
741
+ " with open(path) as f:\n",
742
+ " rows = list(csv.reader(f))\n",
743
+ " assert rows[0][0] == \"wav\" and \"QMOS\" in rows[0], \"Header sai\"\n",
744
+ " for i, r in enumerate(rows[1:], 2):\n",
745
+ " assert len(r) == len(rows[0]), f\"Dòng {i} sai số cột\"\n",
746
+ " print(f\"OK: {len(rows)-1} dòng, header = {rows[0]}\")\n",
747
+ "\n",
748
+ "validate(answer_path)\n",
749
+ "os.system(f\"cd {OUT_DIR} && zip -j submission_track2_exp08_resume.zip answer.txt && unzip -l submission_track2_exp08_resume.zip\")\n",
750
+ "print(\"Sẵn sàng nộp:\", os.path.join(OUT_DIR, \"submission_track2_exp08_resume.zip\"))"
751
+ ]
752
+ },
753
+ {
754
+ "cell_type": "markdown",
755
+ "id": "e5f82cf0",
756
+ "metadata": {},
757
+ "source": [
758
+ "## Ghi chú\n",
759
+ "- **Đầu vào bắt buộc:** `RESUME_CKPT` = `ft_emotion_full.pt` (CÓ backbone). Bản `ft_emotion_meta.pt` cũ chỉ\n",
760
+ " có heads → cell 2 sẽ assert lỗi nhắc dùng file đủ.\n",
761
+ "- **Cache:** trỏ `CACHE_INPUT` tới dataset chứa `aud_train.npz`/`aud_dev.npz` → khỏi trích lại audeering.\n",
762
+ " Nếu LIMIT khác lần trước, cache thiếu stem nào sẽ tự trích bù (resume theo stem).\n",
763
+ "- **Chuẩn hóa lấy TỪ checkpoint** (`emos_mu/sd`, `vad_mu/sd`) → khớp thang head đã train (đừng tính lại).\n",
764
+ "- **best init từ checkpoint** → chỉ lưu nếu train tiếp THỰC SỰ tốt hơn (không sợ tụt).\n",
765
+ "- Nếu val chững: đặt `RESUME_LR_SCALE=0.5` (giảm LR) hoặc tăng `UNFREEZE_TOP_LAYERS` (lưu ý: mở thêm lớp\n",
766
+ " thì lớp mới chưa được train trong checkpoint → cần nhiều epoch hơn).\n",
767
+ "- QMOS: tốt nhất Add Input `answer.txt` exp07 (0.548). Để trộn cột chuẩn, xem kết quả exp08: 5 cột cảm xúc\n",
768
+ " resume + QMOS exp07 → hệ thống mạnh nhất 6 cột.\n",
769
+ "- Ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md`."
770
+ ]
771
+ }
772
+ ],
773
+ "metadata": {
774
+ "jupytext": {
775
+ "cell_metadata_filter": "-all",
776
+ "main_language": "python",
777
+ "notebook_metadata_filter": "-all"
778
+ }
779
+ },
780
+ "nbformat": 4,
781
+ "nbformat_minor": 5
782
+ }
track2/exp08b_finetune_resume_pipeline.py ADDED
@@ -0,0 +1,642 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %% [markdown]
2
+ # # VMC2026 Track 2 — exp08-RESUME (fine-tune TIẾP từ checkpoint + cache) — Kaggle
3
+ #
4
+ # **Mục đích:** train tiếp model fine-tune cảm xúc (exp08) từ **checkpoint đã lưu** thay vì train lại từ
5
+ # đầu — tiết kiệm giờ GPU. Tận dụng:
6
+ # - `ft_emotion_full.pt` (CÓ cả backbone WavLM + heads + thống kê chuẩn hóa) → nạp lại đúng trạng thái.
7
+ # - **cache audeering** `aud_*.npz` (đặc trưng frozen) → KHÔNG trích lại (~đỡ chục phút).
8
+ #
9
+ # > ⚠️ Bắt buộc dùng checkpoint **đủ backbone** (`ft_emotion_full.pt` từ cell "TRAIN TIẾP", hoặc bản
10
+ # > `ft_emotion_meta.pt` MỚI đã vá để lưu cả `wavlm`). Bản `ft_emotion_meta.pt` CŨ chỉ có `heads` → KHÔNG dùng được.
11
+ #
12
+ # ## Chuẩn bị input trên Kaggle (Add Input)
13
+ # 1. Dataset Track 2 (`vmc2026-track2-full`) — wav + nhãn.
14
+ # 2. **Checkpoint**: upload `ft_emotion_full.pt` thành 1 Dataset → trỏ `RESUME_CKPT`.
15
+ # 3. **Cache** (tùy chọn nhưng nên có): upload thư mục chứa `aud_train.npz`, `aud_dev.npz` → trỏ `CACHE_INPUT`.
16
+ # 4. (tùy chọn) `answer.txt` exp07 để mượn cột QMOS 0.548.
17
+ #
18
+ # **Cách chạy:** GPU T4 + Internet On → sửa các slug ở cell 0 → Run All. Lần đầu để `LIMIT_TRAIN=300`.
19
+
20
+ # %% [markdown]
21
+ # ## 0. Cấu hình — SỬA Ở ĐÂY
22
+
23
+ # %%
24
+ import os, shutil
25
+
26
+ DATA_ROOT = "/kaggle/input/vmc2026-track2-full/vmc2026-track2" # << SỬA slug
27
+ WAV_DIR = f"{DATA_ROOT}/wav"
28
+ METADATA_CSV = f"{DATA_ROOT}/metadata.csv"
29
+ TRAIN_CSV = f"{DATA_ROOT}/sets/train.csv"
30
+ DEV_SCP = f"{DATA_ROOT}/sets/dev.scp"
31
+
32
+ # ── Checkpoint + cache để RESUME ─────────────────────────────────────────────
33
+ RESUME_CKPT = "/kaggle/input/ft-emotion-full/ft_emotion_full.pt" # << CHECKPOINT đủ backbone
34
+ CACHE_INPUT = "/kaggle/input/ft-emotion-cache" # << thư mục chứa aud_*.npz (hoặc "" nếu không có)
35
+ EXP07_ANSWER = "/kaggle/input/exp07-answer/answer.txt" # << (tùy chọn) mượn QMOS 0.548; không có → UTMOSv2
36
+
37
+ OUT_DIR = "/kaggle/working"
38
+ CACHE_DIR = "/kaggle/working/ft_cache" # /kaggle/input read-only → copy cache sang đây để ghi/append được
39
+ os.makedirs(CACHE_DIR, exist_ok=True)
40
+
41
+ # ── Fine-tune / siêu tham số (train TIẾP) ────────────────────────────────────
42
+ DEVICE = "cuda"
43
+ SR = 16000
44
+ MAX_SECONDS = 8
45
+ UNFREEZE_TOP_LAYERS = 6 # PHẢI khớp checkpoint (mặc định exp08 = 6)
46
+ TRUNK_HIDDEN = 512 # PHẢI khớp checkpoint
47
+ HEAD_HIDDEN = 128 # PHẢI khớp checkpoint
48
+ DROPOUT = 0.3
49
+ LR_BACKBONE = 1e-5
50
+ LR_HEAD = 1e-3
51
+ RESUME_LR_SCALE = 1.0 # <1.0 để giảm LR khi train tiếp (vd 0.5 nếu val đã chững)
52
+ WEIGHT_DECAY = 1e-5
53
+ EPOCHS = 10 # số epoch train THÊM (run này)
54
+ PATIENCE = 5 # dừng khi val không lên; LUÔN giữ best
55
+ BATCH = 4
56
+ ACCUM = 8 # effective batch = 32
57
+ VAL_FRAC = 0.10
58
+ SEED = 42
59
+ USE_AMP = True
60
+ USE_GRAD_CKPT = True
61
+ USE_AUDEERING = True # PHẢI khớp checkpoint (exp08 = True)
62
+ USE_UNCERTAINTY = True
63
+
64
+ LIMIT_TRAIN = 300 # << LẦN ĐẦU 300; chạy thật None
65
+ LIMIT_DEV = 20 # << LẦN ĐẦU 20; chạy thật None
66
+
67
+ # Mốc exp07 + exp08 để so
68
+ EXP07 = {"emos": 0.795, "cat_err": 0.153, "val": 0.581, "aro": 0.752, "dom": 0.705}
69
+ EXP08 = {"emos": 0.811, "cat_err": 0.133, "val": 0.659, "aro": 0.793, "dom": 0.751} # bản đã nộp
70
+
71
+ EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"]
72
+ _EMO_ALIAS = {
73
+ "angry": "angry", "anger": "angry",
74
+ "happy": "happy", "happiness": "happy", "joy": "happy",
75
+ "neutral": "neutral", "calm": "neutral",
76
+ "sad": "sad", "sadness": "sad",
77
+ "surprise": "surprised", "surprised": "surprised", "surprising": "surprised",
78
+ }
79
+
80
+ def norm_emotion(label):
81
+ key = str(label).strip().lower()
82
+ return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)
83
+
84
+ def stem(p):
85
+ return os.path.splitext(os.path.basename(str(p)))[0]
86
+
87
+ print("DATA_ROOT:", DATA_ROOT)
88
+ for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP, RESUME_CKPT]:
89
+ print((" ✅ " if os.path.exists(p) else " ❌ THIẾU ") + p)
90
+
91
+ # Copy cache (aud_*.npz) từ input read-only sang working để append được
92
+ if CACHE_INPUT and os.path.isdir(CACHE_INPUT):
93
+ n = 0
94
+ for fn in os.listdir(CACHE_INPUT):
95
+ if fn.startswith("aud_") and fn.endswith(".npz"):
96
+ shutil.copy(os.path.join(CACHE_INPUT, fn), os.path.join(CACHE_DIR, fn)); n += 1
97
+ print(f"📦 Copy {n} file cache audeering từ {CACHE_INPUT} → {CACHE_DIR}")
98
+ else:
99
+ print("ℹ️ Không có CACHE_INPUT → sẽ tự trích audeering (chậm hơn lần đầu).")
100
+
101
+ # %% [markdown]
102
+ # ## 1. Cài đặt + tải code SAILER (để dựng đúng kiến trúc WavLM rồi nạp checkpoint đè lên)
103
+
104
+ # %%
105
+ import sys, subprocess
106
+
107
+ def pip_install(*pkgs):
108
+ subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs], check=True)
109
+
110
+ pip_install("loralib", "speechbrain", "speechmos", "librosa", "soundfile",
111
+ "scipy", "scikit-learn", "pandas", "tqdm")
112
+
113
+ REPO_DIR = "/kaggle/working/vox-profile-release"
114
+ if not os.path.exists(REPO_DIR):
115
+ subprocess.run(["git", "clone", "--depth", "1",
116
+ "https://github.com/tiantiaf0627/vox-profile-release.git", REPO_DIR], check=True)
117
+ if REPO_DIR not in sys.path:
118
+ sys.path.insert(0, REPO_DIR)
119
+
120
+ # %% [markdown]
121
+ # ## 2. Dựng WavLM (như exp08) → NẠP trọng số backbone từ checkpoint
122
+ # Dựng đúng kiến trúc (SAILER wrapper → lấy backbone HF; fallback WavLM trắng), rồi `load_state_dict`
123
+ # bằng `ckpt["wavlm"]` → khôi phục đúng trạng thái fine-tune đã lưu.
124
+
125
+ # %%
126
+ import torch
127
+ import torch.nn as nn
128
+ import torch.nn.functional as F
129
+
130
+ device = DEVICE if torch.cuda.is_available() else "cpu"
131
+ print("Device:", device, ("✅ " + torch.cuda.get_device_name(0)) if device == "cuda" else "⚠️ CPU (rất chậm!)")
132
+
133
+ ckpt = torch.load(RESUME_CKPT, map_location="cpu", weights_only=False) # ckpt có numpy (vad_mu) → cần False
134
+ assert "wavlm" in ckpt, ("❌ Checkpoint KHÔNG có 'wavlm' (backbone). Đây là bản ft_emotion_meta.pt CŨ "
135
+ "chỉ lưu heads → không resume được. Hãy dùng ft_emotion_full.pt.")
136
+ print("✅ Nạp checkpoint:", RESUME_CKPT, "| keys:", list(ckpt.keys()))
137
+
138
+ def find_hf_backbone(module):
139
+ cands = []
140
+ for name, m in module.named_modules():
141
+ enc = getattr(m, "encoder", None)
142
+ if getattr(m, "feature_extractor", None) is not None and enc is not None \
143
+ and getattr(enc, "layers", None) is not None:
144
+ cands.append((name, m))
145
+ if not cands:
146
+ return None, None
147
+ cands.sort(key=lambda nm: sum(p.numel() for p in nm[1].parameters()), reverse=True)
148
+ return cands[0]
149
+
150
+ wavlm = None
151
+ try:
152
+ from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402
153
+ _wrapper = WavLMWrapper.from_pretrained("tiantiaf/wavlm-large-categorical-emotion")
154
+ name, wavlm = find_hf_backbone(_wrapper)
155
+ if wavlm is not None:
156
+ print(f"✅ Dựng backbone WavLM từ SAILER wrapper tại '.{name}'")
157
+ except Exception as e:
158
+ print("⚠️ Lỗi nạp SAILER wrapper:", repr(e), "→ fallback WavLM trắng.")
159
+
160
+ if wavlm is None:
161
+ from transformers import WavLMModel
162
+ wavlm = WavLMModel.from_pretrained("microsoft/wavlm-large")
163
+ print("ℹ️ Fallback: microsoft/wavlm-large.")
164
+
165
+ wavlm = wavlm.to(device)
166
+ WAVLM_DIM = int(wavlm.config.hidden_size)
167
+
168
+ # Nạp trọng số đã fine-tune từ checkpoint (đè lên kiến trúc vừa dựng)
169
+ miss, unexp = wavlm.load_state_dict(ckpt["wavlm"], strict=False)
170
+ print(f"🔁 load wavlm từ checkpoint: thiếu {len(miss)} / dư {len(unexp)} key (kỳ vọng ~0).")
171
+ if len(miss) > 20 or len(unexp) > 20:
172
+ print(" ⚠️ Lệch key nhiều → kiến trúc có thể không khớp checkpoint. Kiểm tra UNFREEZE/USE_AUDEERING.")
173
+
174
+ # Đóng băng partial: chỉ mở UNFREEZE_TOP_LAYERS lớp trên
175
+ for p in wavlm.parameters():
176
+ p.requires_grad = False
177
+ enc_layers = wavlm.encoder.layers
178
+ n_layers = len(enc_layers)
179
+ for layer in enc_layers[max(0, n_layers - UNFREEZE_TOP_LAYERS):]:
180
+ for p in layer.parameters():
181
+ p.requires_grad = True
182
+ n_train = sum(p.numel() for p in wavlm.parameters() if p.requires_grad)
183
+ print(f"WavLM: {n_layers} lớp · mở băng {min(UNFREEZE_TOP_LAYERS, n_layers)} → {n_train/1e6:.1f}M param train (dim {WAVLM_DIM})")
184
+
185
+ if USE_GRAD_CKPT:
186
+ wavlm.gradient_checkpointing_enable()
187
+ if hasattr(wavlm, "enable_input_require_grads"):
188
+ wavlm.enable_input_require_grads()
189
+
190
+ def masked_mean(hidden, attn_mask):
191
+ if attn_mask is None:
192
+ return hidden.mean(dim=1)
193
+ try:
194
+ fm = wavlm._get_feature_vector_attention_mask(hidden.shape[1], attn_mask)
195
+ except Exception:
196
+ return hidden.mean(dim=1)
197
+ fm = fm.unsqueeze(-1).to(hidden.dtype)
198
+ return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)
199
+
200
+ def wavlm_embed(input_values, attn_mask):
201
+ out = wavlm(input_values, attention_mask=attn_mask).last_hidden_state
202
+ return masked_mean(out, attn_mask)
203
+
204
+ # %% [markdown]
205
+ # ## 3. audeering FROZEN (đặc trưng phụ) — dùng cache nếu có
206
+
207
+ # %%
208
+ import numpy as np
209
+ import librosa
210
+ from tqdm.auto import tqdm
211
+
212
+ AUD_DIM = 0
213
+ aud_backbone = aud_head = aud_proc = None
214
+ if USE_AUDEERING:
215
+ from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor
216
+ from huggingface_hub import hf_hub_download
217
+ AUD_NAME = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
218
+ aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)
219
+ aud_cfg = Wav2Vec2Config.from_pretrained(AUD_NAME)
220
+ aud_backbone = Wav2Vec2Model(aud_cfg)
221
+ try:
222
+ _sd = __import__("safetensors.torch", fromlist=["load_file"]).load_file(
223
+ hf_hub_download(AUD_NAME, "model.safetensors"))
224
+ except Exception:
225
+ _sd = torch.load(hf_hub_download(AUD_NAME, "pytorch_model.bin"), map_location="cpu")
226
+ bb_sd = {k[len("wav2vec2."):]: v for k, v in _sd.items() if k.startswith("wav2vec2.")}
227
+ aud_backbone.load_state_dict(bb_sd, strict=False)
228
+ _hid = _sd["classifier.dense.weight"].shape[0]
229
+ _out = _sd["classifier.out_proj.weight"].shape[0]
230
+ aud_head = nn.Sequential(nn.Linear(_hid, _hid), nn.Tanh(), nn.Linear(_hid, _out))
231
+ aud_head[0].weight.data.copy_(_sd["classifier.dense.weight"]); aud_head[0].bias.data.copy_(_sd["classifier.dense.bias"])
232
+ aud_head[2].weight.data.copy_(_sd["classifier.out_proj.weight"]); aud_head[2].bias.data.copy_(_sd["classifier.out_proj.bias"])
233
+ aud_backbone = aud_backbone.to(device).eval()
234
+ aud_head = aud_head.to(device).eval()
235
+ AUD_DIM = _hid + 3
236
+ print(f"✅ audeering frozen ({AUD_DIM}-D)")
237
+
238
+ def load_wav(name_or_stem):
239
+ p = name_or_stem if os.path.isabs(str(name_or_stem)) else os.path.join(
240
+ WAV_DIR, name_or_stem if str(name_or_stem).endswith(".wav") else str(name_or_stem) + ".wav")
241
+ if not os.path.exists(p):
242
+ return None
243
+ wave, _ = librosa.load(p, sr=SR, mono=True)
244
+ return wave[: MAX_SECONDS * SR].astype(np.float32)
245
+
246
+ @torch.no_grad()
247
+ def extract_audeering(stems, tag):
248
+ if not USE_AUDEERING:
249
+ return {}
250
+ cache_path = os.path.join(CACHE_DIR, f"aud_{tag}.npz")
251
+ store = {}
252
+ if os.path.exists(cache_path):
253
+ z = np.load(cache_path, allow_pickle=True)
254
+ store = {k: z[k] for k in z.files}
255
+ print(f"[aud/{tag}] nạp cache: {len(store)}")
256
+ todo = [s for s in stems if s not in store]
257
+ for i, s in enumerate(tqdm(todo, desc=f"audeering {tag}")):
258
+ wave = load_wav(s)
259
+ if wave is None:
260
+ continue
261
+ x = aud_proc(wave, sampling_rate=SR).input_values[0]
262
+ x = torch.from_numpy(np.asarray(x, dtype=np.float32)).unsqueeze(0).to(device)
263
+ h = aud_backbone(x)[0].mean(dim=1)
264
+ out = aud_head(h)[0].cpu().numpy()
265
+ vad = np.array([1 + 4 * out[2], 1 + 4 * out[0], 1 + 4 * out[1]], dtype=np.float32)
266
+ store[s] = np.concatenate([h[0].cpu().numpy(), vad]).astype(np.float32)
267
+ if (i + 1) % 500 == 0:
268
+ np.savez(cache_path, **store)
269
+ if todo:
270
+ np.savez(cache_path, **store)
271
+ return store
272
+
273
+ # %% [markdown]
274
+ # ## 4. Đọc & gộp nhãn theo wavID
275
+
276
+ # %%
277
+ import pandas as pd
278
+
279
+ def load_target_emotions():
280
+ tgt = {}
281
+ with open(METADATA_CSV, encoding="utf-8") as f:
282
+ for ln in f:
283
+ parts = ln.strip().split("|")
284
+ if len(parts) >= 2:
285
+ tgt[stem(parts[0])] = norm_emotion(parts[1])
286
+ return tgt
287
+
288
+ def _col(cols_map, *names, df=None, default_idx=None):
289
+ for n in names:
290
+ if n in cols_map:
291
+ return cols_map[n]
292
+ return list(df.columns)[default_idx] if default_idx is not None else None
293
+
294
+ def parse_emocat_votes(cell):
295
+ v = np.zeros(len(EMOTIONS5), dtype=np.float32)
296
+ for tok in str(cell).replace("/", ",").replace(";", ",").replace("|", ",").replace(" ", ",").split(","):
297
+ e = norm_emotion(tok)
298
+ if e in EMOTIONS5:
299
+ v[EMOTIONS5.index(e)] += 1.0
300
+ return v
301
+
302
+ def load_train_labels():
303
+ df = pd.read_csv(TRAIN_CSV, sep="|")
304
+ cols = {c.lower().strip(): c for c in df.columns}
305
+ wav_col = _col(cols, "wavid", "wav", df=df, default_idx=1)
306
+ emos_col = _col(cols, "emos", "emo", "emomos")
307
+ val_col = _col(cols, "val", "valence"); aro_col = _col(cols, "aro", "arousal"); dom_col = _col(cols, "dom", "dominance")
308
+ cat_col = _col(cols, "emocat", "cat", "emotion")
309
+ assert emos_col, f"Không thấy cột eMOS (cột: {list(df.columns)})"
310
+ df["_stem"] = df[wav_col].map(stem)
311
+ rows = []
312
+ for sid, g in df.groupby("_stem"):
313
+ rec = {"wavID": sid, "emos": float(g[emos_col].mean())}
314
+ rec["val"] = float(g[val_col].mean()) if val_col else np.nan
315
+ rec["aro"] = float(g[aro_col].mean()) if aro_col else np.nan
316
+ rec["dom"] = float(g[dom_col].mean()) if dom_col else np.nan
317
+ votes = np.zeros(len(EMOTIONS5), dtype=np.float32)
318
+ if cat_col:
319
+ for cell in g[cat_col]:
320
+ votes += parse_emocat_votes(cell)
321
+ s = votes.sum()
322
+ cat = votes / s if s > 0 else np.full(len(EMOTIONS5), 0.2, dtype=np.float32)
323
+ for i in range(len(EMOTIONS5)):
324
+ rec[f"cat{i}"] = float(cat[i])
325
+ rows.append(rec)
326
+ return pd.DataFrame(rows)
327
+
328
+ target_map = load_target_emotions()
329
+ train_df = load_train_labels()
330
+ HAS_VAD = bool(train_df["val"].notna().any())
331
+ print(f"Target: {len(target_map)} | wav train (gộp): {len(train_df)} | có VAD: {HAS_VAD}")
332
+
333
+ # %% [markdown]
334
+ # ## 5. Dataset/loader — DÙNG thống kê chuẩn hóa TỪ CHECKPOINT (để khớp head đã train)
335
+
336
+ # %%
337
+ from torch.utils.data import Dataset, DataLoader
338
+ from sklearn.model_selection import train_test_split
339
+
340
+ train_stems = [s for s in train_df["wavID"] if target_map.get(s) is not None]
341
+ if LIMIT_TRAIN:
342
+ train_stems = train_stems[:LIMIT_TRAIN]
343
+ aud_tr = extract_audeering(train_stems, "train")
344
+
345
+ lab = train_df.set_index("wavID")
346
+
347
+ # QUAN TRỌNG: lấy mean/std từ checkpoint (head đã train theo thang này) thay vì tính lại.
348
+ emos_mu = float(ckpt["emos_mu"]); emos_sd = float(ckpt["emos_sd"])
349
+ vad_mu = np.asarray(ckpt["vad_mu"], dtype=np.float32); vad_sd = np.asarray(ckpt["vad_sd"], dtype=np.float32)
350
+ print(f"Dùng chuẩn hóa từ ckpt: emos μ={emos_mu:.3f} σ={emos_sd:.3f} | vad μ={np.round(vad_mu,2)}")
351
+
352
+ def onehot_target(tgt):
353
+ v = np.zeros(len(EMOTIONS5), dtype=np.float32)
354
+ if tgt in EMOTIONS5:
355
+ v[EMOTIONS5.index(tgt)] = 1.0
356
+ return v
357
+
358
+ class EmoDataset(Dataset):
359
+ def __init__(self, stems):
360
+ self.stems = [s for s in stems if (load_wav(s) is not None) and ((not USE_AUDEERING) or s in aud_tr)]
361
+ def __len__(self):
362
+ return len(self.stems)
363
+ def __getitem__(self, i):
364
+ s = self.stems[i]
365
+ wave = load_wav(s)
366
+ emos = (float(lab.loc[s, "emos"]) - emos_mu) / emos_sd
367
+ if HAS_VAD:
368
+ vad = (np.array([lab.loc[s, "val"], lab.loc[s, "aro"], lab.loc[s, "dom"]], np.float32) - vad_mu) / vad_sd
369
+ else:
370
+ vad = np.zeros(3, dtype=np.float32)
371
+ cat = np.array([lab.loc[s, f"cat{j}"] for j in range(len(EMOTIONS5))], dtype=np.float32)
372
+ aud = aud_tr[s] if USE_AUDEERING else np.zeros(0, dtype=np.float32)
373
+ return {"wave": wave, "tgt": onehot_target(target_map.get(s)), "aud": aud,
374
+ "emos": np.float32(emos), "vad": vad, "cat": cat,
375
+ "emos_raw": np.float32(lab.loc[s, "emos"]),
376
+ "vad_raw": np.array([lab.loc[s, "val"], lab.loc[s, "aro"], lab.loc[s, "dom"]], np.float32)}
377
+
378
+ def collate(batch):
379
+ lens = [len(b["wave"]) for b in batch]
380
+ L = max(lens)
381
+ waves = np.zeros((len(batch), L), dtype=np.float32)
382
+ mask = np.zeros((len(batch), L), dtype=np.float32)
383
+ for i, b in enumerate(batch):
384
+ waves[i, : len(b["wave"])] = b["wave"]; mask[i, : len(b["wave"])] = 1.0
385
+ return {
386
+ "input_values": torch.from_numpy(waves), "attn_mask": torch.from_numpy(mask).long(),
387
+ "tgt": torch.from_numpy(np.stack([b["tgt"] for b in batch])),
388
+ "aud": torch.from_numpy(np.stack([b["aud"] for b in batch])) if USE_AUDEERING else None,
389
+ "emos": torch.from_numpy(np.stack([b["emos"] for b in batch])).unsqueeze(1),
390
+ "vad": torch.from_numpy(np.stack([b["vad"] for b in batch])),
391
+ "cat": torch.from_numpy(np.stack([b["cat"] for b in batch])),
392
+ "emos_raw": np.stack([b["emos_raw"] for b in batch]),
393
+ "vad_raw": np.stack([b["vad_raw"] for b in batch]),
394
+ }
395
+
396
+ ds = EmoDataset(train_stems)
397
+ print("Dataset hợp lệ:", len(ds), "wav")
398
+ tr_i, va_i = train_test_split(np.arange(len(ds)), test_size=VAL_FRAC, random_state=SEED)
399
+ tr_loader = DataLoader(torch.utils.data.Subset(ds, tr_i), batch_size=BATCH, shuffle=True, collate_fn=collate, num_workers=2)
400
+ va_loader = DataLoader(torch.utils.data.Subset(ds, va_i), batch_size=BATCH, shuffle=False, collate_fn=collate, num_workers=2)
401
+
402
+ # %% [markdown]
403
+ # ## 6. Heads (NẠP từ checkpoint) + optimizer + train TIẾP
404
+
405
+ # %%
406
+ from scipy.stats import spearmanr
407
+
408
+ torch.manual_seed(SEED); np.random.seed(SEED)
409
+ N_EMO = len(EMOTIONS5)
410
+ TRUNK_IN = WAVLM_DIM + (AUD_DIM if USE_AUDEERING else 0)
411
+
412
+ class EmoHeads(nn.Module):
413
+ def __init__(self, d_in, trunk_h, head_h, p, n_emo):
414
+ super().__init__()
415
+ self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),
416
+ nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))
417
+ self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))
418
+ self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))
419
+ self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))
420
+ def forward(self, feat, tgt):
421
+ h = self.trunk(feat)
422
+ return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)
423
+
424
+ heads = EmoHeads(TRUNK_IN, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device)
425
+ hmiss, hunexp = heads.load_state_dict(ckpt["heads"], strict=False)
426
+ print(f"🔁 load heads từ checkpoint: thiếu {len(hmiss)} / dư {len(hunexp)} key (kỳ vọng 0).")
427
+ print(f"Trunk input = {TRUNK_IN} (wavlm {WAVLM_DIM} + aud {AUD_DIM if USE_AUDEERING else 0})")
428
+
429
+ TASKS = ["emos", "cat", "val", "aro", "dom"]
430
+ log_var = nn.Parameter(torch.zeros(len(TASKS), device=device))
431
+ bb_params = [p for p in wavlm.parameters() if p.requires_grad]
432
+ head_params = list(heads.parameters()) + ([log_var] if USE_UNCERTAINTY else [])
433
+ opt = torch.optim.AdamW([
434
+ {"params": bb_params, "lr": LR_BACKBONE * RESUME_LR_SCALE},
435
+ {"params": head_params, "lr": LR_HEAD * RESUME_LR_SCALE},
436
+ ], weight_decay=WEIGHT_DECAY)
437
+ scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP and device == "cuda")
438
+ mse = nn.MSELoss()
439
+
440
+ def soft_ce(logits, target_dist):
441
+ return -(target_dist * F.log_softmax(logits, dim=1)).sum(1).mean()
442
+
443
+ def forward_batch(b):
444
+ feat_wavlm = wavlm_embed(b["input_values"].to(device), b["attn_mask"].to(device))
445
+ feat = torch.cat([feat_wavlm, b["aud"].to(device)], dim=1) if USE_AUDEERING else feat_wavlm
446
+ return heads(feat, b["tgt"].to(device))
447
+
448
+ def compute_loss(emos_p, cat_l, vad_p, b):
449
+ L = {}
450
+ L["emos"] = mse(emos_p, b["emos"].to(device))
451
+ L["cat"] = soft_ce(cat_l, b["cat"].to(device))
452
+ if HAS_VAD:
453
+ vt = b["vad"].to(device)
454
+ L["val"] = mse(vad_p[:, 0:1], vt[:, 0:1]); L["aro"] = mse(vad_p[:, 1:2], vt[:, 1:2]); L["dom"] = mse(vad_p[:, 2:3], vt[:, 2:3])
455
+ else:
456
+ z = torch.zeros((), device=device); L["val"] = L["aro"] = L["dom"] = z
457
+ if USE_UNCERTAINTY:
458
+ return sum(torch.exp(-log_var[i]) * L[t] + log_var[i] for i, t in enumerate(TASKS))
459
+ return sum(L.values())
460
+
461
+ @torch.no_grad()
462
+ def evaluate():
463
+ wavlm.eval(); heads.eval()
464
+ P = {"emos": [], "val": [], "aro": [], "dom": []}; Y = {"emos": [], "val": [], "aro": [], "dom": []}
465
+ catP, catY = [], []
466
+ for b in va_loader:
467
+ with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
468
+ emos_p, cat_l, vad_p = forward_batch(b)
469
+ P["emos"] += emos_p.float().cpu().numpy().ravel().tolist(); Y["emos"] += b["emos_raw"].tolist()
470
+ vad_p = vad_p.float().cpu().numpy()
471
+ for j, t in enumerate(["val", "aro", "dom"]):
472
+ P[t] += vad_p[:, j].tolist(); Y[t] += b["vad_raw"][:, j].tolist()
473
+ catP.append(F.softmax(cat_l, 1).float().cpu().numpy()); catY.append(b["cat"])
474
+ out = {}
475
+ for t in ["emos"] + (["val", "aro", "dom"] if HAS_VAD else []):
476
+ out[t] = spearmanr(P[t], Y[t]).correlation
477
+ q = np.concatenate(catP); p = np.concatenate(catY)
478
+ out["cat_err"] = float(np.abs(q - p).sum(1).mean())
479
+ return out
480
+
481
+ def mean_srcc(m):
482
+ keys = ["emos"] + (["val", "aro", "dom"] if HAS_VAD else [])
483
+ return float(np.mean([m[k] for k in keys]))
484
+
485
+ # Init best TỪ checkpoint hiện tại → chỉ lưu nếu train tiếp TỐT HƠN
486
+ m0 = evaluate(); best = mean_srcc(m0)
487
+ best_state = {"wavlm": {k: v.cpu().clone() for k, v in wavlm.state_dict().items()},
488
+ "heads": {k: v.cpu().clone() for k, v in heads.state_dict().items()}}
489
+ print(f"📍 Checkpoint hiện tại: mean SRCC = {best:.4f} | "
490
+ + " ".join(f"{k}={m0[k]:.3f}" for k in ['emos','val','aro','dom'] if k in m0))
491
+
492
+ bad = 0
493
+ for ep in range(1, EPOCHS + 1):
494
+ wavlm.train(); heads.train()
495
+ opt.zero_grad(); run = 0.0; nb = 0
496
+ for step, b in enumerate(tqdm(tr_loader, desc=f"+epoch {ep}")):
497
+ with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
498
+ emos_p, cat_l, vad_p = forward_batch(b)
499
+ loss = compute_loss(emos_p, cat_l, vad_p, b) / ACCUM
500
+ scaler.scale(loss).backward()
501
+ if (step + 1) % ACCUM == 0:
502
+ scaler.step(opt); scaler.update(); opt.zero_grad()
503
+ run += loss.item() * ACCUM; nb += 1
504
+ m = evaluate(); sc = mean_srcc(m)
505
+ msg = " ".join(f"{k}={m[k]:.3f}" for k in ["emos", "val", "aro", "dom"] if k in m)
506
+ print(f"+epoch {ep:2d} | loss {run/max(nb,1):.4f} | {msg} | cat_err {m['cat_err']:.3f} | mean {sc:.4f} (best {max(best,sc):.4f})")
507
+ if sc > best:
508
+ best = sc
509
+ best_state = {"wavlm": {k: v.cpu().clone() for k, v in wavlm.state_dict().items()},
510
+ "heads": {k: v.cpu().clone() for k, v in heads.state_dict().items()}}
511
+ bad = 0
512
+ else:
513
+ bad += 1
514
+ if bad >= PATIENCE:
515
+ print(f"Early stop (resume) ở +epoch {ep}."); break
516
+
517
+ wavlm.load_state_dict(best_state["wavlm"]); heads.load_state_dict(best_state["heads"])
518
+ final = evaluate()
519
+ print("\n✅ VAL sau resume:")
520
+ print(f" EMOS={final['emos']:.4f} (ckpt {m0['emos']:.3f} · exp08 nộp {EXP08['emos']})")
521
+ if HAS_VAD:
522
+ print(f" VAL/ARO/DOM={final['val']:.4f}/{final['aro']:.4f}/{final['dom']:.4f} "
523
+ f"(exp08 nộp {EXP08['val']}/{EXP08['aro']}/{EXP08['dom']})")
524
+ print(f" mean SRCC: ckpt {mean_srcc(m0):.4f} → sau resume {mean_srcc(final):.4f} "
525
+ + ("🚀 cải thiện" if mean_srcc(final) > mean_srcc(m0) + 1e-4 else "➖ không cải thiện (giữ ckpt cũ)"))
526
+
527
+ torch.save({"wavlm": best_state["wavlm"], "heads": best_state["heads"],
528
+ "emos_mu": emos_mu, "emos_sd": emos_sd, "vad_mu": vad_mu, "vad_sd": vad_sd,
529
+ "WAVLM_DIM": WAVLM_DIM, "AUD_DIM": AUD_DIM, "UNFREEZE_TOP_LAYERS": UNFREEZE_TOP_LAYERS,
530
+ "val_emos": final["emos"]}, os.path.join(OUT_DIR, "ft_emotion_full.pt"))
531
+ print("Đã lưu FULL (có backbone):", os.path.join(OUT_DIR, "ft_emotion_full.pt"))
532
+
533
+ # %% [markdown]
534
+ # ## 7. Dự đoán DEV → answer.txt (5 cột cảm xúc từ resume; QMOS mượn exp07 hoặc UTMOSv2)
535
+
536
+ # %%
537
+ def list_dev():
538
+ with open(DEV_SCP) as f:
539
+ return [ln.strip() for ln in f if ln.strip()]
540
+
541
+ dev_names = list_dev()
542
+ if LIMIT_DEV:
543
+ dev_names = dev_names[:LIMIT_DEV]
544
+ dev_stems = [stem(n) for n in dev_names]
545
+ print("DEV:", len(dev_names), "mẫu")
546
+ aud_dev = extract_audeering(dev_stems, "dev")
547
+
548
+ def load_exp07_qmos():
549
+ if EXP07_ANSWER and os.path.exists(EXP07_ANSWER):
550
+ import csv
551
+ d = {}
552
+ with open(EXP07_ANSWER) as f:
553
+ for row in csv.DictReader(f):
554
+ d[row["wav"]] = float(row["QMOS"]); d[stem(row["wav"])] = float(row["QMOS"])
555
+ print(f"✅ Mượn QMOS từ exp07 ({EXP07_ANSWER}): {len(d)//2} wav")
556
+ return d
557
+ return None
558
+
559
+ qmos_map = load_exp07_qmos()
560
+ if qmos_map is None:
561
+ print("ℹ️ Không có answer.txt exp07 → chấm QMOS bằng UTMOSv2 (T05, vô địch VMC2024).")
562
+ pip_install("git+https://github.com/sarulab-speech/UTMOSv2.git")
563
+ import utmosv2
564
+ v2 = utmosv2.create_model(pretrained=True)
565
+ qmos_map = {}
566
+ for n in tqdm(dev_names, desc="UTMOSv2"):
567
+ wav = os.path.join(WAV_DIR, n if str(n).endswith(".wav") else str(n) + ".wav")
568
+ if not os.path.exists(wav):
569
+ continue
570
+ out = v2.predict(input_path=wav)
571
+ qmos_map[n] = float(out["predicted_mos"]) if isinstance(out, dict) else float(out)
572
+ del v2; torch.cuda.empty_cache() if device == "cuda" else None
573
+
574
+ @torch.no_grad()
575
+ def predict_emotion(sid):
576
+ wave = load_wav(sid)
577
+ if wave is None or (USE_AUDEERING and sid not in aud_dev):
578
+ return None
579
+ wavlm.eval(); heads.eval()
580
+ iv = torch.from_numpy(wave).unsqueeze(0).to(device)
581
+ am = torch.ones((1, len(wave)), dtype=torch.long, device=device)
582
+ tgt = torch.from_numpy(onehot_target(target_map.get(sid))).unsqueeze(0).to(device)
583
+ with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
584
+ fw = wavlm_embed(iv, am)
585
+ feat = torch.cat([fw, torch.from_numpy(aud_dev[sid]).unsqueeze(0).to(device)], dim=1) if USE_AUDEERING else fw
586
+ emos_p, cat_l, vad_p = heads(feat, tgt)
587
+ emos = float(emos_p.item()) * emos_sd + emos_mu
588
+ cat5 = F.softmax(cat_l, 1)[0].float().cpu().numpy()
589
+ vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu
590
+ return emos, cat5, vad3
591
+
592
+ def fmt_cat(p5):
593
+ return "|".join(f"{e}:{p5[i]:.6g}" for i, e in enumerate(EMOTIONS5))
594
+
595
+ def build_answer(out_path):
596
+ n_real = n_def = 0
597
+ with open(out_path, "w") as f:
598
+ f.write("wav,QMOS,EMOS,CAT,VAL,ARO,DOM\n")
599
+ for name in tqdm(dev_names, desc="answer"):
600
+ sid = stem(name)
601
+ pr = predict_emotion(sid)
602
+ if pr is None:
603
+ emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0]); n_def += 1
604
+ else:
605
+ emos, cat5, vad3 = pr; n_real += 1
606
+ qmos = qmos_map.get(name, qmos_map.get(sid, 3.0))
607
+ f.write(f"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\n")
608
+ print(f"Ghi {len(dev_names)} dòng → {out_path} | cảm xúc thật {n_real}, mặc định {n_def}")
609
+
610
+ answer_path = os.path.join(OUT_DIR, "answer.txt")
611
+ build_answer(answer_path)
612
+
613
+ # %% [markdown]
614
+ # ## 8. Validate + zip
615
+
616
+ # %%
617
+ def validate(path):
618
+ import csv
619
+ with open(path) as f:
620
+ rows = list(csv.reader(f))
621
+ assert rows[0][0] == "wav" and "QMOS" in rows[0], "Header sai"
622
+ for i, r in enumerate(rows[1:], 2):
623
+ assert len(r) == len(rows[0]), f"Dòng {i} sai số cột"
624
+ print(f"OK: {len(rows)-1} dòng, header = {rows[0]}")
625
+
626
+ validate(answer_path)
627
+ os.system(f"cd {OUT_DIR} && zip -j submission_track2_exp08_resume.zip answer.txt && unzip -l submission_track2_exp08_resume.zip")
628
+ print("Sẵn sàng nộp:", os.path.join(OUT_DIR, "submission_track2_exp08_resume.zip"))
629
+
630
+ # %% [markdown]
631
+ # ## Ghi chú
632
+ # - **Đầu vào bắt buộc:** `RESUME_CKPT` = `ft_emotion_full.pt` (CÓ backbone). Bản `ft_emotion_meta.pt` cũ chỉ
633
+ # có heads → cell 2 sẽ assert lỗi nhắc dùng file đủ.
634
+ # - **Cache:** trỏ `CACHE_INPUT` tới dataset chứa `aud_train.npz`/`aud_dev.npz` → khỏi trích lại audeering.
635
+ # Nếu LIMIT khác lần trước, cache thiếu stem nào sẽ tự trích bù (resume theo stem).
636
+ # - **Chuẩn hóa lấy TỪ checkpoint** (`emos_mu/sd`, `vad_mu/sd`) → khớp thang head đã train (đừng tính lại).
637
+ # - **best init từ checkpoint** → chỉ lưu nếu train tiếp THỰC SỰ tốt hơn (không sợ tụt).
638
+ # - Nếu val chững: đặt `RESUME_LR_SCALE=0.5` (giảm LR) hoặc tăng `UNFREEZE_TOP_LAYERS` (lưu ý: mở thêm lớp
639
+ # thì lớp mới chưa được train trong checkpoint → cần nhiều epoch hơn).
640
+ # - QMOS: tốt nhất Add Input `answer.txt` exp07 (0.548). Để trộn cột chuẩn, xem kết quả exp08: 5 cột cảm xúc
641
+ # resume + QMOS exp07 → hệ thống mạnh nhất 6 cột.
642
+ # - Ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md`.
track2/exp09a_qmos_utmosv2_probe.ipynb ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "0afb9ac3",
6
+ "metadata": {},
7
+ "source": [
8
+ "# VMC2026 Track 2 — exp09a (PROBE: UTMOSv2 vs UTMOS cho QMOS) — Kaggle\n",
9
+ "\n",
10
+ "**Mục đích (rẻ, KHÔNG tốn lượt nộp):** trước khi fine-tune QMOS, kiểm tra xem\n",
11
+ "**UTMOSv2** (hệ thống **T05 — vô địch VoiceMOS Challenge 2024 Track 1**, naturalness MOS)\n",
12
+ "có **mạnh hơn UTMOS 2022** (đang dùng) trên dữ liệu Track 2 hay không.\n",
13
+ "\n",
14
+ "## Ý tưởng A/B không tốn lượt nộp\n",
15
+ "Tập **train** Track 2 CÓ nhãn `qMOS` thật (`sets/train.csv`). Ta:\n",
16
+ "1. Chấm một mẫu train bằng **UTMOS** (torch.hub `utmos22_strong`) — baseline đang dùng.\n",
17
+ "2. Chấm cùng mẫu đó bằng **UTMOSv2** (`sarulab-speech/UTMOSv2`, MIT).\n",
18
+ "3. So **SRCC mỗi model vs nhãn qMOS vàng** → biết model nào \"xếp hạng\" giống người chấm hơn.\n",
19
+ "\n",
20
+ "> SRCC chấm **thứ hạng** (scale-invariant) → khỏi lo lệch thang điểm. Mẫu ~2.000 wav là đủ ổn định.\n",
21
+ "\n",
22
+ "## Vì sao đáng thử\n",
23
+ "- UTMOSv2 = #1 ở 7/16 metric VMC2024 Track 1 (bỏ xa hạng 3) → bản kế nhiệm trực tiếp của UTMOS.\n",
24
+ "- **Lưu ý:** UTMOSv2 cũng train trên giọng *không* cảm xúc → vẫn có thể lệch domain; A/B này để\n",
25
+ " biết nó có **đáng** làm \"neo\" mạnh hơn cho head QMOS fine-tune (exp09) hay không.\n",
26
+ "\n",
27
+ "**Cách chạy:** GPU T4 + **Internet On** (UTMOSv2 cài từ git + tải checkpoint) → Add Input dataset\n",
28
+ "Track 2 → sửa `DATA_ROOT` → Run All. Lần đầu để `PROBE_N=300` cho nhanh, OK rồi tăng `2000`."
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "markdown",
33
+ "id": "c4f6fdc3",
34
+ "metadata": {},
35
+ "source": [
36
+ "## 0. Cấu hình — SỬA Ở ĐÂY"
37
+ ]
38
+ },
39
+ {
40
+ "cell_type": "code",
41
+ "execution_count": null,
42
+ "id": "df72dc96",
43
+ "metadata": {},
44
+ "outputs": [],
45
+ "source": [
46
+ "import os\n",
47
+ "\n",
48
+ "DATA_ROOT = \"/kaggle/input/vmc2026-track2-full/vmc2026-track2\" # << SỬA slug\n",
49
+ "WAV_DIR = f\"{DATA_ROOT}/wav\"\n",
50
+ "TRAIN_CSV = f\"{DATA_ROOT}/sets/train.csv\" # lisID|wavID|qMOS|emoCat|eMOS|val|dom|aro\n",
51
+ "DEV_SCP = f\"{DATA_ROOT}/sets/dev.scp\"\n",
52
+ "\n",
53
+ "OUT_DIR = \"/kaggle/working\"\n",
54
+ "CACHE_DIR = \"/kaggle/working/qmos_probe_cache\"\n",
55
+ "os.makedirs(CACHE_DIR, exist_ok=True)\n",
56
+ "\n",
57
+ "DEVICE = \"cuda\"\n",
58
+ "PROBE_N = 2000 # số wav train để A/B (lần đầu để 300 cho nhanh). SRCC ~2000 mẫu đã ổn định.\n",
59
+ "SEED = 42\n",
60
+ "\n",
61
+ "# (Tùy chọn) Nếu muốn TẠO LUÔN answer.txt đổi cột QMOS←UTMOSv2 để nộp xác nhận trên DEV:\n",
62
+ "# trỏ tới answer.txt của exp07 (giữ nguyên 5 cột cảm xúc, chỉ thay QMOS).\n",
63
+ "# Để None nếu chỉ muốn chạy A/B nội bộ.\n",
64
+ "EXP07_ANSWER = None # ví dụ: \"/kaggle/input/exp07-answer/answer.txt\"\n",
65
+ "\n",
66
+ "def stem(p):\n",
67
+ " return os.path.splitext(os.path.basename(str(p)))[0]\n",
68
+ "\n",
69
+ "for p in [WAV_DIR, TRAIN_CSV, DEV_SCP]:\n",
70
+ " print((\" ✅ \" if os.path.exists(p) else \" ❌ THIẾU \") + p)"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "markdown",
75
+ "id": "c18d2e88",
76
+ "metadata": {},
77
+ "source": [
78
+ "## 1. Cài đặt (UTMOS + UTMOSv2)"
79
+ ]
80
+ },
81
+ {
82
+ "cell_type": "code",
83
+ "execution_count": null,
84
+ "id": "c21a3cb0",
85
+ "metadata": {},
86
+ "outputs": [],
87
+ "source": [
88
+ "import sys, subprocess\n",
89
+ "\n",
90
+ "def pip_install(*pkgs):\n",
91
+ " subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *pkgs], check=True)\n",
92
+ "\n",
93
+ "pip_install(\"speechmos\", \"librosa\", \"soundfile\", \"pandas\", \"scipy\", \"scikit-learn\", \"tqdm\")\n",
94
+ "# UTMOSv2 (T05) — cài từ git, cần Internet On. Checkpoint tự tải lần đầu.\n",
95
+ "pip_install(\"git+https://github.com/sarulab-speech/UTMOSv2.git\")"
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "markdown",
100
+ "id": "ad56ee86",
101
+ "metadata": {},
102
+ "source": [
103
+ "## 2. Nhãn qMOS vàng (gộp trung bình theo wav)"
104
+ ]
105
+ },
106
+ {
107
+ "cell_type": "code",
108
+ "execution_count": null,
109
+ "id": "66c28d52",
110
+ "metadata": {},
111
+ "outputs": [],
112
+ "source": [
113
+ "import numpy as np\n",
114
+ "import pandas as pd\n",
115
+ "\n",
116
+ "def load_qmos_labels():\n",
117
+ " \"\"\"train.csv (sep '|') → dict {stem: qMOS trung bình theo wav}.\"\"\"\n",
118
+ " df = pd.read_csv(TRAIN_CSV, sep=\"|\")\n",
119
+ " cols = {c.lower().strip(): c for c in df.columns}\n",
120
+ " wav_col = cols.get(\"wavid\") or cols.get(\"wav\") or list(df.columns)[1]\n",
121
+ " qmos_col = cols.get(\"qmos\") or cols.get(\"mos\")\n",
122
+ " assert qmos_col, f\"Không thấy cột qMOS (cột: {list(df.columns)})\"\n",
123
+ " df[\"_stem\"] = df[wav_col].map(stem)\n",
124
+ " g = df.groupby(\"_stem\")[qmos_col].mean()\n",
125
+ " return {s: float(v) for s, v in g.items()}\n",
126
+ "\n",
127
+ "qmos_gold = load_qmos_labels()\n",
128
+ "print(f\"Số wav train có nhãn qMOS: {len(qmos_gold)}\")\n",
129
+ "\n",
130
+ "# Chọn mẫu probe (chỉ giữ wav thật sự tồn tại trên đĩa)\n",
131
+ "rng = np.random.default_rng(SEED)\n",
132
+ "all_stems = [s for s in qmos_gold if os.path.exists(os.path.join(WAV_DIR, s + \".wav\"))]\n",
133
+ "rng.shuffle(all_stems)\n",
134
+ "probe_stems = all_stems[:PROBE_N]\n",
135
+ "print(f\"Mẫu probe: {len(probe_stems)} / {len(all_stems)} wav tồn tại\")"
136
+ ]
137
+ },
138
+ {
139
+ "cell_type": "markdown",
140
+ "id": "13d793d1",
141
+ "metadata": {},
142
+ "source": [
143
+ "## 3. Hàm chấm: UTMOS (cũ) và UTMOSv2 (mới) — đều cache .npz"
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "code",
148
+ "execution_count": null,
149
+ "id": "05dd3e90",
150
+ "metadata": {
151
+ "lines_to_next_cell": 1
152
+ },
153
+ "outputs": [],
154
+ "source": [
155
+ "import torch\n",
156
+ "from scipy.stats import spearmanr, pearsonr\n",
157
+ "\n",
158
+ "device = DEVICE if torch.cuda.is_available() else \"cpu\"\n",
159
+ "print(\"Device:\", device, (\"✅ \" + torch.cuda.get_device_name(0)) if device == \"cuda\" else \"⚠️ CPU\")\n",
160
+ "\n",
161
+ "def score_utmos(stems, tag):\n",
162
+ " \"\"\"UTMOS 2022 (torch.hub utmos22_strong). → dict {stem: score}. Cache.\"\"\"\n",
163
+ " import librosa\n",
164
+ " from tqdm.auto import tqdm\n",
165
+ " cache = os.path.join(CACHE_DIR, f\"utmos_{tag}.npz\")\n",
166
+ " store = {}\n",
167
+ " if os.path.exists(cache):\n",
168
+ " z = np.load(cache, allow_pickle=True)\n",
169
+ " store = {k: float(z[k]) for k in z.files}\n",
170
+ " print(f\"[utmos/{tag}] nạp cache: {len(store)}\")\n",
171
+ " todo = [s for s in stems if s not in store]\n",
172
+ " if todo:\n",
173
+ " predictor = torch.hub.load(\"tarepan/SpeechMOS:v1.2.0\", \"utmos22_strong\",\n",
174
+ " trust_repo=True).to(device).eval()\n",
175
+ " with torch.no_grad():\n",
176
+ " for i, s in enumerate(tqdm(todo, desc=f\"utmos {tag}\")):\n",
177
+ " wav = os.path.join(WAV_DIR, s + \".wav\")\n",
178
+ " wave, _ = librosa.load(wav, sr=16000, mono=True)\n",
179
+ " store[s] = float(predictor(torch.from_numpy(wave).unsqueeze(0).to(device),\n",
180
+ " sr=16000).mean().item())\n",
181
+ " if (i + 1) % 500 == 0:\n",
182
+ " np.savez(cache, **{k: np.float32(v) for k, v in store.items()})\n",
183
+ " np.savez(cache, **{k: np.float32(v) for k, v in store.items()})\n",
184
+ " del predictor\n",
185
+ " torch.cuda.empty_cache() if device == \"cuda\" else None\n",
186
+ " return store\n",
187
+ "\n",
188
+ "def score_utmosv2(stems, tag):\n",
189
+ " \"\"\"UTMOSv2 / T05 (sarulab-speech/UTMOSv2). → dict {stem: score}. Cache.\"\"\"\n",
190
+ " from tqdm.auto import tqdm\n",
191
+ " cache = os.path.join(CACHE_DIR, f\"utmosv2_{tag}.npz\")\n",
192
+ " store = {}\n",
193
+ " if os.path.exists(cache):\n",
194
+ " z = np.load(cache, allow_pickle=True)\n",
195
+ " store = {k: float(z[k]) for k in z.files}\n",
196
+ " print(f\"[utmosv2/{tag}] nạp cache: {len(store)}\")\n",
197
+ " todo = [s for s in stems if s not in store]\n",
198
+ " if todo:\n",
199
+ " import utmosv2\n",
200
+ " model = utmosv2.create_model(pretrained=True) # ensemble, checkpoint tự tải\n",
201
+ " for i, s in enumerate(tqdm(todo, desc=f\"utmosv2 {tag}\")):\n",
202
+ " wav = os.path.join(WAV_DIR, s + \".wav\")\n",
203
+ " out = model.predict(input_path=wav)\n",
204
+ " # predict trả về float (hoặc dict có 'predicted_mos') tùy phiên bản\n",
205
+ " store[s] = float(out[\"predicted_mos\"]) if isinstance(out, dict) else float(out)\n",
206
+ " if (i + 1) % 200 == 0:\n",
207
+ " np.savez(cache, **{k: np.float32(v) for k, v in store.items()})\n",
208
+ " np.savez(cache, **{k: np.float32(v) for k, v in store.items()})\n",
209
+ " del model\n",
210
+ " torch.cuda.empty_cache() if device == \"cuda\" else None\n",
211
+ " return store"
212
+ ]
213
+ },
214
+ {
215
+ "cell_type": "markdown",
216
+ "id": "d04d250c",
217
+ "metadata": {},
218
+ "source": [
219
+ "## 4. Chạy A/B trên mẫu train → in SRCC mỗi model vs nhãn qMOS vàng"
220
+ ]
221
+ },
222
+ {
223
+ "cell_type": "code",
224
+ "execution_count": null,
225
+ "id": "eb9fe414",
226
+ "metadata": {
227
+ "lines_to_next_cell": 1
228
+ },
229
+ "outputs": [],
230
+ "source": [
231
+ "utmos_s = score_utmos(probe_stems, \"probe\")\n",
232
+ "utmosv2_s = score_utmosv2(probe_stems, \"probe\")\n",
233
+ "\n",
234
+ "# Chỉ so trên các stem cả 2 model đều chấm được (để công bằng)\n",
235
+ "common = [s for s in probe_stems if s in utmos_s and s in utmosv2_s and s in qmos_gold]\n",
236
+ "y_gold = np.array([qmos_gold[s] for s in common])\n",
237
+ "p_v1 = np.array([utmos_s[s] for s in common])\n",
238
+ "p_v2 = np.array([utmosv2_s[s] for s in common])\n",
239
+ "print(f\"\\nSố mẫu so sánh chung: {len(common)}\")\n",
240
+ "\n",
241
+ "srcc_v1 = spearmanr(p_v1, y_gold).correlation\n",
242
+ "srcc_v2 = spearmanr(p_v2, y_gold).correlation\n",
243
+ "lcc_v1 = pearsonr(p_v1, y_gold)[0]\n",
244
+ "lcc_v2 = pearsonr(p_v2, y_gold)[0]\n",
245
+ "\n",
246
+ "print(\"\\n📊 A/B trên TRAIN (nhãn qMOS vàng) — UTT-SRCC là metric chính:\")\n",
247
+ "print(f\" UTMOS 2022 (đang dùng) : SRCC = {srcc_v1:.4f} | LCC = {lcc_v1:.4f}\")\n",
248
+ "print(f\" UTMOSv2 / T05 (mới) : SRCC = {srcc_v2:.4f} | LCC = {lcc_v2:.4f}\")\n",
249
+ "delta = srcc_v2 - srcc_v1\n",
250
+ "if delta > 0.01:\n",
251
+ " print(f\" ✅ UTMOSv2 THẮNG (+{delta:.4f} SRCC) → đáng dùng làm neo cho exp09 / đổi cột QMOS.\")\n",
252
+ "elif delta < -0.01:\n",
253
+ " print(f\" ⚠️ UTMOSv2 THUA ({delta:.4f} SRCC) → giữ UTMOS; lệch domain cảm xúc quá mạnh.\")\n",
254
+ "else:\n",
255
+ " print(f\" ➖ Ngang nhau ({delta:+.4f}) → ưu tiên model nào tiện hơn; chốt bằng fine-tune.\")\n",
256
+ "\n",
257
+ "# Mốc tham chiếu leaderboard: UTMOS zero-shot DEV = 0.414; head QMOS exp07 = 0.548.\n",
258
+ "# (SRCC train ≠ SRCC dev nhưng cùng xu hướng → dùng để quyết hướng, không phải điểm nộp.)\n",
259
+ "print(\"\\nℹ️ Mốc leaderboard DEV để đối chiếu: UTMOS zero-shot 0.414 · head QMOS exp07 0.548.\")"
260
+ ]
261
+ },
262
+ {
263
+ "cell_type": "markdown",
264
+ "id": "d808b9d6",
265
+ "metadata": {},
266
+ "source": [
267
+ "## 5. (Tùy chọn) Tạo answer.txt đổi cột QMOS←UTMOSv2 để nộp xác nhận DEV\n",
268
+ "Chỉ chạy nếu `EXP07_ANSWER` trỏ tới answer.txt exp07. Giữ nguyên 5 cột cảm xúc, chỉ thay QMOS."
269
+ ]
270
+ },
271
+ {
272
+ "cell_type": "code",
273
+ "execution_count": null,
274
+ "id": "b08a39f5",
275
+ "metadata": {},
276
+ "outputs": [],
277
+ "source": [
278
+ "def build_swapped_answer(exp07_answer_path, out_path):\n",
279
+ " \"\"\"Đọc answer.txt exp07 (wav,QMOS,EMOS,CAT,VAL,ARO,DOM), thay QMOS = UTMOSv2(dev).\"\"\"\n",
280
+ " import csv\n",
281
+ " with open(DEV_SCP) as f:\n",
282
+ " dev_names = [ln.strip() for ln in f if ln.strip()]\n",
283
+ " dev_stems = [stem(n) for n in dev_names]\n",
284
+ " utmosv2_dev = score_utmosv2(dev_stems, \"dev\") # chấm DEV bằng UTMOSv2 (cache riêng)\n",
285
+ "\n",
286
+ " with open(exp07_answer_path) as f:\n",
287
+ " rows = list(csv.reader(f))\n",
288
+ " header, body = rows[0], rows[1:]\n",
289
+ " qi = header.index(\"QMOS\")\n",
290
+ " n_swap = 0\n",
291
+ " with open(out_path, \"w\") as f:\n",
292
+ " f.write(\",\".join(header) + \"\\n\")\n",
293
+ " for r in body:\n",
294
+ " sid = stem(r[0])\n",
295
+ " if sid in utmosv2_dev:\n",
296
+ " r[qi] = f\"{utmosv2_dev[sid]:.6g}\"\n",
297
+ " n_swap += 1\n",
298
+ " f.write(\",\".join(r) + \"\\n\")\n",
299
+ " print(f\"Ghi {len(body)} dòng → {out_path} | đổi QMOS được {n_swap} dòng\")\n",
300
+ " return out_path\n",
301
+ "\n",
302
+ "if EXP07_ANSWER and os.path.exists(EXP07_ANSWER):\n",
303
+ " out = os.path.join(OUT_DIR, \"answer.txt\")\n",
304
+ " build_swapped_answer(EXP07_ANSWER, out)\n",
305
+ " os.system(f\"cd {OUT_DIR} && zip -j submission_track2_exp09a_utmosv2.zip answer.txt \"\n",
306
+ " f\"&& unzip -l submission_track2_exp09a_utmosv2.zip\")\n",
307
+ " print(\"Sẵn sàng nộp:\", os.path.join(OUT_DIR, \"submission_track2_exp09a_utmosv2.zip\"))\n",
308
+ "else:\n",
309
+ " print(\"Bỏ qua mục 5 (EXP07_ANSWER=None hoặc không tồn tại). Chỉ chạy A/B nội bộ.\")"
310
+ ]
311
+ },
312
+ {
313
+ "cell_type": "markdown",
314
+ "id": "adc8ba21",
315
+ "metadata": {},
316
+ "source": [
317
+ "## Ghi chú\n",
318
+ "- **Đọc kết quả mục 4:** UTMOSv2 SRCC có > UTMOS không?\n",
319
+ " - **Thắng rõ** → dùng UTMOSv2 làm **neo** cho `exp09` (fine-tune WavLM trên nhãn qMOS) thay UTMOS;\n",
320
+ " và/hoặc nộp answer.txt đổi cột (mục 5) để xác nhận trên leaderboard DEV.\n",
321
+ " - **Thua/ngang** → giữ UTMOS làm neo; kết luận \"UTMOSv2 vẫn lệch domain cảm xúc\" (phát hiện cho paper).\n",
322
+ "- **Gotcha Kaggle:** UTMOSv2 cài từ git + tải checkpoint → **Internet On**. Bản nộp Internet-off cần\n",
323
+ " pre-download weights thành Kaggle Dataset.\n",
324
+ "- UTMOSv2 là **ensemble nhiều fold** → chậm hơn UTMOS. Nếu lâu, giảm `PROBE_N` hoặc chấm dần (có cache).\n",
325
+ "- License: UTMOSv2 **MIT** · UTMOS BSD-3. Ghi vào `docs/12_system_description.md`.\n",
326
+ "- Ghi config → kết qu�� → nhận xét vào `docs/04_experiments_log.md` (mục exp09a)."
327
+ ]
328
+ }
329
+ ],
330
+ "metadata": {
331
+ "jupytext": {
332
+ "cell_metadata_filter": "-all",
333
+ "main_language": "python",
334
+ "notebook_metadata_filter": "-all"
335
+ }
336
+ },
337
+ "nbformat": 4,
338
+ "nbformat_minor": 5
339
+ }
track2/exp09a_qmos_utmosv2_probe_pipeline.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %% [markdown]
2
+ # # VMC2026 Track 2 — exp09a (PROBE: UTMOSv2 vs UTMOS cho QMOS) — Kaggle
3
+ #
4
+ # **Mục đích (rẻ, KHÔNG tốn lượt nộp):** trước khi fine-tune QMOS, kiểm tra xem
5
+ # **UTMOSv2** (hệ thống **T05 — vô địch VoiceMOS Challenge 2024 Track 1**, naturalness MOS)
6
+ # có **mạnh hơn UTMOS 2022** (đang dùng) trên dữ liệu Track 2 hay không.
7
+ #
8
+ # ## Ý tưởng A/B không tốn lượt nộp
9
+ # Tập **train** Track 2 CÓ nhãn `qMOS` thật (`sets/train.csv`). Ta:
10
+ # 1. Chấm một mẫu train bằng **UTMOS** (torch.hub `utmos22_strong`) — baseline đang dùng.
11
+ # 2. Chấm cùng mẫu đó bằng **UTMOSv2** (`sarulab-speech/UTMOSv2`, MIT).
12
+ # 3. So **SRCC mỗi model vs nhãn qMOS vàng** → biết model nào "xếp hạng" giống người chấm hơn.
13
+ #
14
+ # > SRCC chấm **thứ hạng** (scale-invariant) → khỏi lo lệch thang điểm. Mẫu ~2.000 wav là đủ ổn định.
15
+ #
16
+ # ## Vì sao đáng thử
17
+ # - UTMOSv2 = #1 ở 7/16 metric VMC2024 Track 1 (bỏ xa hạng 3) → bản kế nhiệm trực tiếp của UTMOS.
18
+ # - **Lưu ý:** UTMOSv2 cũng train trên giọng *không* cảm xúc → vẫn có thể lệch domain; A/B này để
19
+ # biết nó có **đáng** làm "neo" mạnh hơn cho head QMOS fine-tune (exp09) hay không.
20
+ #
21
+ # **Cách chạy:** GPU T4 + **Internet On** (UTMOSv2 cài từ git + tải checkpoint) → Add Input dataset
22
+ # Track 2 → sửa `DATA_ROOT` → Run All. Lần đầu để `PROBE_N=300` cho nhanh, OK rồi tăng `2000`.
23
+
24
+ # %% [markdown]
25
+ # ## 0. Cấu hình — SỬA Ở ĐÂY
26
+
27
+ # %%
28
+ import os
29
+
30
+ DATA_ROOT = "/kaggle/input/vmc2026-track2-full/vmc2026-track2" # << SỬA slug
31
+ WAV_DIR = f"{DATA_ROOT}/wav"
32
+ TRAIN_CSV = f"{DATA_ROOT}/sets/train.csv" # lisID|wavID|qMOS|emoCat|eMOS|val|dom|aro
33
+ DEV_SCP = f"{DATA_ROOT}/sets/dev.scp"
34
+
35
+ OUT_DIR = "/kaggle/working"
36
+ CACHE_DIR = "/kaggle/working/qmos_probe_cache"
37
+ os.makedirs(CACHE_DIR, exist_ok=True)
38
+
39
+ DEVICE = "cuda"
40
+ PROBE_N = 2000 # số wav train để A/B (lần đầu để 300 cho nhanh). SRCC ~2000 mẫu đã ổn định.
41
+ SEED = 42
42
+
43
+ # (Tùy chọn) Nếu muốn TẠO LUÔN answer.txt đổi cột QMOS←UTMOSv2 để nộp xác nhận trên DEV:
44
+ # trỏ tới answer.txt của exp07 (giữ nguyên 5 cột cảm xúc, chỉ thay QMOS).
45
+ # Để None nếu chỉ muốn chạy A/B nội bộ.
46
+ EXP07_ANSWER = None # ví dụ: "/kaggle/input/exp07-answer/answer.txt"
47
+
48
+ def stem(p):
49
+ return os.path.splitext(os.path.basename(str(p)))[0]
50
+
51
+ for p in [WAV_DIR, TRAIN_CSV, DEV_SCP]:
52
+ print((" ✅ " if os.path.exists(p) else " ❌ THIẾU ") + p)
53
+
54
+ # %% [markdown]
55
+ # ## 1. Cài đặt (UTMOS + UTMOSv2)
56
+
57
+ # %%
58
+ import sys, subprocess
59
+
60
+ def pip_install(*pkgs):
61
+ subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs], check=True)
62
+
63
+ pip_install("speechmos", "librosa", "soundfile", "pandas", "scipy", "scikit-learn", "tqdm")
64
+ # UTMOSv2 (T05) — cài từ git, cần Internet On. Checkpoint tự tải lần đầu.
65
+ pip_install("git+https://github.com/sarulab-speech/UTMOSv2.git")
66
+
67
+ # %% [markdown]
68
+ # ## 2. Nhãn qMOS vàng (gộp trung bình theo wav)
69
+
70
+ # %%
71
+ import numpy as np
72
+ import pandas as pd
73
+
74
+ def load_qmos_labels():
75
+ """train.csv (sep '|') → dict {stem: qMOS trung bình theo wav}."""
76
+ df = pd.read_csv(TRAIN_CSV, sep="|")
77
+ cols = {c.lower().strip(): c for c in df.columns}
78
+ wav_col = cols.get("wavid") or cols.get("wav") or list(df.columns)[1]
79
+ qmos_col = cols.get("qmos") or cols.get("mos")
80
+ assert qmos_col, f"Không thấy cột qMOS (cột: {list(df.columns)})"
81
+ df["_stem"] = df[wav_col].map(stem)
82
+ g = df.groupby("_stem")[qmos_col].mean()
83
+ return {s: float(v) for s, v in g.items()}
84
+
85
+ qmos_gold = load_qmos_labels()
86
+ print(f"Số wav train có nhãn qMOS: {len(qmos_gold)}")
87
+
88
+ # Chọn mẫu probe (chỉ giữ wav thật sự tồn tại trên đĩa)
89
+ rng = np.random.default_rng(SEED)
90
+ all_stems = [s for s in qmos_gold if os.path.exists(os.path.join(WAV_DIR, s + ".wav"))]
91
+ rng.shuffle(all_stems)
92
+ probe_stems = all_stems[:PROBE_N]
93
+ print(f"Mẫu probe: {len(probe_stems)} / {len(all_stems)} wav tồn tại")
94
+
95
+ # %% [markdown]
96
+ # ## 3. Hàm chấm: UTMOS (cũ) và UTMOSv2 (mới) — đều cache .npz
97
+
98
+ # %%
99
+ import torch
100
+ from scipy.stats import spearmanr, pearsonr
101
+
102
+ device = DEVICE if torch.cuda.is_available() else "cpu"
103
+ print("Device:", device, ("✅ " + torch.cuda.get_device_name(0)) if device == "cuda" else "⚠️ CPU")
104
+
105
+ def score_utmos(stems, tag):
106
+ """UTMOS 2022 (torch.hub utmos22_strong). → dict {stem: score}. Cache."""
107
+ import librosa
108
+ from tqdm.auto import tqdm
109
+ cache = os.path.join(CACHE_DIR, f"utmos_{tag}.npz")
110
+ store = {}
111
+ if os.path.exists(cache):
112
+ z = np.load(cache, allow_pickle=True)
113
+ store = {k: float(z[k]) for k in z.files}
114
+ print(f"[utmos/{tag}] nạp cache: {len(store)}")
115
+ todo = [s for s in stems if s not in store]
116
+ if todo:
117
+ predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong",
118
+ trust_repo=True).to(device).eval()
119
+ with torch.no_grad():
120
+ for i, s in enumerate(tqdm(todo, desc=f"utmos {tag}")):
121
+ wav = os.path.join(WAV_DIR, s + ".wav")
122
+ wave, _ = librosa.load(wav, sr=16000, mono=True)
123
+ store[s] = float(predictor(torch.from_numpy(wave).unsqueeze(0).to(device),
124
+ sr=16000).mean().item())
125
+ if (i + 1) % 500 == 0:
126
+ np.savez(cache, **{k: np.float32(v) for k, v in store.items()})
127
+ np.savez(cache, **{k: np.float32(v) for k, v in store.items()})
128
+ del predictor
129
+ torch.cuda.empty_cache() if device == "cuda" else None
130
+ return store
131
+
132
+ def score_utmosv2(stems, tag):
133
+ """UTMOSv2 / T05 (sarulab-speech/UTMOSv2). → dict {stem: score}. Cache."""
134
+ from tqdm.auto import tqdm
135
+ cache = os.path.join(CACHE_DIR, f"utmosv2_{tag}.npz")
136
+ store = {}
137
+ if os.path.exists(cache):
138
+ z = np.load(cache, allow_pickle=True)
139
+ store = {k: float(z[k]) for k in z.files}
140
+ print(f"[utmosv2/{tag}] nạp cache: {len(store)}")
141
+ todo = [s for s in stems if s not in store]
142
+ if todo:
143
+ import utmosv2
144
+ model = utmosv2.create_model(pretrained=True) # ensemble, checkpoint tự tải
145
+ for i, s in enumerate(tqdm(todo, desc=f"utmosv2 {tag}")):
146
+ wav = os.path.join(WAV_DIR, s + ".wav")
147
+ out = model.predict(input_path=wav)
148
+ # predict trả về float (hoặc dict có 'predicted_mos') tùy phiên bản
149
+ store[s] = float(out["predicted_mos"]) if isinstance(out, dict) else float(out)
150
+ if (i + 1) % 200 == 0:
151
+ np.savez(cache, **{k: np.float32(v) for k, v in store.items()})
152
+ np.savez(cache, **{k: np.float32(v) for k, v in store.items()})
153
+ del model
154
+ torch.cuda.empty_cache() if device == "cuda" else None
155
+ return store
156
+
157
+ # %% [markdown]
158
+ # ## 4. Chạy A/B trên mẫu train → in SRCC mỗi model vs nhãn qMOS vàng
159
+
160
+ # %%
161
+ utmos_s = score_utmos(probe_stems, "probe")
162
+ utmosv2_s = score_utmosv2(probe_stems, "probe")
163
+
164
+ # Chỉ so trên các stem cả 2 model đều chấm được (để công bằng)
165
+ common = [s for s in probe_stems if s in utmos_s and s in utmosv2_s and s in qmos_gold]
166
+ y_gold = np.array([qmos_gold[s] for s in common])
167
+ p_v1 = np.array([utmos_s[s] for s in common])
168
+ p_v2 = np.array([utmosv2_s[s] for s in common])
169
+ print(f"\nSố mẫu so sánh chung: {len(common)}")
170
+
171
+ srcc_v1 = spearmanr(p_v1, y_gold).correlation
172
+ srcc_v2 = spearmanr(p_v2, y_gold).correlation
173
+ lcc_v1 = pearsonr(p_v1, y_gold)[0]
174
+ lcc_v2 = pearsonr(p_v2, y_gold)[0]
175
+
176
+ print("\n📊 A/B trên TRAIN (nhãn qMOS vàng) — UTT-SRCC là metric chính:")
177
+ print(f" UTMOS 2022 (đang dùng) : SRCC = {srcc_v1:.4f} | LCC = {lcc_v1:.4f}")
178
+ print(f" UTMOSv2 / T05 (mới) : SRCC = {srcc_v2:.4f} | LCC = {lcc_v2:.4f}")
179
+ delta = srcc_v2 - srcc_v1
180
+ if delta > 0.01:
181
+ print(f" ✅ UTMOSv2 THẮNG (+{delta:.4f} SRCC) → đáng dùng làm neo cho exp09 / đổi cột QMOS.")
182
+ elif delta < -0.01:
183
+ print(f" ⚠️ UTMOSv2 THUA ({delta:.4f} SRCC) → giữ UTMOS; lệch domain cảm xúc quá mạnh.")
184
+ else:
185
+ print(f" ➖ Ngang nhau ({delta:+.4f}) → ưu tiên model nào tiện hơn; chốt bằng fine-tune.")
186
+
187
+ # Mốc tham chiếu leaderboard: UTMOS zero-shot DEV = 0.414; head QMOS exp07 = 0.548.
188
+ # (SRCC train ≠ SRCC dev nhưng cùng xu hướng → dùng để quyết hướng, không phải điểm nộp.)
189
+ print("\nℹ️ Mốc leaderboard DEV để đối chiếu: UTMOS zero-shot 0.414 · head QMOS exp07 0.548.")
190
+
191
+ # %% [markdown]
192
+ # ## 5. (Tùy chọn) Tạo answer.txt đổi cột QMOS←UTMOSv2 để nộp xác nhận DEV
193
+ # Chỉ chạy nếu `EXP07_ANSWER` trỏ tới answer.txt exp07. Giữ nguyên 5 cột cảm xúc, chỉ thay QMOS.
194
+
195
+ # %%
196
+ def build_swapped_answer(exp07_answer_path, out_path):
197
+ """Đọc answer.txt exp07 (wav,QMOS,EMOS,CAT,VAL,ARO,DOM), thay QMOS = UTMOSv2(dev)."""
198
+ import csv
199
+ with open(DEV_SCP) as f:
200
+ dev_names = [ln.strip() for ln in f if ln.strip()]
201
+ dev_stems = [stem(n) for n in dev_names]
202
+ utmosv2_dev = score_utmosv2(dev_stems, "dev") # chấm DEV bằng UTMOSv2 (cache riêng)
203
+
204
+ with open(exp07_answer_path) as f:
205
+ rows = list(csv.reader(f))
206
+ header, body = rows[0], rows[1:]
207
+ qi = header.index("QMOS")
208
+ n_swap = 0
209
+ with open(out_path, "w") as f:
210
+ f.write(",".join(header) + "\n")
211
+ for r in body:
212
+ sid = stem(r[0])
213
+ if sid in utmosv2_dev:
214
+ r[qi] = f"{utmosv2_dev[sid]:.6g}"
215
+ n_swap += 1
216
+ f.write(",".join(r) + "\n")
217
+ print(f"Ghi {len(body)} dòng → {out_path} | đổi QMOS được {n_swap} dòng")
218
+ return out_path
219
+
220
+ if EXP07_ANSWER and os.path.exists(EXP07_ANSWER):
221
+ out = os.path.join(OUT_DIR, "answer.txt")
222
+ build_swapped_answer(EXP07_ANSWER, out)
223
+ os.system(f"cd {OUT_DIR} && zip -j submission_track2_exp09a_utmosv2.zip answer.txt "
224
+ f"&& unzip -l submission_track2_exp09a_utmosv2.zip")
225
+ print("Sẵn sàng nộp:", os.path.join(OUT_DIR, "submission_track2_exp09a_utmosv2.zip"))
226
+ else:
227
+ print("Bỏ qua mục 5 (EXP07_ANSWER=None hoặc không tồn tại). Chỉ chạy A/B nội bộ.")
228
+
229
+ # %% [markdown]
230
+ # ## Ghi chú
231
+ # - **Đọc kết quả mục 4:** UTMOSv2 SRCC có > UTMOS không?
232
+ # - **Thắng rõ** → dùng UTMOSv2 làm **neo** cho `exp09` (fine-tune WavLM trên nhãn qMOS) thay UTMOS;
233
+ # và/hoặc nộp answer.txt đổi cột (mục 5) để xác nhận trên leaderboard DEV.
234
+ # - **Thua/ngang** → giữ UTMOS làm neo; kết luận "UTMOSv2 vẫn lệch domain cảm xúc" (phát hiện cho paper).
235
+ # - **Gotcha Kaggle:** UTMOSv2 cài từ git + tải checkpoint → **Internet On**. Bản nộp Internet-off cần
236
+ # pre-download weights thành Kaggle Dataset.
237
+ # - UTMOSv2 là **ensemble nhiều fold** → chậm hơn UTMOS. Nếu lâu, giảm `PROBE_N` hoặc chấm dần (có cache).
238
+ # - License: UTMOSv2 **MIT** · UTMOS BSD-3. Ghi vào `docs/12_system_description.md`.
239
+ # - Ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (mục exp09a).
track2/exp10_finetune_audeering.ipynb ADDED
@@ -0,0 +1,691 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "678096c6",
6
+ "metadata": {},
7
+ "source": [
8
+ "# VMC2026 Track 2 — exp10 (fine-tune AUDEERING riêng + ensemble VAD với exp08) — Kaggle T4\n",
9
+ "\n",
10
+ "**Ý tưởng (Hướng A — an toàn cho T4):** thay vì nhồi 2 backbone large vào 1 model (dễ OOM),\n",
11
+ "ta fine-tune **audeering wav2vec2-large** RIÊNG (1 backbone → vừa T4), rồi **ensemble cột VAD**\n",
12
+ "với exp08 (WavLM fine-tune). Mỗi lần chỉ 1 backbone trong VRAM → không OOM.\n",
13
+ "\n",
14
+ "```\n",
15
+ " [exp08] WavLM fine-tune ─► VAD_wavlm ┐\n",
16
+ " ├─ trung bình ─► VAD cuối (mạnh hơn cả 2)\n",
17
+ " [exp10] audeering fine-tune ─► VAD_aud ┘\n",
18
+ "```\n",
19
+ "audeering vốn là model **dimensional (chuyên VAD)** → fine-tune nó để bổ trợ VAD cho exp08.\n",
20
+ "\n",
21
+ "**Cách chạy:** GPU T4 + Internet On → sửa slug cell 0 → Run All. Lần đầu `LIMIT_TRAIN=300`.\n",
22
+ "Để ensemble: Add Input answer.txt exp08 → trỏ `EXP08_ANSWER`."
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "markdown",
27
+ "id": "291f23e3",
28
+ "metadata": {},
29
+ "source": [
30
+ "## 0. Cấu hình"
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": null,
36
+ "id": "41d619c3",
37
+ "metadata": {},
38
+ "outputs": [],
39
+ "source": [
40
+ "import os\n",
41
+ "\n",
42
+ "DATA_ROOT = \"/kaggle/input/datasets/minhtoan2/vmc2026-track2-full\" # << SỬA slug\n",
43
+ "WAV_DIR = f\"{DATA_ROOT}/wav\"\n",
44
+ "METADATA_CSV = f\"{DATA_ROOT}/metadata.csv\"\n",
45
+ "TRAIN_CSV = f\"{DATA_ROOT}/sets/train.csv\"\n",
46
+ "DEV_SCP = f\"{DATA_ROOT}/sets/dev.scp\"\n",
47
+ "\n",
48
+ "OUT_DIR = \"/kaggle/working\"\n",
49
+ "\n",
50
+ "# QMOS mượn exp07 (0.548); ensemble VAD với answer.txt exp08.\n",
51
+ "EXP07_ANSWER = \"/kaggle/input/exp07-answer/answer.txt\" # << (tùy chọn) mượn QMOS; không có → UTMOSv2\n",
52
+ "EXP08_ANSWER = \"/kaggle/input/exp08-answer/answer.txt\" # << (tùy chọn) để ENSEMBLE VAD; không có → chỉ ra answer audeering\n",
53
+ "\n",
54
+ "# ── Fine-tune audeering (1 backbone) ─────────────────────────────────────────\n",
55
+ "DEVICE = \"cuda\"\n",
56
+ "SR = 16000\n",
57
+ "MAX_SECONDS = 8\n",
58
+ "UNFREEZE_TOP_LAYERS = 6 # số lớp encoder audeering mở băng (T4 thừa sức 1 backbone)\n",
59
+ "TRUNK_HIDDEN = 512\n",
60
+ "HEAD_HIDDEN = 128\n",
61
+ "DROPOUT = 0.3\n",
62
+ "LR_BACKBONE = 1e-5\n",
63
+ "LR_HEAD = 1e-3\n",
64
+ "WEIGHT_DECAY = 1e-5\n",
65
+ "EPOCHS = 12\n",
66
+ "PATIENCE = 4\n",
67
+ "BATCH = 4\n",
68
+ "ACCUM = 8\n",
69
+ "VAL_FRAC = 0.10\n",
70
+ "SEED = 42\n",
71
+ "USE_AMP = True\n",
72
+ "USE_GRAD_CKPT = True\n",
73
+ "USE_UNCERTAINTY = True\n",
74
+ "\n",
75
+ "# Ensemble: cột nào lấy TRUNG BÌNH giữa exp08 và exp10; cột khác giữ từ exp08.\n",
76
+ "ENSEMBLE_COLS = [\"VAL\", \"ARO\", \"DOM\"] # audeering mạnh VAD → ensemble VAD. Thêm \"EMOS\" nếu muốn.\n",
77
+ "\n",
78
+ "LIMIT_TRAIN = 300 # << LẦN ĐẦU 300; chạy thật None\n",
79
+ "LIMIT_DEV = 20 # << LẦN ĐẦU 20; chạy thật None\n",
80
+ "\n",
81
+ "EXP08 = {\"emos\": 0.811, \"cat_err\": 0.133, \"val\": 0.659, \"aro\": 0.793, \"dom\": 0.751}\n",
82
+ "\n",
83
+ "EMOTIONS5 = [\"angry\", \"happy\", \"neutral\", \"sad\", \"surprised\"]\n",
84
+ "_EMO_ALIAS = {\n",
85
+ " \"angry\": \"angry\", \"anger\": \"angry\",\n",
86
+ " \"happy\": \"happy\", \"happiness\": \"happy\", \"joy\": \"happy\",\n",
87
+ " \"neutral\": \"neutral\", \"calm\": \"neutral\",\n",
88
+ " \"sad\": \"sad\", \"sadness\": \"sad\",\n",
89
+ " \"surprise\": \"surprised\", \"surprised\": \"surprised\", \"surprising\": \"surprised\",\n",
90
+ "}\n",
91
+ "\n",
92
+ "def norm_emotion(label):\n",
93
+ " key = str(label).strip().lower()\n",
94
+ " return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)\n",
95
+ "\n",
96
+ "def stem(p):\n",
97
+ " return os.path.splitext(os.path.basename(str(p)))[0]\n",
98
+ "\n",
99
+ "print(\"DATA_ROOT:\", DATA_ROOT)\n",
100
+ "for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP]:\n",
101
+ " print((\" ✅ \" if os.path.exists(p) else \" ❌ THIẾU \") + p)"
102
+ ]
103
+ },
104
+ {
105
+ "cell_type": "markdown",
106
+ "id": "0adb988b",
107
+ "metadata": {},
108
+ "source": [
109
+ "## 1. Cài đặt"
110
+ ]
111
+ },
112
+ {
113
+ "cell_type": "code",
114
+ "execution_count": null,
115
+ "id": "1713d69b",
116
+ "metadata": {},
117
+ "outputs": [],
118
+ "source": [
119
+ "import sys, subprocess\n",
120
+ "\n",
121
+ "def pip_install(*pkgs):\n",
122
+ " subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *pkgs], check=True)\n",
123
+ "\n",
124
+ "pip_install(\"transformers\", \"huggingface_hub\", \"safetensors\", \"speechmos\",\n",
125
+ " \"librosa\", \"soundfile\", \"scipy\", \"scikit-learn\", \"pandas\", \"tqdm\")"
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "markdown",
130
+ "id": "4fcb8b30",
131
+ "metadata": {},
132
+ "source": [
133
+ "## 2. Nạp audeering wav2vec2-large làm backbone FINE-TUNE\n",
134
+ "Nạp backbone tay (tránh lỗi subclass `Wav2Vec2PreTrainedModel` ở transformers mới) rồi mở băng lớp trên."
135
+ ]
136
+ },
137
+ {
138
+ "cell_type": "code",
139
+ "execution_count": null,
140
+ "id": "f0a39dab",
141
+ "metadata": {
142
+ "lines_to_next_cell": 1
143
+ },
144
+ "outputs": [],
145
+ "source": [
146
+ "import torch\n",
147
+ "import torch.nn as nn\n",
148
+ "import torch.nn.functional as F\n",
149
+ "import numpy as np\n",
150
+ "\n",
151
+ "device = DEVICE if torch.cuda.is_available() else \"cpu\"\n",
152
+ "print(\"Device:\", device, (\"✅ \" + torch.cuda.get_device_name(0)) if device == \"cuda\" else \"⚠️ CPU (rất chậm!)\")\n",
153
+ "\n",
154
+ "from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor\n",
155
+ "from huggingface_hub import hf_hub_download\n",
156
+ "\n",
157
+ "AUD_NAME = \"audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim\"\n",
158
+ "aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)\n",
159
+ "aud_cfg = Wav2Vec2Config.from_pretrained(AUD_NAME)\n",
160
+ "aud = Wav2Vec2Model(aud_cfg)\n",
161
+ "try:\n",
162
+ " _sd = __import__(\"safetensors.torch\", fromlist=[\"load_file\"]).load_file(\n",
163
+ " hf_hub_download(AUD_NAME, \"model.safetensors\"))\n",
164
+ "except Exception:\n",
165
+ " _sd = torch.load(hf_hub_download(AUD_NAME, \"pytorch_model.bin\"), map_location=\"cpu\")\n",
166
+ "bb_sd = {k[len(\"wav2vec2.\"):]: v for k, v in _sd.items() if k.startswith(\"wav2vec2.\")}\n",
167
+ "miss, unexp = aud.load_state_dict(bb_sd, strict=False)\n",
168
+ "print(f\"audeering backbone: thiếu {len(miss)} / dư {len(unexp)} key (strict=False)\")\n",
169
+ "aud = aud.to(device)\n",
170
+ "AUD_DIM = int(aud.config.hidden_size)\n",
171
+ "\n",
172
+ "# Đóng băng tất cả, mở băng UNFREEZE_TOP_LAYERS lớp encoder trên cùng\n",
173
+ "for p in aud.parameters():\n",
174
+ " p.requires_grad = False\n",
175
+ "enc_layers = aud.encoder.layers\n",
176
+ "n_layers = len(enc_layers)\n",
177
+ "for layer in enc_layers[max(0, n_layers - UNFREEZE_TOP_LAYERS):]:\n",
178
+ " for p in layer.parameters():\n",
179
+ " p.requires_grad = True\n",
180
+ "n_train = sum(p.numel() for p in aud.parameters() if p.requires_grad)\n",
181
+ "print(f\"audeering: {n_layers} lớp · mở băng {min(UNFREEZE_TOP_LAYERS, n_layers)} → {n_train/1e6:.1f}M param train (dim {AUD_DIM})\")\n",
182
+ "\n",
183
+ "if USE_GRAD_CKPT:\n",
184
+ " aud.gradient_checkpointing_enable()\n",
185
+ " if hasattr(aud, \"enable_input_require_grads\"):\n",
186
+ " aud.enable_input_require_grads()\n",
187
+ "\n",
188
+ "def masked_mean(hidden, attn_mask):\n",
189
+ " if attn_mask is None:\n",
190
+ " return hidden.mean(dim=1)\n",
191
+ " try:\n",
192
+ " fm = aud._get_feature_vector_attention_mask(hidden.shape[1], attn_mask)\n",
193
+ " except Exception:\n",
194
+ " return hidden.mean(dim=1)\n",
195
+ " fm = fm.unsqueeze(-1).to(hidden.dtype)\n",
196
+ " return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)\n",
197
+ "\n",
198
+ "def aud_embed(input_values, attn_mask):\n",
199
+ " out = aud(input_values, attention_mask=attn_mask).last_hidden_state\n",
200
+ " return masked_mean(out, attn_mask)"
201
+ ]
202
+ },
203
+ {
204
+ "cell_type": "markdown",
205
+ "id": "54e993b9",
206
+ "metadata": {},
207
+ "source": [
208
+ "## 3. Nhãn (gộp theo wavID) — như exp08"
209
+ ]
210
+ },
211
+ {
212
+ "cell_type": "code",
213
+ "execution_count": null,
214
+ "id": "46cc0e42",
215
+ "metadata": {},
216
+ "outputs": [],
217
+ "source": [
218
+ "import librosa\n",
219
+ "import pandas as pd\n",
220
+ "from tqdm.auto import tqdm\n",
221
+ "\n",
222
+ "def load_target_emotions():\n",
223
+ " tgt = {}\n",
224
+ " with open(METADATA_CSV, encoding=\"utf-8\") as f:\n",
225
+ " for ln in f:\n",
226
+ " parts = ln.strip().split(\"|\")\n",
227
+ " if len(parts) >= 2:\n",
228
+ " tgt[stem(parts[0])] = norm_emotion(parts[1])\n",
229
+ " return tgt\n",
230
+ "\n",
231
+ "def _col(cols_map, *names, df=None, default_idx=None):\n",
232
+ " for n in names:\n",
233
+ " if n in cols_map:\n",
234
+ " return cols_map[n]\n",
235
+ " return list(df.columns)[default_idx] if default_idx is not None else None\n",
236
+ "\n",
237
+ "def parse_emocat_votes(cell):\n",
238
+ " v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
239
+ " for tok in str(cell).replace(\"/\", \",\").replace(\";\", \",\").replace(\"|\", \",\").replace(\" \", \",\").split(\",\"):\n",
240
+ " e = norm_emotion(tok)\n",
241
+ " if e in EMOTIONS5:\n",
242
+ " v[EMOTIONS5.index(e)] += 1.0\n",
243
+ " return v\n",
244
+ "\n",
245
+ "def load_train_labels():\n",
246
+ " df = pd.read_csv(TRAIN_CSV, sep=\"|\")\n",
247
+ " cols = {c.lower().strip(): c for c in df.columns}\n",
248
+ " wav_col = _col(cols, \"wavid\", \"wav\", df=df, default_idx=1)\n",
249
+ " emos_col = _col(cols, \"emos\", \"emo\", \"emomos\")\n",
250
+ " val_col = _col(cols, \"val\", \"valence\"); aro_col = _col(cols, \"aro\", \"arousal\"); dom_col = _col(cols, \"dom\", \"dominance\")\n",
251
+ " cat_col = _col(cols, \"emocat\", \"cat\", \"emotion\")\n",
252
+ " assert emos_col, f\"Không thấy cột eMOS (cột: {list(df.columns)})\"\n",
253
+ " df[\"_stem\"] = df[wav_col].map(stem)\n",
254
+ " rows = []\n",
255
+ " for sid, g in df.groupby(\"_stem\"):\n",
256
+ " rec = {\"wavID\": sid, \"emos\": float(g[emos_col].mean())}\n",
257
+ " rec[\"val\"] = float(g[val_col].mean()) if val_col else np.nan\n",
258
+ " rec[\"aro\"] = float(g[aro_col].mean()) if aro_col else np.nan\n",
259
+ " rec[\"dom\"] = float(g[dom_col].mean()) if dom_col else np.nan\n",
260
+ " votes = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
261
+ " if cat_col:\n",
262
+ " for cell in g[cat_col]:\n",
263
+ " votes += parse_emocat_votes(cell)\n",
264
+ " s = votes.sum()\n",
265
+ " cat = votes / s if s > 0 else np.full(len(EMOTIONS5), 0.2, dtype=np.float32)\n",
266
+ " for i in range(len(EMOTIONS5)):\n",
267
+ " rec[f\"cat{i}\"] = float(cat[i])\n",
268
+ " rows.append(rec)\n",
269
+ " return pd.DataFrame(rows)\n",
270
+ "\n",
271
+ "target_map = load_target_emotions()\n",
272
+ "train_df = load_train_labels()\n",
273
+ "HAS_VAD = bool(train_df[\"val\"].notna().any())\n",
274
+ "print(f\"Target: {len(target_map)} | wav train (gộp): {len(train_df)} | có VAD: {HAS_VAD}\")"
275
+ ]
276
+ },
277
+ {
278
+ "cell_type": "markdown",
279
+ "id": "0e0bf0bb",
280
+ "metadata": {},
281
+ "source": [
282
+ "## 4. Dataset/loader (input_values qua audeering processor) + chuẩn hóa nhãn"
283
+ ]
284
+ },
285
+ {
286
+ "cell_type": "code",
287
+ "execution_count": null,
288
+ "id": "65768ef9",
289
+ "metadata": {},
290
+ "outputs": [],
291
+ "source": [
292
+ "from torch.utils.data import Dataset, DataLoader\n",
293
+ "from sklearn.model_selection import train_test_split\n",
294
+ "\n",
295
+ "train_stems = [s for s in train_df[\"wavID\"] if target_map.get(s) is not None]\n",
296
+ "if LIMIT_TRAIN:\n",
297
+ " train_stems = train_stems[:LIMIT_TRAIN]\n",
298
+ "lab = train_df.set_index(\"wavID\")\n",
299
+ "\n",
300
+ "def _zfit(arr):\n",
301
+ " a = np.asarray(arr, dtype=np.float32)\n",
302
+ " return float(np.nanmean(a)), float(np.nanstd(a) + 1e-6)\n",
303
+ "\n",
304
+ "emos_mu, emos_sd = _zfit([lab.loc[s, \"emos\"] for s in train_stems])\n",
305
+ "if HAS_VAD:\n",
306
+ " vad_mu = np.array([_zfit([lab.loc[s, c] for s in train_stems])[0] for c in [\"val\", \"aro\", \"dom\"]], dtype=np.float32)\n",
307
+ " vad_sd = np.array([_zfit([lab.loc[s, c] for s in train_stems])[1] for c in [\"val\", \"aro\", \"dom\"]], dtype=np.float32)\n",
308
+ "else:\n",
309
+ " vad_mu = np.zeros(3, dtype=np.float32); vad_sd = np.ones(3, dtype=np.float32)\n",
310
+ "\n",
311
+ "def onehot_target(tgt):\n",
312
+ " v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
313
+ " if tgt in EMOTIONS5:\n",
314
+ " v[EMOTIONS5.index(tgt)] = 1.0\n",
315
+ " return v\n",
316
+ "\n",
317
+ "def load_iv(sid):\n",
318
+ " \"\"\"Đọc wav → chuẩn hóa bằng audeering processor → input_values (1D float32).\"\"\"\n",
319
+ " p = os.path.join(WAV_DIR, sid if str(sid).endswith(\".wav\") else str(sid) + \".wav\")\n",
320
+ " if not os.path.exists(p):\n",
321
+ " return None\n",
322
+ " wave, _ = librosa.load(p, sr=SR, mono=True)\n",
323
+ " wave = wave[: MAX_SECONDS * SR]\n",
324
+ " iv = aud_proc(wave, sampling_rate=SR).input_values[0]\n",
325
+ " return np.asarray(iv, dtype=np.float32)\n",
326
+ "\n",
327
+ "class AudDataset(Dataset):\n",
328
+ " def __init__(self, stems):\n",
329
+ " self.stems = [s for s in stems if load_iv(s) is not None]\n",
330
+ " def __len__(self):\n",
331
+ " return len(self.stems)\n",
332
+ " def __getitem__(self, i):\n",
333
+ " s = self.stems[i]\n",
334
+ " iv = load_iv(s)\n",
335
+ " emos = (float(lab.loc[s, \"emos\"]) - emos_mu) / emos_sd\n",
336
+ " if HAS_VAD:\n",
337
+ " vad = (np.array([lab.loc[s, \"val\"], lab.loc[s, \"aro\"], lab.loc[s, \"dom\"]], np.float32) - vad_mu) / vad_sd\n",
338
+ " else:\n",
339
+ " vad = np.zeros(3, dtype=np.float32)\n",
340
+ " cat = np.array([lab.loc[s, f\"cat{j}\"] for j in range(len(EMOTIONS5))], dtype=np.float32)\n",
341
+ " return {\"iv\": iv, \"tgt\": onehot_target(target_map.get(s)),\n",
342
+ " \"emos\": np.float32(emos), \"vad\": vad, \"cat\": cat,\n",
343
+ " \"emos_raw\": np.float32(lab.loc[s, \"emos\"]),\n",
344
+ " \"vad_raw\": np.array([lab.loc[s, \"val\"], lab.loc[s, \"aro\"], lab.loc[s, \"dom\"]], np.float32)}\n",
345
+ "\n",
346
+ "def collate(batch):\n",
347
+ " L = max(len(b[\"iv\"]) for b in batch)\n",
348
+ " ivs = np.zeros((len(batch), L), dtype=np.float32)\n",
349
+ " mask = np.zeros((len(batch), L), dtype=np.float32)\n",
350
+ " for i, b in enumerate(batch):\n",
351
+ " ivs[i, : len(b[\"iv\"])] = b[\"iv\"]; mask[i, : len(b[\"iv\"])] = 1.0\n",
352
+ " return {\n",
353
+ " \"input_values\": torch.from_numpy(ivs), \"attn_mask\": torch.from_numpy(mask).long(),\n",
354
+ " \"tgt\": torch.from_numpy(np.stack([b[\"tgt\"] for b in batch])),\n",
355
+ " \"emos\": torch.from_numpy(np.stack([b[\"emos\"] for b in batch])).unsqueeze(1),\n",
356
+ " \"vad\": torch.from_numpy(np.stack([b[\"vad\"] for b in batch])),\n",
357
+ " \"cat\": torch.from_numpy(np.stack([b[\"cat\"] for b in batch])),\n",
358
+ " \"emos_raw\": np.stack([b[\"emos_raw\"] for b in batch]),\n",
359
+ " \"vad_raw\": np.stack([b[\"vad_raw\"] for b in batch]),\n",
360
+ " }\n",
361
+ "\n",
362
+ "ds = AudDataset(train_stems)\n",
363
+ "print(\"Dataset hợp lệ:\", len(ds), \"wav\")\n",
364
+ "tr_i, va_i = train_test_split(np.arange(len(ds)), test_size=VAL_FRAC, random_state=SEED)\n",
365
+ "tr_loader = DataLoader(torch.utils.data.Subset(ds, tr_i), batch_size=BATCH, shuffle=True, collate_fn=collate, num_workers=2)\n",
366
+ "va_loader = DataLoader(torch.utils.data.Subset(ds, va_i), batch_size=BATCH, shuffle=False, collate_fn=collate, num_workers=2)"
367
+ ]
368
+ },
369
+ {
370
+ "cell_type": "markdown",
371
+ "id": "697d9ca3",
372
+ "metadata": {},
373
+ "source": [
374
+ "## 5. Heads + train loop (lưu ft_audeering_full.pt mỗi best)"
375
+ ]
376
+ },
377
+ {
378
+ "cell_type": "code",
379
+ "execution_count": null,
380
+ "id": "8fe7ec40",
381
+ "metadata": {
382
+ "lines_to_next_cell": 1
383
+ },
384
+ "outputs": [],
385
+ "source": [
386
+ "from scipy.stats import spearmanr\n",
387
+ "\n",
388
+ "torch.manual_seed(SEED); np.random.seed(SEED)\n",
389
+ "N_EMO = len(EMOTIONS5)\n",
390
+ "\n",
391
+ "class EmoHeads(nn.Module):\n",
392
+ " def __init__(self, d_in, trunk_h, head_h, p, n_emo):\n",
393
+ " super().__init__()\n",
394
+ " self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),\n",
395
+ " nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))\n",
396
+ " self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))\n",
397
+ " self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))\n",
398
+ " self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))\n",
399
+ " def forward(self, feat, tgt):\n",
400
+ " h = self.trunk(feat)\n",
401
+ " return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)\n",
402
+ "\n",
403
+ "heads = EmoHeads(AUD_DIM, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device)\n",
404
+ "\n",
405
+ "TASKS = [\"emos\", \"cat\", \"val\", \"aro\", \"dom\"]\n",
406
+ "log_var = nn.Parameter(torch.zeros(len(TASKS), device=device))\n",
407
+ "bb_params = [p for p in aud.parameters() if p.requires_grad]\n",
408
+ "head_params = list(heads.parameters()) + ([log_var] if USE_UNCERTAINTY else [])\n",
409
+ "opt = torch.optim.AdamW([{\"params\": bb_params, \"lr\": LR_BACKBONE},\n",
410
+ " {\"params\": head_params, \"lr\": LR_HEAD}], weight_decay=WEIGHT_DECAY)\n",
411
+ "scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP and device == \"cuda\")\n",
412
+ "mse = nn.MSELoss()\n",
413
+ "\n",
414
+ "def soft_ce(logits, target_dist):\n",
415
+ " return -(target_dist * F.log_softmax(logits, dim=1)).sum(1).mean()\n",
416
+ "\n",
417
+ "def forward_batch(b):\n",
418
+ " feat = aud_embed(b[\"input_values\"].to(device), b[\"attn_mask\"].to(device))\n",
419
+ " return heads(feat, b[\"tgt\"].to(device))\n",
420
+ "\n",
421
+ "def compute_loss(emos_p, cat_l, vad_p, b):\n",
422
+ " L = {}\n",
423
+ " L[\"emos\"] = mse(emos_p, b[\"emos\"].to(device))\n",
424
+ " L[\"cat\"] = soft_ce(cat_l, b[\"cat\"].to(device))\n",
425
+ " if HAS_VAD:\n",
426
+ " vt = b[\"vad\"].to(device)\n",
427
+ " L[\"val\"] = mse(vad_p[:, 0:1], vt[:, 0:1]); L[\"aro\"] = mse(vad_p[:, 1:2], vt[:, 1:2]); L[\"dom\"] = mse(vad_p[:, 2:3], vt[:, 2:3])\n",
428
+ " else:\n",
429
+ " z = torch.zeros((), device=device); L[\"val\"] = L[\"aro\"] = L[\"dom\"] = z\n",
430
+ " if USE_UNCERTAINTY:\n",
431
+ " return sum(torch.exp(-log_var[i]) * L[t] + log_var[i] for i, t in enumerate(TASKS))\n",
432
+ " return sum(L.values())\n",
433
+ "\n",
434
+ "@torch.no_grad()\n",
435
+ "def evaluate():\n",
436
+ " aud.eval(); heads.eval()\n",
437
+ " P = {\"emos\": [], \"val\": [], \"aro\": [], \"dom\": []}; Y = {\"emos\": [], \"val\": [], \"aro\": [], \"dom\": []}\n",
438
+ " catP, catY = [], []\n",
439
+ " for b in va_loader:\n",
440
+ " with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
441
+ " emos_p, cat_l, vad_p = forward_batch(b)\n",
442
+ " P[\"emos\"] += emos_p.float().cpu().numpy().ravel().tolist(); Y[\"emos\"] += b[\"emos_raw\"].tolist()\n",
443
+ " vad_p = vad_p.float().cpu().numpy()\n",
444
+ " for j, t in enumerate([\"val\", \"aro\", \"dom\"]):\n",
445
+ " P[t] += vad_p[:, j].tolist(); Y[t] += b[\"vad_raw\"][:, j].tolist()\n",
446
+ " catP.append(F.softmax(cat_l, 1).float().cpu().numpy()); catY.append(b[\"cat\"])\n",
447
+ " out = {}\n",
448
+ " for t in [\"emos\"] + ([\"val\", \"aro\", \"dom\"] if HAS_VAD else []):\n",
449
+ " out[t] = spearmanr(P[t], Y[t]).correlation\n",
450
+ " q = np.concatenate(catP); p = np.concatenate(catY)\n",
451
+ " out[\"cat_err\"] = float(np.abs(q - p).sum(1).mean())\n",
452
+ " return out\n",
453
+ "\n",
454
+ "def mean_srcc(m):\n",
455
+ " keys = [\"emos\"] + ([\"val\", \"aro\", \"dom\"] if HAS_VAD else [])\n",
456
+ " return float(np.mean([m[k] for k in keys]))\n",
457
+ "\n",
458
+ "CKPT_PATH = os.path.join(OUT_DIR, \"ft_audeering_full.pt\")\n",
459
+ "def save_full_ckpt(state, val_emos=float(\"nan\")):\n",
460
+ " torch.save({\"aud\": state[\"aud\"], \"heads\": state[\"heads\"],\n",
461
+ " \"emos_mu\": emos_mu, \"emos_sd\": emos_sd, \"vad_mu\": vad_mu, \"vad_sd\": vad_sd,\n",
462
+ " \"AUD_DIM\": AUD_DIM, \"UNFREEZE_TOP_LAYERS\": UNFREEZE_TOP_LAYERS,\n",
463
+ " \"val_emos\": float(val_emos)}, CKPT_PATH)\n",
464
+ "\n",
465
+ "best, best_state, bad = -1e9, None, 0\n",
466
+ "for ep in range(1, EPOCHS + 1):\n",
467
+ " aud.train(); heads.train()\n",
468
+ " opt.zero_grad(); run = 0.0; nb = 0\n",
469
+ " for step, b in enumerate(tqdm(tr_loader, desc=f\"epoch {ep}\")):\n",
470
+ " with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
471
+ " emos_p, cat_l, vad_p = forward_batch(b)\n",
472
+ " loss = compute_loss(emos_p, cat_l, vad_p, b) / ACCUM\n",
473
+ " scaler.scale(loss).backward()\n",
474
+ " if (step + 1) % ACCUM == 0:\n",
475
+ " scaler.step(opt); scaler.update(); opt.zero_grad()\n",
476
+ " run += loss.item() * ACCUM; nb += 1\n",
477
+ " m = evaluate(); sc = mean_srcc(m)\n",
478
+ " msg = \" \".join(f\"{k}={m[k]:.3f}\" for k in [\"emos\", \"val\", \"aro\", \"dom\"] if k in m)\n",
479
+ " print(f\"epoch {ep:2d} | loss {run/max(nb,1):.4f} | {msg} | cat_err {m['cat_err']:.3f} | mean {sc:.4f} (best {max(best,sc):.4f})\")\n",
480
+ " if sc > best:\n",
481
+ " best = sc\n",
482
+ " best_state = {\"aud\": {k: v.cpu().clone() for k, v in aud.state_dict().items()},\n",
483
+ " \"heads\": {k: v.cpu().clone() for k, v in heads.state_dict().items()}}\n",
484
+ " save_full_ckpt(best_state, m[\"emos\"])\n",
485
+ " print(f\" 💾 lưu best → {CKPT_PATH} (epoch {ep}, mean {sc:.4f})\")\n",
486
+ " bad = 0\n",
487
+ " else:\n",
488
+ " bad += 1\n",
489
+ " if bad >= PATIENCE:\n",
490
+ " print(f\"Early stop ở epoch {ep}.\"); break\n",
491
+ "\n",
492
+ "if best_state:\n",
493
+ " aud.load_state_dict(best_state[\"aud\"]); heads.load_state_dict(best_state[\"heads\"])\n",
494
+ "final = evaluate()\n",
495
+ "print(\"\\n✅ VAL (nội bộ) — exp10 (fine-tune audeering):\")\n",
496
+ "print(f\" EMOS={final['emos']:.4f}\", end=\"\")\n",
497
+ "if HAS_VAD:\n",
498
+ " print(f\" | VAL/ARO/DOM={final['val']:.4f}/{final['aro']:.4f}/{final['dom']:.4f} (exp08 {EXP08['val']}/{EXP08['aro']}/{EXP08['dom']})\")\n",
499
+ "else:\n",
500
+ " print()\n",
501
+ "print(f\" → so exp08: audeering {'mạnh' if HAS_VAD and final['val'] > EXP08['val'] else 'yếu/ngang'} ở VAL. \"\n",
502
+ " f\"Ensemble sẽ lấy trung bình 2 model.\")\n",
503
+ "save_full_ckpt(best_state if best_state else {\"aud\": aud.state_dict(), \"heads\": heads.state_dict()}, final[\"emos\"])\n",
504
+ "print(f\"✅ Đã lưu {CKPT_PATH}. NHỚ Save Version!\")"
505
+ ]
506
+ },
507
+ {
508
+ "cell_type": "markdown",
509
+ "id": "4911ff48",
510
+ "metadata": {},
511
+ "source": [
512
+ "## 6. Dự đoán DEV → predictions + answer_audeering.txt"
513
+ ]
514
+ },
515
+ {
516
+ "cell_type": "code",
517
+ "execution_count": null,
518
+ "id": "4ed9a022",
519
+ "metadata": {},
520
+ "outputs": [],
521
+ "source": [
522
+ "def list_dev():\n",
523
+ " with open(DEV_SCP) as f:\n",
524
+ " return [ln.strip() for ln in f if ln.strip()]\n",
525
+ "\n",
526
+ "dev_names = list_dev()\n",
527
+ "if LIMIT_DEV:\n",
528
+ " dev_names = dev_names[:LIMIT_DEV]\n",
529
+ "print(\"DEV:\", len(dev_names), \"mẫu\")\n",
530
+ "\n",
531
+ "def load_exp07_qmos():\n",
532
+ " if EXP07_ANSWER and os.path.exists(EXP07_ANSWER):\n",
533
+ " import csv\n",
534
+ " d = {}\n",
535
+ " with open(EXP07_ANSWER) as f:\n",
536
+ " for row in csv.DictReader(f):\n",
537
+ " d[row[\"wav\"]] = float(row[\"QMOS\"]); d[stem(row[\"wav\"])] = float(row[\"QMOS\"])\n",
538
+ " print(f\"✅ Mượn QMOS từ exp07: {len(d)//2} wav\")\n",
539
+ " return d\n",
540
+ " return None\n",
541
+ "\n",
542
+ "qmos_map = load_exp07_qmos()\n",
543
+ "if qmos_map is None:\n",
544
+ " print(\"ℹ️ Không có exp07 → QMOS bằng UTMOSv2.\")\n",
545
+ " pip_install(\"git+https://github.com/sarulab-speech/UTMOSv2.git\")\n",
546
+ " import utmosv2\n",
547
+ " v2 = utmosv2.create_model(pretrained=True)\n",
548
+ " qmos_map = {}\n",
549
+ " for n in tqdm(dev_names, desc=\"UTMOSv2\"):\n",
550
+ " wav = os.path.join(WAV_DIR, n if str(n).endswith(\".wav\") else str(n) + \".wav\")\n",
551
+ " if os.path.exists(wav):\n",
552
+ " o = v2.predict(input_path=wav)\n",
553
+ " qmos_map[n] = float(o[\"predicted_mos\"]) if isinstance(o, dict) else float(o)\n",
554
+ " del v2; torch.cuda.empty_cache() if device == \"cuda\" else None\n",
555
+ "\n",
556
+ "@torch.no_grad()\n",
557
+ "def predict_emotion(sid):\n",
558
+ " iv = load_iv(sid)\n",
559
+ " if iv is None:\n",
560
+ " return None\n",
561
+ " aud.eval(); heads.eval()\n",
562
+ " ivt = torch.from_numpy(iv).unsqueeze(0).to(device)\n",
563
+ " am = torch.ones((1, len(iv)), dtype=torch.long, device=device)\n",
564
+ " tgt = torch.from_numpy(onehot_target(target_map.get(sid))).unsqueeze(0).to(device)\n",
565
+ " with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
566
+ " feat = aud_embed(ivt, am)\n",
567
+ " emos_p, cat_l, vad_p = heads(feat, tgt)\n",
568
+ " emos = float(emos_p.item()) * emos_sd + emos_mu\n",
569
+ " cat5 = F.softmax(cat_l, 1)[0].float().cpu().numpy()\n",
570
+ " vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu\n",
571
+ " return emos, cat5, vad3\n",
572
+ "\n",
573
+ "def fmt_cat(p5):\n",
574
+ " return \"|\".join(f\"{e}:{p5[i]:.6g}\" for i, e in enumerate(EMOTIONS5))\n",
575
+ "\n",
576
+ "dev_pred = {} # name -> (emos, cat5, vad3)\n",
577
+ "with open(os.path.join(OUT_DIR, \"answer_audeering.txt\"), \"w\") as f:\n",
578
+ " f.write(\"wav,QMOS,EMOS,CAT,VAL,ARO,DOM\\n\")\n",
579
+ " for name in tqdm(dev_names, desc=\"answer_aud\"):\n",
580
+ " sid = stem(name)\n",
581
+ " pr = predict_emotion(sid)\n",
582
+ " if pr is None:\n",
583
+ " emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0])\n",
584
+ " else:\n",
585
+ " emos, cat5, vad3 = pr\n",
586
+ " dev_pred[name] = (emos, cat5, vad3)\n",
587
+ " qmos = qmos_map.get(name, qmos_map.get(sid, 3.0))\n",
588
+ " f.write(f\"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\\n\")\n",
589
+ "print(\"Đã ghi answer_audeering.txt\")"
590
+ ]
591
+ },
592
+ {
593
+ "cell_type": "markdown",
594
+ "id": "72f8f313",
595
+ "metadata": {},
596
+ "source": [
597
+ "## 7. ENSEMBLE với exp08 → answer.txt cuối (trung bình cột VAD)\n",
598
+ "Lấy answer.txt exp08 làm nền; cột trong `ENSEMBLE_COLS` = trung bình (exp08 + exp10). Còn lại giữ exp08."
599
+ ]
600
+ },
601
+ {
602
+ "cell_type": "code",
603
+ "execution_count": null,
604
+ "id": "a33720f9",
605
+ "metadata": {
606
+ "lines_to_next_cell": 1
607
+ },
608
+ "outputs": [],
609
+ "source": [
610
+ "import csv\n",
611
+ "COL_IDX = {\"QMOS\": 1, \"EMOS\": 2, \"VAL\": 4, \"ARO\": 5, \"DOM\": 6} # vị trí cột trong answer.txt\n",
612
+ "AUD_VAL = {\"EMOS\": lambda p: p[0], \"VAL\": lambda p: p[2][0], \"ARO\": lambda p: p[2][1], \"DOM\": lambda p: p[2][2]}\n",
613
+ "\n",
614
+ "answer_path = os.path.join(OUT_DIR, \"answer.txt\")\n",
615
+ "if EXP08_ANSWER and os.path.exists(EXP08_ANSWER):\n",
616
+ " with open(EXP08_ANSWER) as f:\n",
617
+ " rows = list(csv.reader(f))\n",
618
+ " header, body = rows[0], rows[1:]\n",
619
+ " n_ens = 0\n",
620
+ " with open(answer_path, \"w\") as f:\n",
621
+ " f.write(\",\".join(header) + \"\\n\")\n",
622
+ " for r in body:\n",
623
+ " name = r[0]; sid = stem(name)\n",
624
+ " pr = dev_pred.get(name) or dev_pred.get(sid)\n",
625
+ " if pr is not None:\n",
626
+ " for col in ENSEMBLE_COLS:\n",
627
+ " if col in COL_IDX and col in AUD_VAL:\n",
628
+ " v08 = float(r[COL_IDX[col]]); vaud = float(AUD_VAL[col](pr))\n",
629
+ " r[COL_IDX[col]] = f\"{0.5*(v08+vaud):.6g}\"\n",
630
+ " n_ens += 1\n",
631
+ " f.write(\",\".join(r) + \"\\n\")\n",
632
+ " print(f\"✅ Ensemble {ENSEMBLE_COLS}: {n_ens} dòng → {answer_path} (nền exp08 + trung bình audeering)\")\n",
633
+ "else:\n",
634
+ " print(\"ℹ️ Không có EXP08_ANSWER → answer.txt = answer_audeering.txt (chỉ audeering, chưa ensemble).\")\n",
635
+ " import shutil\n",
636
+ " shutil.copy(os.path.join(OUT_DIR, \"answer_audeering.txt\"), answer_path)"
637
+ ]
638
+ },
639
+ {
640
+ "cell_type": "markdown",
641
+ "id": "5089b7fa",
642
+ "metadata": {},
643
+ "source": [
644
+ "## 8. Validate + zip"
645
+ ]
646
+ },
647
+ {
648
+ "cell_type": "code",
649
+ "execution_count": null,
650
+ "id": "66272528",
651
+ "metadata": {},
652
+ "outputs": [],
653
+ "source": [
654
+ "def validate(path):\n",
655
+ " with open(path) as f:\n",
656
+ " rows = list(csv.reader(f))\n",
657
+ " assert rows[0][0] == \"wav\" and \"QMOS\" in rows[0], \"Header sai\"\n",
658
+ " for i, r in enumerate(rows[1:], 2):\n",
659
+ " assert len(r) == len(rows[0]), f\"Dòng {i} sai số cột\"\n",
660
+ " print(f\"OK: {len(rows)-1} dòng, header = {rows[0]}\")\n",
661
+ "\n",
662
+ "validate(answer_path)\n",
663
+ "os.system(f\"cd {OUT_DIR} && zip -j submission_track2_exp10_ensemble.zip answer.txt && unzip -l submission_track2_exp10_ensemble.zip\")\n",
664
+ "print(\"Sẵn sàng nộp:\", os.path.join(OUT_DIR, \"submission_track2_exp10_ensemble.zip\"))"
665
+ ]
666
+ },
667
+ {
668
+ "cell_type": "markdown",
669
+ "id": "994723d2",
670
+ "metadata": {},
671
+ "source": [
672
+ "## Ghi chú\n",
673
+ "- **Hướng A (T4-an toàn):** fine-tune audeering RIÊNG (1 backbone) → ensemble VAD với exp08 → KHÔNG OOM.\n",
674
+ "- **Đọc mục 5:** audeering VAL/ARO/DOM có ≥ exp08 không? Nếu ngang/hơn → ensemble đáng giá.\n",
675
+ "- **Ensemble (mục 7):** mặc định trung bình VAL/ARO/DOM. Thêm \"EMOS\" vào `ENSEMBLE_COLS` nếu audeering EMOS tốt.\n",
676
+ "- **Checkpoint:** lưu `ft_audeering_full.pt` mỗi best (kernel chết vẫn còn). Save Version sau khi xong.\n",
677
+ "- QMOS vẫn mượn exp07 (0.548). So sánh: nộp answer.txt ensemble vs exp08 thuần để xem ensemble có nhích VAD.\n",
678
+ "- Ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (exp10)."
679
+ ]
680
+ }
681
+ ],
682
+ "metadata": {
683
+ "jupytext": {
684
+ "cell_metadata_filter": "-all",
685
+ "main_language": "python",
686
+ "notebook_metadata_filter": "-all"
687
+ }
688
+ },
689
+ "nbformat": 4,
690
+ "nbformat_minor": 5
691
+ }
track2/exp10_finetune_audeering_pipeline.py ADDED
@@ -0,0 +1,553 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %% [markdown]
2
+ # # VMC2026 Track 2 — exp10 (fine-tune AUDEERING riêng + ensemble VAD với exp08) — Kaggle T4
3
+ #
4
+ # **Ý tưởng (Hướng A — an toàn cho T4):** thay vì nhồi 2 backbone large vào 1 model (dễ OOM),
5
+ # ta fine-tune **audeering wav2vec2-large** RIÊNG (1 backbone → vừa T4), rồi **ensemble cột VAD**
6
+ # với exp08 (WavLM fine-tune). Mỗi lần chỉ 1 backbone trong VRAM → không OOM.
7
+ #
8
+ # ```
9
+ # [exp08] WavLM fine-tune ─► VAD_wavlm ┐
10
+ # ├─ trung bình ─► VAD cuối (mạnh hơn cả 2)
11
+ # [exp10] audeering fine-tune ─► VAD_aud ┘
12
+ # ```
13
+ # audeering vốn là model **dimensional (chuyên VAD)** → fine-tune nó để bổ trợ VAD cho exp08.
14
+ #
15
+ # **Cách chạy:** GPU T4 + Internet On → sửa slug cell 0 → Run All. Lần đầu `LIMIT_TRAIN=300`.
16
+ # Để ensemble: Add Input answer.txt exp08 → trỏ `EXP08_ANSWER`.
17
+
18
+ # %% [markdown]
19
+ # ## 0. Cấu hình
20
+
21
+ # %%
22
+ import os
23
+
24
+ DATA_ROOT = "/kaggle/input/datasets/minhtoan2/vmc2026-track2-full" # << SỬA slug
25
+ WAV_DIR = f"{DATA_ROOT}/wav"
26
+ METADATA_CSV = f"{DATA_ROOT}/metadata.csv"
27
+ TRAIN_CSV = f"{DATA_ROOT}/sets/train.csv"
28
+ DEV_SCP = f"{DATA_ROOT}/sets/dev.scp"
29
+
30
+ OUT_DIR = "/kaggle/working"
31
+
32
+ # QMOS mượn exp07 (0.548); ensemble VAD với answer.txt exp08.
33
+ EXP07_ANSWER = "/kaggle/input/exp07-answer/answer.txt" # << (tùy chọn) mượn QMOS; không có → UTMOSv2
34
+ EXP08_ANSWER = "/kaggle/input/exp08-answer/answer.txt" # << (tùy chọn) để ENSEMBLE VAD; không có → chỉ ra answer audeering
35
+
36
+ # ── Fine-tune audeering (1 backbone) ─────────────────────────────────────────
37
+ DEVICE = "cuda"
38
+ SR = 16000
39
+ MAX_SECONDS = 8
40
+ UNFREEZE_TOP_LAYERS = 6 # số lớp encoder audeering mở băng (T4 thừa sức 1 backbone)
41
+ TRUNK_HIDDEN = 512
42
+ HEAD_HIDDEN = 128
43
+ DROPOUT = 0.3
44
+ LR_BACKBONE = 1e-5
45
+ LR_HEAD = 1e-3
46
+ WEIGHT_DECAY = 1e-5
47
+ EPOCHS = 12
48
+ PATIENCE = 4
49
+ BATCH = 4
50
+ ACCUM = 8
51
+ VAL_FRAC = 0.10
52
+ SEED = 42
53
+ USE_AMP = True
54
+ USE_GRAD_CKPT = True
55
+ USE_UNCERTAINTY = True
56
+
57
+ # Ensemble: cột nào lấy TRUNG BÌNH giữa exp08 và exp10; cột khác giữ từ exp08.
58
+ ENSEMBLE_COLS = ["VAL", "ARO", "DOM"] # audeering mạnh VAD → ensemble VAD. Thêm "EMOS" nếu muốn.
59
+
60
+ LIMIT_TRAIN = 300 # << LẦN ĐẦU 300; chạy thật None
61
+ LIMIT_DEV = 20 # << LẦN ĐẦU 20; chạy thật None
62
+
63
+ EXP08 = {"emos": 0.811, "cat_err": 0.133, "val": 0.659, "aro": 0.793, "dom": 0.751}
64
+
65
+ EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"]
66
+ _EMO_ALIAS = {
67
+ "angry": "angry", "anger": "angry",
68
+ "happy": "happy", "happiness": "happy", "joy": "happy",
69
+ "neutral": "neutral", "calm": "neutral",
70
+ "sad": "sad", "sadness": "sad",
71
+ "surprise": "surprised", "surprised": "surprised", "surprising": "surprised",
72
+ }
73
+
74
+ def norm_emotion(label):
75
+ key = str(label).strip().lower()
76
+ return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)
77
+
78
+ def stem(p):
79
+ return os.path.splitext(os.path.basename(str(p)))[0]
80
+
81
+ print("DATA_ROOT:", DATA_ROOT)
82
+ for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP]:
83
+ print((" ✅ " if os.path.exists(p) else " ❌ THIẾU ") + p)
84
+
85
+ # %% [markdown]
86
+ # ## 1. Cài đặt
87
+
88
+ # %%
89
+ import sys, subprocess
90
+
91
+ def pip_install(*pkgs):
92
+ subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs], check=True)
93
+
94
+ pip_install("transformers", "huggingface_hub", "safetensors", "speechmos",
95
+ "librosa", "soundfile", "scipy", "scikit-learn", "pandas", "tqdm")
96
+
97
+ # %% [markdown]
98
+ # ## 2. Nạp audeering wav2vec2-large làm backbone FINE-TUNE
99
+ # Nạp backbone tay (tránh lỗi subclass `Wav2Vec2PreTrainedModel` ở transformers mới) rồi mở băng lớp trên.
100
+
101
+ # %%
102
+ import torch
103
+ import torch.nn as nn
104
+ import torch.nn.functional as F
105
+ import numpy as np
106
+
107
+ device = DEVICE if torch.cuda.is_available() else "cpu"
108
+ print("Device:", device, ("✅ " + torch.cuda.get_device_name(0)) if device == "cuda" else "⚠️ CPU (rất chậm!)")
109
+
110
+ from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor
111
+ from huggingface_hub import hf_hub_download
112
+
113
+ AUD_NAME = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
114
+ aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)
115
+ aud_cfg = Wav2Vec2Config.from_pretrained(AUD_NAME)
116
+ aud = Wav2Vec2Model(aud_cfg)
117
+ try:
118
+ _sd = __import__("safetensors.torch", fromlist=["load_file"]).load_file(
119
+ hf_hub_download(AUD_NAME, "model.safetensors"))
120
+ except Exception:
121
+ _sd = torch.load(hf_hub_download(AUD_NAME, "pytorch_model.bin"), map_location="cpu")
122
+ bb_sd = {k[len("wav2vec2."):]: v for k, v in _sd.items() if k.startswith("wav2vec2.")}
123
+ miss, unexp = aud.load_state_dict(bb_sd, strict=False)
124
+ print(f"audeering backbone: thiếu {len(miss)} / dư {len(unexp)} key (strict=False)")
125
+ aud = aud.to(device)
126
+ AUD_DIM = int(aud.config.hidden_size)
127
+
128
+ # Đóng băng tất cả, mở băng UNFREEZE_TOP_LAYERS lớp encoder trên cùng
129
+ for p in aud.parameters():
130
+ p.requires_grad = False
131
+ enc_layers = aud.encoder.layers
132
+ n_layers = len(enc_layers)
133
+ for layer in enc_layers[max(0, n_layers - UNFREEZE_TOP_LAYERS):]:
134
+ for p in layer.parameters():
135
+ p.requires_grad = True
136
+ n_train = sum(p.numel() for p in aud.parameters() if p.requires_grad)
137
+ print(f"audeering: {n_layers} lớp · mở băng {min(UNFREEZE_TOP_LAYERS, n_layers)} → {n_train/1e6:.1f}M param train (dim {AUD_DIM})")
138
+
139
+ if USE_GRAD_CKPT:
140
+ aud.gradient_checkpointing_enable()
141
+ if hasattr(aud, "enable_input_require_grads"):
142
+ aud.enable_input_require_grads()
143
+
144
+ def masked_mean(hidden, attn_mask):
145
+ if attn_mask is None:
146
+ return hidden.mean(dim=1)
147
+ try:
148
+ fm = aud._get_feature_vector_attention_mask(hidden.shape[1], attn_mask)
149
+ except Exception:
150
+ return hidden.mean(dim=1)
151
+ fm = fm.unsqueeze(-1).to(hidden.dtype)
152
+ return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)
153
+
154
+ def aud_embed(input_values, attn_mask):
155
+ out = aud(input_values, attention_mask=attn_mask).last_hidden_state
156
+ return masked_mean(out, attn_mask)
157
+
158
+ # %% [markdown]
159
+ # ## 3. Nhãn (gộp theo wavID) — như exp08
160
+
161
+ # %%
162
+ import librosa
163
+ import pandas as pd
164
+ from tqdm.auto import tqdm
165
+
166
+ def load_target_emotions():
167
+ tgt = {}
168
+ with open(METADATA_CSV, encoding="utf-8") as f:
169
+ for ln in f:
170
+ parts = ln.strip().split("|")
171
+ if len(parts) >= 2:
172
+ tgt[stem(parts[0])] = norm_emotion(parts[1])
173
+ return tgt
174
+
175
+ def _col(cols_map, *names, df=None, default_idx=None):
176
+ for n in names:
177
+ if n in cols_map:
178
+ return cols_map[n]
179
+ return list(df.columns)[default_idx] if default_idx is not None else None
180
+
181
+ def parse_emocat_votes(cell):
182
+ v = np.zeros(len(EMOTIONS5), dtype=np.float32)
183
+ for tok in str(cell).replace("/", ",").replace(";", ",").replace("|", ",").replace(" ", ",").split(","):
184
+ e = norm_emotion(tok)
185
+ if e in EMOTIONS5:
186
+ v[EMOTIONS5.index(e)] += 1.0
187
+ return v
188
+
189
+ def load_train_labels():
190
+ df = pd.read_csv(TRAIN_CSV, sep="|")
191
+ cols = {c.lower().strip(): c for c in df.columns}
192
+ wav_col = _col(cols, "wavid", "wav", df=df, default_idx=1)
193
+ emos_col = _col(cols, "emos", "emo", "emomos")
194
+ val_col = _col(cols, "val", "valence"); aro_col = _col(cols, "aro", "arousal"); dom_col = _col(cols, "dom", "dominance")
195
+ cat_col = _col(cols, "emocat", "cat", "emotion")
196
+ assert emos_col, f"Không thấy cột eMOS (cột: {list(df.columns)})"
197
+ df["_stem"] = df[wav_col].map(stem)
198
+ rows = []
199
+ for sid, g in df.groupby("_stem"):
200
+ rec = {"wavID": sid, "emos": float(g[emos_col].mean())}
201
+ rec["val"] = float(g[val_col].mean()) if val_col else np.nan
202
+ rec["aro"] = float(g[aro_col].mean()) if aro_col else np.nan
203
+ rec["dom"] = float(g[dom_col].mean()) if dom_col else np.nan
204
+ votes = np.zeros(len(EMOTIONS5), dtype=np.float32)
205
+ if cat_col:
206
+ for cell in g[cat_col]:
207
+ votes += parse_emocat_votes(cell)
208
+ s = votes.sum()
209
+ cat = votes / s if s > 0 else np.full(len(EMOTIONS5), 0.2, dtype=np.float32)
210
+ for i in range(len(EMOTIONS5)):
211
+ rec[f"cat{i}"] = float(cat[i])
212
+ rows.append(rec)
213
+ return pd.DataFrame(rows)
214
+
215
+ target_map = load_target_emotions()
216
+ train_df = load_train_labels()
217
+ HAS_VAD = bool(train_df["val"].notna().any())
218
+ print(f"Target: {len(target_map)} | wav train (gộp): {len(train_df)} | có VAD: {HAS_VAD}")
219
+
220
+ # %% [markdown]
221
+ # ## 4. Dataset/loader (input_values qua audeering processor) + chuẩn hóa nhãn
222
+
223
+ # %%
224
+ from torch.utils.data import Dataset, DataLoader
225
+ from sklearn.model_selection import train_test_split
226
+
227
+ train_stems = [s for s in train_df["wavID"] if target_map.get(s) is not None]
228
+ if LIMIT_TRAIN:
229
+ train_stems = train_stems[:LIMIT_TRAIN]
230
+ lab = train_df.set_index("wavID")
231
+
232
+ def _zfit(arr):
233
+ a = np.asarray(arr, dtype=np.float32)
234
+ return float(np.nanmean(a)), float(np.nanstd(a) + 1e-6)
235
+
236
+ emos_mu, emos_sd = _zfit([lab.loc[s, "emos"] for s in train_stems])
237
+ if HAS_VAD:
238
+ vad_mu = np.array([_zfit([lab.loc[s, c] for s in train_stems])[0] for c in ["val", "aro", "dom"]], dtype=np.float32)
239
+ vad_sd = np.array([_zfit([lab.loc[s, c] for s in train_stems])[1] for c in ["val", "aro", "dom"]], dtype=np.float32)
240
+ else:
241
+ vad_mu = np.zeros(3, dtype=np.float32); vad_sd = np.ones(3, dtype=np.float32)
242
+
243
+ def onehot_target(tgt):
244
+ v = np.zeros(len(EMOTIONS5), dtype=np.float32)
245
+ if tgt in EMOTIONS5:
246
+ v[EMOTIONS5.index(tgt)] = 1.0
247
+ return v
248
+
249
+ def load_iv(sid):
250
+ """Đọc wav → chuẩn hóa bằng audeering processor → input_values (1D float32)."""
251
+ p = os.path.join(WAV_DIR, sid if str(sid).endswith(".wav") else str(sid) + ".wav")
252
+ if not os.path.exists(p):
253
+ return None
254
+ wave, _ = librosa.load(p, sr=SR, mono=True)
255
+ wave = wave[: MAX_SECONDS * SR]
256
+ iv = aud_proc(wave, sampling_rate=SR).input_values[0]
257
+ return np.asarray(iv, dtype=np.float32)
258
+
259
+ class AudDataset(Dataset):
260
+ def __init__(self, stems):
261
+ self.stems = [s for s in stems if load_iv(s) is not None]
262
+ def __len__(self):
263
+ return len(self.stems)
264
+ def __getitem__(self, i):
265
+ s = self.stems[i]
266
+ iv = load_iv(s)
267
+ emos = (float(lab.loc[s, "emos"]) - emos_mu) / emos_sd
268
+ if HAS_VAD:
269
+ vad = (np.array([lab.loc[s, "val"], lab.loc[s, "aro"], lab.loc[s, "dom"]], np.float32) - vad_mu) / vad_sd
270
+ else:
271
+ vad = np.zeros(3, dtype=np.float32)
272
+ cat = np.array([lab.loc[s, f"cat{j}"] for j in range(len(EMOTIONS5))], dtype=np.float32)
273
+ return {"iv": iv, "tgt": onehot_target(target_map.get(s)),
274
+ "emos": np.float32(emos), "vad": vad, "cat": cat,
275
+ "emos_raw": np.float32(lab.loc[s, "emos"]),
276
+ "vad_raw": np.array([lab.loc[s, "val"], lab.loc[s, "aro"], lab.loc[s, "dom"]], np.float32)}
277
+
278
+ def collate(batch):
279
+ L = max(len(b["iv"]) for b in batch)
280
+ ivs = np.zeros((len(batch), L), dtype=np.float32)
281
+ mask = np.zeros((len(batch), L), dtype=np.float32)
282
+ for i, b in enumerate(batch):
283
+ ivs[i, : len(b["iv"])] = b["iv"]; mask[i, : len(b["iv"])] = 1.0
284
+ return {
285
+ "input_values": torch.from_numpy(ivs), "attn_mask": torch.from_numpy(mask).long(),
286
+ "tgt": torch.from_numpy(np.stack([b["tgt"] for b in batch])),
287
+ "emos": torch.from_numpy(np.stack([b["emos"] for b in batch])).unsqueeze(1),
288
+ "vad": torch.from_numpy(np.stack([b["vad"] for b in batch])),
289
+ "cat": torch.from_numpy(np.stack([b["cat"] for b in batch])),
290
+ "emos_raw": np.stack([b["emos_raw"] for b in batch]),
291
+ "vad_raw": np.stack([b["vad_raw"] for b in batch]),
292
+ }
293
+
294
+ ds = AudDataset(train_stems)
295
+ print("Dataset hợp lệ:", len(ds), "wav")
296
+ tr_i, va_i = train_test_split(np.arange(len(ds)), test_size=VAL_FRAC, random_state=SEED)
297
+ tr_loader = DataLoader(torch.utils.data.Subset(ds, tr_i), batch_size=BATCH, shuffle=True, collate_fn=collate, num_workers=2)
298
+ va_loader = DataLoader(torch.utils.data.Subset(ds, va_i), batch_size=BATCH, shuffle=False, collate_fn=collate, num_workers=2)
299
+
300
+ # %% [markdown]
301
+ # ## 5. Heads + train loop (lưu ft_audeering_full.pt mỗi best)
302
+
303
+ # %%
304
+ from scipy.stats import spearmanr
305
+
306
+ torch.manual_seed(SEED); np.random.seed(SEED)
307
+ N_EMO = len(EMOTIONS5)
308
+
309
+ class EmoHeads(nn.Module):
310
+ def __init__(self, d_in, trunk_h, head_h, p, n_emo):
311
+ super().__init__()
312
+ self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),
313
+ nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))
314
+ self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))
315
+ self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))
316
+ self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))
317
+ def forward(self, feat, tgt):
318
+ h = self.trunk(feat)
319
+ return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)
320
+
321
+ heads = EmoHeads(AUD_DIM, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device)
322
+
323
+ TASKS = ["emos", "cat", "val", "aro", "dom"]
324
+ log_var = nn.Parameter(torch.zeros(len(TASKS), device=device))
325
+ bb_params = [p for p in aud.parameters() if p.requires_grad]
326
+ head_params = list(heads.parameters()) + ([log_var] if USE_UNCERTAINTY else [])
327
+ opt = torch.optim.AdamW([{"params": bb_params, "lr": LR_BACKBONE},
328
+ {"params": head_params, "lr": LR_HEAD}], weight_decay=WEIGHT_DECAY)
329
+ scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP and device == "cuda")
330
+ mse = nn.MSELoss()
331
+
332
+ def soft_ce(logits, target_dist):
333
+ return -(target_dist * F.log_softmax(logits, dim=1)).sum(1).mean()
334
+
335
+ def forward_batch(b):
336
+ feat = aud_embed(b["input_values"].to(device), b["attn_mask"].to(device))
337
+ return heads(feat, b["tgt"].to(device))
338
+
339
+ def compute_loss(emos_p, cat_l, vad_p, b):
340
+ L = {}
341
+ L["emos"] = mse(emos_p, b["emos"].to(device))
342
+ L["cat"] = soft_ce(cat_l, b["cat"].to(device))
343
+ if HAS_VAD:
344
+ vt = b["vad"].to(device)
345
+ L["val"] = mse(vad_p[:, 0:1], vt[:, 0:1]); L["aro"] = mse(vad_p[:, 1:2], vt[:, 1:2]); L["dom"] = mse(vad_p[:, 2:3], vt[:, 2:3])
346
+ else:
347
+ z = torch.zeros((), device=device); L["val"] = L["aro"] = L["dom"] = z
348
+ if USE_UNCERTAINTY:
349
+ return sum(torch.exp(-log_var[i]) * L[t] + log_var[i] for i, t in enumerate(TASKS))
350
+ return sum(L.values())
351
+
352
+ @torch.no_grad()
353
+ def evaluate():
354
+ aud.eval(); heads.eval()
355
+ P = {"emos": [], "val": [], "aro": [], "dom": []}; Y = {"emos": [], "val": [], "aro": [], "dom": []}
356
+ catP, catY = [], []
357
+ for b in va_loader:
358
+ with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
359
+ emos_p, cat_l, vad_p = forward_batch(b)
360
+ P["emos"] += emos_p.float().cpu().numpy().ravel().tolist(); Y["emos"] += b["emos_raw"].tolist()
361
+ vad_p = vad_p.float().cpu().numpy()
362
+ for j, t in enumerate(["val", "aro", "dom"]):
363
+ P[t] += vad_p[:, j].tolist(); Y[t] += b["vad_raw"][:, j].tolist()
364
+ catP.append(F.softmax(cat_l, 1).float().cpu().numpy()); catY.append(b["cat"])
365
+ out = {}
366
+ for t in ["emos"] + (["val", "aro", "dom"] if HAS_VAD else []):
367
+ out[t] = spearmanr(P[t], Y[t]).correlation
368
+ q = np.concatenate(catP); p = np.concatenate(catY)
369
+ out["cat_err"] = float(np.abs(q - p).sum(1).mean())
370
+ return out
371
+
372
+ def mean_srcc(m):
373
+ keys = ["emos"] + (["val", "aro", "dom"] if HAS_VAD else [])
374
+ return float(np.mean([m[k] for k in keys]))
375
+
376
+ CKPT_PATH = os.path.join(OUT_DIR, "ft_audeering_full.pt")
377
+ def save_full_ckpt(state, val_emos=float("nan")):
378
+ torch.save({"aud": state["aud"], "heads": state["heads"],
379
+ "emos_mu": emos_mu, "emos_sd": emos_sd, "vad_mu": vad_mu, "vad_sd": vad_sd,
380
+ "AUD_DIM": AUD_DIM, "UNFREEZE_TOP_LAYERS": UNFREEZE_TOP_LAYERS,
381
+ "val_emos": float(val_emos)}, CKPT_PATH)
382
+
383
+ best, best_state, bad = -1e9, None, 0
384
+ for ep in range(1, EPOCHS + 1):
385
+ aud.train(); heads.train()
386
+ opt.zero_grad(); run = 0.0; nb = 0
387
+ for step, b in enumerate(tqdm(tr_loader, desc=f"epoch {ep}")):
388
+ with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
389
+ emos_p, cat_l, vad_p = forward_batch(b)
390
+ loss = compute_loss(emos_p, cat_l, vad_p, b) / ACCUM
391
+ scaler.scale(loss).backward()
392
+ if (step + 1) % ACCUM == 0:
393
+ scaler.step(opt); scaler.update(); opt.zero_grad()
394
+ run += loss.item() * ACCUM; nb += 1
395
+ m = evaluate(); sc = mean_srcc(m)
396
+ msg = " ".join(f"{k}={m[k]:.3f}" for k in ["emos", "val", "aro", "dom"] if k in m)
397
+ print(f"epoch {ep:2d} | loss {run/max(nb,1):.4f} | {msg} | cat_err {m['cat_err']:.3f} | mean {sc:.4f} (best {max(best,sc):.4f})")
398
+ if sc > best:
399
+ best = sc
400
+ best_state = {"aud": {k: v.cpu().clone() for k, v in aud.state_dict().items()},
401
+ "heads": {k: v.cpu().clone() for k, v in heads.state_dict().items()}}
402
+ save_full_ckpt(best_state, m["emos"])
403
+ print(f" 💾 lưu best → {CKPT_PATH} (epoch {ep}, mean {sc:.4f})")
404
+ bad = 0
405
+ else:
406
+ bad += 1
407
+ if bad >= PATIENCE:
408
+ print(f"Early stop ở epoch {ep}."); break
409
+
410
+ if best_state:
411
+ aud.load_state_dict(best_state["aud"]); heads.load_state_dict(best_state["heads"])
412
+ final = evaluate()
413
+ print("\n✅ VAL (nội bộ) — exp10 (fine-tune audeering):")
414
+ print(f" EMOS={final['emos']:.4f}", end="")
415
+ if HAS_VAD:
416
+ print(f" | VAL/ARO/DOM={final['val']:.4f}/{final['aro']:.4f}/{final['dom']:.4f} (exp08 {EXP08['val']}/{EXP08['aro']}/{EXP08['dom']})")
417
+ else:
418
+ print()
419
+ print(f" → so exp08: audeering {'mạnh' if HAS_VAD and final['val'] > EXP08['val'] else 'yếu/ngang'} ở VAL. "
420
+ f"Ensemble sẽ lấy trung bình 2 model.")
421
+ save_full_ckpt(best_state if best_state else {"aud": aud.state_dict(), "heads": heads.state_dict()}, final["emos"])
422
+ print(f"✅ Đã lưu {CKPT_PATH}. NHỚ Save Version!")
423
+
424
+ # %% [markdown]
425
+ # ## 6. Dự đoán DEV → predictions + answer_audeering.txt
426
+
427
+ # %%
428
+ def list_dev():
429
+ with open(DEV_SCP) as f:
430
+ return [ln.strip() for ln in f if ln.strip()]
431
+
432
+ dev_names = list_dev()
433
+ if LIMIT_DEV:
434
+ dev_names = dev_names[:LIMIT_DEV]
435
+ print("DEV:", len(dev_names), "mẫu")
436
+
437
+ def load_exp07_qmos():
438
+ if EXP07_ANSWER and os.path.exists(EXP07_ANSWER):
439
+ import csv
440
+ d = {}
441
+ with open(EXP07_ANSWER) as f:
442
+ for row in csv.DictReader(f):
443
+ d[row["wav"]] = float(row["QMOS"]); d[stem(row["wav"])] = float(row["QMOS"])
444
+ print(f"✅ Mượn QMOS từ exp07: {len(d)//2} wav")
445
+ return d
446
+ return None
447
+
448
+ qmos_map = load_exp07_qmos()
449
+ if qmos_map is None:
450
+ print("ℹ️ Không có exp07 → QMOS bằng UTMOSv2.")
451
+ pip_install("git+https://github.com/sarulab-speech/UTMOSv2.git")
452
+ import utmosv2
453
+ v2 = utmosv2.create_model(pretrained=True)
454
+ qmos_map = {}
455
+ for n in tqdm(dev_names, desc="UTMOSv2"):
456
+ wav = os.path.join(WAV_DIR, n if str(n).endswith(".wav") else str(n) + ".wav")
457
+ if os.path.exists(wav):
458
+ o = v2.predict(input_path=wav)
459
+ qmos_map[n] = float(o["predicted_mos"]) if isinstance(o, dict) else float(o)
460
+ del v2; torch.cuda.empty_cache() if device == "cuda" else None
461
+
462
+ @torch.no_grad()
463
+ def predict_emotion(sid):
464
+ iv = load_iv(sid)
465
+ if iv is None:
466
+ return None
467
+ aud.eval(); heads.eval()
468
+ ivt = torch.from_numpy(iv).unsqueeze(0).to(device)
469
+ am = torch.ones((1, len(iv)), dtype=torch.long, device=device)
470
+ tgt = torch.from_numpy(onehot_target(target_map.get(sid))).unsqueeze(0).to(device)
471
+ with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
472
+ feat = aud_embed(ivt, am)
473
+ emos_p, cat_l, vad_p = heads(feat, tgt)
474
+ emos = float(emos_p.item()) * emos_sd + emos_mu
475
+ cat5 = F.softmax(cat_l, 1)[0].float().cpu().numpy()
476
+ vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu
477
+ return emos, cat5, vad3
478
+
479
+ def fmt_cat(p5):
480
+ return "|".join(f"{e}:{p5[i]:.6g}" for i, e in enumerate(EMOTIONS5))
481
+
482
+ dev_pred = {} # name -> (emos, cat5, vad3)
483
+ with open(os.path.join(OUT_DIR, "answer_audeering.txt"), "w") as f:
484
+ f.write("wav,QMOS,EMOS,CAT,VAL,ARO,DOM\n")
485
+ for name in tqdm(dev_names, desc="answer_aud"):
486
+ sid = stem(name)
487
+ pr = predict_emotion(sid)
488
+ if pr is None:
489
+ emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0])
490
+ else:
491
+ emos, cat5, vad3 = pr
492
+ dev_pred[name] = (emos, cat5, vad3)
493
+ qmos = qmos_map.get(name, qmos_map.get(sid, 3.0))
494
+ f.write(f"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\n")
495
+ print("Đã ghi answer_audeering.txt")
496
+
497
+ # %% [markdown]
498
+ # ## 7. ENSEMBLE với exp08 → answer.txt cuối (trung bình cột VAD)
499
+ # Lấy answer.txt exp08 làm nền; cột trong `ENSEMBLE_COLS` = trung bình (exp08 + exp10). Còn lại giữ exp08.
500
+
501
+ # %%
502
+ import csv
503
+ COL_IDX = {"QMOS": 1, "EMOS": 2, "VAL": 4, "ARO": 5, "DOM": 6} # vị trí cột trong answer.txt
504
+ AUD_VAL = {"EMOS": lambda p: p[0], "VAL": lambda p: p[2][0], "ARO": lambda p: p[2][1], "DOM": lambda p: p[2][2]}
505
+
506
+ answer_path = os.path.join(OUT_DIR, "answer.txt")
507
+ if EXP08_ANSWER and os.path.exists(EXP08_ANSWER):
508
+ with open(EXP08_ANSWER) as f:
509
+ rows = list(csv.reader(f))
510
+ header, body = rows[0], rows[1:]
511
+ n_ens = 0
512
+ with open(answer_path, "w") as f:
513
+ f.write(",".join(header) + "\n")
514
+ for r in body:
515
+ name = r[0]; sid = stem(name)
516
+ pr = dev_pred.get(name) or dev_pred.get(sid)
517
+ if pr is not None:
518
+ for col in ENSEMBLE_COLS:
519
+ if col in COL_IDX and col in AUD_VAL:
520
+ v08 = float(r[COL_IDX[col]]); vaud = float(AUD_VAL[col](pr))
521
+ r[COL_IDX[col]] = f"{0.5*(v08+vaud):.6g}"
522
+ n_ens += 1
523
+ f.write(",".join(r) + "\n")
524
+ print(f"✅ Ensemble {ENSEMBLE_COLS}: {n_ens} dòng → {answer_path} (nền exp08 + trung bình audeering)")
525
+ else:
526
+ print("ℹ️ Không có EXP08_ANSWER → answer.txt = answer_audeering.txt (chỉ audeering, chưa ensemble).")
527
+ import shutil
528
+ shutil.copy(os.path.join(OUT_DIR, "answer_audeering.txt"), answer_path)
529
+
530
+ # %% [markdown]
531
+ # ## 8. Validate + zip
532
+
533
+ # %%
534
+ def validate(path):
535
+ with open(path) as f:
536
+ rows = list(csv.reader(f))
537
+ assert rows[0][0] == "wav" and "QMOS" in rows[0], "Header sai"
538
+ for i, r in enumerate(rows[1:], 2):
539
+ assert len(r) == len(rows[0]), f"Dòng {i} sai số cột"
540
+ print(f"OK: {len(rows)-1} dòng, header = {rows[0]}")
541
+
542
+ validate(answer_path)
543
+ os.system(f"cd {OUT_DIR} && zip -j submission_track2_exp10_ensemble.zip answer.txt && unzip -l submission_track2_exp10_ensemble.zip")
544
+ print("Sẵn sàng nộp:", os.path.join(OUT_DIR, "submission_track2_exp10_ensemble.zip"))
545
+
546
+ # %% [markdown]
547
+ # ## Ghi chú
548
+ # - **Hướng A (T4-an toàn):** fine-tune audeering RIÊNG (1 backbone) → ensemble VAD với exp08 → KHÔNG OOM.
549
+ # - **Đọc mục 5:** audeering VAL/ARO/DOM có ≥ exp08 không? Nếu ngang/hơn → ensemble đáng giá.
550
+ # - **Ensemble (mục 7):** mặc định trung bình VAL/ARO/DOM. Thêm "EMOS" vào `ENSEMBLE_COLS` nếu audeering EMOS tốt.
551
+ # - **Checkpoint:** lưu `ft_audeering_full.pt` mỗi best (kernel chết vẫn còn). Save Version sau khi xong.
552
+ # - QMOS vẫn mượn exp07 (0.548). So sánh: nộp answer.txt ensemble vs exp08 thuần để xem ensemble có nhích VAD.
553
+ # - Ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (exp10).
track2/exp11_finetune_joint.ipynb ADDED
@@ -0,0 +1,805 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "a2dce1b4",
6
+ "metadata": {},
7
+ "source": [
8
+ "# VMC2026 Track 2 — exp11 (FINE-TUNE ĐỒNG THỜI WavLM + audeering, FUSION 1 model) — Kaggle T4\n",
9
+ "\n",
10
+ "**Khác exp08:** exp08 chỉ fine-tune WavLM, audeering **đóng băng** (frozen, cache). exp11 **MỞ BĂNG CẢ HAI**\n",
11
+ "backbone và fuse đặc trưng **trong cùng 1 model** → cả hai cùng học cho bài MOS cảm xúc 2026.\n",
12
+ "\n",
13
+ "```\n",
14
+ " wav ─┬─► WavLM-large (warm-start exp08, TRAINABLE: mở băng N lớp trên) ─► pool ─► emb_wavlm ┐\n",
15
+ " └─► audeering MSP (TRAINABLE: mở băng N lớp trên) ─► pool ─► [emb_aud(1024) | vad3] ──────┼─► TRUNK ─┬─► EMOS (+target)\n",
16
+ " ┘ ├─► CAT (5)\n",
17
+ " └─► VAD (3)\n",
18
+ " QMOS: KHÔNG train ở đây → mượn cột QMOS exp07 (0.548) hoặc UTMOSv2.\n",
19
+ "```\n",
20
+ "\n",
21
+ "## Vì sao \"feature fusion + fine-tune cả 2\" (khác ensemble exp10)\n",
22
+ "- **exp10 = ensemble:** 2 model RIÊNG → trung bình cột VAD ở mức answer. An toàn nhưng 2 model không \"nói chuyện\".\n",
23
+ "- **exp11 = fusion:** 1 model, 2 backbone fuse Ở TRONG → trunk học phối hợp cả hai góc nhìn (WavLM categorical +\n",
24
+ " audeering dimensional) → kỳ vọng mạnh hơn nếu không OOM/overfit.\n",
25
+ "\n",
26
+ "## ⚠️ ĐÁNH ĐỔI PHẢI BIẾT — đây là cấu hình NẶNG nhất (2 backbone large cùng có gradient)\n",
27
+ "- **Rủi ro OOM cao trên T4 (16GB).** Đã bật sẵn mọi cách giảm bộ nhớ: `BATCH=1` + grad-accum,\n",
28
+ " gradient-checkpointing CẢ 2 backbone, AMP fp16, `MAX_SECONDS=6`, mở băng ÍT lớp (mặc định 4 mỗi backbone).\n",
29
+ "- Nếu vẫn OOM: giảm `UNFREEZE_WAVLM`/`UNFREEZE_AUD` → 2, giảm `MAX_SECONDS` → 5, tăng `ACCUM`.\n",
30
+ "- **Chậm + đốt giờ GPU** (2 backbone forward+backward, không cache được). **LẦN ĐẦU BẮT BUỘC `LIMIT_TRAIN=300`,\n",
31
+ " `LIMIT_DEV=20`** để chỉnh trơn rồi mới `None`.\n",
32
+ "- **Lưới an toàn:** đừng đốt lượt nộp — chỉ nộp khi exp11 thắng exp08 (0.811) TRÊN VAL NỘI BỘ.\n",
33
+ "\n",
34
+ "**Cách chạy:** GPU **T4** + Internet **On** → Add Input (data + checkpoint exp08 + [tùy chọn] answer exp07) →\n",
35
+ "sửa slug cell 0 → Run All. Ghi config→kết quả→nhận xét vào `docs/04_experiments_log.md` (exp11)."
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "markdown",
40
+ "id": "f6d884a7",
41
+ "metadata": {},
42
+ "source": [
43
+ "## 0. Cấu hình — SỬA Ở ĐÂY"
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "code",
48
+ "execution_count": null,
49
+ "id": "6eca25f7",
50
+ "metadata": {},
51
+ "outputs": [],
52
+ "source": [
53
+ "import os\n",
54
+ "\n",
55
+ "DATA_ROOT = \"/kaggle/input/datasets/minhtoan2/vmc2026-track2-full/vmc2026-track2\" # << SỬA slug cho khớp Add Input\n",
56
+ "WAV_DIR = f\"{DATA_ROOT}/wav\"\n",
57
+ "METADATA_CSV = f\"{DATA_ROOT}/metadata.csv\" # wavID|emotion|transcript (KHÔNG header)\n",
58
+ "TRAIN_CSV = f\"{DATA_ROOT}/sets/train.csv\" # lisID|wavID|qMOS|emoCat|eMOS|val|dom|aro\n",
59
+ "DEV_SCP = f\"{DATA_ROOT}/sets/dev.scp\"\n",
60
+ "\n",
61
+ "# ── Warm-start / RESUME: trỏ tới 1 trong 2 loại checkpoint ───────────────────\n",
62
+ "# • ft_emotion_full_20epoch.pt (exp08): có 'wavlm'+'heads' → WARM-START (audeering từ pretrained gốc).\n",
63
+ "# • ft_joint_full.pt (exp11): có thêm 'aud'+'aud_head' → RESUME ĐỦ (khôi phục cả 2 backbone đã fine-tune).\n",
64
+ "# Notebook TỰ nhận biết theo key trong checkpoint. Để \"\" nếu train WavLM từ SAILER trắng.\n",
65
+ "WARMSTART_CKPT = \"/kaggle/input/ft-joint-full/ft_joint_full.pt\" # << exp08 ckpt (warm-start) HOẶC exp11 ckpt (resume)\n",
66
+ "\n",
67
+ "# Mượn cột QMOS exp07 (0.548). Không có → UTMOSv2.\n",
68
+ "EXP07_ANSWER = \"/kaggle/input/exp07-answer/answer.txt\" # << (tùy chọn) answer.txt exp07; không có → UTMOSv2\n",
69
+ "\n",
70
+ "OUT_DIR = \"/kaggle/working\"\n",
71
+ "\n",
72
+ "# ── Fine-tune / siêu tham số (CẤU HÌNH NẶNG — đã tối ưu cho T4) ───────────────\n",
73
+ "DEVICE = \"cuda\"\n",
74
+ "SR = 16000\n",
75
+ "MAX_SECONDS = 6 # ↓ so exp08 (8) để tiết kiệm VRAM (2 backbone)\n",
76
+ "UNFREEZE_WAVLM = 4 # số lớp encoder WavLM mở băng (OOM → 2)\n",
77
+ "UNFREEZE_AUD = 4 # số lớp encoder audeering mở băng (OOM → 2)\n",
78
+ "TRUNK_HIDDEN = 512 # PHẢI khớp checkpoint exp08 nếu warm-start heads\n",
79
+ "HEAD_HIDDEN = 128 # PHẢI khớp checkpoint exp08\n",
80
+ "DROPOUT = 0.3\n",
81
+ "LR_BACKBONE = 1e-5 # LR chung cho 2 backbone\n",
82
+ "LR_HEAD = 1e-3\n",
83
+ "RESUME_LR_SCALE = 1.0 # <1.0 để GIẢM LR khi resume (vd 0.5 nếu val đã chững) — nhân vào cả 2 nhóm LR\n",
84
+ "WEIGHT_DECAY = 1e-5\n",
85
+ "EPOCHS = 12\n",
86
+ "PATIENCE = 4 # dừng khi val không lên; LUÔN giữ best\n",
87
+ "BATCH = 1 # ⚠️ 2 backbone large → batch nhỏ\n",
88
+ "ACCUM = 16 # effective batch = 16\n",
89
+ "VAL_FRAC = 0.10\n",
90
+ "SEED = 42\n",
91
+ "USE_AMP = True\n",
92
+ "USE_GRAD_CKPT = True\n",
93
+ "USE_UNCERTAINTY = True\n",
94
+ "\n",
95
+ "LIMIT_TRAIN = 300 # << LẦN ĐẦU 300; chạy thật None\n",
96
+ "LIMIT_DEV = 20 # << LẦN ĐẦU 20; chạy thật None\n",
97
+ "\n",
98
+ "EXP08 = {\"emos\": 0.811, \"cat_err\": 0.133, \"val\": 0.659, \"aro\": 0.793, \"dom\": 0.751} # mốc để so\n",
99
+ "\n",
100
+ "EMOTIONS5 = [\"angry\", \"happy\", \"neutral\", \"sad\", \"surprised\"]\n",
101
+ "_EMO_ALIAS = {\n",
102
+ " \"angry\": \"angry\", \"anger\": \"angry\",\n",
103
+ " \"happy\": \"happy\", \"happiness\": \"happy\", \"joy\": \"happy\",\n",
104
+ " \"neutral\": \"neutral\", \"calm\": \"neutral\",\n",
105
+ " \"sad\": \"sad\", \"sadness\": \"sad\",\n",
106
+ " \"surprise\": \"surprised\", \"surprised\": \"surprised\", \"surprising\": \"surprised\",\n",
107
+ "}\n",
108
+ "\n",
109
+ "def norm_emotion(label):\n",
110
+ " key = str(label).strip().lower()\n",
111
+ " return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)\n",
112
+ "\n",
113
+ "def stem(p):\n",
114
+ " return os.path.splitext(os.path.basename(str(p)))[0]\n",
115
+ "\n",
116
+ "print(\"DATA_ROOT:\", DATA_ROOT)\n",
117
+ "for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP]:\n",
118
+ " print((\" ✅ \" if os.path.exists(p) else \" ❌ THIẾU \") + p)\n",
119
+ "print((\" ✅ \" if (WARMSTART_CKPT and os.path.exists(WARMSTART_CKPT)) else \" ⚠️ KHÔNG có \") + str(WARMSTART_CKPT)\n",
120
+ " + (\" → warm-start\" if (WARMSTART_CKPT and os.path.exists(WARMSTART_CKPT)) else \" → train từ SAILER trắng\"))"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "markdown",
125
+ "id": "b15a7e01",
126
+ "metadata": {},
127
+ "source": [
128
+ "## 1. Cài đặt + tải code SAILER (dựng đúng kiến trúc WavLM)"
129
+ ]
130
+ },
131
+ {
132
+ "cell_type": "code",
133
+ "execution_count": null,
134
+ "id": "2aacc36b",
135
+ "metadata": {},
136
+ "outputs": [],
137
+ "source": [
138
+ "import sys, subprocess\n",
139
+ "\n",
140
+ "def pip_install(*pkgs):\n",
141
+ " subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *pkgs], check=True)\n",
142
+ "\n",
143
+ "pip_install(\"transformers\", \"huggingface_hub\", \"safetensors\", \"loralib\", \"speechbrain\",\n",
144
+ " \"speechmos\", \"librosa\", \"soundfile\", \"scipy\", \"scikit-learn\", \"pandas\", \"tqdm\")\n",
145
+ "\n",
146
+ "REPO_DIR = \"/kaggle/working/vox-profile-release\"\n",
147
+ "if not os.path.exists(REPO_DIR):\n",
148
+ " subprocess.run([\"git\", \"clone\", \"--depth\", \"1\",\n",
149
+ " \"https://github.com/tiantiaf0627/vox-profile-release.git\", REPO_DIR], check=True)\n",
150
+ "if REPO_DIR not in sys.path:\n",
151
+ " sys.path.insert(0, REPO_DIR)"
152
+ ]
153
+ },
154
+ {
155
+ "cell_type": "markdown",
156
+ "id": "3021edeb",
157
+ "metadata": {},
158
+ "source": [
159
+ "## 2A. WavLM TRAINABLE (warm-start SAILER / checkpoint exp08)"
160
+ ]
161
+ },
162
+ {
163
+ "cell_type": "code",
164
+ "execution_count": null,
165
+ "id": "d2459502",
166
+ "metadata": {
167
+ "lines_to_next_cell": 1
168
+ },
169
+ "outputs": [],
170
+ "source": [
171
+ "import torch\n",
172
+ "import torch.nn as nn\n",
173
+ "import torch.nn.functional as F\n",
174
+ "import numpy as np\n",
175
+ "\n",
176
+ "device = DEVICE if torch.cuda.is_available() else \"cpu\"\n",
177
+ "print(\"Device:\", device, (\"✅ \" + torch.cuda.get_device_name(0)) if device == \"cuda\" else \"⚠️ CPU (rất chậm!)\")\n",
178
+ "\n",
179
+ "# Nạp checkpoint exp08 (nếu có) — lấy cả 'wavlm', 'heads', thống kê chuẩn hóa\n",
180
+ "ckpt = None\n",
181
+ "if WARMSTART_CKPT and os.path.exists(WARMSTART_CKPT):\n",
182
+ " ckpt = torch.load(WARMSTART_CKPT, map_location=\"cpu\", weights_only=False)\n",
183
+ " print(\"✅ Nạp checkpoint warm-start:\", WARMSTART_CKPT, \"| keys:\", list(ckpt.keys()))\n",
184
+ " if \"wavlm\" not in ckpt:\n",
185
+ " print(\" ⚠️ Checkpoint KHÔNG có 'wavlm' (chỉ heads?) → vẫn dựng WavLM từ SAILER, chỉ warm-start heads nếu khớp.\")\n",
186
+ "\n",
187
+ "def find_hf_backbone(module):\n",
188
+ " cands = []\n",
189
+ " for name, m in module.named_modules():\n",
190
+ " enc = getattr(m, \"encoder\", None)\n",
191
+ " if getattr(m, \"feature_extractor\", None) is not None and enc is not None \\\n",
192
+ " and getattr(enc, \"layers\", None) is not None:\n",
193
+ " cands.append((name, m))\n",
194
+ " if not cands:\n",
195
+ " return None, None\n",
196
+ " cands.sort(key=lambda nm: sum(p.numel() for p in nm[1].parameters()), reverse=True)\n",
197
+ " return cands[0]\n",
198
+ "\n",
199
+ "wavlm = None\n",
200
+ "try:\n",
201
+ " from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402\n",
202
+ " _wrapper = WavLMWrapper.from_pretrained(\"tiantiaf/wavlm-large-categorical-emotion\")\n",
203
+ " name, wavlm = find_hf_backbone(_wrapper)\n",
204
+ " if wavlm is not None:\n",
205
+ " print(f\"✅ Dựng backbone WavLM từ SAILER wrapper tại '.{name}'\")\n",
206
+ "except Exception as e:\n",
207
+ " print(\"⚠️ Lỗi nạp SAILER wrapper:\", repr(e), \"→ fallback WavLM trắng.\")\n",
208
+ "\n",
209
+ "if wavlm is None:\n",
210
+ " from transformers import WavLMModel\n",
211
+ " wavlm = WavLMModel.from_pretrained(\"microsoft/wavlm-large\")\n",
212
+ " print(\"ℹ️ Fallback: microsoft/wavlm-large.\")\n",
213
+ "\n",
214
+ "wavlm = wavlm.to(device)\n",
215
+ "WAVLM_DIM = int(wavlm.config.hidden_size)\n",
216
+ "wavlm.config.layerdrop = 0.0 # ⚠️ tắt layerdrop khi dùng gradient-checkpointing (tránh CheckpointError)\n",
217
+ "\n",
218
+ "# Đè trọng số đã fine-tune từ checkpoint exp08 (nếu có)\n",
219
+ "if ckpt is not None and \"wavlm\" in ckpt:\n",
220
+ " miss, unexp = wavlm.load_state_dict(ckpt[\"wavlm\"], strict=False)\n",
221
+ " print(f\"🔁 load wavlm từ checkpoint exp08: thiếu {len(miss)} / dư {len(unexp)} key (kỳ vọng ~0).\")\n",
222
+ "\n",
223
+ "# Đóng băng partial: chỉ mở UNFREEZE_WAVLM lớp trên\n",
224
+ "for p in wavlm.parameters():\n",
225
+ " p.requires_grad = False\n",
226
+ "_wl = wavlm.encoder.layers\n",
227
+ "for layer in _wl[max(0, len(_wl) - UNFREEZE_WAVLM):]:\n",
228
+ " for p in layer.parameters():\n",
229
+ " p.requires_grad = True\n",
230
+ "print(f\"WavLM: {len(_wl)} lớp · mở băng {min(UNFREEZE_WAVLM, len(_wl))} → \"\n",
231
+ " f\"{sum(p.numel() for p in wavlm.parameters() if p.requires_grad)/1e6:.1f}M param train (dim {WAVLM_DIM})\")\n",
232
+ "\n",
233
+ "if USE_GRAD_CKPT:\n",
234
+ " wavlm.gradient_checkpointing_enable()\n",
235
+ " if hasattr(wavlm, \"enable_input_require_grads\"):\n",
236
+ " wavlm.enable_input_require_grads()\n",
237
+ "\n",
238
+ "def masked_mean(hidden, attn_mask, model):\n",
239
+ " if attn_mask is None:\n",
240
+ " return hidden.mean(dim=1)\n",
241
+ " try:\n",
242
+ " fm = model._get_feature_vector_attention_mask(hidden.shape[1], attn_mask)\n",
243
+ " except Exception:\n",
244
+ " return hidden.mean(dim=1)\n",
245
+ " fm = fm.unsqueeze(-1).to(hidden.dtype)\n",
246
+ " return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)\n",
247
+ "\n",
248
+ "def wavlm_embed(input_values, attn_mask):\n",
249
+ " out = wavlm(input_values, attention_mask=attn_mask).last_hidden_state\n",
250
+ " return masked_mean(out, attn_mask, wavlm)"
251
+ ]
252
+ },
253
+ {
254
+ "cell_type": "markdown",
255
+ "id": "20a0c88d",
256
+ "metadata": {},
257
+ "source": [
258
+ "## 2B. audeering TRAINABLE (mở băng — khác exp08 là frozen)\n",
259
+ "Nạp backbone tay + head dimensional gốc; mở băng `UNFREEZE_AUD` lớp trên. Đặc trưng fuse = [hidden(1024) | vad3]."
260
+ ]
261
+ },
262
+ {
263
+ "cell_type": "code",
264
+ "execution_count": null,
265
+ "id": "9360d566",
266
+ "metadata": {
267
+ "lines_to_next_cell": 1
268
+ },
269
+ "outputs": [],
270
+ "source": [
271
+ "from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor\n",
272
+ "from huggingface_hub import hf_hub_download\n",
273
+ "\n",
274
+ "AUD_NAME = \"audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim\"\n",
275
+ "aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)\n",
276
+ "aud_cfg = Wav2Vec2Config.from_pretrained(AUD_NAME)\n",
277
+ "aud = Wav2Vec2Model(aud_cfg)\n",
278
+ "try:\n",
279
+ " _sd = __import__(\"safetensors.torch\", fromlist=[\"load_file\"]).load_file(\n",
280
+ " hf_hub_download(AUD_NAME, \"model.safetensors\"))\n",
281
+ "except Exception:\n",
282
+ " _sd = torch.load(hf_hub_download(AUD_NAME, \"pytorch_model.bin\"), map_location=\"cpu\")\n",
283
+ "bb_sd = {k[len(\"wav2vec2.\"):]: v for k, v in _sd.items() if k.startswith(\"wav2vec2.\")}\n",
284
+ "aud.load_state_dict(bb_sd, strict=False)\n",
285
+ "_hid = _sd[\"classifier.dense.weight\"].shape[0]\n",
286
+ "aud_head = nn.Sequential(nn.Linear(_hid, _hid), nn.Tanh(), nn.Linear(_hid, _sd[\"classifier.out_proj.weight\"].shape[0]))\n",
287
+ "aud_head[0].weight.data.copy_(_sd[\"classifier.dense.weight\"]); aud_head[0].bias.data.copy_(_sd[\"classifier.dense.bias\"])\n",
288
+ "aud_head[2].weight.data.copy_(_sd[\"classifier.out_proj.weight\"]); aud_head[2].bias.data.copy_(_sd[\"classifier.out_proj.bias\"])\n",
289
+ "aud = aud.to(device); aud_head = aud_head.to(device)\n",
290
+ "aud.config.layerdrop = 0.0 # ⚠️ tắt layerdrop khi dùng gradient-checkpointing (tránh CheckpointError)\n",
291
+ "AUD_DIM = _hid + 3 # = 1027 (khớp exp08 để warm-start heads)\n",
292
+ "\n",
293
+ "# RESUME: nếu checkpoint là ft_joint_full.pt (có 'aud') → khôi phục audeering ĐÃ fine-tune (đè pretrained)\n",
294
+ "if ckpt is not None and \"aud\" in ckpt:\n",
295
+ " amiss, aunexp = aud.load_state_dict(ckpt[\"aud\"], strict=False)\n",
296
+ " print(f\"🔁 RESUME audeering từ checkpoint: thiếu {len(amiss)} / dư {len(aunexp)} key (kỳ vọng ~0).\")\n",
297
+ " if \"aud_head\" in ckpt:\n",
298
+ " aud_head.load_state_dict(ckpt[\"aud_head\"]); print(\"🔁 RESUME aud_head từ checkpoint.\")\n",
299
+ "else:\n",
300
+ " print(\"ℹ️ Checkpoint không có 'aud' → audeering khởi từ pretrained gốc (chế độ warm-start exp08).\")\n",
301
+ "\n",
302
+ "# Đóng băng partial audeering: mở UNFREEZE_AUD lớp trên + head dimensional luôn trainable\n",
303
+ "for p in aud.parameters():\n",
304
+ " p.requires_grad = False\n",
305
+ "_al = aud.encoder.layers\n",
306
+ "for layer in _al[max(0, len(_al) - UNFREEZE_AUD):]:\n",
307
+ " for p in layer.parameters():\n",
308
+ " p.requires_grad = True\n",
309
+ "for p in aud_head.parameters():\n",
310
+ " p.requires_grad = True\n",
311
+ "print(f\"audeering: {len(_al)} lớp · mở băng {min(UNFREEZE_AUD, len(_al))} → \"\n",
312
+ " f\"{sum(p.numel() for p in aud.parameters() if p.requires_grad)/1e6:.1f}M param train (hidden {_hid}, fuse dim {AUD_DIM})\")\n",
313
+ "\n",
314
+ "if USE_GRAD_CKPT:\n",
315
+ " aud.gradient_checkpointing_enable()\n",
316
+ " if hasattr(aud, \"enable_input_require_grads\"):\n",
317
+ " aud.enable_input_require_grads()\n",
318
+ "\n",
319
+ "def aud_embed(input_values, attn_mask):\n",
320
+ " \"\"\"Trả về [hidden(1024) | vad3] — vad3 từ head dimensional gốc, theo thứ tự VAL,ARO,DOM.\"\"\"\n",
321
+ " h = masked_mean(aud(input_values, attention_mask=attn_mask).last_hidden_state, attn_mask, aud)\n",
322
+ " out = aud_head(h) # [B,3] thứ tự gốc audeering: (arousal, dominance, valence)\n",
323
+ " vad = torch.stack([1 + 4 * out[:, 2], 1 + 4 * out[:, 0], 1 + 4 * out[:, 1]], dim=1) # → VAL,ARO,DOM\n",
324
+ " return torch.cat([h, vad], dim=1)"
325
+ ]
326
+ },
327
+ {
328
+ "cell_type": "markdown",
329
+ "id": "7550bb8e",
330
+ "metadata": {},
331
+ "source": [
332
+ "## 3. Đọc & gộp nhãn theo wavID (như exp08)"
333
+ ]
334
+ },
335
+ {
336
+ "cell_type": "code",
337
+ "execution_count": null,
338
+ "id": "86f83d8a",
339
+ "metadata": {},
340
+ "outputs": [],
341
+ "source": [
342
+ "import librosa\n",
343
+ "import pandas as pd\n",
344
+ "from tqdm.auto import tqdm\n",
345
+ "\n",
346
+ "def load_target_emotions():\n",
347
+ " tgt = {}\n",
348
+ " with open(METADATA_CSV, encoding=\"utf-8\") as f:\n",
349
+ " for ln in f:\n",
350
+ " parts = ln.strip().split(\"|\")\n",
351
+ " if len(parts) >= 2:\n",
352
+ " tgt[stem(parts[0])] = norm_emotion(parts[1])\n",
353
+ " return tgt\n",
354
+ "\n",
355
+ "def _col(cols_map, *names, df=None, default_idx=None):\n",
356
+ " for n in names:\n",
357
+ " if n in cols_map:\n",
358
+ " return cols_map[n]\n",
359
+ " return list(df.columns)[default_idx] if default_idx is not None else None\n",
360
+ "\n",
361
+ "def parse_emocat_votes(cell):\n",
362
+ " v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
363
+ " for tok in str(cell).replace(\"/\", \",\").replace(\";\", \",\").replace(\"|\", \",\").replace(\" \", \",\").split(\",\"):\n",
364
+ " e = norm_emotion(tok)\n",
365
+ " if e in EMOTIONS5:\n",
366
+ " v[EMOTIONS5.index(e)] += 1.0\n",
367
+ " return v\n",
368
+ "\n",
369
+ "def load_train_labels():\n",
370
+ " df = pd.read_csv(TRAIN_CSV, sep=\"|\")\n",
371
+ " cols = {c.lower().strip(): c for c in df.columns}\n",
372
+ " wav_col = _col(cols, \"wavid\", \"wav\", df=df, default_idx=1)\n",
373
+ " emos_col = _col(cols, \"emos\", \"emo\", \"emomos\")\n",
374
+ " val_col = _col(cols, \"val\", \"valence\"); aro_col = _col(cols, \"aro\", \"arousal\"); dom_col = _col(cols, \"dom\", \"dominance\")\n",
375
+ " cat_col = _col(cols, \"emocat\", \"cat\", \"emotion\")\n",
376
+ " assert emos_col, f\"Không thấy cột eMOS (cột: {list(df.columns)})\"\n",
377
+ " df[\"_stem\"] = df[wav_col].map(stem)\n",
378
+ " rows = []\n",
379
+ " for sid, g in df.groupby(\"_stem\"):\n",
380
+ " rec = {\"wavID\": sid, \"emos\": float(g[emos_col].mean())}\n",
381
+ " rec[\"val\"] = float(g[val_col].mean()) if val_col else np.nan\n",
382
+ " rec[\"aro\"] = float(g[aro_col].mean()) if aro_col else np.nan\n",
383
+ " rec[\"dom\"] = float(g[dom_col].mean()) if dom_col else np.nan\n",
384
+ " votes = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
385
+ " if cat_col:\n",
386
+ " for cell in g[cat_col]:\n",
387
+ " votes += parse_emocat_votes(cell)\n",
388
+ " s = votes.sum()\n",
389
+ " cat = votes / s if s > 0 else np.full(len(EMOTIONS5), 0.2, dtype=np.float32)\n",
390
+ " for i in range(len(EMOTIONS5)):\n",
391
+ " rec[f\"cat{i}\"] = float(cat[i])\n",
392
+ " rows.append(rec)\n",
393
+ " return pd.DataFrame(rows)\n",
394
+ "\n",
395
+ "target_map = load_target_emotions()\n",
396
+ "train_df = load_train_labels()\n",
397
+ "HAS_VAD = bool(train_df[\"val\"].notna().any())\n",
398
+ "print(f\"Target: {len(target_map)} | wav train (gộp): {len(train_df)} | có VAD: {HAS_VAD}\")"
399
+ ]
400
+ },
401
+ {
402
+ "cell_type": "markdown",
403
+ "id": "02e003af",
404
+ "metadata": {},
405
+ "source": [
406
+ "## 4. Dataset/loader — trả về CẢ raw wave (cho WavLM) + input_values audeering\n",
407
+ "Hai backbone cần đầu vào khác nhau: WavLM nhận wave thô; audeering nhận wave đã chuẩn hóa bởi processor.\n",
408
+ "Cùng độ dài → dùng chung attention mask theo mức sample."
409
+ ]
410
+ },
411
+ {
412
+ "cell_type": "code",
413
+ "execution_count": null,
414
+ "id": "f91d8e80",
415
+ "metadata": {},
416
+ "outputs": [],
417
+ "source": [
418
+ "from torch.utils.data import Dataset, DataLoader\n",
419
+ "from sklearn.model_selection import train_test_split\n",
420
+ "\n",
421
+ "train_stems = [s for s in train_df[\"wavID\"] if target_map.get(s) is not None]\n",
422
+ "if LIMIT_TRAIN:\n",
423
+ " train_stems = train_stems[:LIMIT_TRAIN]\n",
424
+ "lab = train_df.set_index(\"wavID\")\n",
425
+ "\n",
426
+ "# Chuẩn hóa: lấy TỪ checkpoint nếu warm-start (để khớp head đã train); không thì fit từ data.\n",
427
+ "if ckpt is not None and \"emos_mu\" in ckpt:\n",
428
+ " emos_mu = float(ckpt[\"emos_mu\"]); emos_sd = float(ckpt[\"emos_sd\"])\n",
429
+ " vad_mu = np.asarray(ckpt[\"vad_mu\"], dtype=np.float32); vad_sd = np.asarray(ckpt[\"vad_sd\"], dtype=np.float32)\n",
430
+ " print(f\"Chuẩn hóa TỪ ckpt: emos μ={emos_mu:.3f} σ={emos_sd:.3f} | vad μ={np.round(vad_mu,2)}\")\n",
431
+ "else:\n",
432
+ " def _zfit(a):\n",
433
+ " a = np.asarray(a, dtype=np.float32); return float(np.nanmean(a)), float(np.nanstd(a) + 1e-6)\n",
434
+ " emos_mu, emos_sd = _zfit([lab.loc[s, \"emos\"] for s in train_stems])\n",
435
+ " if HAS_VAD:\n",
436
+ " vad_mu = np.array([_zfit([lab.loc[s, c] for s in train_stems])[0] for c in [\"val\", \"aro\", \"dom\"]], np.float32)\n",
437
+ " vad_sd = np.array([_zfit([lab.loc[s, c] for s in train_stems])[1] for c in [\"val\", \"aro\", \"dom\"]], np.float32)\n",
438
+ " else:\n",
439
+ " vad_mu = np.zeros(3, np.float32); vad_sd = np.ones(3, np.float32)\n",
440
+ " print(f\"Chuẩn hóa fit từ data: emos μ={emos_mu:.3f} σ={emos_sd:.3f}\")\n",
441
+ "\n",
442
+ "def onehot_target(tgt):\n",
443
+ " v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
444
+ " if tgt in EMOTIONS5:\n",
445
+ " v[EMOTIONS5.index(tgt)] = 1.0\n",
446
+ " return v\n",
447
+ "\n",
448
+ "def load_pair(sid):\n",
449
+ " \"\"\"Trả về (wave_thô, iv_audeering) cùng độ dài; None nếu thiếu file.\"\"\"\n",
450
+ " p = os.path.join(WAV_DIR, sid if str(sid).endswith(\".wav\") else str(sid) + \".wav\")\n",
451
+ " if not os.path.exists(p):\n",
452
+ " return None\n",
453
+ " wave, _ = librosa.load(p, sr=SR, mono=True)\n",
454
+ " wave = wave[: MAX_SECONDS * SR].astype(np.float32)\n",
455
+ " iv = np.asarray(aud_proc(wave, sampling_rate=SR).input_values[0], dtype=np.float32)\n",
456
+ " return wave, iv\n",
457
+ "\n",
458
+ "class JointDataset(Dataset):\n",
459
+ " def __init__(self, stems):\n",
460
+ " self.stems = [s for s in stems if load_pair(s) is not None]\n",
461
+ " def __len__(self):\n",
462
+ " return len(self.stems)\n",
463
+ " def __getitem__(self, i):\n",
464
+ " s = self.stems[i]\n",
465
+ " wave, iv = load_pair(s)\n",
466
+ " emos = (float(lab.loc[s, \"emos\"]) - emos_mu) / emos_sd\n",
467
+ " if HAS_VAD:\n",
468
+ " vad = (np.array([lab.loc[s, \"val\"], lab.loc[s, \"aro\"], lab.loc[s, \"dom\"]], np.float32) - vad_mu) / vad_sd\n",
469
+ " else:\n",
470
+ " vad = np.zeros(3, dtype=np.float32)\n",
471
+ " cat = np.array([lab.loc[s, f\"cat{j}\"] for j in range(len(EMOTIONS5))], dtype=np.float32)\n",
472
+ " return {\"wave\": wave, \"iv\": iv, \"tgt\": onehot_target(target_map.get(s)),\n",
473
+ " \"emos\": np.float32(emos), \"vad\": vad, \"cat\": cat,\n",
474
+ " \"emos_raw\": np.float32(lab.loc[s, \"emos\"]),\n",
475
+ " \"vad_raw\": np.array([lab.loc[s, \"val\"], lab.loc[s, \"aro\"], lab.loc[s, \"dom\"]], np.float32)}\n",
476
+ "\n",
477
+ "def collate(batch):\n",
478
+ " L = max(len(b[\"wave\"]) for b in batch)\n",
479
+ " waves = np.zeros((len(batch), L), dtype=np.float32)\n",
480
+ " ivs = np.zeros((len(batch), L), dtype=np.float32)\n",
481
+ " mask = np.zeros((len(batch), L), dtype=np.float32)\n",
482
+ " for i, b in enumerate(batch):\n",
483
+ " n = len(b[\"wave\"])\n",
484
+ " waves[i, :n] = b[\"wave\"]; ivs[i, :len(b[\"iv\"])] = b[\"iv\"]; mask[i, :n] = 1.0\n",
485
+ " return {\n",
486
+ " \"wave\": torch.from_numpy(waves), \"iv\": torch.from_numpy(ivs), \"attn_mask\": torch.from_numpy(mask).long(),\n",
487
+ " \"tgt\": torch.from_numpy(np.stack([b[\"tgt\"] for b in batch])),\n",
488
+ " \"emos\": torch.from_numpy(np.stack([b[\"emos\"] for b in batch])).unsqueeze(1),\n",
489
+ " \"vad\": torch.from_numpy(np.stack([b[\"vad\"] for b in batch])),\n",
490
+ " \"cat\": torch.from_numpy(np.stack([b[\"cat\"] for b in batch])),\n",
491
+ " \"emos_raw\": np.stack([b[\"emos_raw\"] for b in batch]),\n",
492
+ " \"vad_raw\": np.stack([b[\"vad_raw\"] for b in batch]),\n",
493
+ " }\n",
494
+ "\n",
495
+ "ds = JointDataset(train_stems)\n",
496
+ "print(\"Dataset hợp lệ:\", len(ds), \"wav\")\n",
497
+ "tr_i, va_i = train_test_split(np.arange(len(ds)), test_size=VAL_FRAC, random_state=SEED)\n",
498
+ "tr_loader = DataLoader(torch.utils.data.Subset(ds, tr_i), batch_size=BATCH, shuffle=True, collate_fn=collate, num_workers=2)\n",
499
+ "va_loader = DataLoader(torch.utils.data.Subset(ds, va_i), batch_size=BATCH, shuffle=False, collate_fn=collate, num_workers=2)"
500
+ ]
501
+ },
502
+ {
503
+ "cell_type": "markdown",
504
+ "id": "0f85c871",
505
+ "metadata": {},
506
+ "source": [
507
+ "## 5. Heads (warm-start exp08 nếu khớp) + optimizer 2 backbone + train loop"
508
+ ]
509
+ },
510
+ {
511
+ "cell_type": "code",
512
+ "execution_count": null,
513
+ "id": "b0a71176",
514
+ "metadata": {
515
+ "lines_to_next_cell": 1
516
+ },
517
+ "outputs": [],
518
+ "source": [
519
+ "from scipy.stats import spearmanr\n",
520
+ "\n",
521
+ "torch.manual_seed(SEED); np.random.seed(SEED)\n",
522
+ "N_EMO = len(EMOTIONS5)\n",
523
+ "TRUNK_IN = WAVLM_DIM + AUD_DIM\n",
524
+ "\n",
525
+ "class EmoHeads(nn.Module):\n",
526
+ " def __init__(self, d_in, trunk_h, head_h, p, n_emo):\n",
527
+ " super().__init__()\n",
528
+ " self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),\n",
529
+ " nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))\n",
530
+ " self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))\n",
531
+ " self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))\n",
532
+ " self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))\n",
533
+ " def forward(self, feat, tgt):\n",
534
+ " h = self.trunk(feat)\n",
535
+ " return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)\n",
536
+ "\n",
537
+ "heads = EmoHeads(TRUNK_IN, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device)\n",
538
+ "if ckpt is not None and \"heads\" in ckpt:\n",
539
+ " hmiss, hunexp = heads.load_state_dict(ckpt[\"heads\"], strict=False)\n",
540
+ " if len(hmiss) == 0 and len(hunexp) == 0:\n",
541
+ " print(\"🔁 warm-start heads từ exp08: KHỚP hoàn toàn.\")\n",
542
+ " else:\n",
543
+ " print(f\"⚠️ heads exp08 lệch (thiếu {len(hmiss)}/dư {len(hunexp)}) → có thể TRUNK_IN khác. Heads init mới phần lệch.\")\n",
544
+ "print(f\"Trunk input = {TRUNK_IN} (wavlm {WAVLM_DIM} + aud {AUD_DIM})\")\n",
545
+ "\n",
546
+ "TASKS = [\"emos\", \"cat\", \"val\", \"aro\", \"dom\"]\n",
547
+ "log_var = nn.Parameter(torch.zeros(len(TASKS), device=device))\n",
548
+ "bb_params = [p for p in wavlm.parameters() if p.requires_grad] + \\\n",
549
+ " [p for p in aud.parameters() if p.requires_grad] + list(aud_head.parameters())\n",
550
+ "head_params = list(heads.parameters()) + ([log_var] if USE_UNCERTAINTY else [])\n",
551
+ "opt = torch.optim.AdamW([{\"params\": bb_params, \"lr\": LR_BACKBONE * RESUME_LR_SCALE},\n",
552
+ " {\"params\": head_params, \"lr\": LR_HEAD * RESUME_LR_SCALE}], weight_decay=WEIGHT_DECAY)\n",
553
+ "scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP and device == \"cuda\")\n",
554
+ "mse = nn.MSELoss()\n",
555
+ "\n",
556
+ "def soft_ce(logits, target_dist):\n",
557
+ " return -(target_dist * F.log_softmax(logits, dim=1)).sum(1).mean()\n",
558
+ "\n",
559
+ "def forward_batch(b):\n",
560
+ " am = b[\"attn_mask\"].to(device)\n",
561
+ " fw = wavlm_embed(b[\"wave\"].to(device), am)\n",
562
+ " fa = aud_embed(b[\"iv\"].to(device), am)\n",
563
+ " return heads(torch.cat([fw, fa], dim=1), b[\"tgt\"].to(device))\n",
564
+ "\n",
565
+ "def compute_loss(emos_p, cat_l, vad_p, b):\n",
566
+ " L = {}\n",
567
+ " L[\"emos\"] = mse(emos_p, b[\"emos\"].to(device))\n",
568
+ " L[\"cat\"] = soft_ce(cat_l, b[\"cat\"].to(device))\n",
569
+ " if HAS_VAD:\n",
570
+ " vt = b[\"vad\"].to(device)\n",
571
+ " L[\"val\"] = mse(vad_p[:, 0:1], vt[:, 0:1]); L[\"aro\"] = mse(vad_p[:, 1:2], vt[:, 1:2]); L[\"dom\"] = mse(vad_p[:, 2:3], vt[:, 2:3])\n",
572
+ " else:\n",
573
+ " z = torch.zeros((), device=device); L[\"val\"] = L[\"aro\"] = L[\"dom\"] = z\n",
574
+ " if USE_UNCERTAINTY:\n",
575
+ " return sum(torch.exp(-log_var[i]) * L[t] + log_var[i] for i, t in enumerate(TASKS))\n",
576
+ " return sum(L.values())\n",
577
+ "\n",
578
+ "def set_train(flag):\n",
579
+ " wavlm.train(flag); aud.train(flag); aud_head.train(flag); heads.train(flag)\n",
580
+ "\n",
581
+ "@torch.no_grad()\n",
582
+ "def evaluate():\n",
583
+ " set_train(False)\n",
584
+ " P = {\"emos\": [], \"val\": [], \"aro\": [], \"dom\": []}; Y = {\"emos\": [], \"val\": [], \"aro\": [], \"dom\": []}\n",
585
+ " catP, catY = [], []\n",
586
+ " for b in va_loader:\n",
587
+ " with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
588
+ " emos_p, cat_l, vad_p = forward_batch(b)\n",
589
+ " P[\"emos\"] += emos_p.float().cpu().numpy().ravel().tolist(); Y[\"emos\"] += b[\"emos_raw\"].tolist()\n",
590
+ " vad_p = vad_p.float().cpu().numpy()\n",
591
+ " for j, t in enumerate([\"val\", \"aro\", \"dom\"]):\n",
592
+ " P[t] += vad_p[:, j].tolist(); Y[t] += b[\"vad_raw\"][:, j].tolist()\n",
593
+ " catP.append(F.softmax(cat_l, 1).float().cpu().numpy()); catY.append(b[\"cat\"])\n",
594
+ " out = {}\n",
595
+ " for t in [\"emos\"] + ([\"val\", \"aro\", \"dom\"] if HAS_VAD else []):\n",
596
+ " out[t] = spearmanr(P[t], Y[t]).correlation\n",
597
+ " q = np.concatenate(catP); p = np.concatenate(catY)\n",
598
+ " out[\"cat_err\"] = float(np.abs(q - p).sum(1).mean())\n",
599
+ " return out\n",
600
+ "\n",
601
+ "def mean_srcc(m):\n",
602
+ " keys = [\"emos\"] + ([\"val\", \"aro\", \"dom\"] if HAS_VAD else [])\n",
603
+ " return float(np.mean([m[k] for k in keys]))\n",
604
+ "\n",
605
+ "def snapshot():\n",
606
+ " return {\"wavlm\": {k: v.cpu().clone() for k, v in wavlm.state_dict().items()},\n",
607
+ " \"aud\": {k: v.cpu().clone() for k, v in aud.state_dict().items()},\n",
608
+ " \"aud_head\": {k: v.cpu().clone() for k, v in aud_head.state_dict().items()},\n",
609
+ " \"heads\": {k: v.cpu().clone() for k, v in heads.state_dict().items()}}\n",
610
+ "\n",
611
+ "CKPT_PATH = os.path.join(OUT_DIR, \"ft_joint_full.pt\")\n",
612
+ "def save_full(state, val_emos=float(\"nan\")):\n",
613
+ " torch.save({**state, \"emos_mu\": emos_mu, \"emos_sd\": emos_sd, \"vad_mu\": vad_mu, \"vad_sd\": vad_sd,\n",
614
+ " \"WAVLM_DIM\": WAVLM_DIM, \"AUD_DIM\": AUD_DIM,\n",
615
+ " \"UNFREEZE_WAVLM\": UNFREEZE_WAVLM, \"UNFREEZE_AUD\": UNFREEZE_AUD,\n",
616
+ " \"val_emos\": float(val_emos)}, CKPT_PATH)\n",
617
+ "\n",
618
+ "# Init best từ trạng thái warm-start hiện tại → chỉ lưu nếu train tốt hơn\n",
619
+ "m0 = evaluate(); best = mean_srcc(m0); best_state = snapshot(); save_full(best_state, m0.get(\"emos\", float(\"nan\")))\n",
620
+ "print(f\"📍 Khởi điểm (warm-start): mean SRCC = {best:.4f} | \"\n",
621
+ " + \" \".join(f\"{k}={m0[k]:.3f}\" for k in ['emos','val','aro','dom'] if k in m0))\n",
622
+ "\n",
623
+ "bad = 0\n",
624
+ "for ep in range(1, EPOCHS + 1):\n",
625
+ " set_train(True)\n",
626
+ " opt.zero_grad(); run = 0.0; nb = 0\n",
627
+ " for step, b in enumerate(tqdm(tr_loader, desc=f\"epoch {ep}\")):\n",
628
+ " with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
629
+ " emos_p, cat_l, vad_p = forward_batch(b)\n",
630
+ " loss = compute_loss(emos_p, cat_l, vad_p, b) / ACCUM\n",
631
+ " scaler.scale(loss).backward()\n",
632
+ " if (step + 1) % ACCUM == 0:\n",
633
+ " scaler.step(opt); scaler.update(); opt.zero_grad()\n",
634
+ " run += loss.item() * ACCUM; nb += 1\n",
635
+ " m = evaluate(); sc = mean_srcc(m)\n",
636
+ " msg = \" \".join(f\"{k}={m[k]:.3f}\" for k in [\"emos\", \"val\", \"aro\", \"dom\"] if k in m)\n",
637
+ " print(f\"epoch {ep:2d} | loss {run/max(nb,1):.4f} | {msg} | cat_err {m['cat_err']:.3f} | mean {sc:.4f} (best {max(best,sc):.4f})\")\n",
638
+ " if sc > best:\n",
639
+ " best = sc; best_state = snapshot(); save_full(best_state, m[\"emos\"])\n",
640
+ " print(f\" 💾 lưu best → {CKPT_PATH} (epoch {ep}, mean {sc:.4f})\"); bad = 0\n",
641
+ " else:\n",
642
+ " bad += 1\n",
643
+ " if bad >= PATIENCE:\n",
644
+ " print(f\"Early stop ở epoch {ep}.\"); break\n",
645
+ "\n",
646
+ "# Nạp lại best\n",
647
+ "wavlm.load_state_dict(best_state[\"wavlm\"]); aud.load_state_dict(best_state[\"aud\"])\n",
648
+ "aud_head.load_state_dict(best_state[\"aud_head\"]); heads.load_state_dict(best_state[\"heads\"])\n",
649
+ "final = evaluate()\n",
650
+ "print(\"\\n✅ VAL (nội bộ) — exp11 (fine-tune CẢ 2 + fusion):\")\n",
651
+ "print(f\" EMOS={final['emos']:.4f} (exp08 {EXP08['emos']})\")\n",
652
+ "if HAS_VAD:\n",
653
+ " print(f\" VAL/ARO/DOM={final['val']:.4f}/{final['aro']:.4f}/{final['dom']:.4f} \"\n",
654
+ " f\"(exp08 {EXP08['val']}/{EXP08['aro']}/{EXP08['dom']})\")\n",
655
+ "print(f\" mean SRCC: warm-start {mean_srcc(m0):.4f} → exp11 {mean_srcc(final):.4f} \"\n",
656
+ " + (\"🚀 cải thiện\" if mean_srcc(final) > mean_srcc(m0) + 1e-4 else \"➖ không cải thiện\"))\n",
657
+ "save_full(best_state, final.get(\"emos\", float(\"nan\")))\n",
658
+ "print(\"Đã lưu FULL:\", CKPT_PATH, \"→ NHỚ Save Version!\")"
659
+ ]
660
+ },
661
+ {
662
+ "cell_type": "markdown",
663
+ "id": "dcb57395",
664
+ "metadata": {},
665
+ "source": [
666
+ "## 6. Dự đoán DEV → answer.txt (5 cột cảm xúc từ exp11; QMOS mượn exp07 / UTMOSv2)"
667
+ ]
668
+ },
669
+ {
670
+ "cell_type": "code",
671
+ "execution_count": null,
672
+ "id": "fcae1d4c",
673
+ "metadata": {
674
+ "lines_to_next_cell": 1
675
+ },
676
+ "outputs": [],
677
+ "source": [
678
+ "def list_dev():\n",
679
+ " with open(DEV_SCP) as f:\n",
680
+ " return [ln.strip() for ln in f if ln.strip()]\n",
681
+ "\n",
682
+ "dev_names = list_dev()\n",
683
+ "if LIMIT_DEV:\n",
684
+ " dev_names = dev_names[:LIMIT_DEV]\n",
685
+ "print(\"DEV:\", len(dev_names), \"mẫu\")\n",
686
+ "\n",
687
+ "def load_exp07_qmos():\n",
688
+ " if EXP07_ANSWER and os.path.exists(EXP07_ANSWER):\n",
689
+ " import csv\n",
690
+ " d = {}\n",
691
+ " with open(EXP07_ANSWER) as f:\n",
692
+ " for row in csv.DictReader(f):\n",
693
+ " d[row[\"wav\"]] = float(row[\"QMOS\"]); d[stem(row[\"wav\"])] = float(row[\"QMOS\"])\n",
694
+ " print(f\"✅ Mượn QMOS exp07: {len(d)//2} wav\")\n",
695
+ " return d\n",
696
+ " return None\n",
697
+ "\n",
698
+ "qmos_map = load_exp07_qmos()\n",
699
+ "if qmos_map is None:\n",
700
+ " print(\"ℹ️ Không có exp07 → QMOS bằng UTMOSv2.\")\n",
701
+ " pip_install(\"git+https://github.com/sarulab-speech/UTMOSv2.git\")\n",
702
+ " import utmosv2\n",
703
+ " v2 = utmosv2.create_model(pretrained=True)\n",
704
+ " qmos_map = {}\n",
705
+ " for n in tqdm(dev_names, desc=\"UTMOSv2\"):\n",
706
+ " wav = os.path.join(WAV_DIR, n if str(n).endswith(\".wav\") else str(n) + \".wav\")\n",
707
+ " if os.path.exists(wav):\n",
708
+ " o = v2.predict(input_path=wav)\n",
709
+ " qmos_map[n] = float(o[\"predicted_mos\"]) if isinstance(o, dict) else float(o)\n",
710
+ " del v2; torch.cuda.empty_cache() if device == \"cuda\" else None\n",
711
+ "\n",
712
+ "@torch.no_grad()\n",
713
+ "def predict_emotion(sid):\n",
714
+ " pair = load_pair(sid)\n",
715
+ " if pair is None:\n",
716
+ " return None\n",
717
+ " wave, iv = pair\n",
718
+ " set_train(False)\n",
719
+ " w = torch.from_numpy(wave).unsqueeze(0).to(device)\n",
720
+ " ivt = torch.from_numpy(iv).unsqueeze(0).to(device)\n",
721
+ " am = torch.ones((1, len(wave)), dtype=torch.long, device=device)\n",
722
+ " tgt = torch.from_numpy(onehot_target(target_map.get(sid))).unsqueeze(0).to(device)\n",
723
+ " with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
724
+ " feat = torch.cat([wavlm_embed(w, am), aud_embed(ivt, am)], dim=1)\n",
725
+ " emos_p, cat_l, vad_p = heads(feat, tgt)\n",
726
+ " emos = float(emos_p.item()) * emos_sd + emos_mu\n",
727
+ " cat5 = F.softmax(cat_l, 1)[0].float().cpu().numpy()\n",
728
+ " vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu\n",
729
+ " return emos, cat5, vad3\n",
730
+ "\n",
731
+ "def fmt_cat(p5):\n",
732
+ " return \"|\".join(f\"{e}:{p5[i]:.6g}\" for i, e in enumerate(EMOTIONS5))\n",
733
+ "\n",
734
+ "answer_path = os.path.join(OUT_DIR, \"answer.txt\")\n",
735
+ "n_real = n_def = 0\n",
736
+ "with open(answer_path, \"w\") as f:\n",
737
+ " f.write(\"wav,QMOS,EMOS,CAT,VAL,ARO,DOM\\n\")\n",
738
+ " for name in tqdm(dev_names, desc=\"answer\"):\n",
739
+ " sid = stem(name)\n",
740
+ " pr = predict_emotion(sid)\n",
741
+ " if pr is None:\n",
742
+ " emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0]); n_def += 1\n",
743
+ " else:\n",
744
+ " emos, cat5, vad3 = pr; n_real += 1\n",
745
+ " qmos = qmos_map.get(name, qmos_map.get(sid, 3.0))\n",
746
+ " f.write(f\"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\\n\")\n",
747
+ "print(f\"Ghi {len(dev_names)} dòng → {answer_path} | cảm xúc thật {n_real}, mặc định {n_def}\")"
748
+ ]
749
+ },
750
+ {
751
+ "cell_type": "markdown",
752
+ "id": "7b04afd9",
753
+ "metadata": {},
754
+ "source": [
755
+ "## 7. Validate + zip"
756
+ ]
757
+ },
758
+ {
759
+ "cell_type": "code",
760
+ "execution_count": null,
761
+ "id": "0dfcf6ef",
762
+ "metadata": {},
763
+ "outputs": [],
764
+ "source": [
765
+ "def validate(path):\n",
766
+ " import csv\n",
767
+ " with open(path) as f:\n",
768
+ " rows = list(csv.reader(f))\n",
769
+ " assert rows[0][0] == \"wav\" and \"QMOS\" in rows[0], \"Header sai\"\n",
770
+ " for i, r in enumerate(rows[1:], 2):\n",
771
+ " assert len(r) == len(rows[0]), f\"Dòng {i} sai số cột\"\n",
772
+ " print(f\"OK: {len(rows)-1} dòng, header = {rows[0]}\")\n",
773
+ "\n",
774
+ "validate(answer_path)\n",
775
+ "os.system(f\"cd {OUT_DIR} && zip -j submission_track2_exp11_joint.zip answer.txt && unzip -l submission_track2_exp11_joint.zip\")\n",
776
+ "print(\"Sẵn sàng nộp:\", os.path.join(OUT_DIR, \"submission_track2_exp11_joint.zip\"))"
777
+ ]
778
+ },
779
+ {
780
+ "cell_type": "markdown",
781
+ "id": "8b7adc9b",
782
+ "metadata": {},
783
+ "source": [
784
+ "## Ghi chú\n",
785
+ "- **exp11 = fine-tune CẢ WavLM + audeering, FUSION 1 model** (khác exp08 audeering frozen, khác exp10 ensemble).\n",
786
+ "- **Warm-start:** WavLM + heads từ `ft_emotion_full_20epoch.pt` (exp08) → bắt đầu từ điểm tốt; audeering từ\n",
787
+ " pretrained gốc, mở băng để học thêm. Khởi điểm = đúng exp08 → train chỉ có thể tốt lên (giữ best).\n",
788
+ "- **OOM:** đây là cấu hình nặng nhất. Nếu CUDA OOM → giảm `UNFREEZE_WAVLM`/`UNFREEZE_AUD` (4→2),\n",
789
+ " `MAX_SECONDS` (6→5), giữ `BATCH=1` + tăng `ACCUM`.\n",
790
+ "- **Checkpoint:** lưu `ft_joint_full.pt` mỗi best (đủ cả 2 backbone + heads) → kernel chết vẫn còn. Save Version!\n",
791
+ "- **QMOS** vẫn mượn exp07 (0.548). So sánh nộp: exp11 vs exp08(0.811) vs exp10(ensemble) → chọn bản tốt nhất.\n",
792
+ "- Ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (exp11)."
793
+ ]
794
+ }
795
+ ],
796
+ "metadata": {
797
+ "jupytext": {
798
+ "cell_metadata_filter": "-all",
799
+ "main_language": "python",
800
+ "notebook_metadata_filter": "-all"
801
+ }
802
+ },
803
+ "nbformat": 4,
804
+ "nbformat_minor": 5
805
+ }
track2/exp11_finetune_joint_pipeline.py ADDED
@@ -0,0 +1,665 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %% [markdown]
2
+ # # VMC2026 Track 2 — exp11 (FINE-TUNE ĐỒNG THỜI WavLM + audeering, FUSION 1 model) — Kaggle T4
3
+ #
4
+ # **Khác exp08:** exp08 chỉ fine-tune WavLM, audeering **đóng băng** (frozen, cache). exp11 **MỞ BĂNG CẢ HAI**
5
+ # backbone và fuse đặc trưng **trong cùng 1 model** → cả hai cùng học cho bài MOS cảm xúc 2026.
6
+ #
7
+ # ```
8
+ # wav ─┬─► WavLM-large (warm-start exp08, TRAINABLE: mở băng N lớp trên) ─► pool ─► emb_wavlm ┐
9
+ # └─► audeering MSP (TRAINABLE: mở băng N lớp trên) ─► pool ─► [emb_aud(1024) | vad3] ──────┼─► TRUNK ─┬─► EMOS (+target)
10
+ # ┘ ├─► CAT (5)
11
+ # └─► VAD (3)
12
+ # QMOS: KHÔNG train ở đây → mượn cột QMOS exp07 (0.548) hoặc UTMOSv2.
13
+ # ```
14
+ #
15
+ # ## Vì sao "feature fusion + fine-tune cả 2" (khác ensemble exp10)
16
+ # - **exp10 = ensemble:** 2 model RIÊNG → trung bình cột VAD ở mức answer. An toàn nhưng 2 model không "nói chuyện".
17
+ # - **exp11 = fusion:** 1 model, 2 backbone fuse Ở TRONG → trunk học phối hợp cả hai góc nhìn (WavLM categorical +
18
+ # audeering dimensional) → kỳ vọng mạnh hơn nếu không OOM/overfit.
19
+ #
20
+ # ## ⚠️ ĐÁNH ĐỔI PHẢI BIẾT — đây là cấu hình NẶNG nhất (2 backbone large cùng có gradient)
21
+ # - **Rủi ro OOM cao trên T4 (16GB).** Đã bật sẵn mọi cách giảm bộ nhớ: `BATCH=1` + grad-accum,
22
+ # gradient-checkpointing CẢ 2 backbone, AMP fp16, `MAX_SECONDS=6`, mở băng ÍT lớp (mặc định 4 mỗi backbone).
23
+ # - Nếu vẫn OOM: giảm `UNFREEZE_WAVLM`/`UNFREEZE_AUD` → 2, giảm `MAX_SECONDS` → 5, tăng `ACCUM`.
24
+ # - **Chậm + đốt giờ GPU** (2 backbone forward+backward, không cache được). **LẦN ĐẦU BẮT BUỘC `LIMIT_TRAIN=300`,
25
+ # `LIMIT_DEV=20`** để chỉnh trơn rồi mới `None`.
26
+ # - **Lưới an toàn:** đừng đốt lượt nộp — chỉ nộp khi exp11 thắng exp08 (0.811) TRÊN VAL NỘI BỘ.
27
+ #
28
+ # **Cách chạy:** GPU **T4** + Internet **On** → Add Input (data + checkpoint exp08 + [tùy chọn] answer exp07) →
29
+ # sửa slug cell 0 → Run All. Ghi config→kết quả→nhận xét vào `docs/04_experiments_log.md` (exp11).
30
+
31
+ # %% [markdown]
32
+ # ## 0. Cấu hình — SỬA Ở ĐÂY
33
+
34
+ # %%
35
+ import os
36
+
37
+ DATA_ROOT = "/kaggle/input/datasets/minhtoan2/vmc2026-track2-full/vmc2026-track2" # << SỬA slug cho khớp Add Input
38
+ WAV_DIR = f"{DATA_ROOT}/wav"
39
+ METADATA_CSV = f"{DATA_ROOT}/metadata.csv" # wavID|emotion|transcript (KHÔNG header)
40
+ TRAIN_CSV = f"{DATA_ROOT}/sets/train.csv" # lisID|wavID|qMOS|emoCat|eMOS|val|dom|aro
41
+ DEV_SCP = f"{DATA_ROOT}/sets/dev.scp"
42
+
43
+ # ── Warm-start / RESUME: trỏ tới 1 trong 2 loại checkpoint ───────────────────
44
+ # • ft_emotion_full_20epoch.pt (exp08): có 'wavlm'+'heads' → WARM-START (audeering từ pretrained gốc).
45
+ # • ft_joint_full.pt (exp11): có thêm 'aud'+'aud_head' → RESUME ĐỦ (khôi phục cả 2 backbone đã fine-tune).
46
+ # Notebook TỰ nhận biết theo key trong checkpoint. Để "" nếu train WavLM từ SAILER trắng.
47
+ WARMSTART_CKPT = "/kaggle/input/ft-joint-full/ft_joint_full.pt" # << exp08 ckpt (warm-start) HOẶC exp11 ckpt (resume)
48
+
49
+ # Mượn cột QMOS exp07 (0.548). Không có → UTMOSv2.
50
+ EXP07_ANSWER = "/kaggle/input/exp07-answer/answer.txt" # << (tùy chọn) answer.txt exp07; không có → UTMOSv2
51
+
52
+ OUT_DIR = "/kaggle/working"
53
+
54
+ # ── Fine-tune / siêu tham số (CẤU HÌNH NẶNG — đã tối ưu cho T4) ───────────────
55
+ DEVICE = "cuda"
56
+ SR = 16000
57
+ MAX_SECONDS = 6 # ↓ so exp08 (8) để tiết kiệm VRAM (2 backbone)
58
+ UNFREEZE_WAVLM = 4 # số lớp encoder WavLM mở băng (OOM → 2)
59
+ UNFREEZE_AUD = 4 # số lớp encoder audeering mở băng (OOM → 2)
60
+ TRUNK_HIDDEN = 512 # PHẢI khớp checkpoint exp08 nếu warm-start heads
61
+ HEAD_HIDDEN = 128 # PHẢI khớp checkpoint exp08
62
+ DROPOUT = 0.3
63
+ LR_BACKBONE = 1e-5 # LR chung cho 2 backbone
64
+ LR_HEAD = 1e-3
65
+ RESUME_LR_SCALE = 1.0 # <1.0 để GIẢM LR khi resume (vd 0.5 nếu val đã chững) — nhân vào cả 2 nhóm LR
66
+ WEIGHT_DECAY = 1e-5
67
+ EPOCHS = 12
68
+ PATIENCE = 4 # dừng khi val không lên; LUÔN giữ best
69
+ BATCH = 1 # ⚠️ 2 backbone large → batch nhỏ
70
+ ACCUM = 16 # effective batch = 16
71
+ VAL_FRAC = 0.10
72
+ SEED = 42
73
+ USE_AMP = True
74
+ USE_GRAD_CKPT = True
75
+ USE_UNCERTAINTY = True
76
+
77
+ LIMIT_TRAIN = 300 # << LẦN ĐẦU 300; chạy thật None
78
+ LIMIT_DEV = 20 # << LẦN ĐẦU 20; ch���y thật None
79
+
80
+ EXP08 = {"emos": 0.811, "cat_err": 0.133, "val": 0.659, "aro": 0.793, "dom": 0.751} # mốc để so
81
+
82
+ EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"]
83
+ _EMO_ALIAS = {
84
+ "angry": "angry", "anger": "angry",
85
+ "happy": "happy", "happiness": "happy", "joy": "happy",
86
+ "neutral": "neutral", "calm": "neutral",
87
+ "sad": "sad", "sadness": "sad",
88
+ "surprise": "surprised", "surprised": "surprised", "surprising": "surprised",
89
+ }
90
+
91
+ def norm_emotion(label):
92
+ key = str(label).strip().lower()
93
+ return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)
94
+
95
+ def stem(p):
96
+ return os.path.splitext(os.path.basename(str(p)))[0]
97
+
98
+ print("DATA_ROOT:", DATA_ROOT)
99
+ for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP]:
100
+ print((" ✅ " if os.path.exists(p) else " ❌ THIẾU ") + p)
101
+ print((" ✅ " if (WARMSTART_CKPT and os.path.exists(WARMSTART_CKPT)) else " ⚠️ KHÔNG có ") + str(WARMSTART_CKPT)
102
+ + (" → warm-start" if (WARMSTART_CKPT and os.path.exists(WARMSTART_CKPT)) else " → train từ SAILER trắng"))
103
+
104
+ # %% [markdown]
105
+ # ## 1. Cài đặt + tải code SAILER (dựng đúng kiến trúc WavLM)
106
+
107
+ # %%
108
+ import sys, subprocess
109
+
110
+ def pip_install(*pkgs):
111
+ subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs], check=True)
112
+
113
+ pip_install("transformers", "huggingface_hub", "safetensors", "loralib", "speechbrain",
114
+ "speechmos", "librosa", "soundfile", "scipy", "scikit-learn", "pandas", "tqdm")
115
+
116
+ REPO_DIR = "/kaggle/working/vox-profile-release"
117
+ if not os.path.exists(REPO_DIR):
118
+ subprocess.run(["git", "clone", "--depth", "1",
119
+ "https://github.com/tiantiaf0627/vox-profile-release.git", REPO_DIR], check=True)
120
+ if REPO_DIR not in sys.path:
121
+ sys.path.insert(0, REPO_DIR)
122
+
123
+ # %% [markdown]
124
+ # ## 2A. WavLM TRAINABLE (warm-start SAILER / checkpoint exp08)
125
+
126
+ # %%
127
+ import torch
128
+ import torch.nn as nn
129
+ import torch.nn.functional as F
130
+ import numpy as np
131
+
132
+ device = DEVICE if torch.cuda.is_available() else "cpu"
133
+ print("Device:", device, ("✅ " + torch.cuda.get_device_name(0)) if device == "cuda" else "⚠️ CPU (rất chậm!)")
134
+
135
+ # Nạp checkpoint exp08 (nếu có) — lấy cả 'wavlm', 'heads', thống kê chuẩn hóa
136
+ ckpt = None
137
+ if WARMSTART_CKPT and os.path.exists(WARMSTART_CKPT):
138
+ ckpt = torch.load(WARMSTART_CKPT, map_location="cpu", weights_only=False)
139
+ print("✅ Nạp checkpoint warm-start:", WARMSTART_CKPT, "| keys:", list(ckpt.keys()))
140
+ if "wavlm" not in ckpt:
141
+ print(" ⚠️ Checkpoint KHÔNG có 'wavlm' (chỉ heads?) → vẫn dựng WavLM từ SAILER, chỉ warm-start heads nếu khớp.")
142
+
143
+ def find_hf_backbone(module):
144
+ cands = []
145
+ for name, m in module.named_modules():
146
+ enc = getattr(m, "encoder", None)
147
+ if getattr(m, "feature_extractor", None) is not None and enc is not None \
148
+ and getattr(enc, "layers", None) is not None:
149
+ cands.append((name, m))
150
+ if not cands:
151
+ return None, None
152
+ cands.sort(key=lambda nm: sum(p.numel() for p in nm[1].parameters()), reverse=True)
153
+ return cands[0]
154
+
155
+ wavlm = None
156
+ try:
157
+ from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402
158
+ _wrapper = WavLMWrapper.from_pretrained("tiantiaf/wavlm-large-categorical-emotion")
159
+ name, wavlm = find_hf_backbone(_wrapper)
160
+ if wavlm is not None:
161
+ print(f"✅ Dựng backbone WavLM từ SAILER wrapper tại '.{name}'")
162
+ except Exception as e:
163
+ print("⚠️ Lỗi nạp SAILER wrapper:", repr(e), "→ fallback WavLM trắng.")
164
+
165
+ if wavlm is None:
166
+ from transformers import WavLMModel
167
+ wavlm = WavLMModel.from_pretrained("microsoft/wavlm-large")
168
+ print("ℹ️ Fallback: microsoft/wavlm-large.")
169
+
170
+ wavlm = wavlm.to(device)
171
+ WAVLM_DIM = int(wavlm.config.hidden_size)
172
+ wavlm.config.layerdrop = 0.0 # ⚠️ tắt layerdrop khi dùng gradient-checkpointing (tránh CheckpointError)
173
+
174
+ # Đè trọng số đã fine-tune từ checkpoint exp08 (nếu có)
175
+ if ckpt is not None and "wavlm" in ckpt:
176
+ miss, unexp = wavlm.load_state_dict(ckpt["wavlm"], strict=False)
177
+ print(f"🔁 load wavlm từ checkpoint exp08: thiếu {len(miss)} / dư {len(unexp)} key (kỳ vọng ~0).")
178
+
179
+ # Đóng băng partial: chỉ mở UNFREEZE_WAVLM lớp trên
180
+ for p in wavlm.parameters():
181
+ p.requires_grad = False
182
+ _wl = wavlm.encoder.layers
183
+ for layer in _wl[max(0, len(_wl) - UNFREEZE_WAVLM):]:
184
+ for p in layer.parameters():
185
+ p.requires_grad = True
186
+ print(f"WavLM: {len(_wl)} lớp · mở băng {min(UNFREEZE_WAVLM, len(_wl))} → "
187
+ f"{sum(p.numel() for p in wavlm.parameters() if p.requires_grad)/1e6:.1f}M param train (dim {WAVLM_DIM})")
188
+
189
+ if USE_GRAD_CKPT:
190
+ wavlm.gradient_checkpointing_enable()
191
+ if hasattr(wavlm, "enable_input_require_grads"):
192
+ wavlm.enable_input_require_grads()
193
+
194
+ def masked_mean(hidden, attn_mask, model):
195
+ if attn_mask is None:
196
+ return hidden.mean(dim=1)
197
+ try:
198
+ fm = model._get_feature_vector_attention_mask(hidden.shape[1], attn_mask)
199
+ except Exception:
200
+ return hidden.mean(dim=1)
201
+ fm = fm.unsqueeze(-1).to(hidden.dtype)
202
+ return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)
203
+
204
+ def wavlm_embed(input_values, attn_mask):
205
+ out = wavlm(input_values, attention_mask=attn_mask).last_hidden_state
206
+ return masked_mean(out, attn_mask, wavlm)
207
+
208
+ # %% [markdown]
209
+ # ## 2B. audeering TRAINABLE (mở băng — khác exp08 là frozen)
210
+ # Nạp backbone tay + head dimensional gốc; mở băng `UNFREEZE_AUD` lớp trên. Đặc trưng fuse = [hidden(1024) | vad3].
211
+
212
+ # %%
213
+ from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor
214
+ from huggingface_hub import hf_hub_download
215
+
216
+ AUD_NAME = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
217
+ aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)
218
+ aud_cfg = Wav2Vec2Config.from_pretrained(AUD_NAME)
219
+ aud = Wav2Vec2Model(aud_cfg)
220
+ try:
221
+ _sd = __import__("safetensors.torch", fromlist=["load_file"]).load_file(
222
+ hf_hub_download(AUD_NAME, "model.safetensors"))
223
+ except Exception:
224
+ _sd = torch.load(hf_hub_download(AUD_NAME, "pytorch_model.bin"), map_location="cpu")
225
+ bb_sd = {k[len("wav2vec2."):]: v for k, v in _sd.items() if k.startswith("wav2vec2.")}
226
+ aud.load_state_dict(bb_sd, strict=False)
227
+ _hid = _sd["classifier.dense.weight"].shape[0]
228
+ aud_head = nn.Sequential(nn.Linear(_hid, _hid), nn.Tanh(), nn.Linear(_hid, _sd["classifier.out_proj.weight"].shape[0]))
229
+ aud_head[0].weight.data.copy_(_sd["classifier.dense.weight"]); aud_head[0].bias.data.copy_(_sd["classifier.dense.bias"])
230
+ aud_head[2].weight.data.copy_(_sd["classifier.out_proj.weight"]); aud_head[2].bias.data.copy_(_sd["classifier.out_proj.bias"])
231
+ aud = aud.to(device); aud_head = aud_head.to(device)
232
+ aud.config.layerdrop = 0.0 # ⚠️ tắt layerdrop khi dùng gradient-checkpointing (tránh CheckpointError)
233
+ AUD_DIM = _hid + 3 # = 1027 (khớp exp08 để warm-start heads)
234
+
235
+ # RESUME: nếu checkpoint là ft_joint_full.pt (có 'aud') → khôi phục audeering ĐÃ fine-tune (đè pretrained)
236
+ if ckpt is not None and "aud" in ckpt:
237
+ amiss, aunexp = aud.load_state_dict(ckpt["aud"], strict=False)
238
+ print(f"🔁 RESUME audeering từ checkpoint: thiếu {len(amiss)} / dư {len(aunexp)} key (kỳ vọng ~0).")
239
+ if "aud_head" in ckpt:
240
+ aud_head.load_state_dict(ckpt["aud_head"]); print("🔁 RESUME aud_head từ checkpoint.")
241
+ else:
242
+ print("ℹ️ Checkpoint không có 'aud' → audeering khởi từ pretrained gốc (chế độ warm-start exp08).")
243
+
244
+ # Đóng băng partial audeering: mở UNFREEZE_AUD lớp trên + head dimensional luôn trainable
245
+ for p in aud.parameters():
246
+ p.requires_grad = False
247
+ _al = aud.encoder.layers
248
+ for layer in _al[max(0, len(_al) - UNFREEZE_AUD):]:
249
+ for p in layer.parameters():
250
+ p.requires_grad = True
251
+ for p in aud_head.parameters():
252
+ p.requires_grad = True
253
+ print(f"audeering: {len(_al)} lớp · mở băng {min(UNFREEZE_AUD, len(_al))} → "
254
+ f"{sum(p.numel() for p in aud.parameters() if p.requires_grad)/1e6:.1f}M param train (hidden {_hid}, fuse dim {AUD_DIM})")
255
+
256
+ if USE_GRAD_CKPT:
257
+ aud.gradient_checkpointing_enable()
258
+ if hasattr(aud, "enable_input_require_grads"):
259
+ aud.enable_input_require_grads()
260
+
261
+ def aud_embed(input_values, attn_mask):
262
+ """Trả về [hidden(1024) | vad3] — vad3 từ head dimensional gốc, theo thứ tự VAL,ARO,DOM."""
263
+ h = masked_mean(aud(input_values, attention_mask=attn_mask).last_hidden_state, attn_mask, aud)
264
+ out = aud_head(h) # [B,3] thứ tự gốc audeering: (arousal, dominance, valence)
265
+ vad = torch.stack([1 + 4 * out[:, 2], 1 + 4 * out[:, 0], 1 + 4 * out[:, 1]], dim=1) # → VAL,ARO,DOM
266
+ return torch.cat([h, vad], dim=1)
267
+
268
+ # %% [markdown]
269
+ # ## 3. Đọc & gộp nhãn theo wavID (như exp08)
270
+
271
+ # %%
272
+ import librosa
273
+ import pandas as pd
274
+ from tqdm.auto import tqdm
275
+
276
+ def load_target_emotions():
277
+ tgt = {}
278
+ with open(METADATA_CSV, encoding="utf-8") as f:
279
+ for ln in f:
280
+ parts = ln.strip().split("|")
281
+ if len(parts) >= 2:
282
+ tgt[stem(parts[0])] = norm_emotion(parts[1])
283
+ return tgt
284
+
285
+ def _col(cols_map, *names, df=None, default_idx=None):
286
+ for n in names:
287
+ if n in cols_map:
288
+ return cols_map[n]
289
+ return list(df.columns)[default_idx] if default_idx is not None else None
290
+
291
+ def parse_emocat_votes(cell):
292
+ v = np.zeros(len(EMOTIONS5), dtype=np.float32)
293
+ for tok in str(cell).replace("/", ",").replace(";", ",").replace("|", ",").replace(" ", ",").split(","):
294
+ e = norm_emotion(tok)
295
+ if e in EMOTIONS5:
296
+ v[EMOTIONS5.index(e)] += 1.0
297
+ return v
298
+
299
+ def load_train_labels():
300
+ df = pd.read_csv(TRAIN_CSV, sep="|")
301
+ cols = {c.lower().strip(): c for c in df.columns}
302
+ wav_col = _col(cols, "wavid", "wav", df=df, default_idx=1)
303
+ emos_col = _col(cols, "emos", "emo", "emomos")
304
+ val_col = _col(cols, "val", "valence"); aro_col = _col(cols, "aro", "arousal"); dom_col = _col(cols, "dom", "dominance")
305
+ cat_col = _col(cols, "emocat", "cat", "emotion")
306
+ assert emos_col, f"Không thấy cột eMOS (cột: {list(df.columns)})"
307
+ df["_stem"] = df[wav_col].map(stem)
308
+ rows = []
309
+ for sid, g in df.groupby("_stem"):
310
+ rec = {"wavID": sid, "emos": float(g[emos_col].mean())}
311
+ rec["val"] = float(g[val_col].mean()) if val_col else np.nan
312
+ rec["aro"] = float(g[aro_col].mean()) if aro_col else np.nan
313
+ rec["dom"] = float(g[dom_col].mean()) if dom_col else np.nan
314
+ votes = np.zeros(len(EMOTIONS5), dtype=np.float32)
315
+ if cat_col:
316
+ for cell in g[cat_col]:
317
+ votes += parse_emocat_votes(cell)
318
+ s = votes.sum()
319
+ cat = votes / s if s > 0 else np.full(len(EMOTIONS5), 0.2, dtype=np.float32)
320
+ for i in range(len(EMOTIONS5)):
321
+ rec[f"cat{i}"] = float(cat[i])
322
+ rows.append(rec)
323
+ return pd.DataFrame(rows)
324
+
325
+ target_map = load_target_emotions()
326
+ train_df = load_train_labels()
327
+ HAS_VAD = bool(train_df["val"].notna().any())
328
+ print(f"Target: {len(target_map)} | wav train (gộp): {len(train_df)} | có VAD: {HAS_VAD}")
329
+
330
+ # %% [markdown]
331
+ # ## 4. Dataset/loader — trả về CẢ raw wave (cho WavLM) + input_values audeering
332
+ # Hai backbone cần đầu vào khác nhau: WavLM nhận wave thô; audeering nhận wave đã chuẩn hóa bởi processor.
333
+ # Cùng độ dài → dùng chung attention mask theo mức sample.
334
+
335
+ # %%
336
+ from torch.utils.data import Dataset, DataLoader
337
+ from sklearn.model_selection import train_test_split
338
+
339
+ train_stems = [s for s in train_df["wavID"] if target_map.get(s) is not None]
340
+ if LIMIT_TRAIN:
341
+ train_stems = train_stems[:LIMIT_TRAIN]
342
+ lab = train_df.set_index("wavID")
343
+
344
+ # Chuẩn hóa: lấy TỪ checkpoint nếu warm-start (để khớp head đã train); không thì fit từ data.
345
+ if ckpt is not None and "emos_mu" in ckpt:
346
+ emos_mu = float(ckpt["emos_mu"]); emos_sd = float(ckpt["emos_sd"])
347
+ vad_mu = np.asarray(ckpt["vad_mu"], dtype=np.float32); vad_sd = np.asarray(ckpt["vad_sd"], dtype=np.float32)
348
+ print(f"Chuẩn hóa TỪ ckpt: emos μ={emos_mu:.3f} σ={emos_sd:.3f} | vad μ={np.round(vad_mu,2)}")
349
+ else:
350
+ def _zfit(a):
351
+ a = np.asarray(a, dtype=np.float32); return float(np.nanmean(a)), float(np.nanstd(a) + 1e-6)
352
+ emos_mu, emos_sd = _zfit([lab.loc[s, "emos"] for s in train_stems])
353
+ if HAS_VAD:
354
+ vad_mu = np.array([_zfit([lab.loc[s, c] for s in train_stems])[0] for c in ["val", "aro", "dom"]], np.float32)
355
+ vad_sd = np.array([_zfit([lab.loc[s, c] for s in train_stems])[1] for c in ["val", "aro", "dom"]], np.float32)
356
+ else:
357
+ vad_mu = np.zeros(3, np.float32); vad_sd = np.ones(3, np.float32)
358
+ print(f"Chuẩn hóa fit từ data: emos μ={emos_mu:.3f} σ={emos_sd:.3f}")
359
+
360
+ def onehot_target(tgt):
361
+ v = np.zeros(len(EMOTIONS5), dtype=np.float32)
362
+ if tgt in EMOTIONS5:
363
+ v[EMOTIONS5.index(tgt)] = 1.0
364
+ return v
365
+
366
+ def load_pair(sid):
367
+ """Trả về (wave_thô, iv_audeering) cùng độ dài; None nếu thiếu file."""
368
+ p = os.path.join(WAV_DIR, sid if str(sid).endswith(".wav") else str(sid) + ".wav")
369
+ if not os.path.exists(p):
370
+ return None
371
+ wave, _ = librosa.load(p, sr=SR, mono=True)
372
+ wave = wave[: MAX_SECONDS * SR].astype(np.float32)
373
+ iv = np.asarray(aud_proc(wave, sampling_rate=SR).input_values[0], dtype=np.float32)
374
+ return wave, iv
375
+
376
+ class JointDataset(Dataset):
377
+ def __init__(self, stems):
378
+ self.stems = [s for s in stems if load_pair(s) is not None]
379
+ def __len__(self):
380
+ return len(self.stems)
381
+ def __getitem__(self, i):
382
+ s = self.stems[i]
383
+ wave, iv = load_pair(s)
384
+ emos = (float(lab.loc[s, "emos"]) - emos_mu) / emos_sd
385
+ if HAS_VAD:
386
+ vad = (np.array([lab.loc[s, "val"], lab.loc[s, "aro"], lab.loc[s, "dom"]], np.float32) - vad_mu) / vad_sd
387
+ else:
388
+ vad = np.zeros(3, dtype=np.float32)
389
+ cat = np.array([lab.loc[s, f"cat{j}"] for j in range(len(EMOTIONS5))], dtype=np.float32)
390
+ return {"wave": wave, "iv": iv, "tgt": onehot_target(target_map.get(s)),
391
+ "emos": np.float32(emos), "vad": vad, "cat": cat,
392
+ "emos_raw": np.float32(lab.loc[s, "emos"]),
393
+ "vad_raw": np.array([lab.loc[s, "val"], lab.loc[s, "aro"], lab.loc[s, "dom"]], np.float32)}
394
+
395
+ def collate(batch):
396
+ L = max(len(b["wave"]) for b in batch)
397
+ waves = np.zeros((len(batch), L), dtype=np.float32)
398
+ ivs = np.zeros((len(batch), L), dtype=np.float32)
399
+ mask = np.zeros((len(batch), L), dtype=np.float32)
400
+ for i, b in enumerate(batch):
401
+ n = len(b["wave"])
402
+ waves[i, :n] = b["wave"]; ivs[i, :len(b["iv"])] = b["iv"]; mask[i, :n] = 1.0
403
+ return {
404
+ "wave": torch.from_numpy(waves), "iv": torch.from_numpy(ivs), "attn_mask": torch.from_numpy(mask).long(),
405
+ "tgt": torch.from_numpy(np.stack([b["tgt"] for b in batch])),
406
+ "emos": torch.from_numpy(np.stack([b["emos"] for b in batch])).unsqueeze(1),
407
+ "vad": torch.from_numpy(np.stack([b["vad"] for b in batch])),
408
+ "cat": torch.from_numpy(np.stack([b["cat"] for b in batch])),
409
+ "emos_raw": np.stack([b["emos_raw"] for b in batch]),
410
+ "vad_raw": np.stack([b["vad_raw"] for b in batch]),
411
+ }
412
+
413
+ ds = JointDataset(train_stems)
414
+ print("Dataset hợp lệ:", len(ds), "wav")
415
+ tr_i, va_i = train_test_split(np.arange(len(ds)), test_size=VAL_FRAC, random_state=SEED)
416
+ tr_loader = DataLoader(torch.utils.data.Subset(ds, tr_i), batch_size=BATCH, shuffle=True, collate_fn=collate, num_workers=2)
417
+ va_loader = DataLoader(torch.utils.data.Subset(ds, va_i), batch_size=BATCH, shuffle=False, collate_fn=collate, num_workers=2)
418
+
419
+ # %% [markdown]
420
+ # ## 5. Heads (warm-start exp08 nếu khớp) + optimizer 2 backbone + train loop
421
+
422
+ # %%
423
+ from scipy.stats import spearmanr
424
+
425
+ torch.manual_seed(SEED); np.random.seed(SEED)
426
+ N_EMO = len(EMOTIONS5)
427
+ TRUNK_IN = WAVLM_DIM + AUD_DIM
428
+
429
+ class EmoHeads(nn.Module):
430
+ def __init__(self, d_in, trunk_h, head_h, p, n_emo):
431
+ super().__init__()
432
+ self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),
433
+ nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))
434
+ self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))
435
+ self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))
436
+ self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))
437
+ def forward(self, feat, tgt):
438
+ h = self.trunk(feat)
439
+ return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)
440
+
441
+ heads = EmoHeads(TRUNK_IN, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device)
442
+ if ckpt is not None and "heads" in ckpt:
443
+ hmiss, hunexp = heads.load_state_dict(ckpt["heads"], strict=False)
444
+ if len(hmiss) == 0 and len(hunexp) == 0:
445
+ print("🔁 warm-start heads từ exp08: KHỚP hoàn toàn.")
446
+ else:
447
+ print(f"⚠️ heads exp08 lệch (thiếu {len(hmiss)}/dư {len(hunexp)}) → có thể TRUNK_IN khác. Heads init mới phần lệch.")
448
+ print(f"Trunk input = {TRUNK_IN} (wavlm {WAVLM_DIM} + aud {AUD_DIM})")
449
+
450
+ TASKS = ["emos", "cat", "val", "aro", "dom"]
451
+ log_var = nn.Parameter(torch.zeros(len(TASKS), device=device))
452
+ bb_params = [p for p in wavlm.parameters() if p.requires_grad] + \
453
+ [p for p in aud.parameters() if p.requires_grad] + list(aud_head.parameters())
454
+ head_params = list(heads.parameters()) + ([log_var] if USE_UNCERTAINTY else [])
455
+ opt = torch.optim.AdamW([{"params": bb_params, "lr": LR_BACKBONE * RESUME_LR_SCALE},
456
+ {"params": head_params, "lr": LR_HEAD * RESUME_LR_SCALE}], weight_decay=WEIGHT_DECAY)
457
+ scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP and device == "cuda")
458
+ mse = nn.MSELoss()
459
+
460
+ def soft_ce(logits, target_dist):
461
+ return -(target_dist * F.log_softmax(logits, dim=1)).sum(1).mean()
462
+
463
+ def forward_batch(b):
464
+ am = b["attn_mask"].to(device)
465
+ fw = wavlm_embed(b["wave"].to(device), am)
466
+ fa = aud_embed(b["iv"].to(device), am)
467
+ return heads(torch.cat([fw, fa], dim=1), b["tgt"].to(device))
468
+
469
+ def compute_loss(emos_p, cat_l, vad_p, b):
470
+ L = {}
471
+ L["emos"] = mse(emos_p, b["emos"].to(device))
472
+ L["cat"] = soft_ce(cat_l, b["cat"].to(device))
473
+ if HAS_VAD:
474
+ vt = b["vad"].to(device)
475
+ L["val"] = mse(vad_p[:, 0:1], vt[:, 0:1]); L["aro"] = mse(vad_p[:, 1:2], vt[:, 1:2]); L["dom"] = mse(vad_p[:, 2:3], vt[:, 2:3])
476
+ else:
477
+ z = torch.zeros((), device=device); L["val"] = L["aro"] = L["dom"] = z
478
+ if USE_UNCERTAINTY:
479
+ return sum(torch.exp(-log_var[i]) * L[t] + log_var[i] for i, t in enumerate(TASKS))
480
+ return sum(L.values())
481
+
482
+ def set_train(flag):
483
+ wavlm.train(flag); aud.train(flag); aud_head.train(flag); heads.train(flag)
484
+
485
+ @torch.no_grad()
486
+ def evaluate():
487
+ set_train(False)
488
+ P = {"emos": [], "val": [], "aro": [], "dom": []}; Y = {"emos": [], "val": [], "aro": [], "dom": []}
489
+ catP, catY = [], []
490
+ for b in va_loader:
491
+ with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
492
+ emos_p, cat_l, vad_p = forward_batch(b)
493
+ P["emos"] += emos_p.float().cpu().numpy().ravel().tolist(); Y["emos"] += b["emos_raw"].tolist()
494
+ vad_p = vad_p.float().cpu().numpy()
495
+ for j, t in enumerate(["val", "aro", "dom"]):
496
+ P[t] += vad_p[:, j].tolist(); Y[t] += b["vad_raw"][:, j].tolist()
497
+ catP.append(F.softmax(cat_l, 1).float().cpu().numpy()); catY.append(b["cat"])
498
+ out = {}
499
+ for t in ["emos"] + (["val", "aro", "dom"] if HAS_VAD else []):
500
+ out[t] = spearmanr(P[t], Y[t]).correlation
501
+ q = np.concatenate(catP); p = np.concatenate(catY)
502
+ out["cat_err"] = float(np.abs(q - p).sum(1).mean())
503
+ return out
504
+
505
+ def mean_srcc(m):
506
+ keys = ["emos"] + (["val", "aro", "dom"] if HAS_VAD else [])
507
+ return float(np.mean([m[k] for k in keys]))
508
+
509
+ def snapshot():
510
+ return {"wavlm": {k: v.cpu().clone() for k, v in wavlm.state_dict().items()},
511
+ "aud": {k: v.cpu().clone() for k, v in aud.state_dict().items()},
512
+ "aud_head": {k: v.cpu().clone() for k, v in aud_head.state_dict().items()},
513
+ "heads": {k: v.cpu().clone() for k, v in heads.state_dict().items()}}
514
+
515
+ CKPT_PATH = os.path.join(OUT_DIR, "ft_joint_full.pt")
516
+ def save_full(state, val_emos=float("nan")):
517
+ torch.save({**state, "emos_mu": emos_mu, "emos_sd": emos_sd, "vad_mu": vad_mu, "vad_sd": vad_sd,
518
+ "WAVLM_DIM": WAVLM_DIM, "AUD_DIM": AUD_DIM,
519
+ "UNFREEZE_WAVLM": UNFREEZE_WAVLM, "UNFREEZE_AUD": UNFREEZE_AUD,
520
+ "val_emos": float(val_emos)}, CKPT_PATH)
521
+
522
+ # Init best từ trạng thái warm-start hiện tại → chỉ lưu nếu train tốt hơn
523
+ m0 = evaluate(); best = mean_srcc(m0); best_state = snapshot(); save_full(best_state, m0.get("emos", float("nan")))
524
+ print(f"📍 Khởi điểm (warm-start): mean SRCC = {best:.4f} | "
525
+ + " ".join(f"{k}={m0[k]:.3f}" for k in ['emos','val','aro','dom'] if k in m0))
526
+
527
+ bad = 0
528
+ for ep in range(1, EPOCHS + 1):
529
+ set_train(True)
530
+ opt.zero_grad(); run = 0.0; nb = 0
531
+ for step, b in enumerate(tqdm(tr_loader, desc=f"epoch {ep}")):
532
+ with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
533
+ emos_p, cat_l, vad_p = forward_batch(b)
534
+ loss = compute_loss(emos_p, cat_l, vad_p, b) / ACCUM
535
+ scaler.scale(loss).backward()
536
+ if (step + 1) % ACCUM == 0:
537
+ scaler.step(opt); scaler.update(); opt.zero_grad()
538
+ run += loss.item() * ACCUM; nb += 1
539
+ m = evaluate(); sc = mean_srcc(m)
540
+ msg = " ".join(f"{k}={m[k]:.3f}" for k in ["emos", "val", "aro", "dom"] if k in m)
541
+ print(f"epoch {ep:2d} | loss {run/max(nb,1):.4f} | {msg} | cat_err {m['cat_err']:.3f} | mean {sc:.4f} (best {max(best,sc):.4f})")
542
+ if sc > best:
543
+ best = sc; best_state = snapshot(); save_full(best_state, m["emos"])
544
+ print(f" 💾 lưu best → {CKPT_PATH} (epoch {ep}, mean {sc:.4f})"); bad = 0
545
+ else:
546
+ bad += 1
547
+ if bad >= PATIENCE:
548
+ print(f"Early stop ở epoch {ep}."); break
549
+
550
+ # Nạp lại best
551
+ wavlm.load_state_dict(best_state["wavlm"]); aud.load_state_dict(best_state["aud"])
552
+ aud_head.load_state_dict(best_state["aud_head"]); heads.load_state_dict(best_state["heads"])
553
+ final = evaluate()
554
+ print("\n✅ VAL (nội bộ) — exp11 (fine-tune CẢ 2 + fusion):")
555
+ print(f" EMOS={final['emos']:.4f} (exp08 {EXP08['emos']})")
556
+ if HAS_VAD:
557
+ print(f" VAL/ARO/DOM={final['val']:.4f}/{final['aro']:.4f}/{final['dom']:.4f} "
558
+ f"(exp08 {EXP08['val']}/{EXP08['aro']}/{EXP08['dom']})")
559
+ print(f" mean SRCC: warm-start {mean_srcc(m0):.4f} → exp11 {mean_srcc(final):.4f} "
560
+ + ("🚀 cải thiện" if mean_srcc(final) > mean_srcc(m0) + 1e-4 else "➖ không cải thiện"))
561
+ save_full(best_state, final.get("emos", float("nan")))
562
+ print("Đã lưu FULL:", CKPT_PATH, "→ NHỚ Save Version!")
563
+
564
+ # %% [markdown]
565
+ # ## 6. Dự đoán DEV → answer.txt (5 cột cảm xúc từ exp11; QMOS mượn exp07 / UTMOSv2)
566
+
567
+ # %%
568
+ def list_dev():
569
+ with open(DEV_SCP) as f:
570
+ return [ln.strip() for ln in f if ln.strip()]
571
+
572
+ dev_names = list_dev()
573
+ if LIMIT_DEV:
574
+ dev_names = dev_names[:LIMIT_DEV]
575
+ print("DEV:", len(dev_names), "mẫu")
576
+
577
+ def load_exp07_qmos():
578
+ if EXP07_ANSWER and os.path.exists(EXP07_ANSWER):
579
+ import csv
580
+ d = {}
581
+ with open(EXP07_ANSWER) as f:
582
+ for row in csv.DictReader(f):
583
+ d[row["wav"]] = float(row["QMOS"]); d[stem(row["wav"])] = float(row["QMOS"])
584
+ print(f"✅ Mượn QMOS exp07: {len(d)//2} wav")
585
+ return d
586
+ return None
587
+
588
+ qmos_map = load_exp07_qmos()
589
+ if qmos_map is None:
590
+ print("ℹ️ Không có exp07 → QMOS bằng UTMOSv2.")
591
+ pip_install("git+https://github.com/sarulab-speech/UTMOSv2.git")
592
+ import utmosv2
593
+ v2 = utmosv2.create_model(pretrained=True)
594
+ qmos_map = {}
595
+ for n in tqdm(dev_names, desc="UTMOSv2"):
596
+ wav = os.path.join(WAV_DIR, n if str(n).endswith(".wav") else str(n) + ".wav")
597
+ if os.path.exists(wav):
598
+ o = v2.predict(input_path=wav)
599
+ qmos_map[n] = float(o["predicted_mos"]) if isinstance(o, dict) else float(o)
600
+ del v2; torch.cuda.empty_cache() if device == "cuda" else None
601
+
602
+ @torch.no_grad()
603
+ def predict_emotion(sid):
604
+ pair = load_pair(sid)
605
+ if pair is None:
606
+ return None
607
+ wave, iv = pair
608
+ set_train(False)
609
+ w = torch.from_numpy(wave).unsqueeze(0).to(device)
610
+ ivt = torch.from_numpy(iv).unsqueeze(0).to(device)
611
+ am = torch.ones((1, len(wave)), dtype=torch.long, device=device)
612
+ tgt = torch.from_numpy(onehot_target(target_map.get(sid))).unsqueeze(0).to(device)
613
+ with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
614
+ feat = torch.cat([wavlm_embed(w, am), aud_embed(ivt, am)], dim=1)
615
+ emos_p, cat_l, vad_p = heads(feat, tgt)
616
+ emos = float(emos_p.item()) * emos_sd + emos_mu
617
+ cat5 = F.softmax(cat_l, 1)[0].float().cpu().numpy()
618
+ vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu
619
+ return emos, cat5, vad3
620
+
621
+ def fmt_cat(p5):
622
+ return "|".join(f"{e}:{p5[i]:.6g}" for i, e in enumerate(EMOTIONS5))
623
+
624
+ answer_path = os.path.join(OUT_DIR, "answer.txt")
625
+ n_real = n_def = 0
626
+ with open(answer_path, "w") as f:
627
+ f.write("wav,QMOS,EMOS,CAT,VAL,ARO,DOM\n")
628
+ for name in tqdm(dev_names, desc="answer"):
629
+ sid = stem(name)
630
+ pr = predict_emotion(sid)
631
+ if pr is None:
632
+ emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0]); n_def += 1
633
+ else:
634
+ emos, cat5, vad3 = pr; n_real += 1
635
+ qmos = qmos_map.get(name, qmos_map.get(sid, 3.0))
636
+ f.write(f"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\n")
637
+ print(f"Ghi {len(dev_names)} dòng → {answer_path} | cảm xúc thật {n_real}, mặc định {n_def}")
638
+
639
+ # %% [markdown]
640
+ # ## 7. Validate + zip
641
+
642
+ # %%
643
+ def validate(path):
644
+ import csv
645
+ with open(path) as f:
646
+ rows = list(csv.reader(f))
647
+ assert rows[0][0] == "wav" and "QMOS" in rows[0], "Header sai"
648
+ for i, r in enumerate(rows[1:], 2):
649
+ assert len(r) == len(rows[0]), f"Dòng {i} sai số cột"
650
+ print(f"OK: {len(rows)-1} dòng, header = {rows[0]}")
651
+
652
+ validate(answer_path)
653
+ os.system(f"cd {OUT_DIR} && zip -j submission_track2_exp11_joint.zip answer.txt && unzip -l submission_track2_exp11_joint.zip")
654
+ print("Sẵn sàng nộp:", os.path.join(OUT_DIR, "submission_track2_exp11_joint.zip"))
655
+
656
+ # %% [markdown]
657
+ # ## Ghi chú
658
+ # - **exp11 = fine-tune CẢ WavLM + audeering, FUSION 1 model** (khác exp08 audeering frozen, khác exp10 ensemble).
659
+ # - **Warm-start:** WavLM + heads từ `ft_emotion_full_20epoch.pt` (exp08) → bắt đầu từ điểm tốt; audeering từ
660
+ # pretrained gốc, mở băng để học thêm. Khởi điểm = đúng exp08 → train chỉ có thể tốt lên (giữ best).
661
+ # - **OOM:** đây là cấu hình nặng nhất. Nếu CUDA OOM → giảm `UNFREEZE_WAVLM`/`UNFREEZE_AUD` (4→2),
662
+ # `MAX_SECONDS` (6→5), giữ `BATCH=1` + tăng `ACCUM`.
663
+ # - **Checkpoint:** lưu `ft_joint_full.pt` mỗi best (đủ cả 2 backbone + heads) → kernel chết vẫn còn. Save Version!
664
+ # - **QMOS** vẫn mượn exp07 (0.548). So sánh nộp: exp11 vs exp08(0.811) vs exp10(ensemble) → chọn bản tốt nhất.
665
+ # - Ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (exp11).
track2/exp12_wavlm_scratch.ipynb ADDED
@@ -0,0 +1,690 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "73aea642",
6
+ "metadata": {},
7
+ "source": [
8
+ "# VMC2026 Track 2 — exp12 (WavLM: SCRATCH vs BASE vs SAILER — ablation khởi tạo) — Kaggle T4\n",
9
+ "\n",
10
+ "**Mục đích:** kiểm chứng giả thuyết của mentor — *\"với 12k data, train from scratch có tốt hơn fine-tune không?\"*\n",
11
+ "Một notebook, đổi cờ `INIT_MODE` để chạy 3 cách khởi tạo backbone WavLM, so trên CÙNG kiến trúc/data:\n",
12
+ "\n",
13
+ "| INIT_MODE | Khởi tạo WavLM | Train gì | Ý nghĩa |\n",
14
+ "|---|---|---|---|\n",
15
+ "| `scratch` | **ngẫu nhiên** (không pretrain) | **toàn bộ** backbone | \"from scratch\" đúng nghĩa mentor nói |\n",
16
+ "| `base` | microsoft/wavlm-large (pretrain SSL, KHÔNG cảm xúc) | mở băng N lớp trên | đo lợi ích của SAILER warm-start |\n",
17
+ "| `sailer` | warm-start cảm xúc (như exp08) | mở băng N lớp trên | bản mạnh hiện tại |\n",
18
+ "\n",
19
+ "**Chỉ WavLM** (bỏ audeering) để cô lập đúng biến \"khởi tạo\". QMOS mượn exp07 / UTMOSv2.\n",
20
+ "\n",
21
+ "## ⚠️ Kỳ vọng trung thực (để đọc kết quả đúng)\n",
22
+ "- `scratch` gần như CHẮC CHẮN yếu hơn `base`/`sailer`: 12k mẫu là quá ít để dạy WavLM \"nghe\" từ đầu\n",
23
+ " (SSL pretrain dùng ~94.000 GIỜ audio). Đây là ablation để **chứng minh bằng số**, không phải để vượt.\n",
24
+ "- `scratch` phải mở băng TOÀN BỘ (mới có gì để học) → **nặng + chậm + dễ OOM** trên T4. Dùng LIMIT nhỏ trước.\n",
25
+ "- So sánh bằng **VAL nội bộ** giữa 3 mode đã đủ kết luận; muốn chắc thì nộp mode tốt nhất lên DEV.\n",
26
+ "\n",
27
+ "**Cách chạy:** GPU T4 + Internet On → sửa cell 0 (`INIT_MODE` + slug) → Run All. Chạy 3 lần đổi INIT_MODE."
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "markdown",
32
+ "id": "c0242f5c",
33
+ "metadata": {},
34
+ "source": [
35
+ "## 0. Cấu hình — SỬA Ở ĐÂY"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": null,
41
+ "id": "a167d9d3",
42
+ "metadata": {},
43
+ "outputs": [],
44
+ "source": [
45
+ "import os\n",
46
+ "\n",
47
+ "INIT_MODE = \"sailer\" # << \"scratch\" | \"base\" | \"sailer\" (đổi rồi chạy lại để so) — \"sailer\" = WavLM warm-start cảm xúc\n",
48
+ "\n",
49
+ "DATA_ROOT = \"/kaggle/input/datasets/minhtoan2/vmc2026-track2-full\" # << SỬA slug cho khớp Add Input\n",
50
+ "WAV_DIR = f\"{DATA_ROOT}/wav\"\n",
51
+ "METADATA_CSV = f\"{DATA_ROOT}/metadata.csv\"\n",
52
+ "TRAIN_CSV = f\"{DATA_ROOT}/sets/train.csv\"\n",
53
+ "DEV_SCP = f\"{DATA_ROOT}/sets/dev.scp\"\n",
54
+ "\n",
55
+ "EXP07_ANSWER = \"/kaggle/input/exp07-answer/answer.txt\" # << (tùy chọn) mượn QMOS 0.548; không có → UTMOSv2\n",
56
+ "OUT_DIR = \"/kaggle/working\"\n",
57
+ "\n",
58
+ "# ── Siêu tham số ─────────────────────────────────────────────────────────────\n",
59
+ "DEVICE = \"cuda\"\n",
60
+ "SR = 16000\n",
61
+ "MAX_SECONDS = 6\n",
62
+ "TRUNK_HIDDEN = 512\n",
63
+ "HEAD_HIDDEN = 128\n",
64
+ "DROPOUT = 0.3\n",
65
+ "WEIGHT_DECAY = 1e-5\n",
66
+ "EPOCHS = 15\n",
67
+ "PATIENCE = 5\n",
68
+ "BATCH = 4\n",
69
+ "ACCUM = 8\n",
70
+ "VAL_FRAC = 0.10\n",
71
+ "SEED = 42\n",
72
+ "USE_AMP = True\n",
73
+ "USE_GRAD_CKPT = True\n",
74
+ "USE_UNCERTAINTY = True\n",
75
+ "\n",
76
+ "# Khởi tạo & LR & mở băng — TỰ đặt theo INIT_MODE (scratch cần LR lớn + mở băng toàn bộ)\n",
77
+ "if INIT_MODE == \"scratch\":\n",
78
+ " UNFREEZE_TOP_LAYERS = \"all\" # random init → phải train tất cả mới học được\n",
79
+ " LR_BACKBONE = 1e-4 # random init cần bước lớn hơn fine-tune\n",
80
+ "elif INIT_MODE in (\"base\", \"sailer\"):\n",
81
+ " UNFREEZE_TOP_LAYERS = 6 # fine-tune: chỉ mở băng N lớp trên (tiết kiệm VRAM, chống overfit)\n",
82
+ " LR_BACKBONE = 1e-5\n",
83
+ "else:\n",
84
+ " raise ValueError(f\"INIT_MODE lạ: {INIT_MODE}\")\n",
85
+ "LR_HEAD = 1e-3\n",
86
+ "\n",
87
+ "LIMIT_TRAIN = 300 # << LẦN ĐẦU 300; chạy thật None\n",
88
+ "LIMIT_DEV = 20 # << LẦN ĐẦU 20; chạy thật None\n",
89
+ "\n",
90
+ "EXP08 = {\"emos\": 0.811, \"cat_err\": 0.133, \"val\": 0.659, \"aro\": 0.793, \"dom\": 0.751} # mốc DEV để tham khảo\n",
91
+ "\n",
92
+ "EMOTIONS5 = [\"angry\", \"happy\", \"neutral\", \"sad\", \"surprised\"]\n",
93
+ "_EMO_ALIAS = {\n",
94
+ " \"angry\": \"angry\", \"anger\": \"angry\",\n",
95
+ " \"happy\": \"happy\", \"happiness\": \"happy\", \"joy\": \"happy\",\n",
96
+ " \"neutral\": \"neutral\", \"calm\": \"neutral\",\n",
97
+ " \"sad\": \"sad\", \"sadness\": \"sad\",\n",
98
+ " \"surprise\": \"surprised\", \"surprised\": \"surprised\", \"surprising\": \"surprised\",\n",
99
+ "}\n",
100
+ "\n",
101
+ "def norm_emotion(label):\n",
102
+ " key = str(label).strip().lower()\n",
103
+ " return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)\n",
104
+ "\n",
105
+ "def stem(p):\n",
106
+ " return os.path.splitext(os.path.basename(str(p)))[0]\n",
107
+ "\n",
108
+ "print(f\"INIT_MODE = {INIT_MODE} | UNFREEZE = {UNFREEZE_TOP_LAYERS} | LR_BACKBONE = {LR_BACKBONE}\")\n",
109
+ "print(\"DATA_ROOT:\", DATA_ROOT)\n",
110
+ "for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP]:\n",
111
+ " print((\" ✅ \" if os.path.exists(p) else \" ❌ THIẾU \") + p)"
112
+ ]
113
+ },
114
+ {
115
+ "cell_type": "markdown",
116
+ "id": "46cd2554",
117
+ "metadata": {},
118
+ "source": [
119
+ "## 1. Cài đặt (clone SAILER chỉ khi INIT_MODE='sailer')"
120
+ ]
121
+ },
122
+ {
123
+ "cell_type": "code",
124
+ "execution_count": null,
125
+ "id": "51808707",
126
+ "metadata": {},
127
+ "outputs": [],
128
+ "source": [
129
+ "import sys, subprocess\n",
130
+ "import numpy as _np\n",
131
+ "\n",
132
+ "# ⚠️ KHÓA numpy = bản Kaggle đang có → pip KHÔNG được nâng/hạ numpy → tránh \"SystemError: bad call flags\"\n",
133
+ "# (lỗi import torch do numpy lệch phiên bản với torch đã biên dịch sẵn).\n",
134
+ "_NPIN = f\"numpy=={_np.__version__}\"\n",
135
+ "print(\"Khóa numpy ở:\", _NPIN)\n",
136
+ "\n",
137
+ "def pip_install(*pkgs):\n",
138
+ " subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *pkgs, _NPIN], check=True)\n",
139
+ "\n",
140
+ "# Kaggle đã có sẵn torch/transformers/librosa/scipy/sklearn/pandas/tqdm/huggingface_hub/safetensors.\n",
141
+ "# Chỉ cài thêm vài gói speech còn thiếu (kèm khóa numpy ở trên).\n",
142
+ "pip_install(\"loralib\", \"speechmos\", \"soundfile\")\n",
143
+ "if INIT_MODE == \"sailer\":\n",
144
+ " pip_install(\"speechbrain\")\n",
145
+ "\n",
146
+ "if INIT_MODE == \"sailer\":\n",
147
+ " REPO_DIR = \"/kaggle/working/vox-profile-release\"\n",
148
+ " if not os.path.exists(REPO_DIR):\n",
149
+ " subprocess.run([\"git\", \"clone\", \"--depth\", \"1\",\n",
150
+ " \"https://github.com/tiantiaf0627/vox-profile-release.git\", REPO_DIR], check=True)\n",
151
+ " if REPO_DIR not in sys.path:\n",
152
+ " sys.path.insert(0, REPO_DIR)"
153
+ ]
154
+ },
155
+ {
156
+ "cell_type": "markdown",
157
+ "id": "5288727c",
158
+ "metadata": {},
159
+ "source": [
160
+ "## 2. Dựng WavLM theo INIT_MODE"
161
+ ]
162
+ },
163
+ {
164
+ "cell_type": "code",
165
+ "execution_count": null,
166
+ "id": "c828dcd3",
167
+ "metadata": {
168
+ "lines_to_next_cell": 1
169
+ },
170
+ "outputs": [],
171
+ "source": [
172
+ "import torch\n",
173
+ "import torch.nn as nn\n",
174
+ "import torch.nn.functional as F\n",
175
+ "import numpy as np\n",
176
+ "\n",
177
+ "device = DEVICE if torch.cuda.is_available() else \"cpu\"\n",
178
+ "print(\"Device:\", device, (\"✅ \" + torch.cuda.get_device_name(0)) if device == \"cuda\" else \"⚠️ CPU (rất chậm!)\")\n",
179
+ "\n",
180
+ "from transformers import WavLMModel, WavLMConfig\n",
181
+ "\n",
182
+ "def find_hf_backbone(module):\n",
183
+ " cands = []\n",
184
+ " for name, m in module.named_modules():\n",
185
+ " enc = getattr(m, \"encoder\", None)\n",
186
+ " if getattr(m, \"feature_extractor\", None) is not None and enc is not None \\\n",
187
+ " and getattr(enc, \"layers\", None) is not None:\n",
188
+ " cands.append((name, m))\n",
189
+ " if not cands:\n",
190
+ " return None, None\n",
191
+ " cands.sort(key=lambda nm: sum(p.numel() for p in nm[1].parameters()), reverse=True)\n",
192
+ " return cands[0]\n",
193
+ "\n",
194
+ "wavlm = None\n",
195
+ "if INIT_MODE == \"scratch\":\n",
196
+ " # Random init NHƯNG giữ ĐÚNG kiến trúc large (để công bằng với base/sailer)\n",
197
+ " cfg = WavLMConfig.from_pretrained(\"microsoft/wavlm-large\")\n",
198
+ " wavlm = WavLMModel(cfg) # KHÔNG load trọng số → ngẫu nhiên\n",
199
+ " print(\"🎲 WavLM-large khởi tạo NGẪU NHIÊN (from scratch, không pretrain).\")\n",
200
+ "elif INIT_MODE == \"base\":\n",
201
+ " wavlm = WavLMModel.from_pretrained(\"microsoft/wavlm-large\")\n",
202
+ " print(\"📦 WavLM-large pretrain SSL (chưa học cảm xúc).\")\n",
203
+ "elif INIT_MODE == \"sailer\":\n",
204
+ " try:\n",
205
+ " from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402\n",
206
+ " _wrapper = WavLMWrapper.from_pretrained(\"tiantiaf/wavlm-large-categorical-emotion\")\n",
207
+ " name, wavlm = find_hf_backbone(_wrapper)\n",
208
+ " print(f\"🔥 WavLM warm-start SAILER (cảm xúc) tại '.{name}'\")\n",
209
+ " except Exception as e:\n",
210
+ " print(\"⚠️ Lỗi nạp SAILER:\", repr(e), \"→ fallback base pretrained.\")\n",
211
+ " wavlm = WavLMModel.from_pretrained(\"microsoft/wavlm-large\")\n",
212
+ "\n",
213
+ "wavlm = wavlm.to(device)\n",
214
+ "WAVLM_DIM = int(wavlm.config.hidden_size)\n",
215
+ "wavlm.config.layerdrop = 0.0 # ⚠️ BẮT BUỘC khi dùng gradient-checkpointing (tránh CheckpointError do bỏ lớp ngẫu nhiên)\n",
216
+ "\n",
217
+ "# Mở băng theo cấu hình\n",
218
+ "if UNFREEZE_TOP_LAYERS == \"all\":\n",
219
+ " for p in wavlm.parameters():\n",
220
+ " p.requires_grad = True\n",
221
+ " n_open = \"ALL\"\n",
222
+ "else:\n",
223
+ " for p in wavlm.parameters():\n",
224
+ " p.requires_grad = False\n",
225
+ " _wl = wavlm.encoder.layers\n",
226
+ " for layer in _wl[max(0, len(_wl) - UNFREEZE_TOP_LAYERS):]:\n",
227
+ " for p in layer.parameters():\n",
228
+ " p.requires_grad = True\n",
229
+ " n_open = f\"top {min(UNFREEZE_TOP_LAYERS, len(_wl))}/{len(_wl)}\"\n",
230
+ "print(f\"WavLM mở băng: {n_open} → {sum(p.numel() for p in wavlm.parameters() if p.requires_grad)/1e6:.1f}M param train (dim {WAVLM_DIM})\")\n",
231
+ "\n",
232
+ "if USE_GRAD_CKPT:\n",
233
+ " wavlm.gradient_checkpointing_enable()\n",
234
+ " if hasattr(wavlm, \"enable_input_require_grads\"):\n",
235
+ " wavlm.enable_input_require_grads()\n",
236
+ "\n",
237
+ "def masked_mean(hidden, attn_mask):\n",
238
+ " if attn_mask is None:\n",
239
+ " return hidden.mean(dim=1)\n",
240
+ " try:\n",
241
+ " fm = wavlm._get_feature_vector_attention_mask(hidden.shape[1], attn_mask)\n",
242
+ " except Exception:\n",
243
+ " return hidden.mean(dim=1)\n",
244
+ " fm = fm.unsqueeze(-1).to(hidden.dtype)\n",
245
+ " return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)\n",
246
+ "\n",
247
+ "def wavlm_embed(input_values, attn_mask):\n",
248
+ " out = wavlm(input_values, attention_mask=attn_mask).last_hidden_state\n",
249
+ " return masked_mean(out, attn_mask)"
250
+ ]
251
+ },
252
+ {
253
+ "cell_type": "markdown",
254
+ "id": "6b963eb2",
255
+ "metadata": {},
256
+ "source": [
257
+ "## 3. Đọc & gộp nhãn theo wavID"
258
+ ]
259
+ },
260
+ {
261
+ "cell_type": "code",
262
+ "execution_count": null,
263
+ "id": "4ba4667b",
264
+ "metadata": {},
265
+ "outputs": [],
266
+ "source": [
267
+ "import librosa\n",
268
+ "import pandas as pd\n",
269
+ "from tqdm.auto import tqdm\n",
270
+ "\n",
271
+ "def load_target_emotions():\n",
272
+ " tgt = {}\n",
273
+ " with open(METADATA_CSV, encoding=\"utf-8\") as f:\n",
274
+ " for ln in f:\n",
275
+ " parts = ln.strip().split(\"|\")\n",
276
+ " if len(parts) >= 2:\n",
277
+ " tgt[stem(parts[0])] = norm_emotion(parts[1])\n",
278
+ " return tgt\n",
279
+ "\n",
280
+ "def _col(cols_map, *names, df=None, default_idx=None):\n",
281
+ " for n in names:\n",
282
+ " if n in cols_map:\n",
283
+ " return cols_map[n]\n",
284
+ " return list(df.columns)[default_idx] if default_idx is not None else None\n",
285
+ "\n",
286
+ "def parse_emocat_votes(cell):\n",
287
+ " v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
288
+ " for tok in str(cell).replace(\"/\", \",\").replace(\";\", \",\").replace(\"|\", \",\").replace(\" \", \",\").split(\",\"):\n",
289
+ " e = norm_emotion(tok)\n",
290
+ " if e in EMOTIONS5:\n",
291
+ " v[EMOTIONS5.index(e)] += 1.0\n",
292
+ " return v\n",
293
+ "\n",
294
+ "def load_train_labels():\n",
295
+ " df = pd.read_csv(TRAIN_CSV, sep=\"|\")\n",
296
+ " cols = {c.lower().strip(): c for c in df.columns}\n",
297
+ " wav_col = _col(cols, \"wavid\", \"wav\", df=df, default_idx=1)\n",
298
+ " emos_col = _col(cols, \"emos\", \"emo\", \"emomos\")\n",
299
+ " val_col = _col(cols, \"val\", \"valence\"); aro_col = _col(cols, \"aro\", \"arousal\"); dom_col = _col(cols, \"dom\", \"dominance\")\n",
300
+ " cat_col = _col(cols, \"emocat\", \"cat\", \"emotion\")\n",
301
+ " assert emos_col, f\"Không thấy cột eMOS (cột: {list(df.columns)})\"\n",
302
+ " df[\"_stem\"] = df[wav_col].map(stem)\n",
303
+ " rows = []\n",
304
+ " for sid, g in df.groupby(\"_stem\"):\n",
305
+ " rec = {\"wavID\": sid, \"emos\": float(g[emos_col].mean())}\n",
306
+ " rec[\"val\"] = float(g[val_col].mean()) if val_col else np.nan\n",
307
+ " rec[\"aro\"] = float(g[aro_col].mean()) if aro_col else np.nan\n",
308
+ " rec[\"dom\"] = float(g[dom_col].mean()) if dom_col else np.nan\n",
309
+ " votes = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
310
+ " if cat_col:\n",
311
+ " for cell in g[cat_col]:\n",
312
+ " votes += parse_emocat_votes(cell)\n",
313
+ " s = votes.sum()\n",
314
+ " cat = votes / s if s > 0 else np.full(len(EMOTIONS5), 0.2, dtype=np.float32)\n",
315
+ " for i in range(len(EMOTIONS5)):\n",
316
+ " rec[f\"cat{i}\"] = float(cat[i])\n",
317
+ " rows.append(rec)\n",
318
+ " return pd.DataFrame(rows)\n",
319
+ "\n",
320
+ "target_map = load_target_emotions()\n",
321
+ "train_df = load_train_labels()\n",
322
+ "HAS_VAD = bool(train_df[\"val\"].notna().any())\n",
323
+ "print(f\"Target: {len(target_map)} | wav train (gộp): {len(train_df)} | có VAD: {HAS_VAD}\")"
324
+ ]
325
+ },
326
+ {
327
+ "cell_type": "markdown",
328
+ "id": "7efc0957",
329
+ "metadata": {},
330
+ "source": [
331
+ "## 4. Dataset/loader (chỉ raw wave cho WavLM)"
332
+ ]
333
+ },
334
+ {
335
+ "cell_type": "code",
336
+ "execution_count": null,
337
+ "id": "a0ae0f55",
338
+ "metadata": {},
339
+ "outputs": [],
340
+ "source": [
341
+ "from torch.utils.data import Dataset, DataLoader\n",
342
+ "from sklearn.model_selection import train_test_split\n",
343
+ "\n",
344
+ "train_stems = [s for s in train_df[\"wavID\"] if target_map.get(s) is not None]\n",
345
+ "if LIMIT_TRAIN:\n",
346
+ " train_stems = train_stems[:LIMIT_TRAIN]\n",
347
+ "lab = train_df.set_index(\"wavID\")\n",
348
+ "\n",
349
+ "def _zfit(a):\n",
350
+ " a = np.asarray(a, dtype=np.float32); return float(np.nanmean(a)), float(np.nanstd(a) + 1e-6)\n",
351
+ "emos_mu, emos_sd = _zfit([lab.loc[s, \"emos\"] for s in train_stems])\n",
352
+ "if HAS_VAD:\n",
353
+ " vad_mu = np.array([_zfit([lab.loc[s, c] for s in train_stems])[0] for c in [\"val\", \"aro\", \"dom\"]], np.float32)\n",
354
+ " vad_sd = np.array([_zfit([lab.loc[s, c] for s in train_stems])[1] for c in [\"val\", \"aro\", \"dom\"]], np.float32)\n",
355
+ "else:\n",
356
+ " vad_mu = np.zeros(3, np.float32); vad_sd = np.ones(3, np.float32)\n",
357
+ "print(f\"Chuẩn hóa: emos μ={emos_mu:.3f} σ={emos_sd:.3f}\")\n",
358
+ "\n",
359
+ "def onehot_target(tgt):\n",
360
+ " v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
361
+ " if tgt in EMOTIONS5:\n",
362
+ " v[EMOTIONS5.index(tgt)] = 1.0\n",
363
+ " return v\n",
364
+ "\n",
365
+ "def load_wav(sid):\n",
366
+ " p = os.path.join(WAV_DIR, sid if str(sid).endswith(\".wav\") else str(sid) + \".wav\")\n",
367
+ " if not os.path.exists(p):\n",
368
+ " return None\n",
369
+ " wave, _ = librosa.load(p, sr=SR, mono=True)\n",
370
+ " return wave[: MAX_SECONDS * SR].astype(np.float32)\n",
371
+ "\n",
372
+ "class EmoDataset(Dataset):\n",
373
+ " def __init__(self, stems):\n",
374
+ " self.stems = [s for s in stems if load_wav(s) is not None]\n",
375
+ " def __len__(self):\n",
376
+ " return len(self.stems)\n",
377
+ " def __getitem__(self, i):\n",
378
+ " s = self.stems[i]\n",
379
+ " wave = load_wav(s)\n",
380
+ " emos = (float(lab.loc[s, \"emos\"]) - emos_mu) / emos_sd\n",
381
+ " if HAS_VAD:\n",
382
+ " vad = (np.array([lab.loc[s, \"val\"], lab.loc[s, \"aro\"], lab.loc[s, \"dom\"]], np.float32) - vad_mu) / vad_sd\n",
383
+ " else:\n",
384
+ " vad = np.zeros(3, dtype=np.float32)\n",
385
+ " cat = np.array([lab.loc[s, f\"cat{j}\"] for j in range(len(EMOTIONS5))], dtype=np.float32)\n",
386
+ " return {\"wave\": wave, \"tgt\": onehot_target(target_map.get(s)),\n",
387
+ " \"emos\": np.float32(emos), \"vad\": vad, \"cat\": cat,\n",
388
+ " \"emos_raw\": np.float32(lab.loc[s, \"emos\"]),\n",
389
+ " \"vad_raw\": np.array([lab.loc[s, \"val\"], lab.loc[s, \"aro\"], lab.loc[s, \"dom\"]], np.float32)}\n",
390
+ "\n",
391
+ "def collate(batch):\n",
392
+ " L = max(len(b[\"wave\"]) for b in batch)\n",
393
+ " waves = np.zeros((len(batch), L), dtype=np.float32)\n",
394
+ " mask = np.zeros((len(batch), L), dtype=np.float32)\n",
395
+ " for i, b in enumerate(batch):\n",
396
+ " waves[i, : len(b[\"wave\"])] = b[\"wave\"]; mask[i, : len(b[\"wave\"])] = 1.0\n",
397
+ " return {\n",
398
+ " \"input_values\": torch.from_numpy(waves), \"attn_mask\": torch.from_numpy(mask).long(),\n",
399
+ " \"tgt\": torch.from_numpy(np.stack([b[\"tgt\"] for b in batch])),\n",
400
+ " \"emos\": torch.from_numpy(np.stack([b[\"emos\"] for b in batch])).unsqueeze(1),\n",
401
+ " \"vad\": torch.from_numpy(np.stack([b[\"vad\"] for b in batch])),\n",
402
+ " \"cat\": torch.from_numpy(np.stack([b[\"cat\"] for b in batch])),\n",
403
+ " \"emos_raw\": np.stack([b[\"emos_raw\"] for b in batch]),\n",
404
+ " \"vad_raw\": np.stack([b[\"vad_raw\"] for b in batch]),\n",
405
+ " }\n",
406
+ "\n",
407
+ "ds = EmoDataset(train_stems)\n",
408
+ "print(\"Dataset hợp lệ:\", len(ds), \"wav\")\n",
409
+ "tr_i, va_i = train_test_split(np.arange(len(ds)), test_size=VAL_FRAC, random_state=SEED)\n",
410
+ "tr_loader = DataLoader(torch.utils.data.Subset(ds, tr_i), batch_size=BATCH, shuffle=True, collate_fn=collate, num_workers=2)\n",
411
+ "va_loader = DataLoader(torch.utils.data.Subset(ds, va_i), batch_size=BATCH, shuffle=False, collate_fn=collate, num_workers=2)"
412
+ ]
413
+ },
414
+ {
415
+ "cell_type": "markdown",
416
+ "id": "d249653b",
417
+ "metadata": {},
418
+ "source": [
419
+ "## 5. Heads + train loop"
420
+ ]
421
+ },
422
+ {
423
+ "cell_type": "code",
424
+ "execution_count": null,
425
+ "id": "fac929f9",
426
+ "metadata": {
427
+ "lines_to_next_cell": 1
428
+ },
429
+ "outputs": [],
430
+ "source": [
431
+ "from scipy.stats import spearmanr\n",
432
+ "\n",
433
+ "torch.manual_seed(SEED); np.random.seed(SEED)\n",
434
+ "N_EMO = len(EMOTIONS5)\n",
435
+ "\n",
436
+ "class EmoHeads(nn.Module):\n",
437
+ " def __init__(self, d_in, trunk_h, head_h, p, n_emo):\n",
438
+ " super().__init__()\n",
439
+ " self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),\n",
440
+ " nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))\n",
441
+ " self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))\n",
442
+ " self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))\n",
443
+ " self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))\n",
444
+ " def forward(self, feat, tgt):\n",
445
+ " h = self.trunk(feat)\n",
446
+ " return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)\n",
447
+ "\n",
448
+ "heads = EmoHeads(WAVLM_DIM, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device)\n",
449
+ "\n",
450
+ "TASKS = [\"emos\", \"cat\", \"val\", \"aro\", \"dom\"]\n",
451
+ "log_var = nn.Parameter(torch.zeros(len(TASKS), device=device))\n",
452
+ "bb_params = [p for p in wavlm.parameters() if p.requires_grad]\n",
453
+ "head_params = list(heads.parameters()) + ([log_var] if USE_UNCERTAINTY else [])\n",
454
+ "opt = torch.optim.AdamW([{\"params\": bb_params, \"lr\": LR_BACKBONE},\n",
455
+ " {\"params\": head_params, \"lr\": LR_HEAD}], weight_decay=WEIGHT_DECAY)\n",
456
+ "scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP and device == \"cuda\")\n",
457
+ "mse = nn.MSELoss()\n",
458
+ "\n",
459
+ "def soft_ce(logits, target_dist):\n",
460
+ " return -(target_dist * F.log_softmax(logits, dim=1)).sum(1).mean()\n",
461
+ "\n",
462
+ "def forward_batch(b):\n",
463
+ " feat = wavlm_embed(b[\"input_values\"].to(device), b[\"attn_mask\"].to(device))\n",
464
+ " return heads(feat, b[\"tgt\"].to(device))\n",
465
+ "\n",
466
+ "def compute_loss(emos_p, cat_l, vad_p, b):\n",
467
+ " L = {}\n",
468
+ " L[\"emos\"] = mse(emos_p, b[\"emos\"].to(device))\n",
469
+ " L[\"cat\"] = soft_ce(cat_l, b[\"cat\"].to(device))\n",
470
+ " if HAS_VAD:\n",
471
+ " vt = b[\"vad\"].to(device)\n",
472
+ " L[\"val\"] = mse(vad_p[:, 0:1], vt[:, 0:1]); L[\"aro\"] = mse(vad_p[:, 1:2], vt[:, 1:2]); L[\"dom\"] = mse(vad_p[:, 2:3], vt[:, 2:3])\n",
473
+ " else:\n",
474
+ " z = torch.zeros((), device=device); L[\"val\"] = L[\"aro\"] = L[\"dom\"] = z\n",
475
+ " if USE_UNCERTAINTY:\n",
476
+ " return sum(torch.exp(-log_var[i]) * L[t] + log_var[i] for i, t in enumerate(TASKS))\n",
477
+ " return sum(L.values())\n",
478
+ "\n",
479
+ "@torch.no_grad()\n",
480
+ "def evaluate():\n",
481
+ " wavlm.eval(); heads.eval()\n",
482
+ " P = {\"emos\": [], \"val\": [], \"aro\": [], \"dom\": []}; Y = {\"emos\": [], \"val\": [], \"aro\": [], \"dom\": []}\n",
483
+ " catP, catY = [], []\n",
484
+ " for b in va_loader:\n",
485
+ " with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
486
+ " emos_p, cat_l, vad_p = forward_batch(b)\n",
487
+ " P[\"emos\"] += emos_p.float().cpu().numpy().ravel().tolist(); Y[\"emos\"] += b[\"emos_raw\"].tolist()\n",
488
+ " vad_p = vad_p.float().cpu().numpy()\n",
489
+ " for j, t in enumerate([\"val\", \"aro\", \"dom\"]):\n",
490
+ " P[t] += vad_p[:, j].tolist(); Y[t] += b[\"vad_raw\"][:, j].tolist()\n",
491
+ " catP.append(F.softmax(cat_l, 1).float().cpu().numpy()); catY.append(b[\"cat\"])\n",
492
+ " out = {}\n",
493
+ " for t in [\"emos\"] + ([\"val\", \"aro\", \"dom\"] if HAS_VAD else []):\n",
494
+ " out[t] = spearmanr(P[t], Y[t]).correlation\n",
495
+ " q = np.concatenate(catP); p = np.concatenate(catY)\n",
496
+ " out[\"cat_err\"] = float(np.abs(q - p).sum(1).mean())\n",
497
+ " return out\n",
498
+ "\n",
499
+ "def mean_srcc(m):\n",
500
+ " keys = [\"emos\"] + ([\"val\", \"aro\", \"dom\"] if HAS_VAD else [])\n",
501
+ " return float(np.mean([m[k] for k in keys]))\n",
502
+ "\n",
503
+ "CKPT_PATH = os.path.join(OUT_DIR, f\"ft_wavlm_{INIT_MODE}.pt\")\n",
504
+ "def save_full(state, val_emos=float(\"nan\")):\n",
505
+ " torch.save({\"wavlm\": state[\"wavlm\"], \"heads\": state[\"heads\"], \"INIT_MODE\": INIT_MODE,\n",
506
+ " \"emos_mu\": emos_mu, \"emos_sd\": emos_sd, \"vad_mu\": vad_mu, \"vad_sd\": vad_sd,\n",
507
+ " \"WAVLM_DIM\": WAVLM_DIM, \"val_emos\": float(val_emos)}, CKPT_PATH)\n",
508
+ "\n",
509
+ "best, best_state, bad = -1e9, None, 0\n",
510
+ "for ep in range(1, EPOCHS + 1):\n",
511
+ " wavlm.train(); heads.train()\n",
512
+ " opt.zero_grad(); run = 0.0; nb = 0\n",
513
+ " for step, b in enumerate(tqdm(tr_loader, desc=f\"[{INIT_MODE}] epoch {ep}\")):\n",
514
+ " with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
515
+ " emos_p, cat_l, vad_p = forward_batch(b)\n",
516
+ " loss = compute_loss(emos_p, cat_l, vad_p, b) / ACCUM\n",
517
+ " scaler.scale(loss).backward()\n",
518
+ " if (step + 1) % ACCUM == 0:\n",
519
+ " scaler.step(opt); scaler.update(); opt.zero_grad()\n",
520
+ " run += loss.item() * ACCUM; nb += 1\n",
521
+ " m = evaluate(); sc = mean_srcc(m)\n",
522
+ " msg = \" \".join(f\"{k}={m[k]:.3f}\" for k in [\"emos\", \"val\", \"aro\", \"dom\"] if k in m)\n",
523
+ " print(f\"[{INIT_MODE}] epoch {ep:2d} | loss {run/max(nb,1):.4f} | {msg} | cat_err {m['cat_err']:.3f} | mean {sc:.4f} (best {max(best,sc):.4f})\")\n",
524
+ " if sc > best:\n",
525
+ " best = sc\n",
526
+ " best_state = {\"wavlm\": {k: v.cpu().clone() for k, v in wavlm.state_dict().items()},\n",
527
+ " \"heads\": {k: v.cpu().clone() for k, v in heads.state_dict().items()}}\n",
528
+ " save_full(best_state, m[\"emos\"]); bad = 0\n",
529
+ " print(f\" 💾 lưu best → {CKPT_PATH} (epoch {ep}, mean {sc:.4f})\")\n",
530
+ " else:\n",
531
+ " bad += 1\n",
532
+ " if bad >= PATIENCE:\n",
533
+ " print(f\"Early stop ở epoch {ep}.\"); break\n",
534
+ "\n",
535
+ "if best_state:\n",
536
+ " wavlm.load_state_dict(best_state[\"wavlm\"]); heads.load_state_dict(best_state[\"heads\"])\n",
537
+ "final = evaluate()\n",
538
+ "print(f\"\\n✅ VAL (nội bộ) — exp12 INIT_MODE={INIT_MODE}:\")\n",
539
+ "print(f\" EMOS={final['emos']:.4f}\", end=\"\")\n",
540
+ "if HAS_VAD:\n",
541
+ " print(f\" | VAL/ARO/DOM={final['val']:.4f}/{final['aro']:.4f}/{final['dom']:.4f}\")\n",
542
+ "else:\n",
543
+ " print()\n",
544
+ "print(f\" cat_err={final['cat_err']:.4f} | mean SRCC={mean_srcc(final):.4f}\")\n",
545
+ "print(f\" (Mốc DEV exp08 để tham khảo: EMOS {EXP08['emos']}, VAD {EXP08['val']}/{EXP08['aro']}/{EXP08['dom']})\")\n",
546
+ "print(\" ➜ GHI con số này vào bảng ablation 04_ rồi đổi INIT_MODE chạy lại để so 3 mode.\")"
547
+ ]
548
+ },
549
+ {
550
+ "cell_type": "markdown",
551
+ "id": "0874af79",
552
+ "metadata": {},
553
+ "source": [
554
+ "## 6. Dự đoán DEV → answer.txt (QMOS mượn exp07 / UTMOSv2)"
555
+ ]
556
+ },
557
+ {
558
+ "cell_type": "code",
559
+ "execution_count": null,
560
+ "id": "b21df9af",
561
+ "metadata": {
562
+ "lines_to_next_cell": 1
563
+ },
564
+ "outputs": [],
565
+ "source": [
566
+ "def list_dev():\n",
567
+ " with open(DEV_SCP) as f:\n",
568
+ " return [ln.strip() for ln in f if ln.strip()]\n",
569
+ "\n",
570
+ "dev_names = list_dev()\n",
571
+ "if LIMIT_DEV:\n",
572
+ " dev_names = dev_names[:LIMIT_DEV]\n",
573
+ "print(\"DEV:\", len(dev_names), \"mẫu\")\n",
574
+ "\n",
575
+ "def load_exp07_qmos():\n",
576
+ " if EXP07_ANSWER and os.path.exists(EXP07_ANSWER):\n",
577
+ " import csv\n",
578
+ " d = {}\n",
579
+ " with open(EXP07_ANSWER) as f:\n",
580
+ " for row in csv.DictReader(f):\n",
581
+ " d[row[\"wav\"]] = float(row[\"QMOS\"]); d[stem(row[\"wav\"])] = float(row[\"QMOS\"])\n",
582
+ " print(f\"✅ Mượn QMOS exp07: {len(d)//2} wav\")\n",
583
+ " return d\n",
584
+ " return None\n",
585
+ "\n",
586
+ "qmos_map = load_exp07_qmos()\n",
587
+ "if qmos_map is None:\n",
588
+ " print(\"ℹ️ Không có exp07 → QMOS bằng UTMOSv2.\")\n",
589
+ " pip_install(\"git+https://github.com/sarulab-speech/UTMOSv2.git\")\n",
590
+ " import utmosv2\n",
591
+ " v2 = utmosv2.create_model(pretrained=True)\n",
592
+ " qmos_map = {}\n",
593
+ " for n in tqdm(dev_names, desc=\"UTMOSv2\"):\n",
594
+ " wav = os.path.join(WAV_DIR, n if str(n).endswith(\".wav\") else str(n) + \".wav\")\n",
595
+ " if os.path.exists(wav):\n",
596
+ " o = v2.predict(input_path=wav)\n",
597
+ " qmos_map[n] = float(o[\"predicted_mos\"]) if isinstance(o, dict) else float(o)\n",
598
+ " del v2; torch.cuda.empty_cache() if device == \"cuda\" else None\n",
599
+ "\n",
600
+ "@torch.no_grad()\n",
601
+ "def predict_emotion(sid):\n",
602
+ " wave = load_wav(sid)\n",
603
+ " if wave is None:\n",
604
+ " return None\n",
605
+ " wavlm.eval(); heads.eval()\n",
606
+ " iv = torch.from_numpy(wave).unsqueeze(0).to(device)\n",
607
+ " am = torch.ones((1, len(wave)), dtype=torch.long, device=device)\n",
608
+ " tgt = torch.from_numpy(onehot_target(target_map.get(sid))).unsqueeze(0).to(device)\n",
609
+ " with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
610
+ " feat = wavlm_embed(iv, am)\n",
611
+ " emos_p, cat_l, vad_p = heads(feat, tgt)\n",
612
+ " emos = float(emos_p.item()) * emos_sd + emos_mu\n",
613
+ " cat5 = F.softmax(cat_l, 1)[0].float().cpu().numpy()\n",
614
+ " vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu\n",
615
+ " return emos, cat5, vad3\n",
616
+ "\n",
617
+ "def fmt_cat(p5):\n",
618
+ " return \"|\".join(f\"{e}:{p5[i]:.6g}\" for i, e in enumerate(EMOTIONS5))\n",
619
+ "\n",
620
+ "answer_path = os.path.join(OUT_DIR, f\"answer_{INIT_MODE}.txt\")\n",
621
+ "n_real = n_def = 0\n",
622
+ "with open(answer_path, \"w\") as f:\n",
623
+ " f.write(\"wav,QMOS,EMOS,CAT,VAL,ARO,DOM\\n\")\n",
624
+ " for name in tqdm(dev_names, desc=f\"answer[{INIT_MODE}]\"):\n",
625
+ " sid = stem(name)\n",
626
+ " pr = predict_emotion(sid)\n",
627
+ " if pr is None:\n",
628
+ " emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0]); n_def += 1\n",
629
+ " else:\n",
630
+ " emos, cat5, vad3 = pr; n_real += 1\n",
631
+ " qmos = qmos_map.get(name, qmos_map.get(sid, 3.0))\n",
632
+ " f.write(f\"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\\n\")\n",
633
+ "print(f\"Ghi {len(dev_names)} dòng → {answer_path} | thật {n_real}, mặc định {n_def}\")"
634
+ ]
635
+ },
636
+ {
637
+ "cell_type": "markdown",
638
+ "id": "eaebc2e5",
639
+ "metadata": {},
640
+ "source": [
641
+ "## 7. Validate + zip"
642
+ ]
643
+ },
644
+ {
645
+ "cell_type": "code",
646
+ "execution_count": null,
647
+ "id": "43e0440b",
648
+ "metadata": {},
649
+ "outputs": [],
650
+ "source": [
651
+ "def validate(path):\n",
652
+ " import csv\n",
653
+ " with open(path) as f:\n",
654
+ " rows = list(csv.reader(f))\n",
655
+ " assert rows[0][0] == \"wav\" and \"QMOS\" in rows[0], \"Header sai\"\n",
656
+ " for i, r in enumerate(rows[1:], 2):\n",
657
+ " assert len(r) == len(rows[0]), f\"Dòng {i} sai số cột\"\n",
658
+ " print(f\"OK: {len(rows)-1} dòng, header = {rows[0]}\")\n",
659
+ "\n",
660
+ "validate(answer_path)\n",
661
+ "os.system(f\"cd {OUT_DIR} && cp answer_{INIT_MODE}.txt answer.txt && zip -j submission_track2_exp12_{INIT_MODE}.zip answer.txt && unzip -l submission_track2_exp12_{INIT_MODE}.zip\")\n",
662
+ "print(\"Sẵn sàng nộp:\", os.path.join(OUT_DIR, f\"submission_track2_exp12_{INIT_MODE}.zip\"))"
663
+ ]
664
+ },
665
+ {
666
+ "cell_type": "markdown",
667
+ "id": "ade69063",
668
+ "metadata": {},
669
+ "source": [
670
+ "## Ghi chú\n",
671
+ "- **Chạy 3 lần** đổi `INIT_MODE` (\"scratch\"→\"base\"→\"sailer\"), ghi `mean SRCC` mỗi lần vào BẢNG ABLATION\n",
672
+ " trong `docs/04_experiments_log.md` → trả lời mentor bằng số: from-scratch tốt hơn fine-tune không?\n",
673
+ "- **scratch nặng:** mở băng toàn bộ WavLM-large. Nếu OOM → giảm `BATCH` (4→2), `MAX_SECONDS` (6→5),\n",
674
+ " hoặc đổi sang `microsoft/wavlm-base-plus` (sửa cell 2) cho khả thi (lưu ý: khác kiến trúc → so kém công bằng hơn).\n",
675
+ "- **scratch chậm + cần nhiều epoch hơn** (random init): để `EPOCHS=15`, `PATIENCE=5`. Vẫn nhiều khả năng < base/sailer.\n",
676
+ "- **Đừng nhầm VAL nội bộ với DEV.** So 3 mode bằng VAL nội bộ đã đủ kết luận; muốn chắc thì nộp mode tốt nhất.\n",
677
+ "- Checkpoint lưu `ft_wavlm_<mode>.pt`. Save Version sau mỗi lần chạy."
678
+ ]
679
+ }
680
+ ],
681
+ "metadata": {
682
+ "jupytext": {
683
+ "cell_metadata_filter": "-all",
684
+ "main_language": "python",
685
+ "notebook_metadata_filter": "-all"
686
+ }
687
+ },
688
+ "nbformat": 4,
689
+ "nbformat_minor": 5
690
+ }
track2/exp12_wavlm_scratch_pipeline.py ADDED
@@ -0,0 +1,564 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %% [markdown]
2
+ # # VMC2026 Track 2 — exp12 (WavLM: SCRATCH vs BASE vs SAILER — ablation khởi tạo) — Kaggle T4
3
+ #
4
+ # **Mục đích:** kiểm chứng giả thuyết của mentor — *"với 12k data, train from scratch có tốt hơn fine-tune không?"*
5
+ # Một notebook, đổi cờ `INIT_MODE` để chạy 3 cách khởi tạo backbone WavLM, so trên CÙNG kiến trúc/data:
6
+ #
7
+ # | INIT_MODE | Khởi tạo WavLM | Train gì | Ý nghĩa |
8
+ # |---|---|---|---|
9
+ # | `scratch` | **ngẫu nhiên** (không pretrain) | **toàn bộ** backbone | "from scratch" đúng nghĩa mentor nói |
10
+ # | `base` | microsoft/wavlm-large (pretrain SSL, KHÔNG cảm xúc) | mở băng N lớp trên | đo lợi ích của SAILER warm-start |
11
+ # | `sailer` | warm-start cảm xúc (như exp08) | mở băng N lớp trên | bản mạnh hiện tại |
12
+ #
13
+ # **Chỉ WavLM** (bỏ audeering) để cô lập đúng biến "khởi tạo". QMOS mượn exp07 / UTMOSv2.
14
+ #
15
+ # ## ⚠️ Kỳ vọng trung thực (để đọc kết quả đúng)
16
+ # - `scratch` gần như CHẮC CHẮN yếu hơn `base`/`sailer`: 12k mẫu là quá ít để dạy WavLM "nghe" từ đầu
17
+ # (SSL pretrain dùng ~94.000 GIỜ audio). Đây là ablation để **chứng minh bằng số**, không phải để vượt.
18
+ # - `scratch` phải mở băng TOÀN BỘ (mới có gì để học) → **nặng + chậm + dễ OOM** trên T4. Dùng LIMIT nhỏ trước.
19
+ # - So sánh bằng **VAL nội bộ** giữa 3 mode đã đủ kết luận; muốn chắc thì nộp mode tốt nhất lên DEV.
20
+ #
21
+ # **Cách chạy:** GPU T4 + Internet On → sửa cell 0 (`INIT_MODE` + slug) → Run All. Chạy 3 lần đổi INIT_MODE.
22
+
23
+ # %% [markdown]
24
+ # ## 0. Cấu hình — SỬA Ở ĐÂY
25
+
26
+ # %%
27
+ import os
28
+
29
+ INIT_MODE = "sailer" # << "scratch" | "base" | "sailer" (đổi rồi chạy lại để so) — "sailer" = WavLM warm-start cảm xúc
30
+
31
+ DATA_ROOT = "/kaggle/input/datasets/minhtoan2/vmc2026-track2-full" # << SỬA slug cho khớp Add Input
32
+ WAV_DIR = f"{DATA_ROOT}/wav"
33
+ METADATA_CSV = f"{DATA_ROOT}/metadata.csv"
34
+ TRAIN_CSV = f"{DATA_ROOT}/sets/train.csv"
35
+ DEV_SCP = f"{DATA_ROOT}/sets/dev.scp"
36
+
37
+ EXP07_ANSWER = "/kaggle/input/exp07-answer/answer.txt" # << (tùy chọn) mượn QMOS 0.548; không có → UTMOSv2
38
+ OUT_DIR = "/kaggle/working"
39
+
40
+ # ── Siêu tham số ─────────────────────────────────────────────────────────────
41
+ DEVICE = "cuda"
42
+ SR = 16000
43
+ MAX_SECONDS = 6
44
+ TRUNK_HIDDEN = 512
45
+ HEAD_HIDDEN = 128
46
+ DROPOUT = 0.3
47
+ WEIGHT_DECAY = 1e-5
48
+ EPOCHS = 15
49
+ PATIENCE = 5
50
+ BATCH = 4
51
+ ACCUM = 8
52
+ VAL_FRAC = 0.10
53
+ SEED = 42
54
+ USE_AMP = True
55
+ USE_GRAD_CKPT = True
56
+ USE_UNCERTAINTY = True
57
+
58
+ # Khởi tạo & LR & mở băng — TỰ đặt theo INIT_MODE (scratch cần LR lớn + mở băng toàn bộ)
59
+ if INIT_MODE == "scratch":
60
+ UNFREEZE_TOP_LAYERS = "all" # random init → phải train tất cả mới học được
61
+ LR_BACKBONE = 1e-4 # random init cần bước lớn hơn fine-tune
62
+ elif INIT_MODE in ("base", "sailer"):
63
+ UNFREEZE_TOP_LAYERS = 6 # fine-tune: chỉ mở băng N lớp trên (tiết kiệm VRAM, chống overfit)
64
+ LR_BACKBONE = 1e-5
65
+ else:
66
+ raise ValueError(f"INIT_MODE lạ: {INIT_MODE}")
67
+ LR_HEAD = 1e-3
68
+
69
+ LIMIT_TRAIN = 300 # << LẦN ĐẦU 300; chạy thật None
70
+ LIMIT_DEV = 20 # << LẦN ĐẦU 20; chạy thật None
71
+
72
+ EXP08 = {"emos": 0.811, "cat_err": 0.133, "val": 0.659, "aro": 0.793, "dom": 0.751} # mốc DEV để tham khảo
73
+
74
+ EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"]
75
+ _EMO_ALIAS = {
76
+ "angry": "angry", "anger": "angry",
77
+ "happy": "happy", "happiness": "happy", "joy": "happy",
78
+ "neutral": "neutral", "calm": "neutral",
79
+ "sad": "sad", "sadness": "sad",
80
+ "surprise": "surprised", "surprised": "surprised", "surprising": "surprised",
81
+ }
82
+
83
+ def norm_emotion(label):
84
+ key = str(label).strip().lower()
85
+ return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)
86
+
87
+ def stem(p):
88
+ return os.path.splitext(os.path.basename(str(p)))[0]
89
+
90
+ print(f"INIT_MODE = {INIT_MODE} | UNFREEZE = {UNFREEZE_TOP_LAYERS} | LR_BACKBONE = {LR_BACKBONE}")
91
+ print("DATA_ROOT:", DATA_ROOT)
92
+ for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP]:
93
+ print((" ✅ " if os.path.exists(p) else " ❌ THIẾU ") + p)
94
+
95
+ # %% [markdown]
96
+ # ## 1. Cài đặt (clone SAILER chỉ khi INIT_MODE='sailer')
97
+
98
+ # %%
99
+ import sys, subprocess
100
+ import numpy as _np
101
+
102
+ # ⚠️ KHÓA numpy = bản Kaggle đang có → pip KHÔNG được nâng/hạ numpy → tránh "SystemError: bad call flags"
103
+ # (lỗi import torch do numpy lệch phiên bản với torch đã biên dịch sẵn).
104
+ _NPIN = f"numpy=={_np.__version__}"
105
+ print("Khóa numpy ở:", _NPIN)
106
+
107
+ def pip_install(*pkgs):
108
+ subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs, _NPIN], check=True)
109
+
110
+ # Kaggle đã có sẵn torch/transformers/librosa/scipy/sklearn/pandas/tqdm/huggingface_hub/safetensors.
111
+ # Chỉ cài thêm vài gói speech còn thiếu (kèm khóa numpy ở trên).
112
+ pip_install("loralib", "speechmos", "soundfile")
113
+ if INIT_MODE == "sailer":
114
+ pip_install("speechbrain")
115
+
116
+ if INIT_MODE == "sailer":
117
+ REPO_DIR = "/kaggle/working/vox-profile-release"
118
+ if not os.path.exists(REPO_DIR):
119
+ subprocess.run(["git", "clone", "--depth", "1",
120
+ "https://github.com/tiantiaf0627/vox-profile-release.git", REPO_DIR], check=True)
121
+ if REPO_DIR not in sys.path:
122
+ sys.path.insert(0, REPO_DIR)
123
+
124
+ # %% [markdown]
125
+ # ## 2. Dựng WavLM theo INIT_MODE
126
+
127
+ # %%
128
+ import torch
129
+ import torch.nn as nn
130
+ import torch.nn.functional as F
131
+ import numpy as np
132
+
133
+ device = DEVICE if torch.cuda.is_available() else "cpu"
134
+ print("Device:", device, ("✅ " + torch.cuda.get_device_name(0)) if device == "cuda" else "⚠️ CPU (rất chậm!)")
135
+
136
+ from transformers import WavLMModel, WavLMConfig
137
+
138
+ def find_hf_backbone(module):
139
+ cands = []
140
+ for name, m in module.named_modules():
141
+ enc = getattr(m, "encoder", None)
142
+ if getattr(m, "feature_extractor", None) is not None and enc is not None \
143
+ and getattr(enc, "layers", None) is not None:
144
+ cands.append((name, m))
145
+ if not cands:
146
+ return None, None
147
+ cands.sort(key=lambda nm: sum(p.numel() for p in nm[1].parameters()), reverse=True)
148
+ return cands[0]
149
+
150
+ wavlm = None
151
+ if INIT_MODE == "scratch":
152
+ # Random init NHƯNG giữ ĐÚNG kiến trúc large (để công bằng với base/sailer)
153
+ cfg = WavLMConfig.from_pretrained("microsoft/wavlm-large")
154
+ wavlm = WavLMModel(cfg) # KHÔNG load trọng số → ngẫu nhiên
155
+ print("🎲 WavLM-large khởi tạo NGẪU NHIÊN (from scratch, không pretrain).")
156
+ elif INIT_MODE == "base":
157
+ wavlm = WavLMModel.from_pretrained("microsoft/wavlm-large")
158
+ print("📦 WavLM-large pretrain SSL (chưa học cảm xúc).")
159
+ elif INIT_MODE == "sailer":
160
+ try:
161
+ from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402
162
+ _wrapper = WavLMWrapper.from_pretrained("tiantiaf/wavlm-large-categorical-emotion")
163
+ name, wavlm = find_hf_backbone(_wrapper)
164
+ print(f"🔥 WavLM warm-start SAILER (cảm xúc) tại '.{name}'")
165
+ except Exception as e:
166
+ print("⚠️ Lỗi nạp SAILER:", repr(e), "→ fallback base pretrained.")
167
+ wavlm = WavLMModel.from_pretrained("microsoft/wavlm-large")
168
+
169
+ wavlm = wavlm.to(device)
170
+ WAVLM_DIM = int(wavlm.config.hidden_size)
171
+ wavlm.config.layerdrop = 0.0 # ⚠️ BẮT BUỘC khi dùng gradient-checkpointing (tránh CheckpointError do bỏ lớp ngẫu nhiên)
172
+
173
+ # Mở băng theo cấu hình
174
+ if UNFREEZE_TOP_LAYERS == "all":
175
+ for p in wavlm.parameters():
176
+ p.requires_grad = True
177
+ n_open = "ALL"
178
+ else:
179
+ for p in wavlm.parameters():
180
+ p.requires_grad = False
181
+ _wl = wavlm.encoder.layers
182
+ for layer in _wl[max(0, len(_wl) - UNFREEZE_TOP_LAYERS):]:
183
+ for p in layer.parameters():
184
+ p.requires_grad = True
185
+ n_open = f"top {min(UNFREEZE_TOP_LAYERS, len(_wl))}/{len(_wl)}"
186
+ print(f"WavLM mở băng: {n_open} → {sum(p.numel() for p in wavlm.parameters() if p.requires_grad)/1e6:.1f}M param train (dim {WAVLM_DIM})")
187
+
188
+ if USE_GRAD_CKPT:
189
+ wavlm.gradient_checkpointing_enable()
190
+ if hasattr(wavlm, "enable_input_require_grads"):
191
+ wavlm.enable_input_require_grads()
192
+
193
+ def masked_mean(hidden, attn_mask):
194
+ if attn_mask is None:
195
+ return hidden.mean(dim=1)
196
+ try:
197
+ fm = wavlm._get_feature_vector_attention_mask(hidden.shape[1], attn_mask)
198
+ except Exception:
199
+ return hidden.mean(dim=1)
200
+ fm = fm.unsqueeze(-1).to(hidden.dtype)
201
+ return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)
202
+
203
+ def wavlm_embed(input_values, attn_mask):
204
+ out = wavlm(input_values, attention_mask=attn_mask).last_hidden_state
205
+ return masked_mean(out, attn_mask)
206
+
207
+ # %% [markdown]
208
+ # ## 3. Đọc & gộp nhãn theo wavID
209
+
210
+ # %%
211
+ import librosa
212
+ import pandas as pd
213
+ from tqdm.auto import tqdm
214
+
215
+ def load_target_emotions():
216
+ tgt = {}
217
+ with open(METADATA_CSV, encoding="utf-8") as f:
218
+ for ln in f:
219
+ parts = ln.strip().split("|")
220
+ if len(parts) >= 2:
221
+ tgt[stem(parts[0])] = norm_emotion(parts[1])
222
+ return tgt
223
+
224
+ def _col(cols_map, *names, df=None, default_idx=None):
225
+ for n in names:
226
+ if n in cols_map:
227
+ return cols_map[n]
228
+ return list(df.columns)[default_idx] if default_idx is not None else None
229
+
230
+ def parse_emocat_votes(cell):
231
+ v = np.zeros(len(EMOTIONS5), dtype=np.float32)
232
+ for tok in str(cell).replace("/", ",").replace(";", ",").replace("|", ",").replace(" ", ",").split(","):
233
+ e = norm_emotion(tok)
234
+ if e in EMOTIONS5:
235
+ v[EMOTIONS5.index(e)] += 1.0
236
+ return v
237
+
238
+ def load_train_labels():
239
+ df = pd.read_csv(TRAIN_CSV, sep="|")
240
+ cols = {c.lower().strip(): c for c in df.columns}
241
+ wav_col = _col(cols, "wavid", "wav", df=df, default_idx=1)
242
+ emos_col = _col(cols, "emos", "emo", "emomos")
243
+ val_col = _col(cols, "val", "valence"); aro_col = _col(cols, "aro", "arousal"); dom_col = _col(cols, "dom", "dominance")
244
+ cat_col = _col(cols, "emocat", "cat", "emotion")
245
+ assert emos_col, f"Không thấy cột eMOS (cột: {list(df.columns)})"
246
+ df["_stem"] = df[wav_col].map(stem)
247
+ rows = []
248
+ for sid, g in df.groupby("_stem"):
249
+ rec = {"wavID": sid, "emos": float(g[emos_col].mean())}
250
+ rec["val"] = float(g[val_col].mean()) if val_col else np.nan
251
+ rec["aro"] = float(g[aro_col].mean()) if aro_col else np.nan
252
+ rec["dom"] = float(g[dom_col].mean()) if dom_col else np.nan
253
+ votes = np.zeros(len(EMOTIONS5), dtype=np.float32)
254
+ if cat_col:
255
+ for cell in g[cat_col]:
256
+ votes += parse_emocat_votes(cell)
257
+ s = votes.sum()
258
+ cat = votes / s if s > 0 else np.full(len(EMOTIONS5), 0.2, dtype=np.float32)
259
+ for i in range(len(EMOTIONS5)):
260
+ rec[f"cat{i}"] = float(cat[i])
261
+ rows.append(rec)
262
+ return pd.DataFrame(rows)
263
+
264
+ target_map = load_target_emotions()
265
+ train_df = load_train_labels()
266
+ HAS_VAD = bool(train_df["val"].notna().any())
267
+ print(f"Target: {len(target_map)} | wav train (gộp): {len(train_df)} | có VAD: {HAS_VAD}")
268
+
269
+ # %% [markdown]
270
+ # ## 4. Dataset/loader (chỉ raw wave cho WavLM)
271
+
272
+ # %%
273
+ from torch.utils.data import Dataset, DataLoader
274
+ from sklearn.model_selection import train_test_split
275
+
276
+ train_stems = [s for s in train_df["wavID"] if target_map.get(s) is not None]
277
+ if LIMIT_TRAIN:
278
+ train_stems = train_stems[:LIMIT_TRAIN]
279
+ lab = train_df.set_index("wavID")
280
+
281
+ def _zfit(a):
282
+ a = np.asarray(a, dtype=np.float32); return float(np.nanmean(a)), float(np.nanstd(a) + 1e-6)
283
+ emos_mu, emos_sd = _zfit([lab.loc[s, "emos"] for s in train_stems])
284
+ if HAS_VAD:
285
+ vad_mu = np.array([_zfit([lab.loc[s, c] for s in train_stems])[0] for c in ["val", "aro", "dom"]], np.float32)
286
+ vad_sd = np.array([_zfit([lab.loc[s, c] for s in train_stems])[1] for c in ["val", "aro", "dom"]], np.float32)
287
+ else:
288
+ vad_mu = np.zeros(3, np.float32); vad_sd = np.ones(3, np.float32)
289
+ print(f"Chuẩn hóa: emos μ={emos_mu:.3f} σ={emos_sd:.3f}")
290
+
291
+ def onehot_target(tgt):
292
+ v = np.zeros(len(EMOTIONS5), dtype=np.float32)
293
+ if tgt in EMOTIONS5:
294
+ v[EMOTIONS5.index(tgt)] = 1.0
295
+ return v
296
+
297
+ def load_wav(sid):
298
+ p = os.path.join(WAV_DIR, sid if str(sid).endswith(".wav") else str(sid) + ".wav")
299
+ if not os.path.exists(p):
300
+ return None
301
+ wave, _ = librosa.load(p, sr=SR, mono=True)
302
+ return wave[: MAX_SECONDS * SR].astype(np.float32)
303
+
304
+ class EmoDataset(Dataset):
305
+ def __init__(self, stems):
306
+ self.stems = [s for s in stems if load_wav(s) is not None]
307
+ def __len__(self):
308
+ return len(self.stems)
309
+ def __getitem__(self, i):
310
+ s = self.stems[i]
311
+ wave = load_wav(s)
312
+ emos = (float(lab.loc[s, "emos"]) - emos_mu) / emos_sd
313
+ if HAS_VAD:
314
+ vad = (np.array([lab.loc[s, "val"], lab.loc[s, "aro"], lab.loc[s, "dom"]], np.float32) - vad_mu) / vad_sd
315
+ else:
316
+ vad = np.zeros(3, dtype=np.float32)
317
+ cat = np.array([lab.loc[s, f"cat{j}"] for j in range(len(EMOTIONS5))], dtype=np.float32)
318
+ return {"wave": wave, "tgt": onehot_target(target_map.get(s)),
319
+ "emos": np.float32(emos), "vad": vad, "cat": cat,
320
+ "emos_raw": np.float32(lab.loc[s, "emos"]),
321
+ "vad_raw": np.array([lab.loc[s, "val"], lab.loc[s, "aro"], lab.loc[s, "dom"]], np.float32)}
322
+
323
+ def collate(batch):
324
+ L = max(len(b["wave"]) for b in batch)
325
+ waves = np.zeros((len(batch), L), dtype=np.float32)
326
+ mask = np.zeros((len(batch), L), dtype=np.float32)
327
+ for i, b in enumerate(batch):
328
+ waves[i, : len(b["wave"])] = b["wave"]; mask[i, : len(b["wave"])] = 1.0
329
+ return {
330
+ "input_values": torch.from_numpy(waves), "attn_mask": torch.from_numpy(mask).long(),
331
+ "tgt": torch.from_numpy(np.stack([b["tgt"] for b in batch])),
332
+ "emos": torch.from_numpy(np.stack([b["emos"] for b in batch])).unsqueeze(1),
333
+ "vad": torch.from_numpy(np.stack([b["vad"] for b in batch])),
334
+ "cat": torch.from_numpy(np.stack([b["cat"] for b in batch])),
335
+ "emos_raw": np.stack([b["emos_raw"] for b in batch]),
336
+ "vad_raw": np.stack([b["vad_raw"] for b in batch]),
337
+ }
338
+
339
+ ds = EmoDataset(train_stems)
340
+ print("Dataset hợp lệ:", len(ds), "wav")
341
+ tr_i, va_i = train_test_split(np.arange(len(ds)), test_size=VAL_FRAC, random_state=SEED)
342
+ tr_loader = DataLoader(torch.utils.data.Subset(ds, tr_i), batch_size=BATCH, shuffle=True, collate_fn=collate, num_workers=2)
343
+ va_loader = DataLoader(torch.utils.data.Subset(ds, va_i), batch_size=BATCH, shuffle=False, collate_fn=collate, num_workers=2)
344
+
345
+ # %% [markdown]
346
+ # ## 5. Heads + train loop
347
+
348
+ # %%
349
+ from scipy.stats import spearmanr
350
+
351
+ torch.manual_seed(SEED); np.random.seed(SEED)
352
+ N_EMO = len(EMOTIONS5)
353
+
354
+ class EmoHeads(nn.Module):
355
+ def __init__(self, d_in, trunk_h, head_h, p, n_emo):
356
+ super().__init__()
357
+ self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),
358
+ nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))
359
+ self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))
360
+ self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))
361
+ self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))
362
+ def forward(self, feat, tgt):
363
+ h = self.trunk(feat)
364
+ return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)
365
+
366
+ heads = EmoHeads(WAVLM_DIM, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device)
367
+
368
+ TASKS = ["emos", "cat", "val", "aro", "dom"]
369
+ log_var = nn.Parameter(torch.zeros(len(TASKS), device=device))
370
+ bb_params = [p for p in wavlm.parameters() if p.requires_grad]
371
+ head_params = list(heads.parameters()) + ([log_var] if USE_UNCERTAINTY else [])
372
+ opt = torch.optim.AdamW([{"params": bb_params, "lr": LR_BACKBONE},
373
+ {"params": head_params, "lr": LR_HEAD}], weight_decay=WEIGHT_DECAY)
374
+ scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP and device == "cuda")
375
+ mse = nn.MSELoss()
376
+
377
+ def soft_ce(logits, target_dist):
378
+ return -(target_dist * F.log_softmax(logits, dim=1)).sum(1).mean()
379
+
380
+ def forward_batch(b):
381
+ feat = wavlm_embed(b["input_values"].to(device), b["attn_mask"].to(device))
382
+ return heads(feat, b["tgt"].to(device))
383
+
384
+ def compute_loss(emos_p, cat_l, vad_p, b):
385
+ L = {}
386
+ L["emos"] = mse(emos_p, b["emos"].to(device))
387
+ L["cat"] = soft_ce(cat_l, b["cat"].to(device))
388
+ if HAS_VAD:
389
+ vt = b["vad"].to(device)
390
+ L["val"] = mse(vad_p[:, 0:1], vt[:, 0:1]); L["aro"] = mse(vad_p[:, 1:2], vt[:, 1:2]); L["dom"] = mse(vad_p[:, 2:3], vt[:, 2:3])
391
+ else:
392
+ z = torch.zeros((), device=device); L["val"] = L["aro"] = L["dom"] = z
393
+ if USE_UNCERTAINTY:
394
+ return sum(torch.exp(-log_var[i]) * L[t] + log_var[i] for i, t in enumerate(TASKS))
395
+ return sum(L.values())
396
+
397
+ @torch.no_grad()
398
+ def evaluate():
399
+ wavlm.eval(); heads.eval()
400
+ P = {"emos": [], "val": [], "aro": [], "dom": []}; Y = {"emos": [], "val": [], "aro": [], "dom": []}
401
+ catP, catY = [], []
402
+ for b in va_loader:
403
+ with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
404
+ emos_p, cat_l, vad_p = forward_batch(b)
405
+ P["emos"] += emos_p.float().cpu().numpy().ravel().tolist(); Y["emos"] += b["emos_raw"].tolist()
406
+ vad_p = vad_p.float().cpu().numpy()
407
+ for j, t in enumerate(["val", "aro", "dom"]):
408
+ P[t] += vad_p[:, j].tolist(); Y[t] += b["vad_raw"][:, j].tolist()
409
+ catP.append(F.softmax(cat_l, 1).float().cpu().numpy()); catY.append(b["cat"])
410
+ out = {}
411
+ for t in ["emos"] + (["val", "aro", "dom"] if HAS_VAD else []):
412
+ out[t] = spearmanr(P[t], Y[t]).correlation
413
+ q = np.concatenate(catP); p = np.concatenate(catY)
414
+ out["cat_err"] = float(np.abs(q - p).sum(1).mean())
415
+ return out
416
+
417
+ def mean_srcc(m):
418
+ keys = ["emos"] + (["val", "aro", "dom"] if HAS_VAD else [])
419
+ return float(np.mean([m[k] for k in keys]))
420
+
421
+ CKPT_PATH = os.path.join(OUT_DIR, f"ft_wavlm_{INIT_MODE}.pt")
422
+ def save_full(state, val_emos=float("nan")):
423
+ torch.save({"wavlm": state["wavlm"], "heads": state["heads"], "INIT_MODE": INIT_MODE,
424
+ "emos_mu": emos_mu, "emos_sd": emos_sd, "vad_mu": vad_mu, "vad_sd": vad_sd,
425
+ "WAVLM_DIM": WAVLM_DIM, "val_emos": float(val_emos)}, CKPT_PATH)
426
+
427
+ best, best_state, bad = -1e9, None, 0
428
+ for ep in range(1, EPOCHS + 1):
429
+ wavlm.train(); heads.train()
430
+ opt.zero_grad(); run = 0.0; nb = 0
431
+ for step, b in enumerate(tqdm(tr_loader, desc=f"[{INIT_MODE}] epoch {ep}")):
432
+ with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
433
+ emos_p, cat_l, vad_p = forward_batch(b)
434
+ loss = compute_loss(emos_p, cat_l, vad_p, b) / ACCUM
435
+ scaler.scale(loss).backward()
436
+ if (step + 1) % ACCUM == 0:
437
+ scaler.step(opt); scaler.update(); opt.zero_grad()
438
+ run += loss.item() * ACCUM; nb += 1
439
+ m = evaluate(); sc = mean_srcc(m)
440
+ msg = " ".join(f"{k}={m[k]:.3f}" for k in ["emos", "val", "aro", "dom"] if k in m)
441
+ print(f"[{INIT_MODE}] epoch {ep:2d} | loss {run/max(nb,1):.4f} | {msg} | cat_err {m['cat_err']:.3f} | mean {sc:.4f} (best {max(best,sc):.4f})")
442
+ if sc > best:
443
+ best = sc
444
+ best_state = {"wavlm": {k: v.cpu().clone() for k, v in wavlm.state_dict().items()},
445
+ "heads": {k: v.cpu().clone() for k, v in heads.state_dict().items()}}
446
+ save_full(best_state, m["emos"]); bad = 0
447
+ print(f" 💾 lưu best → {CKPT_PATH} (epoch {ep}, mean {sc:.4f})")
448
+ else:
449
+ bad += 1
450
+ if bad >= PATIENCE:
451
+ print(f"Early stop ở epoch {ep}."); break
452
+
453
+ if best_state:
454
+ wavlm.load_state_dict(best_state["wavlm"]); heads.load_state_dict(best_state["heads"])
455
+ final = evaluate()
456
+ print(f"\n✅ VAL (nội bộ) — exp12 INIT_MODE={INIT_MODE}:")
457
+ print(f" EMOS={final['emos']:.4f}", end="")
458
+ if HAS_VAD:
459
+ print(f" | VAL/ARO/DOM={final['val']:.4f}/{final['aro']:.4f}/{final['dom']:.4f}")
460
+ else:
461
+ print()
462
+ print(f" cat_err={final['cat_err']:.4f} | mean SRCC={mean_srcc(final):.4f}")
463
+ print(f" (Mốc DEV exp08 để tham khảo: EMOS {EXP08['emos']}, VAD {EXP08['val']}/{EXP08['aro']}/{EXP08['dom']})")
464
+ print(" ➜ GHI con số này vào bảng ablation 04_ rồi đổi INIT_MODE chạy lại để so 3 mode.")
465
+
466
+ # %% [markdown]
467
+ # ## 6. Dự đoán DEV → answer.txt (QMOS mượn exp07 / UTMOSv2)
468
+
469
+ # %%
470
+ def list_dev():
471
+ with open(DEV_SCP) as f:
472
+ return [ln.strip() for ln in f if ln.strip()]
473
+
474
+ dev_names = list_dev()
475
+ if LIMIT_DEV:
476
+ dev_names = dev_names[:LIMIT_DEV]
477
+ print("DEV:", len(dev_names), "mẫu")
478
+
479
+ def load_exp07_qmos():
480
+ if EXP07_ANSWER and os.path.exists(EXP07_ANSWER):
481
+ import csv
482
+ d = {}
483
+ with open(EXP07_ANSWER) as f:
484
+ for row in csv.DictReader(f):
485
+ d[row["wav"]] = float(row["QMOS"]); d[stem(row["wav"])] = float(row["QMOS"])
486
+ print(f"✅ Mượn QMOS exp07: {len(d)//2} wav")
487
+ return d
488
+ return None
489
+
490
+ qmos_map = load_exp07_qmos()
491
+ if qmos_map is None:
492
+ print("ℹ️ Không có exp07 → QMOS bằng UTMOSv2.")
493
+ pip_install("git+https://github.com/sarulab-speech/UTMOSv2.git")
494
+ import utmosv2
495
+ v2 = utmosv2.create_model(pretrained=True)
496
+ qmos_map = {}
497
+ for n in tqdm(dev_names, desc="UTMOSv2"):
498
+ wav = os.path.join(WAV_DIR, n if str(n).endswith(".wav") else str(n) + ".wav")
499
+ if os.path.exists(wav):
500
+ o = v2.predict(input_path=wav)
501
+ qmos_map[n] = float(o["predicted_mos"]) if isinstance(o, dict) else float(o)
502
+ del v2; torch.cuda.empty_cache() if device == "cuda" else None
503
+
504
+ @torch.no_grad()
505
+ def predict_emotion(sid):
506
+ wave = load_wav(sid)
507
+ if wave is None:
508
+ return None
509
+ wavlm.eval(); heads.eval()
510
+ iv = torch.from_numpy(wave).unsqueeze(0).to(device)
511
+ am = torch.ones((1, len(wave)), dtype=torch.long, device=device)
512
+ tgt = torch.from_numpy(onehot_target(target_map.get(sid))).unsqueeze(0).to(device)
513
+ with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
514
+ feat = wavlm_embed(iv, am)
515
+ emos_p, cat_l, vad_p = heads(feat, tgt)
516
+ emos = float(emos_p.item()) * emos_sd + emos_mu
517
+ cat5 = F.softmax(cat_l, 1)[0].float().cpu().numpy()
518
+ vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu
519
+ return emos, cat5, vad3
520
+
521
+ def fmt_cat(p5):
522
+ return "|".join(f"{e}:{p5[i]:.6g}" for i, e in enumerate(EMOTIONS5))
523
+
524
+ answer_path = os.path.join(OUT_DIR, f"answer_{INIT_MODE}.txt")
525
+ n_real = n_def = 0
526
+ with open(answer_path, "w") as f:
527
+ f.write("wav,QMOS,EMOS,CAT,VAL,ARO,DOM\n")
528
+ for name in tqdm(dev_names, desc=f"answer[{INIT_MODE}]"):
529
+ sid = stem(name)
530
+ pr = predict_emotion(sid)
531
+ if pr is None:
532
+ emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0]); n_def += 1
533
+ else:
534
+ emos, cat5, vad3 = pr; n_real += 1
535
+ qmos = qmos_map.get(name, qmos_map.get(sid, 3.0))
536
+ f.write(f"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\n")
537
+ print(f"Ghi {len(dev_names)} dòng → {answer_path} | thật {n_real}, mặc định {n_def}")
538
+
539
+ # %% [markdown]
540
+ # ## 7. Validate + zip
541
+
542
+ # %%
543
+ def validate(path):
544
+ import csv
545
+ with open(path) as f:
546
+ rows = list(csv.reader(f))
547
+ assert rows[0][0] == "wav" and "QMOS" in rows[0], "Header sai"
548
+ for i, r in enumerate(rows[1:], 2):
549
+ assert len(r) == len(rows[0]), f"Dòng {i} sai số cột"
550
+ print(f"OK: {len(rows)-1} dòng, header = {rows[0]}")
551
+
552
+ validate(answer_path)
553
+ os.system(f"cd {OUT_DIR} && cp answer_{INIT_MODE}.txt answer.txt && zip -j submission_track2_exp12_{INIT_MODE}.zip answer.txt && unzip -l submission_track2_exp12_{INIT_MODE}.zip")
554
+ print("Sẵn sàng nộp:", os.path.join(OUT_DIR, f"submission_track2_exp12_{INIT_MODE}.zip"))
555
+
556
+ # %% [markdown]
557
+ # ## Ghi chú
558
+ # - **Chạy 3 lần** đổi `INIT_MODE` ("scratch"→"base"→"sailer"), ghi `mean SRCC` mỗi lần vào BẢNG ABLATION
559
+ # trong `docs/04_experiments_log.md` → trả lời mentor bằng số: from-scratch tốt hơn fine-tune không?
560
+ # - **scratch nặng:** mở băng toàn bộ WavLM-large. Nếu OOM → giảm `BATCH` (4→2), `MAX_SECONDS` (6→5),
561
+ # hoặc đổi sang `microsoft/wavlm-base-plus` (sửa cell 2) cho khả thi (lưu ý: khác kiến trúc → so kém công bằng hơn).
562
+ # - **scratch chậm + cần nhiều epoch hơn** (random init): để `EPOCHS=15`, `PATIENCE=5`. Vẫn nhiều khả năng < base/sailer.
563
+ # - **Đừng nhầm VAL nội bộ với DEV.** So 3 mode bằng VAL nội bộ đã đủ kết luận; muốn chắc thì nộp mode tốt nhất.
564
+ # - Checkpoint lưu `ft_wavlm_<mode>.pt`. Save Version sau mỗi lần chạy.
track2/exp13_finetune_qmos.ipynb ADDED
@@ -0,0 +1,733 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "d3c827cb",
6
+ "metadata": {},
7
+ "source": [
8
+ "# VMC2026 Track 2 — exp13 (FINE-TUNE UTMOS cho QMOS) + answer 6 cột — Kaggle\n",
9
+ "\n",
10
+ "**Mục tiêu:** QMOS hiện tốt nhất = 0.548 (exp07, head ĐÓNG BĂNG + neo UTMOS). exp13 thử **MỞ BĂNG\n",
11
+ "(fine-tune) thẳng UTMOS** trên nhãn `qMOS` thật của Track 2 → kéo model chất lượng về đúng domain giọng\n",
12
+ "cảm xúc. Sau đó **mượn 5 cột cảm xúc từ checkpoint exp08** (`ft_emotion_full_20epoch.pt` — bản TỐT NHẤT)\n",
13
+ "→ ghép `answer.txt` 6 cột.\n",
14
+ "\n",
15
+ "## Vì sao fine-tune UTMOS (không phải UTMOSv2)\n",
16
+ "- UTMOS (`utmos22_strong`, tarepan/SpeechMOS) = **1 model đơn**, tải qua `torch.hub`, **bản thân đã dự đoán\n",
17
+ " QMOS** → warm-start hoàn hảo cho cột chất lượng (khác UTMOSv2 = ensemble nhiều fold + 2 luồng → khó train).\n",
18
+ "- forward: `model(wave[B,T], sr) -> MOS[B]`, là `nn.Module` chuẩn → backprop được toàn model.\n",
19
+ "- **Không dùng neo UTMOS riêng** (đã chốt): khi fine-tune chính UTMOS thì \"neo\" nằm sẵn trong trọng số\n",
20
+ " warm-start → head/neo ngoài là thừa.\n",
21
+ "\n",
22
+ "## Thiết kế\n",
23
+ "```\n",
24
+ " [PHẦN A] wav ─► UTMOS (utmos22_strong, TRAINABLE, warm-start pretrained) ─► QMOS (train trên qMOS gold)\n",
25
+ " [PHẦN B] wav ─► WavLM(exp08 ft) + audeering(frozen) ─► EMOS/CAT/VAD (NẠP ckpt, chỉ inference)\n",
26
+ " [PHẦN C] ghép QMOS(A) + 5 cột cảm xúc(B) ─► answer.txt 6 cột ─► validate ─► zip\n",
27
+ "```\n",
28
+ "\n",
29
+ "## ⚠️ Phải biết trước\n",
30
+ "- Fine-tune = **không cache** (mỗi epoch chạy lại UTMOS forward+backward) → tốn giờ GPU. **Lần đầu BẮT BUỘC\n",
31
+ " `LIMIT_TRAIN=300`, `LIMIT_DEV=20`** để chỉnh trơn rồi mới `None`.\n",
32
+ "- Lưới an toàn: chỉ nộp QMOS fine-tune nếu **SRCC val nội bộ > zero-shot UTMOS** (mục A in cả 2 số).\n",
33
+ "- **Lưu checkpoint `ft_qmos_utmos.pt` mỗi best + Save Version NGAY** (bài học exp08: kernel chết là mất).\n",
34
+ "\n",
35
+ "**Cách chạy Kaggle:** GPU **T4** + Internet **On** → Add Input (1) dataset Track 2, (2) dataset chứa\n",
36
+ "`ft_emotion_full.pt` (exp08), (3) tùy chọn cache `aud_dev.npz` → sửa slug cell 0 → Run All."
37
+ ]
38
+ },
39
+ {
40
+ "cell_type": "markdown",
41
+ "id": "6a78806d",
42
+ "metadata": {},
43
+ "source": [
44
+ "## 0. Cấu hình — SỬA Ở ĐÂY"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": null,
50
+ "id": "1374fa7d",
51
+ "metadata": {},
52
+ "outputs": [],
53
+ "source": [
54
+ "import os, shutil, glob\n",
55
+ "\n",
56
+ "# ── TỰ DÒ DATA_ROOT (quét /kaggle/input tìm thư mục có sets/train.csv + wav/ + metadata.csv) ──\n",
57
+ "def find_data_root(search_root=\"/kaggle/input\"):\n",
58
+ " cands = []\n",
59
+ " for train_csv in glob.glob(os.path.join(search_root, \"**\", \"sets\", \"train.csv\"), recursive=True):\n",
60
+ " root = os.path.dirname(os.path.dirname(train_csv)) # .../<root>/sets/train.csv → <root>\n",
61
+ " score = os.path.isdir(os.path.join(root, \"wav\")) + os.path.exists(os.path.join(root, \"metadata.csv\"))\n",
62
+ " cands.append((score, root))\n",
63
+ " cands.sort(reverse=True) # ưu tiên thư mục đủ wav + metadata\n",
64
+ " return cands\n",
65
+ "\n",
66
+ "_cands = find_data_root(\"/kaggle/input\")\n",
67
+ "if _cands:\n",
68
+ " print(\"🔎 Ứng viên DATA_ROOT (điểm cao = đủ wav+metadata):\")\n",
69
+ " for sc, r in _cands:\n",
70
+ " print(f\" [{sc}/2] {r}\")\n",
71
+ " DATA_ROOT = _cands[0][1]\n",
72
+ " print(f\"👉 Tự chọn DATA_ROOT = {DATA_ROOT}\")\n",
73
+ "else:\n",
74
+ " DATA_ROOT = \"/kaggle/input/datasets/minhtoan2\" # dự phòng — sửa tay nếu auto-dò không thấy\n",
75
+ " print(f\"❌ Không thấy sets/train.csv trong /kaggle/input → dùng dự phòng {DATA_ROOT} (đã Add Input chưa?)\")\n",
76
+ "\n",
77
+ "WAV_DIR = f\"{DATA_ROOT}/wav\"\n",
78
+ "METADATA_CSV = f\"{DATA_ROOT}/metadata.csv\" # wavID|emotion|transcript (cho cột cảm xúc)\n",
79
+ "TRAIN_CSV = f\"{DATA_ROOT}/sets/train.csv\" # lisID|wavID|qMOS|emoCat|eMOS|val|dom|aro\n",
80
+ "DEV_SCP = f\"{DATA_ROOT}/sets/dev.scp\"\n",
81
+ "\n",
82
+ "# ── Checkpoint cảm xúc exp08 (để sinh 5 cột EMOS/CAT/VAD) ─────────────────────\n",
83
+ "# ⭐ TỐT NHẤT = ft_emotion_full_20epoch.pt (bản 20 epoch) — dùng bản này, KHÔNG dùng ft_emotion_full.pt.\n",
84
+ "EMO_CKPT = \"/kaggle/input/ft-emotion-full/ft_emotion_full_20epoch.pt\" # << ckpt exp08 20ep (CÓ backbone WavLM)\n",
85
+ "CACHE_INPUT = \"/kaggle/input/ft-emotion-cache\" # << (tùy chọn) thư mục chứa aud_dev.npz; \"\" nếu không có\n",
86
+ "\n",
87
+ "OUT_DIR = \"/kaggle/working\"\n",
88
+ "CACHE_DIR = \"/kaggle/working/ft_cache\" # /kaggle/input read-only → copy cache audeering sang đây\n",
89
+ "os.makedirs(CACHE_DIR, exist_ok=True)\n",
90
+ "\n",
91
+ "# ── PHẦN A: fine-tune UTMOS (QMOS) ───────────────────────────────────────────\n",
92
+ "DEVICE = \"cuda\"\n",
93
+ "SR = 16000\n",
94
+ "QMOS_MAX_SEC = 12 # cắt audio chặn bộ nhớ backprop (UTMOS); OOM thì giảm 10/8\n",
95
+ "LR = 1e-5 # LR nhỏ cho fine-tune (warm-start sẵn tốt)\n",
96
+ "WEIGHT_DECAY = 1e-5\n",
97
+ "EPOCHS = 10 # TRẦN; early-stop quyết số epoch thật\n",
98
+ "PATIENCE = 3\n",
99
+ "BATCH = 1 # UTMOS forward KHÔNG có attention-mask → BATCH=1 an toàn (pad zero sẽ lệch pooling)\n",
100
+ "ACCUM = 16 # effective batch = BATCH*ACCUM = 16\n",
101
+ "VAL_FRAC = 0.10\n",
102
+ "SEED = 42\n",
103
+ "USE_AMP = True\n",
104
+ "RANK_LAMBDA = 0.0 # 0 = chỉ MSE. >0 (vd 0.3) = cộng pairwise ranking loss (tối ưu thẳng thứ hạng=SRCC)\n",
105
+ "FREEZE_FEAT_EXT = True # đóng băng feature-extractor (CNN conv) của UTMOS → đỡ VRAM + chống overfit\n",
106
+ "\n",
107
+ "# ── PHẦN B: inference cảm xúc (PHẢI khớp kiến trúc exp08) ─────────────────────\n",
108
+ "EMO_MAX_SEC = 8\n",
109
+ "UNFREEZE_TOP_LAYERS = 6 # khớp ckpt exp08\n",
110
+ "TRUNK_HIDDEN = 512\n",
111
+ "HEAD_HIDDEN = 128\n",
112
+ "DROPOUT = 0.3\n",
113
+ "USE_AUDEERING = True # khớp ckpt exp08\n",
114
+ "\n",
115
+ "LIMIT_TRAIN = 300 # << LẦN ĐẦU 300; chạy thật None\n",
116
+ "LIMIT_DEV = 20 # << LẦN ĐẦU 20; chạy thật None\n",
117
+ "\n",
118
+ "# Mốc QMOS để so (leaderboard DEV)\n",
119
+ "QMOS_BASELINE = {\"utmos_zeroshot\": 0.414, \"exp07_head\": 0.548}\n",
120
+ "\n",
121
+ "EMOTIONS5 = [\"angry\", \"happy\", \"neutral\", \"sad\", \"surprised\"]\n",
122
+ "_EMO_ALIAS = {\n",
123
+ " \"angry\": \"angry\", \"anger\": \"angry\",\n",
124
+ " \"happy\": \"happy\", \"happiness\": \"happy\", \"joy\": \"happy\",\n",
125
+ " \"neutral\": \"neutral\", \"calm\": \"neutral\",\n",
126
+ " \"sad\": \"sad\", \"sadness\": \"sad\",\n",
127
+ " \"surprise\": \"surprised\", \"surprised\": \"surprised\", \"surprising\": \"surprised\",\n",
128
+ "}\n",
129
+ "\n",
130
+ "def norm_emotion(label):\n",
131
+ " key = str(label).strip().lower()\n",
132
+ " return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)\n",
133
+ "\n",
134
+ "def stem(p):\n",
135
+ " return os.path.splitext(os.path.basename(str(p)))[0]\n",
136
+ "\n",
137
+ "print(\"DATA_ROOT:\", DATA_ROOT)\n",
138
+ "for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP, EMO_CKPT]:\n",
139
+ " print((\" ✅ \" if os.path.exists(p) else \" ❌ THIẾU \") + p)\n",
140
+ "print(f\"Fine-tune UTMOS: LR {LR} · BATCH {BATCH}×ACCUM {ACCUM} · MAX {QMOS_MAX_SEC}s · rank λ {RANK_LAMBDA}\")\n",
141
+ "\n",
142
+ "# Copy cache audeering (aud_dev.npz) từ input read-only sang working (để cột cảm xúc khỏi trích lại)\n",
143
+ "if CACHE_INPUT and os.path.isdir(CACHE_INPUT):\n",
144
+ " n = 0\n",
145
+ " for fn in os.listdir(CACHE_INPUT):\n",
146
+ " if fn.startswith(\"aud_\") and fn.endswith(\".npz\"):\n",
147
+ " shutil.copy(os.path.join(CACHE_INPUT, fn), os.path.join(CACHE_DIR, fn)); n += 1\n",
148
+ " print(f\"📦 Copy {n} file cache audeering từ {CACHE_INPUT} → {CACHE_DIR}\")\n",
149
+ "else:\n",
150
+ " print(\"ℹ️ Không có CACHE_INPUT → sẽ tự trích audeering cho DEV (chậm hơn lần đầu).\")"
151
+ ]
152
+ },
153
+ {
154
+ "cell_type": "markdown",
155
+ "id": "5e568431",
156
+ "metadata": {},
157
+ "source": [
158
+ "## 1. Cài đặt"
159
+ ]
160
+ },
161
+ {
162
+ "cell_type": "code",
163
+ "execution_count": null,
164
+ "id": "731d1056",
165
+ "metadata": {},
166
+ "outputs": [],
167
+ "source": [
168
+ "import sys, subprocess\n",
169
+ "\n",
170
+ "def pip_install(*pkgs):\n",
171
+ " subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *pkgs], check=True)\n",
172
+ "\n",
173
+ "pip_install(\"speechmos\", \"loralib\", \"speechbrain\", \"librosa\", \"soundfile\",\n",
174
+ " \"scipy\", \"scikit-learn\", \"pandas\", \"tqdm\")\n",
175
+ "\n",
176
+ "# Code SAILER (để dựng đúng kiến trúc WavLM rồi nạp ckpt exp08 đè lên) — chỉ cần cho PHẦN B\n",
177
+ "REPO_DIR = \"/kaggle/working/vox-profile-release\"\n",
178
+ "if not os.path.exists(REPO_DIR):\n",
179
+ " subprocess.run([\"git\", \"clone\", \"--depth\", \"1\",\n",
180
+ " \"https://github.com/tiantiaf0627/vox-profile-release.git\", REPO_DIR], check=True)\n",
181
+ "if REPO_DIR not in sys.path:\n",
182
+ " sys.path.insert(0, REPO_DIR)"
183
+ ]
184
+ },
185
+ {
186
+ "cell_type": "markdown",
187
+ "id": "732b81f8",
188
+ "metadata": {},
189
+ "source": [
190
+ "## 2. Nhãn vàng qMOS (gộp trung bình theo wav) — như exp06/exp09a"
191
+ ]
192
+ },
193
+ {
194
+ "cell_type": "code",
195
+ "execution_count": null,
196
+ "id": "8cfb94af",
197
+ "metadata": {},
198
+ "outputs": [],
199
+ "source": [
200
+ "import numpy as np\n",
201
+ "import pandas as pd\n",
202
+ "\n",
203
+ "def load_qmos_labels():\n",
204
+ " df = pd.read_csv(TRAIN_CSV, sep=\"|\")\n",
205
+ " cols = {c.lower().strip(): c for c in df.columns}\n",
206
+ " wav_col = cols.get(\"wavid\") or cols.get(\"wav\") or list(df.columns)[1]\n",
207
+ " qmos_col = cols.get(\"qmos\") or cols.get(\"mos\")\n",
208
+ " assert qmos_col, f\"Không thấy cột qMOS (cột: {list(df.columns)})\"\n",
209
+ " df[\"_stem\"] = df[wav_col].map(stem)\n",
210
+ " g = df.groupby(\"_stem\")[qmos_col].mean()\n",
211
+ " return {s: float(v) for s, v in g.items()}\n",
212
+ "\n",
213
+ "qmos_gold = load_qmos_labels()\n",
214
+ "print(f\"Số wav train có nhãn qMOS: {len(qmos_gold)}\")\n",
215
+ "_vals = np.array(list(qmos_gold.values()))\n",
216
+ "print(f\"qMOS gold: mean {_vals.mean():.3f} · std {_vals.std():.3f} · min {_vals.min():.2f} · max {_vals.max():.2f}\")"
217
+ ]
218
+ },
219
+ {
220
+ "cell_type": "markdown",
221
+ "id": "374534a0",
222
+ "metadata": {},
223
+ "source": [
224
+ "## 3. PHẦN A — Fine-tune UTMOS trên qMOS\n",
225
+ "UTMOS xuất MOS thang ~1–5 (đã warm-start) → train MSE trên thang GỐC (không z-score, để giữ ý nghĩa warm-start).\n",
226
+ "`BATCH=1` + grad-accum: tránh phải pad (UTMOS forward không nhận attention-mask)."
227
+ ]
228
+ },
229
+ {
230
+ "cell_type": "code",
231
+ "execution_count": null,
232
+ "id": "4636d35c",
233
+ "metadata": {
234
+ "lines_to_next_cell": 1
235
+ },
236
+ "outputs": [],
237
+ "source": [
238
+ "import torch\n",
239
+ "import torch.nn as nn\n",
240
+ "import librosa\n",
241
+ "from tqdm.auto import tqdm\n",
242
+ "from scipy.stats import spearmanr\n",
243
+ "from sklearn.model_selection import train_test_split\n",
244
+ "\n",
245
+ "device = DEVICE if torch.cuda.is_available() else \"cpu\"\n",
246
+ "print(\"Device:\", device, (\"✅ \" + torch.cuda.get_device_name(0)) if device == \"cuda\" else \"⚠️ CPU (rất chậm!)\")\n",
247
+ "torch.manual_seed(SEED); np.random.seed(SEED)\n",
248
+ "\n",
249
+ "# Nạp UTMOS (torch.hub) — model nn.Module, forward(wave[B,T], sr) -> MOS[B]\n",
250
+ "utmos = torch.hub.load(\"tarepan/SpeechMOS:v1.2.0\", \"utmos22_strong\", trust_repo=True).to(device)\n",
251
+ "n_all = sum(p.numel() for p in utmos.parameters())\n",
252
+ "\n",
253
+ "# (tùy chọn) đóng băng feature-extractor (các lớp conv trích đặc trưng) → đỡ VRAM + chống overfit\n",
254
+ "if FREEZE_FEAT_EXT:\n",
255
+ " n_frozen = 0\n",
256
+ " for name, p in utmos.named_parameters():\n",
257
+ " if \"feature_extractor\" in name or \"feature_projection\" in name or \"conv\" in name.lower():\n",
258
+ " p.requires_grad = False; n_frozen += p.numel()\n",
259
+ " print(f\"❄️ Đóng băng feature-extractor: {n_frozen/1e6:.1f}M / {n_all/1e6:.1f}M param\")\n",
260
+ "n_train = sum(p.numel() for p in utmos.parameters() if p.requires_grad)\n",
261
+ "print(f\"UTMOS: {n_all/1e6:.1f}M param tổng · {n_train/1e6:.1f}M param sẽ train\")\n",
262
+ "\n",
263
+ "def load_wav_qmos(sid):\n",
264
+ " p = os.path.join(WAV_DIR, sid + \".wav\")\n",
265
+ " if not os.path.exists(p):\n",
266
+ " return None\n",
267
+ " wave, _ = librosa.load(p, sr=SR, mono=True)\n",
268
+ " return wave[: QMOS_MAX_SEC * SR].astype(np.float32)\n",
269
+ "\n",
270
+ "# Tập train QMOS: chỉ wav tồn tại trên đĩa\n",
271
+ "train_stems_q = [s for s in qmos_gold if os.path.exists(os.path.join(WAV_DIR, s + \".wav\"))]\n",
272
+ "np.random.shuffle(train_stems_q)\n",
273
+ "if LIMIT_TRAIN:\n",
274
+ " train_stems_q = train_stems_q[:LIMIT_TRAIN]\n",
275
+ "tr_q, va_q = train_test_split(train_stems_q, test_size=VAL_FRAC, random_state=SEED)\n",
276
+ "print(f\"QMOS train: {len(tr_q)} · val nội bộ: {len(va_q)}\")\n",
277
+ "\n",
278
+ "opt = torch.optim.AdamW([p for p in utmos.parameters() if p.requires_grad],\n",
279
+ " lr=LR, weight_decay=WEIGHT_DECAY)\n",
280
+ "scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP and device == \"cuda\")\n",
281
+ "mse = nn.MSELoss()\n",
282
+ "\n",
283
+ "def utmos_forward(wave_np):\n",
284
+ " \"\"\"1 wav numpy -> MOS scalar tensor (giữ grad).\"\"\"\n",
285
+ " x = torch.from_numpy(wave_np).unsqueeze(0).to(device) # [1, T]\n",
286
+ " out = utmos(x, SR) # [1] (hoặc [1,?])\n",
287
+ " return out.reshape(-1).mean() # scalar an toàn mọi shape\n",
288
+ "\n",
289
+ "def pairwise_rank_loss(preds, targets):\n",
290
+ " \"\"\"Hinge ranking trên các cặp trong 1 nhóm (khuyến khích đúng thứ hạng = tối ưu SRCC).\"\"\"\n",
291
+ " p = torch.stack(preds); t = torch.tensor(targets, device=device, dtype=torch.float32)\n",
292
+ " if len(p) < 2:\n",
293
+ " return torch.zeros((), device=device)\n",
294
+ " sign = torch.sign(t.unsqueeze(0) - t.unsqueeze(1))\n",
295
+ " diff = p.unsqueeze(0) - p.unsqueeze(1)\n",
296
+ " return torch.relu(-sign * diff).mean()\n",
297
+ "\n",
298
+ "@torch.no_grad()\n",
299
+ "def eval_qmos_val():\n",
300
+ " utmos.eval()\n",
301
+ " preds, gts = [], []\n",
302
+ " for s in va_q:\n",
303
+ " wave = load_wav_qmos(s)\n",
304
+ " if wave is None:\n",
305
+ " continue\n",
306
+ " with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
307
+ " preds.append(float(utmos_forward(wave).item()))\n",
308
+ " gts.append(qmos_gold[s])\n",
309
+ " return float(spearmanr(preds, gts).correlation)\n",
310
+ "\n",
311
+ "# Baseline ZERO-SHOT (trước khi train) trên CÙNG val → mốc phải vượt\n",
312
+ "srcc_zeroshot = eval_qmos_val()\n",
313
+ "print(f\"\\n📍 UTMOS zero-shot (val nội bộ): SRCC = {srcc_zeroshot:.4f} \"\n",
314
+ " f\"(leaderboard DEV ~{QMOS_BASELINE['utmos_zeroshot']}; exp07 head {QMOS_BASELINE['exp07_head']})\")\n",
315
+ "\n",
316
+ "CKPT_QMOS = os.path.join(OUT_DIR, \"ft_qmos_utmos.pt\")\n",
317
+ "def save_qmos_ckpt(srcc):\n",
318
+ " torch.save({\"utmos_state\": {k: v.cpu() for k, v in utmos.state_dict().items()},\n",
319
+ " \"val_srcc\": float(srcc), \"raw_scale\": True,\n",
320
+ " \"QMOS_MAX_SEC\": QMOS_MAX_SEC, \"FREEZE_FEAT_EXT\": FREEZE_FEAT_EXT}, CKPT_QMOS)\n",
321
+ "\n",
322
+ "best, best_state, bad = srcc_zeroshot, {k: v.cpu().clone() for k, v in utmos.state_dict().items()}, 0\n",
323
+ "save_qmos_ckpt(best) # lưu sẵn bản zero-shot (worst case vẫn = baseline)\n",
324
+ "\n",
325
+ "# Gom theo CỬA SỔ = ACCUM mẫu HỢP LỆ (micro). Hai chế độ backward:\n",
326
+ "# • RANK off (mặc định) → backward NGAY từng mẫu → đồ thị giải phóng liền → VRAM thấp.\n",
327
+ "# • RANK on → ranking cần SO các pred TRONG cửa sổ → PHẢI giữ đồ thị cả cửa sổ →\n",
328
+ "# gom MSE (win_loss) + pred (buf_p) rồi backward MỘT lần (MSE_mean + λ·rank).\n",
329
+ "# ⚠️ Lỗi cũ: backward MSE từng bước đã giải phóng đồ thị → rank_loss.backward() sau đó\n",
330
+ "# sẽ lỗi \"backward through the graph a second time\". Bản này gom rồi backward 1 lần → hết lỗi.\n",
331
+ "for ep in range(1, EPOCHS + 1):\n",
332
+ " utmos.train()\n",
333
+ " opt.zero_grad()\n",
334
+ " np.random.shuffle(tr_q)\n",
335
+ " run = 0.0; nb = 0\n",
336
+ " micro = 0; win_loss = None; buf_p, buf_t = [], []\n",
337
+ " for s in tqdm(tr_q, desc=f\"epoch {ep}\"):\n",
338
+ " wave = load_wav_qmos(s)\n",
339
+ " if wave is None:\n",
340
+ " continue\n",
341
+ " with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
342
+ " pred = utmos_forward(wave)\n",
343
+ " loss = mse(pred, torch.tensor(qmos_gold[s], device=device, dtype=pred.dtype))\n",
344
+ " run += float(loss.item()); nb += 1\n",
345
+ " if RANK_LAMBDA > 0:\n",
346
+ " win_loss = loss if win_loss is None else win_loss + loss # GIỮ đồ thị (không backward ngay)\n",
347
+ " buf_p.append(pred); buf_t.append(qmos_gold[s]); micro += 1\n",
348
+ " else:\n",
349
+ " scaler.scale(loss / ACCUM).backward(); micro += 1 # backward ngay → VRAM thấp\n",
350
+ " if micro == ACCUM:\n",
351
+ " if RANK_LAMBDA > 0:\n",
352
+ " total = win_loss / micro\n",
353
+ " if len(buf_p) >= 2:\n",
354
+ " total = total + RANK_LAMBDA * pairwise_rank_loss(buf_p, buf_t)\n",
355
+ " scaler.scale(total).backward()\n",
356
+ " scaler.step(opt); scaler.update(); opt.zero_grad()\n",
357
+ " micro = 0; win_loss = None; buf_p, buf_t = [], []\n",
358
+ " # flush cửa sổ dư cuối epoch (số mẫu không chia hết cho ACCUM)\n",
359
+ " if micro > 0:\n",
360
+ " if RANK_LAMBDA > 0:\n",
361
+ " total = win_loss / micro\n",
362
+ " if len(buf_p) >= 2:\n",
363
+ " total = total + RANK_LAMBDA * pairwise_rank_loss(buf_p, buf_t)\n",
364
+ " scaler.scale(total).backward()\n",
365
+ " scaler.step(opt); scaler.update(); opt.zero_grad()\n",
366
+ " sc = eval_qmos_val()\n",
367
+ " print(f\"epoch {ep:2d} | loss {run/max(nb,1):.4f} | val SRCC {sc:.4f} \"\n",
368
+ " f\"(zero-shot {srcc_zeroshot:.4f} · best {max(best,sc):.4f})\")\n",
369
+ " if sc > best:\n",
370
+ " best = sc\n",
371
+ " best_state = {k: v.cpu().clone() for k, v in utmos.state_dict().items()}\n",
372
+ " save_qmos_ckpt(best)\n",
373
+ " print(f\" 💾 lưu best → {CKPT_QMOS} (epoch {ep}, SRCC {sc:.4f})\")\n",
374
+ " bad = 0\n",
375
+ " else:\n",
376
+ " bad += 1\n",
377
+ " if bad >= PATIENCE:\n",
378
+ " print(f\"Early stop ở epoch {ep}.\"); break\n",
379
+ "\n",
380
+ "utmos.load_state_dict(best_state)\n",
381
+ "print(f\"\\n✅ PHẦN A xong — QMOS val nội bộ: zero-shot {srcc_zeroshot:.4f} → fine-tune {best:.4f} \"\n",
382
+ " + (\"🚀 cải thiện\" if best > srcc_zeroshot + 1e-4 else \"➖ KHÔNG vượt zero-shot\"))\n",
383
+ "if best <= srcc_zeroshot + 1e-4:\n",
384
+ " print(\" ⚠️ Fine-tune chưa vượt zero-shot → cân nhắc tăng EPOCHS / bật RANK_LAMBDA=0.3 / \"\n",
385
+ " \"mở băng feature-extractor (FREEZE_FEAT_EXT=False); hoặc giữ QMOS exp07 (0.548).\")"
386
+ ]
387
+ },
388
+ {
389
+ "cell_type": "markdown",
390
+ "id": "d2448fa6",
391
+ "metadata": {},
392
+ "source": [
393
+ "## 4. PHẦN A (tiếp) — Dự đoán QMOS cho DEV bằng UTMOS đã fine-tune"
394
+ ]
395
+ },
396
+ {
397
+ "cell_type": "code",
398
+ "execution_count": null,
399
+ "id": "f8463ca5",
400
+ "metadata": {},
401
+ "outputs": [],
402
+ "source": [
403
+ "def list_dev():\n",
404
+ " with open(DEV_SCP) as f:\n",
405
+ " return [ln.strip() for ln in f if ln.strip()]\n",
406
+ "\n",
407
+ "dev_names = list_dev()\n",
408
+ "if LIMIT_DEV:\n",
409
+ " dev_names = dev_names[:LIMIT_DEV]\n",
410
+ "print(\"DEV:\", len(dev_names), \"mẫu\")\n",
411
+ "\n",
412
+ "@torch.no_grad()\n",
413
+ "def predict_qmos(name):\n",
414
+ " p = os.path.join(WAV_DIR, name if str(name).endswith(\".wav\") else str(name) + \".wav\")\n",
415
+ " if not os.path.exists(p):\n",
416
+ " return None\n",
417
+ " wave, _ = librosa.load(p, sr=SR, mono=True)\n",
418
+ " wave = wave[: QMOS_MAX_SEC * SR].astype(np.float32)\n",
419
+ " utmos.eval()\n",
420
+ " with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
421
+ " v = float(utmos_forward(wave).item())\n",
422
+ " return float(np.clip(v, 1.0, 5.0))\n",
423
+ "\n",
424
+ "qmos_pred = {}\n",
425
+ "n_real = n_def = 0\n",
426
+ "for name in tqdm(dev_names, desc=\"QMOS dev\"):\n",
427
+ " v = predict_qmos(name)\n",
428
+ " if v is None:\n",
429
+ " v = 3.0; n_def += 1\n",
430
+ " else:\n",
431
+ " n_real += 1\n",
432
+ " qmos_pred[name] = v\n",
433
+ "print(f\"QMOS dự đoán: thật {n_real}, mặc định {n_def}\")\n",
434
+ "\n",
435
+ "# Giải phóng UTMOS trước khi nạp backbone cảm xúc (đỡ VRAM T4)\n",
436
+ "del utmos, opt, scaler\n",
437
+ "torch.cuda.empty_cache() if device == \"cuda\" else None"
438
+ ]
439
+ },
440
+ {
441
+ "cell_type": "markdown",
442
+ "id": "b67b5d6b",
443
+ "metadata": {},
444
+ "source": [
445
+ "## 5. PHẦN B — Nạp ckpt exp08 (WavLM ft + audeering) → 5 cột cảm xúc cho DEV\n",
446
+ "Tái dùng nguyên cơ chế load của exp08b: dựng kiến trúc → `load_state_dict` từ `ft_emotion_full_20epoch.pt`."
447
+ ]
448
+ },
449
+ {
450
+ "cell_type": "code",
451
+ "execution_count": null,
452
+ "id": "f98eca99",
453
+ "metadata": {
454
+ "lines_to_next_cell": 1
455
+ },
456
+ "outputs": [],
457
+ "source": [
458
+ "import torch.nn.functional as F\n",
459
+ "\n",
460
+ "ckpt = torch.load(EMO_CKPT, map_location=\"cpu\", weights_only=False) # ckpt có numpy → weights_only=False\n",
461
+ "assert \"wavlm\" in ckpt, (\"❌ EMO_CKPT không có 'wavlm' (backbone). Cần ft_emotion_full_20epoch.pt (bản đủ backbone), \"\n",
462
+ " \"KHÔNG phải ft_emotion_meta.pt cũ.\")\n",
463
+ "print(\"✅ Nạp ckpt cảm xúc:\", EMO_CKPT, \"| keys:\", list(ckpt.keys()))\n",
464
+ "\n",
465
+ "def find_hf_backbone(module):\n",
466
+ " cands = []\n",
467
+ " for nm, m in module.named_modules():\n",
468
+ " enc = getattr(m, \"encoder\", None)\n",
469
+ " if getattr(m, \"feature_extractor\", None) is not None and enc is not None \\\n",
470
+ " and getattr(enc, \"layers\", None) is not None:\n",
471
+ " cands.append((nm, m))\n",
472
+ " if not cands:\n",
473
+ " return None, None\n",
474
+ " cands.sort(key=lambda x: sum(p.numel() for p in x[1].parameters()), reverse=True)\n",
475
+ " return cands[0]\n",
476
+ "\n",
477
+ "wavlm = None\n",
478
+ "try:\n",
479
+ " from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402\n",
480
+ " _wrapper = WavLMWrapper.from_pretrained(\"tiantiaf/wavlm-large-categorical-emotion\")\n",
481
+ " _name, wavlm = find_hf_backbone(_wrapper)\n",
482
+ " if wavlm is not None:\n",
483
+ " print(f\"✅ Dựng backbone WavLM từ SAILER wrapper tại '.{_name}'\")\n",
484
+ "except Exception as e:\n",
485
+ " print(\"⚠️ Lỗi nạp SAILER wrapper:\", repr(e), \"→ fallback WavLM trắng.\")\n",
486
+ "if wavlm is None:\n",
487
+ " from transformers import WavLMModel\n",
488
+ " wavlm = WavLMModel.from_pretrained(\"microsoft/wavlm-large\")\n",
489
+ " print(\"ℹ️ Fallback: microsoft/wavlm-large.\")\n",
490
+ "\n",
491
+ "wavlm = wavlm.to(device).eval()\n",
492
+ "WAVLM_DIM = int(wavlm.config.hidden_size)\n",
493
+ "miss, unexp = wavlm.load_state_dict(ckpt[\"wavlm\"], strict=False)\n",
494
+ "print(f\"🔁 load wavlm từ ckpt: thiếu {len(miss)} / dư {len(unexp)} key (kỳ vọng ~0).\")\n",
495
+ "\n",
496
+ "def masked_mean(hidden, attn_mask):\n",
497
+ " if attn_mask is None:\n",
498
+ " return hidden.mean(dim=1)\n",
499
+ " try:\n",
500
+ " fm = wavlm._get_feature_vector_attention_mask(hidden.shape[1], attn_mask)\n",
501
+ " except Exception:\n",
502
+ " return hidden.mean(dim=1)\n",
503
+ " fm = fm.unsqueeze(-1).to(hidden.dtype)\n",
504
+ " return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)\n",
505
+ "\n",
506
+ "@torch.no_grad()\n",
507
+ "def wavlm_embed(input_values, attn_mask):\n",
508
+ " out = wavlm(input_values, attention_mask=attn_mask).last_hidden_state\n",
509
+ " return masked_mean(out, attn_mask)\n",
510
+ "\n",
511
+ "# ── audeering FROZEN (đặc trưng phụ) — như exp08 ──\n",
512
+ "AUD_DIM = 0\n",
513
+ "aud_backbone = aud_head = aud_proc = None\n",
514
+ "if USE_AUDEERING:\n",
515
+ " from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor\n",
516
+ " from huggingface_hub import hf_hub_download\n",
517
+ " AUD_NAME = \"audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim\"\n",
518
+ " aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)\n",
519
+ " aud_cfg = Wav2Vec2Config.from_pretrained(AUD_NAME)\n",
520
+ " aud_backbone = Wav2Vec2Model(aud_cfg)\n",
521
+ " try:\n",
522
+ " _sd = __import__(\"safetensors.torch\", fromlist=[\"load_file\"]).load_file(\n",
523
+ " hf_hub_download(AUD_NAME, \"model.safetensors\"))\n",
524
+ " except Exception:\n",
525
+ " _sd = torch.load(hf_hub_download(AUD_NAME, \"pytorch_model.bin\"), map_location=\"cpu\")\n",
526
+ " bb_sd = {k[len(\"wav2vec2.\"):]: v for k, v in _sd.items() if k.startswith(\"wav2vec2.\")}\n",
527
+ " aud_backbone.load_state_dict(bb_sd, strict=False)\n",
528
+ " _hid = _sd[\"classifier.dense.weight\"].shape[0]\n",
529
+ " _out = _sd[\"classifier.out_proj.weight\"].shape[0]\n",
530
+ " aud_head = nn.Sequential(nn.Linear(_hid, _hid), nn.Tanh(), nn.Linear(_hid, _out))\n",
531
+ " aud_head[0].weight.data.copy_(_sd[\"classifier.dense.weight\"]); aud_head[0].bias.data.copy_(_sd[\"classifier.dense.bias\"])\n",
532
+ " aud_head[2].weight.data.copy_(_sd[\"classifier.out_proj.weight\"]); aud_head[2].bias.data.copy_(_sd[\"classifier.out_proj.bias\"])\n",
533
+ " aud_backbone = aud_backbone.to(device).eval()\n",
534
+ " aud_head = aud_head.to(device).eval()\n",
535
+ " AUD_DIM = _hid + 3\n",
536
+ " print(f\"✅ audeering frozen ({AUD_DIM}-D)\")\n",
537
+ "\n",
538
+ "def load_wav_emo(sid):\n",
539
+ " p = os.path.join(WAV_DIR, sid + \".wav\")\n",
540
+ " if not os.path.exists(p):\n",
541
+ " return None\n",
542
+ " wave, _ = librosa.load(p, sr=SR, mono=True)\n",
543
+ " return wave[: EMO_MAX_SEC * SR].astype(np.float32)\n",
544
+ "\n",
545
+ "@torch.no_grad()\n",
546
+ "def extract_audeering(stems, tag):\n",
547
+ " if not USE_AUDEERING:\n",
548
+ " return {}\n",
549
+ " cache_path = os.path.join(CACHE_DIR, f\"aud_{tag}.npz\")\n",
550
+ " store = {}\n",
551
+ " if os.path.exists(cache_path):\n",
552
+ " z = np.load(cache_path, allow_pickle=True)\n",
553
+ " store = {k: z[k] for k in z.files}\n",
554
+ " print(f\"[aud/{tag}] nạp cache: {len(store)}\")\n",
555
+ " todo = [s for s in stems if s not in store]\n",
556
+ " for i, s in enumerate(tqdm(todo, desc=f\"audeering {tag}\")):\n",
557
+ " wave = load_wav_emo(s)\n",
558
+ " if wave is None:\n",
559
+ " continue\n",
560
+ " x = aud_proc(wave, sampling_rate=SR).input_values[0]\n",
561
+ " x = torch.from_numpy(np.asarray(x, dtype=np.float32)).unsqueeze(0).to(device)\n",
562
+ " h = aud_backbone(x)[0].mean(dim=1)\n",
563
+ " out = aud_head(h)[0].cpu().numpy()\n",
564
+ " vad = np.array([1 + 4 * out[2], 1 + 4 * out[0], 1 + 4 * out[1]], dtype=np.float32) # [VAL,ARO,DOM]\n",
565
+ " store[s] = np.concatenate([h[0].cpu().numpy(), vad]).astype(np.float32)\n",
566
+ " if (i + 1) % 500 == 0:\n",
567
+ " np.savez(cache_path, **store)\n",
568
+ " if todo:\n",
569
+ " np.savez(cache_path, **store)\n",
570
+ " return store\n",
571
+ "\n",
572
+ "# ── EmoHeads (khớp exp08) + nạp trọng số head + thống kê chuẩn hóa từ ckpt ──\n",
573
+ "N_EMO = len(EMOTIONS5)\n",
574
+ "TRUNK_IN = WAVLM_DIM + (AUD_DIM if USE_AUDEERING else 0)\n",
575
+ "\n",
576
+ "class EmoHeads(nn.Module):\n",
577
+ " def __init__(self, d_in, trunk_h, head_h, p, n_emo):\n",
578
+ " super().__init__()\n",
579
+ " self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),\n",
580
+ " nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))\n",
581
+ " self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))\n",
582
+ " self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))\n",
583
+ " self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))\n",
584
+ " def forward(self, feat, tgt):\n",
585
+ " h = self.trunk(feat)\n",
586
+ " return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)\n",
587
+ "\n",
588
+ "heads = EmoHeads(TRUNK_IN, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device).eval()\n",
589
+ "hmiss, hunexp = heads.load_state_dict(ckpt[\"heads\"], strict=False)\n",
590
+ "print(f\"🔁 load heads từ ckpt: thiếu {len(hmiss)} / dư {len(hunexp)} key (kỳ vọng 0).\")\n",
591
+ "\n",
592
+ "emos_mu = float(ckpt[\"emos_mu\"]); emos_sd = float(ckpt[\"emos_sd\"])\n",
593
+ "vad_mu = np.asarray(ckpt[\"vad_mu\"], dtype=np.float32); vad_sd = np.asarray(ckpt[\"vad_sd\"], dtype=np.float32)\n",
594
+ "print(f\"Chuẩn hóa từ ckpt: emos μ={emos_mu:.3f} σ={emos_sd:.3f} | vad μ={np.round(vad_mu,2)}\")\n",
595
+ "\n",
596
+ "# Target cảm xúc (cho EMOS head) từ metadata\n",
597
+ "def load_target_emotions():\n",
598
+ " tgt = {}\n",
599
+ " with open(METADATA_CSV, encoding=\"utf-8\") as f:\n",
600
+ " for ln in f:\n",
601
+ " parts = ln.strip().split(\"|\")\n",
602
+ " if len(parts) >= 2:\n",
603
+ " tgt[stem(parts[0])] = norm_emotion(parts[1])\n",
604
+ " return tgt\n",
605
+ "\n",
606
+ "target_map = load_target_emotions()\n",
607
+ "\n",
608
+ "def onehot_target(tgt):\n",
609
+ " v = np.zeros(N_EMO, dtype=np.float32)\n",
610
+ " if tgt in EMOTIONS5:\n",
611
+ " v[EMOTIONS5.index(tgt)] = 1.0\n",
612
+ " return v\n",
613
+ "\n",
614
+ "dev_stems = [stem(n) for n in dev_names]\n",
615
+ "aud_dev = extract_audeering(dev_stems, \"dev\")\n",
616
+ "\n",
617
+ "@torch.no_grad()\n",
618
+ "def predict_emotion(sid):\n",
619
+ " wave = load_wav_emo(sid)\n",
620
+ " if wave is None or (USE_AUDEERING and sid not in aud_dev):\n",
621
+ " return None\n",
622
+ " iv = torch.from_numpy(wave).unsqueeze(0).to(device)\n",
623
+ " am = torch.ones((1, len(wave)), dtype=torch.long, device=device)\n",
624
+ " tgt = torch.from_numpy(onehot_target(target_map.get(sid))).unsqueeze(0).to(device)\n",
625
+ " with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
626
+ " fw = wavlm_embed(iv, am)\n",
627
+ " feat = torch.cat([fw, torch.from_numpy(aud_dev[sid]).unsqueeze(0).to(device)], dim=1) if USE_AUDEERING else fw\n",
628
+ " emos_p, cat_l, vad_p = heads(feat, tgt)\n",
629
+ " emos = float(emos_p.item()) * emos_sd + emos_mu\n",
630
+ " cat5 = F.softmax(cat_l, 1)[0].float().cpu().numpy()\n",
631
+ " vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu\n",
632
+ " return emos, cat5, vad3"
633
+ ]
634
+ },
635
+ {
636
+ "cell_type": "markdown",
637
+ "id": "f813bfaf",
638
+ "metadata": {},
639
+ "source": [
640
+ "## 6. PHẦN C — Ghép QMOS (fine-tune) + 5 cột cảm xúc (exp08) → answer.txt 6 cột"
641
+ ]
642
+ },
643
+ {
644
+ "cell_type": "code",
645
+ "execution_count": null,
646
+ "id": "f9fd3208",
647
+ "metadata": {
648
+ "lines_to_next_cell": 1
649
+ },
650
+ "outputs": [],
651
+ "source": [
652
+ "def fmt_cat(p5):\n",
653
+ " return \"|\".join(f\"{e}:{p5[i]:.6g}\" for i, e in enumerate(EMOTIONS5))\n",
654
+ "\n",
655
+ "def build_answer(out_path):\n",
656
+ " n_real = n_def = 0\n",
657
+ " with open(out_path, \"w\") as f:\n",
658
+ " f.write(\"wav,QMOS,EMOS,CAT,VAL,ARO,DOM\\n\")\n",
659
+ " for name in tqdm(dev_names, desc=\"answer\"):\n",
660
+ " sid = stem(name)\n",
661
+ " pr = predict_emotion(sid)\n",
662
+ " if pr is None:\n",
663
+ " emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0]); n_def += 1\n",
664
+ " else:\n",
665
+ " emos, cat5, vad3 = pr; n_real += 1\n",
666
+ " qmos = qmos_pred.get(name, qmos_pred.get(sid, 3.0))\n",
667
+ " f.write(f\"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\\n\")\n",
668
+ " print(f\"Ghi {len(dev_names)} dòng → {out_path} | cảm xúc thật {n_real}, mặc định {n_def}\")\n",
669
+ "\n",
670
+ "answer_path = os.path.join(OUT_DIR, \"answer.txt\")\n",
671
+ "build_answer(answer_path)"
672
+ ]
673
+ },
674
+ {
675
+ "cell_type": "markdown",
676
+ "id": "3f78d7fe",
677
+ "metadata": {},
678
+ "source": [
679
+ "## 7. Validate + zip"
680
+ ]
681
+ },
682
+ {
683
+ "cell_type": "code",
684
+ "execution_count": null,
685
+ "id": "d873783b",
686
+ "metadata": {},
687
+ "outputs": [],
688
+ "source": [
689
+ "def validate(path):\n",
690
+ " import csv\n",
691
+ " with open(path) as f:\n",
692
+ " rows = list(csv.reader(f))\n",
693
+ " assert rows[0][0] == \"wav\" and \"QMOS\" in rows[0] and \"EMOS\" in rows[0], \"Header sai\"\n",
694
+ " for i, r in enumerate(rows[1:], 2):\n",
695
+ " assert len(r) == len(rows[0]), f\"Dòng {i} sai số cột\"\n",
696
+ " print(f\"OK: {len(rows)-1} dòng, header = {rows[0]}\")\n",
697
+ "\n",
698
+ "validate(answer_path)\n",
699
+ "os.system(f\"cd {OUT_DIR} && zip -j submission_track2_exp13_ft-qmos.zip answer.txt \"\n",
700
+ " f\"&& unzip -l submission_track2_exp13_ft-qmos.zip\")\n",
701
+ "print(\"Sẵn sàng nộp:\", os.path.join(OUT_DIR, \"submission_track2_exp13_ft-qmos.zip\"))"
702
+ ]
703
+ },
704
+ {
705
+ "cell_type": "markdown",
706
+ "id": "33290a62",
707
+ "metadata": {},
708
+ "source": [
709
+ "## Ghi chú\n",
710
+ "- **Lần đầu** `LIMIT_TRAIN=300`, `LIMIT_DEV=20` để chạy trơn (không OOM, 1 epoch xong); rồi đặt `None`.\n",
711
+ "- **OOM trên T4?** giảm `QMOS_MAX_SEC` (12→10→8); giữ `FREEZE_FEAT_EXT=True`; `BATCH=1` đã là min.\n",
712
+ " ⚠️ **Bật `RANK_LAMBDA>0` tốn VRAM hơn** vì phải GIỮ đồ thị cả cửa sổ ACCUM (=16) để so thứ hạng →\n",
713
+ " nếu OOM khi bật ranking: giảm `ACCUM` (vd 8, cũng là kích thước nhóm ranking) hoặc `QMOS_MAX_SEC`.\n",
714
+ "- **Đọc mục A:** so `val SRCC fine-tune` với `zero-shot`. Chỉ nộp QMOS fine-tune nếu **vượt zero-shot**\n",
715
+ " (lý tưởng vượt cả exp07 0.548); nếu không → giữ QMOS exp07 (Add Input answer.txt exp07, đổi cột QMOS).\n",
716
+ "- Nếu chưa vượt: tăng `EPOCHS`, bật `RANK_LAMBDA=0.3` (tối ưu thẳng thứ hạng), hoặc `FREEZE_FEAT_EXT=False`\n",
717
+ " (mở băng feature-extractor — mạnh hơn nhưng dễ overfit + nặng VRAM).\n",
718
+ "- **Lưu checkpoint:** `ft_qmos_utmos.pt` lưu mỗi best → **Save Version NGAY** sau khi chạy (bài học exp08).\n",
719
+ "- **License QMOS:** UTMOS/SpeechMOS (kiểm tra license tarepan/SpeechMOS) — khai báo `docs/12_system_description.md`.\n",
720
+ "- Ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (mục exp13)."
721
+ ]
722
+ }
723
+ ],
724
+ "metadata": {
725
+ "jupytext": {
726
+ "cell_metadata_filter": "-all",
727
+ "main_language": "python",
728
+ "notebook_metadata_filter": "-all"
729
+ }
730
+ },
731
+ "nbformat": 4,
732
+ "nbformat_minor": 5
733
+ }
track2/exp13_finetune_qmos_pipeline.py ADDED
@@ -0,0 +1,607 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %% [markdown]
2
+ # # VMC2026 Track 2 — exp13 (FINE-TUNE UTMOS cho QMOS) + answer 6 cột — Kaggle
3
+ #
4
+ # **Mục tiêu:** QMOS hiện tốt nhất = 0.548 (exp07, head ĐÓNG BĂNG + neo UTMOS). exp13 thử **MỞ BĂNG
5
+ # (fine-tune) thẳng UTMOS** trên nhãn `qMOS` thật của Track 2 → kéo model chất lượng về đúng domain giọng
6
+ # cảm xúc. Sau đó **mượn 5 cột cảm xúc từ checkpoint exp08** (`ft_emotion_full_20epoch.pt` — bản TỐT NHẤT)
7
+ # → ghép `answer.txt` 6 cột.
8
+ #
9
+ # ## Vì sao fine-tune UTMOS (không phải UTMOSv2)
10
+ # - UTMOS (`utmos22_strong`, tarepan/SpeechMOS) = **1 model đơn**, tải qua `torch.hub`, **bản thân đã dự đoán
11
+ # QMOS** → warm-start hoàn hảo cho cột chất lượng (khác UTMOSv2 = ensemble nhiều fold + 2 luồng → khó train).
12
+ # - forward: `model(wave[B,T], sr) -> MOS[B]`, là `nn.Module` chuẩn → backprop được toàn model.
13
+ # - **Không dùng neo UTMOS riêng** (đã chốt): khi fine-tune chính UTMOS thì "neo" nằm sẵn trong trọng số
14
+ # warm-start → head/neo ngoài là thừa.
15
+ #
16
+ # ## Thiết kế
17
+ # ```
18
+ # [PHẦN A] wav ─► UTMOS (utmos22_strong, TRAINABLE, warm-start pretrained) ─► QMOS (train trên qMOS gold)
19
+ # [PHẦN B] wav ─► WavLM(exp08 ft) + audeering(frozen) ─► EMOS/CAT/VAD (NẠP ckpt, chỉ inference)
20
+ # [PHẦN C] ghép QMOS(A) + 5 cột cảm xúc(B) ─► answer.txt 6 cột ─► validate ─► zip
21
+ # ```
22
+ #
23
+ # ## ⚠️ Phải biết trước
24
+ # - Fine-tune = **không cache** (mỗi epoch chạy lại UTMOS forward+backward) → tốn giờ GPU. **Lần đầu BẮT BUỘC
25
+ # `LIMIT_TRAIN=300`, `LIMIT_DEV=20`** để chỉnh trơn rồi mới `None`.
26
+ # - Lưới an toàn: chỉ nộp QMOS fine-tune nếu **SRCC val nội bộ > zero-shot UTMOS** (mục A in cả 2 số).
27
+ # - **Lưu checkpoint `ft_qmos_utmos.pt` mỗi best + Save Version NGAY** (bài học exp08: kernel chết là mất).
28
+ #
29
+ # **Cách chạy Kaggle:** GPU **T4** + Internet **On** → Add Input (1) dataset Track 2, (2) dataset chứa
30
+ # `ft_emotion_full.pt` (exp08), (3) tùy chọn cache `aud_dev.npz` → sửa slug cell 0 → Run All.
31
+
32
+ # %% [markdown]
33
+ # ## 0. Cấu hình — SỬA Ở ĐÂY
34
+
35
+ # %%
36
+ import os, shutil, glob
37
+
38
+ # ── TỰ DÒ DATA_ROOT (quét /kaggle/input tìm thư mục có sets/train.csv + wav/ + metadata.csv) ──
39
+ def find_data_root(search_root="/kaggle/input"):
40
+ cands = []
41
+ for train_csv in glob.glob(os.path.join(search_root, "**", "sets", "train.csv"), recursive=True):
42
+ root = os.path.dirname(os.path.dirname(train_csv)) # .../<root>/sets/train.csv → <root>
43
+ score = os.path.isdir(os.path.join(root, "wav")) + os.path.exists(os.path.join(root, "metadata.csv"))
44
+ cands.append((score, root))
45
+ cands.sort(reverse=True) # ưu tiên thư mục đủ wav + metadata
46
+ return cands
47
+
48
+ _cands = find_data_root("/kaggle/input")
49
+ if _cands:
50
+ print("🔎 Ứng viên DATA_ROOT (điểm cao = đủ wav+metadata):")
51
+ for sc, r in _cands:
52
+ print(f" [{sc}/2] {r}")
53
+ DATA_ROOT = _cands[0][1]
54
+ print(f"👉 Tự chọn DATA_ROOT = {DATA_ROOT}")
55
+ else:
56
+ DATA_ROOT = "/kaggle/input/datasets/minhtoan2" # dự phòng — sửa tay nếu auto-dò không thấy
57
+ print(f"❌ Không thấy sets/train.csv trong /kaggle/input → dùng dự phòng {DATA_ROOT} (đã Add Input chưa?)")
58
+
59
+ WAV_DIR = f"{DATA_ROOT}/wav"
60
+ METADATA_CSV = f"{DATA_ROOT}/metadata.csv" # wavID|emotion|transcript (cho cột cảm xúc)
61
+ TRAIN_CSV = f"{DATA_ROOT}/sets/train.csv" # lisID|wavID|qMOS|emoCat|eMOS|val|dom|aro
62
+ DEV_SCP = f"{DATA_ROOT}/sets/dev.scp"
63
+
64
+ # ── Checkpoint cảm xúc exp08 (để sinh 5 cột EMOS/CAT/VAD) ─────────────────────
65
+ # ⭐ TỐT NHẤT = ft_emotion_full_20epoch.pt (bản 20 epoch) — dùng bản này, KHÔNG dùng ft_emotion_full.pt.
66
+ EMO_CKPT = "/kaggle/input/ft-emotion-full/ft_emotion_full_20epoch.pt" # << ckpt exp08 20ep (CÓ backbone WavLM)
67
+ CACHE_INPUT = "/kaggle/input/ft-emotion-cache" # << (tùy chọn) thư mục chứa aud_dev.npz; "" nếu không có
68
+
69
+ OUT_DIR = "/kaggle/working"
70
+ CACHE_DIR = "/kaggle/working/ft_cache" # /kaggle/input read-only → copy cache audeering sang đây
71
+ os.makedirs(CACHE_DIR, exist_ok=True)
72
+
73
+ # ── PHẦN A: fine-tune UTMOS (QMOS) ───────────────────────────────────────────
74
+ DEVICE = "cuda"
75
+ SR = 16000
76
+ QMOS_MAX_SEC = 12 # cắt audio chặn bộ nhớ backprop (UTMOS); OOM thì giảm 10/8
77
+ LR = 1e-5 # LR nhỏ cho fine-tune (warm-start sẵn tốt)
78
+ WEIGHT_DECAY = 1e-5
79
+ EPOCHS = 10 # TRẦN; early-stop quyết số epoch thật
80
+ PATIENCE = 3
81
+ BATCH = 1 # UTMOS forward KHÔNG có attention-mask → BATCH=1 an toàn (pad zero sẽ lệch pooling)
82
+ ACCUM = 16 # effective batch = BATCH*ACCUM = 16
83
+ VAL_FRAC = 0.10
84
+ SEED = 42
85
+ USE_AMP = True
86
+ RANK_LAMBDA = 0.0 # 0 = chỉ MSE. >0 (vd 0.3) = cộng pairwise ranking loss (tối ưu thẳng thứ hạng=SRCC)
87
+ FREEZE_FEAT_EXT = True # đóng băng feature-extractor (CNN conv) của UTMOS → đỡ VRAM + chống overfit
88
+
89
+ # ── PHẦN B: inference cảm xúc (PHẢI khớp kiến trúc exp08) ─────────────────────
90
+ EMO_MAX_SEC = 8
91
+ UNFREEZE_TOP_LAYERS = 6 # khớp ckpt exp08
92
+ TRUNK_HIDDEN = 512
93
+ HEAD_HIDDEN = 128
94
+ DROPOUT = 0.3
95
+ USE_AUDEERING = True # khớp ckpt exp08
96
+
97
+ LIMIT_TRAIN = 300 # << LẦN ĐẦU 300; chạy thật None
98
+ LIMIT_DEV = 20 # << LẦN ĐẦU 20; chạy thật None
99
+
100
+ # Mốc QMOS để so (leaderboard DEV)
101
+ QMOS_BASELINE = {"utmos_zeroshot": 0.414, "exp07_head": 0.548}
102
+
103
+ EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"]
104
+ _EMO_ALIAS = {
105
+ "angry": "angry", "anger": "angry",
106
+ "happy": "happy", "happiness": "happy", "joy": "happy",
107
+ "neutral": "neutral", "calm": "neutral",
108
+ "sad": "sad", "sadness": "sad",
109
+ "surprise": "surprised", "surprised": "surprised", "surprising": "surprised",
110
+ }
111
+
112
+ def norm_emotion(label):
113
+ key = str(label).strip().lower()
114
+ return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)
115
+
116
+ def stem(p):
117
+ return os.path.splitext(os.path.basename(str(p)))[0]
118
+
119
+ print("DATA_ROOT:", DATA_ROOT)
120
+ for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP, EMO_CKPT]:
121
+ print((" ✅ " if os.path.exists(p) else " ❌ THIẾU ") + p)
122
+ print(f"Fine-tune UTMOS: LR {LR} · BATCH {BATCH}×ACCUM {ACCUM} · MAX {QMOS_MAX_SEC}s · rank λ {RANK_LAMBDA}")
123
+
124
+ # Copy cache audeering (aud_dev.npz) từ input read-only sang working (để cột cảm xúc khỏi trích lại)
125
+ if CACHE_INPUT and os.path.isdir(CACHE_INPUT):
126
+ n = 0
127
+ for fn in os.listdir(CACHE_INPUT):
128
+ if fn.startswith("aud_") and fn.endswith(".npz"):
129
+ shutil.copy(os.path.join(CACHE_INPUT, fn), os.path.join(CACHE_DIR, fn)); n += 1
130
+ print(f"📦 Copy {n} file cache audeering từ {CACHE_INPUT} → {CACHE_DIR}")
131
+ else:
132
+ print("ℹ️ Không có CACHE_INPUT → sẽ tự trích audeering cho DEV (chậm hơn lần đầu).")
133
+
134
+ # %% [markdown]
135
+ # ## 1. Cài đặt
136
+
137
+ # %%
138
+ import sys, subprocess
139
+
140
+ def pip_install(*pkgs):
141
+ subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs], check=True)
142
+
143
+ pip_install("speechmos", "loralib", "speechbrain", "librosa", "soundfile",
144
+ "scipy", "scikit-learn", "pandas", "tqdm")
145
+
146
+ # Code SAILER (để dựng đúng kiến trúc WavLM rồi nạp ckpt exp08 đè lên) — chỉ cần cho PHẦN B
147
+ REPO_DIR = "/kaggle/working/vox-profile-release"
148
+ if not os.path.exists(REPO_DIR):
149
+ subprocess.run(["git", "clone", "--depth", "1",
150
+ "https://github.com/tiantiaf0627/vox-profile-release.git", REPO_DIR], check=True)
151
+ if REPO_DIR not in sys.path:
152
+ sys.path.insert(0, REPO_DIR)
153
+
154
+ # %% [markdown]
155
+ # ## 2. Nhãn vàng qMOS (gộp trung bình theo wav) — như exp06/exp09a
156
+
157
+ # %%
158
+ import numpy as np
159
+ import pandas as pd
160
+
161
+ def load_qmos_labels():
162
+ df = pd.read_csv(TRAIN_CSV, sep="|")
163
+ cols = {c.lower().strip(): c for c in df.columns}
164
+ wav_col = cols.get("wavid") or cols.get("wav") or list(df.columns)[1]
165
+ qmos_col = cols.get("qmos") or cols.get("mos")
166
+ assert qmos_col, f"Không thấy cột qMOS (cột: {list(df.columns)})"
167
+ df["_stem"] = df[wav_col].map(stem)
168
+ g = df.groupby("_stem")[qmos_col].mean()
169
+ return {s: float(v) for s, v in g.items()}
170
+
171
+ qmos_gold = load_qmos_labels()
172
+ print(f"Số wav train có nhãn qMOS: {len(qmos_gold)}")
173
+ _vals = np.array(list(qmos_gold.values()))
174
+ print(f"qMOS gold: mean {_vals.mean():.3f} · std {_vals.std():.3f} · min {_vals.min():.2f} · max {_vals.max():.2f}")
175
+
176
+ # %% [markdown]
177
+ # ## 3. PHẦN A — Fine-tune UTMOS trên qMOS
178
+ # UTMOS xuất MOS thang ~1–5 (đã warm-start) → train MSE trên thang GỐC (không z-score, để giữ ý nghĩa warm-start).
179
+ # `BATCH=1` + grad-accum: tránh phải pad (UTMOS forward không nhận attention-mask).
180
+
181
+ # %%
182
+ import torch
183
+ import torch.nn as nn
184
+ import librosa
185
+ from tqdm.auto import tqdm
186
+ from scipy.stats import spearmanr
187
+ from sklearn.model_selection import train_test_split
188
+
189
+ device = DEVICE if torch.cuda.is_available() else "cpu"
190
+ print("Device:", device, ("✅ " + torch.cuda.get_device_name(0)) if device == "cuda" else "⚠️ CPU (rất chậm!)")
191
+ torch.manual_seed(SEED); np.random.seed(SEED)
192
+
193
+ # Nạp UTMOS (torch.hub) — model nn.Module, forward(wave[B,T], sr) -> MOS[B]
194
+ utmos = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True).to(device)
195
+ n_all = sum(p.numel() for p in utmos.parameters())
196
+
197
+ # (tùy chọn) đóng băng feature-extractor (các lớp conv trích đặc trưng) → đỡ VRAM + chống overfit
198
+ if FREEZE_FEAT_EXT:
199
+ n_frozen = 0
200
+ for name, p in utmos.named_parameters():
201
+ if "feature_extractor" in name or "feature_projection" in name or "conv" in name.lower():
202
+ p.requires_grad = False; n_frozen += p.numel()
203
+ print(f"❄️ Đóng băng feature-extractor: {n_frozen/1e6:.1f}M / {n_all/1e6:.1f}M param")
204
+ n_train = sum(p.numel() for p in utmos.parameters() if p.requires_grad)
205
+ print(f"UTMOS: {n_all/1e6:.1f}M param tổng · {n_train/1e6:.1f}M param sẽ train")
206
+
207
+ def load_wav_qmos(sid):
208
+ p = os.path.join(WAV_DIR, sid + ".wav")
209
+ if not os.path.exists(p):
210
+ return None
211
+ wave, _ = librosa.load(p, sr=SR, mono=True)
212
+ return wave[: QMOS_MAX_SEC * SR].astype(np.float32)
213
+
214
+ # Tập train QMOS: chỉ wav tồn tại trên đĩa
215
+ train_stems_q = [s for s in qmos_gold if os.path.exists(os.path.join(WAV_DIR, s + ".wav"))]
216
+ np.random.shuffle(train_stems_q)
217
+ if LIMIT_TRAIN:
218
+ train_stems_q = train_stems_q[:LIMIT_TRAIN]
219
+ tr_q, va_q = train_test_split(train_stems_q, test_size=VAL_FRAC, random_state=SEED)
220
+ print(f"QMOS train: {len(tr_q)} · val nội bộ: {len(va_q)}")
221
+
222
+ opt = torch.optim.AdamW([p for p in utmos.parameters() if p.requires_grad],
223
+ lr=LR, weight_decay=WEIGHT_DECAY)
224
+ scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP and device == "cuda")
225
+ mse = nn.MSELoss()
226
+
227
+ def utmos_forward(wave_np):
228
+ """1 wav numpy -> MOS scalar tensor (giữ grad)."""
229
+ x = torch.from_numpy(wave_np).unsqueeze(0).to(device) # [1, T]
230
+ out = utmos(x, SR) # [1] (hoặc [1,?])
231
+ return out.reshape(-1).mean() # scalar an toàn mọi shape
232
+
233
+ def pairwise_rank_loss(preds, targets):
234
+ """Hinge ranking trên các cặp trong 1 nhóm (khuyến khích đúng thứ hạng = tối ưu SRCC)."""
235
+ p = torch.stack(preds); t = torch.tensor(targets, device=device, dtype=torch.float32)
236
+ if len(p) < 2:
237
+ return torch.zeros((), device=device)
238
+ sign = torch.sign(t.unsqueeze(0) - t.unsqueeze(1))
239
+ diff = p.unsqueeze(0) - p.unsqueeze(1)
240
+ return torch.relu(-sign * diff).mean()
241
+
242
+ @torch.no_grad()
243
+ def eval_qmos_val():
244
+ utmos.eval()
245
+ preds, gts = [], []
246
+ for s in va_q:
247
+ wave = load_wav_qmos(s)
248
+ if wave is None:
249
+ continue
250
+ with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
251
+ preds.append(float(utmos_forward(wave).item()))
252
+ gts.append(qmos_gold[s])
253
+ return float(spearmanr(preds, gts).correlation)
254
+
255
+ # Baseline ZERO-SHOT (trước khi train) trên CÙNG val → mốc phải vượt
256
+ srcc_zeroshot = eval_qmos_val()
257
+ print(f"\n📍 UTMOS zero-shot (val nội bộ): SRCC = {srcc_zeroshot:.4f} "
258
+ f"(leaderboard DEV ~{QMOS_BASELINE['utmos_zeroshot']}; exp07 head {QMOS_BASELINE['exp07_head']})")
259
+
260
+ CKPT_QMOS = os.path.join(OUT_DIR, "ft_qmos_utmos.pt")
261
+ def save_qmos_ckpt(srcc):
262
+ torch.save({"utmos_state": {k: v.cpu() for k, v in utmos.state_dict().items()},
263
+ "val_srcc": float(srcc), "raw_scale": True,
264
+ "QMOS_MAX_SEC": QMOS_MAX_SEC, "FREEZE_FEAT_EXT": FREEZE_FEAT_EXT}, CKPT_QMOS)
265
+
266
+ best, best_state, bad = srcc_zeroshot, {k: v.cpu().clone() for k, v in utmos.state_dict().items()}, 0
267
+ save_qmos_ckpt(best) # lưu sẵn bản zero-shot (worst case vẫn = baseline)
268
+
269
+ # Gom theo CỬA SỔ = ACCUM mẫu HỢP LỆ (micro). Hai chế độ backward:
270
+ # • RANK off (mặc định) → backward NGAY từng mẫu → đồ thị giải phóng liền → VRAM thấp.
271
+ # • RANK on → ranking cần SO các pred TRONG cửa sổ → PHẢI giữ đồ thị cả cửa sổ →
272
+ # gom MSE (win_loss) + pred (buf_p) rồi backward MỘT lần (MSE_mean + λ·rank).
273
+ # ⚠️ Lỗi cũ: backward MSE từng bước đã giải phóng đồ thị → rank_loss.backward() sau đó
274
+ # sẽ lỗi "backward through the graph a second time". Bản này gom rồi backward 1 lần → hết lỗi.
275
+ for ep in range(1, EPOCHS + 1):
276
+ utmos.train()
277
+ opt.zero_grad()
278
+ np.random.shuffle(tr_q)
279
+ run = 0.0; nb = 0
280
+ micro = 0; win_loss = None; buf_p, buf_t = [], []
281
+ for s in tqdm(tr_q, desc=f"epoch {ep}"):
282
+ wave = load_wav_qmos(s)
283
+ if wave is None:
284
+ continue
285
+ with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
286
+ pred = utmos_forward(wave)
287
+ loss = mse(pred, torch.tensor(qmos_gold[s], device=device, dtype=pred.dtype))
288
+ run += float(loss.item()); nb += 1
289
+ if RANK_LAMBDA > 0:
290
+ win_loss = loss if win_loss is None else win_loss + loss # GIỮ đồ thị (không backward ngay)
291
+ buf_p.append(pred); buf_t.append(qmos_gold[s]); micro += 1
292
+ else:
293
+ scaler.scale(loss / ACCUM).backward(); micro += 1 # backward ngay → VRAM thấp
294
+ if micro == ACCUM:
295
+ if RANK_LAMBDA > 0:
296
+ total = win_loss / micro
297
+ if len(buf_p) >= 2:
298
+ total = total + RANK_LAMBDA * pairwise_rank_loss(buf_p, buf_t)
299
+ scaler.scale(total).backward()
300
+ scaler.step(opt); scaler.update(); opt.zero_grad()
301
+ micro = 0; win_loss = None; buf_p, buf_t = [], []
302
+ # flush cửa sổ dư cuối epoch (số mẫu không chia hết cho ACCUM)
303
+ if micro > 0:
304
+ if RANK_LAMBDA > 0:
305
+ total = win_loss / micro
306
+ if len(buf_p) >= 2:
307
+ total = total + RANK_LAMBDA * pairwise_rank_loss(buf_p, buf_t)
308
+ scaler.scale(total).backward()
309
+ scaler.step(opt); scaler.update(); opt.zero_grad()
310
+ sc = eval_qmos_val()
311
+ print(f"epoch {ep:2d} | loss {run/max(nb,1):.4f} | val SRCC {sc:.4f} "
312
+ f"(zero-shot {srcc_zeroshot:.4f} · best {max(best,sc):.4f})")
313
+ if sc > best:
314
+ best = sc
315
+ best_state = {k: v.cpu().clone() for k, v in utmos.state_dict().items()}
316
+ save_qmos_ckpt(best)
317
+ print(f" 💾 lưu best → {CKPT_QMOS} (epoch {ep}, SRCC {sc:.4f})")
318
+ bad = 0
319
+ else:
320
+ bad += 1
321
+ if bad >= PATIENCE:
322
+ print(f"Early stop ở epoch {ep}."); break
323
+
324
+ utmos.load_state_dict(best_state)
325
+ print(f"\n✅ PHẦN A xong — QMOS val nội bộ: zero-shot {srcc_zeroshot:.4f} → fine-tune {best:.4f} "
326
+ + ("🚀 cải thiện" if best > srcc_zeroshot + 1e-4 else "➖ KHÔNG vượt zero-shot"))
327
+ if best <= srcc_zeroshot + 1e-4:
328
+ print(" ⚠️ Fine-tune chưa vượt zero-shot → cân nhắc tăng EPOCHS / bật RANK_LAMBDA=0.3 / "
329
+ "mở băng feature-extractor (FREEZE_FEAT_EXT=False); hoặc giữ QMOS exp07 (0.548).")
330
+
331
+ # %% [markdown]
332
+ # ## 4. PHẦN A (tiếp) — Dự đoán QMOS cho DEV bằng UTMOS đã fine-tune
333
+
334
+ # %%
335
+ def list_dev():
336
+ with open(DEV_SCP) as f:
337
+ return [ln.strip() for ln in f if ln.strip()]
338
+
339
+ dev_names = list_dev()
340
+ if LIMIT_DEV:
341
+ dev_names = dev_names[:LIMIT_DEV]
342
+ print("DEV:", len(dev_names), "mẫu")
343
+
344
+ @torch.no_grad()
345
+ def predict_qmos(name):
346
+ p = os.path.join(WAV_DIR, name if str(name).endswith(".wav") else str(name) + ".wav")
347
+ if not os.path.exists(p):
348
+ return None
349
+ wave, _ = librosa.load(p, sr=SR, mono=True)
350
+ wave = wave[: QMOS_MAX_SEC * SR].astype(np.float32)
351
+ utmos.eval()
352
+ with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
353
+ v = float(utmos_forward(wave).item())
354
+ return float(np.clip(v, 1.0, 5.0))
355
+
356
+ qmos_pred = {}
357
+ n_real = n_def = 0
358
+ for name in tqdm(dev_names, desc="QMOS dev"):
359
+ v = predict_qmos(name)
360
+ if v is None:
361
+ v = 3.0; n_def += 1
362
+ else:
363
+ n_real += 1
364
+ qmos_pred[name] = v
365
+ print(f"QMOS dự đoán: thật {n_real}, mặc định {n_def}")
366
+
367
+ # Giải phóng UTMOS trước khi nạp backbone cảm xúc (đỡ VRAM T4)
368
+ del utmos, opt, scaler
369
+ torch.cuda.empty_cache() if device == "cuda" else None
370
+
371
+ # %% [markdown]
372
+ # ## 5. PHẦN B — Nạp ckpt exp08 (WavLM ft + audeering) → 5 cột cảm xúc cho DEV
373
+ # Tái dùng nguyên cơ chế load của exp08b: dựng kiến trúc → `load_state_dict` từ `ft_emotion_full_20epoch.pt`.
374
+
375
+ # %%
376
+ import torch.nn.functional as F
377
+
378
+ ckpt = torch.load(EMO_CKPT, map_location="cpu", weights_only=False) # ckpt có numpy → weights_only=False
379
+ assert "wavlm" in ckpt, ("❌ EMO_CKPT không có 'wavlm' (backbone). Cần ft_emotion_full_20epoch.pt (bản đủ backbone), "
380
+ "KHÔNG phải ft_emotion_meta.pt cũ.")
381
+ print("✅ Nạp ckpt cảm xúc:", EMO_CKPT, "| keys:", list(ckpt.keys()))
382
+
383
+ def find_hf_backbone(module):
384
+ cands = []
385
+ for nm, m in module.named_modules():
386
+ enc = getattr(m, "encoder", None)
387
+ if getattr(m, "feature_extractor", None) is not None and enc is not None \
388
+ and getattr(enc, "layers", None) is not None:
389
+ cands.append((nm, m))
390
+ if not cands:
391
+ return None, None
392
+ cands.sort(key=lambda x: sum(p.numel() for p in x[1].parameters()), reverse=True)
393
+ return cands[0]
394
+
395
+ wavlm = None
396
+ try:
397
+ from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402
398
+ _wrapper = WavLMWrapper.from_pretrained("tiantiaf/wavlm-large-categorical-emotion")
399
+ _name, wavlm = find_hf_backbone(_wrapper)
400
+ if wavlm is not None:
401
+ print(f"✅ Dựng backbone WavLM từ SAILER wrapper tại '.{_name}'")
402
+ except Exception as e:
403
+ print("⚠️ Lỗi nạp SAILER wrapper:", repr(e), "→ fallback WavLM trắng.")
404
+ if wavlm is None:
405
+ from transformers import WavLMModel
406
+ wavlm = WavLMModel.from_pretrained("microsoft/wavlm-large")
407
+ print("ℹ️ Fallback: microsoft/wavlm-large.")
408
+
409
+ wavlm = wavlm.to(device).eval()
410
+ WAVLM_DIM = int(wavlm.config.hidden_size)
411
+ miss, unexp = wavlm.load_state_dict(ckpt["wavlm"], strict=False)
412
+ print(f"🔁 load wavlm từ ckpt: thiếu {len(miss)} / dư {len(unexp)} key (kỳ vọng ~0).")
413
+
414
+ def masked_mean(hidden, attn_mask):
415
+ if attn_mask is None:
416
+ return hidden.mean(dim=1)
417
+ try:
418
+ fm = wavlm._get_feature_vector_attention_mask(hidden.shape[1], attn_mask)
419
+ except Exception:
420
+ return hidden.mean(dim=1)
421
+ fm = fm.unsqueeze(-1).to(hidden.dtype)
422
+ return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)
423
+
424
+ @torch.no_grad()
425
+ def wavlm_embed(input_values, attn_mask):
426
+ out = wavlm(input_values, attention_mask=attn_mask).last_hidden_state
427
+ return masked_mean(out, attn_mask)
428
+
429
+ # ── audeering FROZEN (đặc trưng phụ) — như exp08 ──
430
+ AUD_DIM = 0
431
+ aud_backbone = aud_head = aud_proc = None
432
+ if USE_AUDEERING:
433
+ from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor
434
+ from huggingface_hub import hf_hub_download
435
+ AUD_NAME = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
436
+ aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)
437
+ aud_cfg = Wav2Vec2Config.from_pretrained(AUD_NAME)
438
+ aud_backbone = Wav2Vec2Model(aud_cfg)
439
+ try:
440
+ _sd = __import__("safetensors.torch", fromlist=["load_file"]).load_file(
441
+ hf_hub_download(AUD_NAME, "model.safetensors"))
442
+ except Exception:
443
+ _sd = torch.load(hf_hub_download(AUD_NAME, "pytorch_model.bin"), map_location="cpu")
444
+ bb_sd = {k[len("wav2vec2."):]: v for k, v in _sd.items() if k.startswith("wav2vec2.")}
445
+ aud_backbone.load_state_dict(bb_sd, strict=False)
446
+ _hid = _sd["classifier.dense.weight"].shape[0]
447
+ _out = _sd["classifier.out_proj.weight"].shape[0]
448
+ aud_head = nn.Sequential(nn.Linear(_hid, _hid), nn.Tanh(), nn.Linear(_hid, _out))
449
+ aud_head[0].weight.data.copy_(_sd["classifier.dense.weight"]); aud_head[0].bias.data.copy_(_sd["classifier.dense.bias"])
450
+ aud_head[2].weight.data.copy_(_sd["classifier.out_proj.weight"]); aud_head[2].bias.data.copy_(_sd["classifier.out_proj.bias"])
451
+ aud_backbone = aud_backbone.to(device).eval()
452
+ aud_head = aud_head.to(device).eval()
453
+ AUD_DIM = _hid + 3
454
+ print(f"✅ audeering frozen ({AUD_DIM}-D)")
455
+
456
+ def load_wav_emo(sid):
457
+ p = os.path.join(WAV_DIR, sid + ".wav")
458
+ if not os.path.exists(p):
459
+ return None
460
+ wave, _ = librosa.load(p, sr=SR, mono=True)
461
+ return wave[: EMO_MAX_SEC * SR].astype(np.float32)
462
+
463
+ @torch.no_grad()
464
+ def extract_audeering(stems, tag):
465
+ if not USE_AUDEERING:
466
+ return {}
467
+ cache_path = os.path.join(CACHE_DIR, f"aud_{tag}.npz")
468
+ store = {}
469
+ if os.path.exists(cache_path):
470
+ z = np.load(cache_path, allow_pickle=True)
471
+ store = {k: z[k] for k in z.files}
472
+ print(f"[aud/{tag}] nạp cache: {len(store)}")
473
+ todo = [s for s in stems if s not in store]
474
+ for i, s in enumerate(tqdm(todo, desc=f"audeering {tag}")):
475
+ wave = load_wav_emo(s)
476
+ if wave is None:
477
+ continue
478
+ x = aud_proc(wave, sampling_rate=SR).input_values[0]
479
+ x = torch.from_numpy(np.asarray(x, dtype=np.float32)).unsqueeze(0).to(device)
480
+ h = aud_backbone(x)[0].mean(dim=1)
481
+ out = aud_head(h)[0].cpu().numpy()
482
+ vad = np.array([1 + 4 * out[2], 1 + 4 * out[0], 1 + 4 * out[1]], dtype=np.float32) # [VAL,ARO,DOM]
483
+ store[s] = np.concatenate([h[0].cpu().numpy(), vad]).astype(np.float32)
484
+ if (i + 1) % 500 == 0:
485
+ np.savez(cache_path, **store)
486
+ if todo:
487
+ np.savez(cache_path, **store)
488
+ return store
489
+
490
+ # ── EmoHeads (khớp exp08) + nạp trọng số head + thống kê chuẩn hóa từ ckpt ──
491
+ N_EMO = len(EMOTIONS5)
492
+ TRUNK_IN = WAVLM_DIM + (AUD_DIM if USE_AUDEERING else 0)
493
+
494
+ class EmoHeads(nn.Module):
495
+ def __init__(self, d_in, trunk_h, head_h, p, n_emo):
496
+ super().__init__()
497
+ self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),
498
+ nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))
499
+ self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))
500
+ self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))
501
+ self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))
502
+ def forward(self, feat, tgt):
503
+ h = self.trunk(feat)
504
+ return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)
505
+
506
+ heads = EmoHeads(TRUNK_IN, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device).eval()
507
+ hmiss, hunexp = heads.load_state_dict(ckpt["heads"], strict=False)
508
+ print(f"🔁 load heads từ ckpt: thiếu {len(hmiss)} / dư {len(hunexp)} key (kỳ vọng 0).")
509
+
510
+ emos_mu = float(ckpt["emos_mu"]); emos_sd = float(ckpt["emos_sd"])
511
+ vad_mu = np.asarray(ckpt["vad_mu"], dtype=np.float32); vad_sd = np.asarray(ckpt["vad_sd"], dtype=np.float32)
512
+ print(f"Chuẩn hóa từ ckpt: emos μ={emos_mu:.3f} σ={emos_sd:.3f} | vad μ={np.round(vad_mu,2)}")
513
+
514
+ # Target cảm xúc (cho EMOS head) từ metadata
515
+ def load_target_emotions():
516
+ tgt = {}
517
+ with open(METADATA_CSV, encoding="utf-8") as f:
518
+ for ln in f:
519
+ parts = ln.strip().split("|")
520
+ if len(parts) >= 2:
521
+ tgt[stem(parts[0])] = norm_emotion(parts[1])
522
+ return tgt
523
+
524
+ target_map = load_target_emotions()
525
+
526
+ def onehot_target(tgt):
527
+ v = np.zeros(N_EMO, dtype=np.float32)
528
+ if tgt in EMOTIONS5:
529
+ v[EMOTIONS5.index(tgt)] = 1.0
530
+ return v
531
+
532
+ dev_stems = [stem(n) for n in dev_names]
533
+ aud_dev = extract_audeering(dev_stems, "dev")
534
+
535
+ @torch.no_grad()
536
+ def predict_emotion(sid):
537
+ wave = load_wav_emo(sid)
538
+ if wave is None or (USE_AUDEERING and sid not in aud_dev):
539
+ return None
540
+ iv = torch.from_numpy(wave).unsqueeze(0).to(device)
541
+ am = torch.ones((1, len(wave)), dtype=torch.long, device=device)
542
+ tgt = torch.from_numpy(onehot_target(target_map.get(sid))).unsqueeze(0).to(device)
543
+ with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
544
+ fw = wavlm_embed(iv, am)
545
+ feat = torch.cat([fw, torch.from_numpy(aud_dev[sid]).unsqueeze(0).to(device)], dim=1) if USE_AUDEERING else fw
546
+ emos_p, cat_l, vad_p = heads(feat, tgt)
547
+ emos = float(emos_p.item()) * emos_sd + emos_mu
548
+ cat5 = F.softmax(cat_l, 1)[0].float().cpu().numpy()
549
+ vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu
550
+ return emos, cat5, vad3
551
+
552
+ # %% [markdown]
553
+ # ## 6. PHẦN C — Ghép QMOS (fine-tune) + 5 cột cảm xúc (exp08) → answer.txt 6 cột
554
+
555
+ # %%
556
+ def fmt_cat(p5):
557
+ return "|".join(f"{e}:{p5[i]:.6g}" for i, e in enumerate(EMOTIONS5))
558
+
559
+ def build_answer(out_path):
560
+ n_real = n_def = 0
561
+ with open(out_path, "w") as f:
562
+ f.write("wav,QMOS,EMOS,CAT,VAL,ARO,DOM\n")
563
+ for name in tqdm(dev_names, desc="answer"):
564
+ sid = stem(name)
565
+ pr = predict_emotion(sid)
566
+ if pr is None:
567
+ emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0]); n_def += 1
568
+ else:
569
+ emos, cat5, vad3 = pr; n_real += 1
570
+ qmos = qmos_pred.get(name, qmos_pred.get(sid, 3.0))
571
+ f.write(f"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\n")
572
+ print(f"Ghi {len(dev_names)} dòng → {out_path} | cảm xúc thật {n_real}, mặc định {n_def}")
573
+
574
+ answer_path = os.path.join(OUT_DIR, "answer.txt")
575
+ build_answer(answer_path)
576
+
577
+ # %% [markdown]
578
+ # ## 7. Validate + zip
579
+
580
+ # %%
581
+ def validate(path):
582
+ import csv
583
+ with open(path) as f:
584
+ rows = list(csv.reader(f))
585
+ assert rows[0][0] == "wav" and "QMOS" in rows[0] and "EMOS" in rows[0], "Header sai"
586
+ for i, r in enumerate(rows[1:], 2):
587
+ assert len(r) == len(rows[0]), f"Dòng {i} sai số cột"
588
+ print(f"OK: {len(rows)-1} dòng, header = {rows[0]}")
589
+
590
+ validate(answer_path)
591
+ os.system(f"cd {OUT_DIR} && zip -j submission_track2_exp13_ft-qmos.zip answer.txt "
592
+ f"&& unzip -l submission_track2_exp13_ft-qmos.zip")
593
+ print("Sẵn sàng nộp:", os.path.join(OUT_DIR, "submission_track2_exp13_ft-qmos.zip"))
594
+
595
+ # %% [markdown]
596
+ # ## Ghi chú
597
+ # - **Lần đầu** `LIMIT_TRAIN=300`, `LIMIT_DEV=20` để chạy trơn (không OOM, 1 epoch xong); rồi đặt `None`.
598
+ # - **OOM trên T4?** giảm `QMOS_MAX_SEC` (12→10→8); giữ `FREEZE_FEAT_EXT=True`; `BATCH=1` đã là min.
599
+ # ⚠️ **Bật `RANK_LAMBDA>0` tốn VRAM hơn** vì phải GIỮ đồ thị cả cửa sổ ACCUM (=16) để so thứ hạng →
600
+ # nếu OOM khi bật ranking: giảm `ACCUM` (vd 8, cũng là kích thước nhóm ranking) hoặc `QMOS_MAX_SEC`.
601
+ # - **Đọc mục A:** so `val SRCC fine-tune` với `zero-shot`. Chỉ nộp QMOS fine-tune nếu **vượt zero-shot**
602
+ # (lý tưởng vượt cả exp07 0.548); nếu không → giữ QMOS exp07 (Add Input answer.txt exp07, đổi cột QMOS).
603
+ # - Nếu chưa vượt: tăng `EPOCHS`, bật `RANK_LAMBDA=0.3` (tối ưu thẳng thứ hạng), hoặc `FREEZE_FEAT_EXT=False`
604
+ # (mở băng feature-extractor — mạnh hơn nhưng dễ overfit + nặng VRAM).
605
+ # - **Lưu checkpoint:** `ft_qmos_utmos.pt` lưu mỗi best → **Save Version NGAY** sau khi chạy (bài học exp08).
606
+ # - **License QMOS:** UTMOS/SpeechMOS (kiểm tra license tarepan/SpeechMOS) — khai báo `docs/12_system_description.md`.
607
+ # - Ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (mục exp13).
track2/exp14_mamba_head.ipynb ADDED
@@ -0,0 +1,952 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "63b4bfa4",
6
+ "metadata": {},
7
+ "source": [
8
+ "# VMC2026 Track 2 — exp14 (MAMBA temporal head, CỘNG vào FUSION 6 cột) — Kaggle\n",
9
+ "\n",
10
+ "**Ý tưởng (theo gợi ý mentor \"thử Mamba\"):** exp04/exp07 đều **mean-pool** đặc trưng SSL →\n",
11
+ "mỗi wav thành 1 vector → mất hết **động lực theo thời gian** (lên/xuống giọng, ngắt quãng, rung).\n",
12
+ "**Mamba** là State Space Model (SSM) xử lý **chuỗi** với độ phức tạp tuyến tính → cho nó **dãy frame**\n",
13
+ "(chưa pool) để học temporal dynamics, rồi mới pool. Tham khảo: MambaRate (AudioMOS 2025), arXiv:2507.12090.\n",
14
+ "\n",
15
+ "## exp14 = exp07 + 1 nhánh Mamba (CỘNG thêm, không thay thế)\n",
16
+ "```\n",
17
+ " ┌─ đặc trưng POOLED [e2v_emb|e2v_p5|sailer_emb|sailer_p9|sailer_vad3] (y hệt exp07 → DÙNG LẠI cache)\n",
18
+ " mỗi wav ──┤\n",
19
+ " └─ WavLM frame-level (chuỗi T×1024) ─► Mamba (2 lớp, 2 chiều) ─► attn-pool ─► z_seq (Z chiều)\n",
20
+ " │\n",
21
+ " concat ──► TRUNK chung ──► 6 head: QMOS · EMOS · CAT · VAL · ARO · DOM\n",
22
+ "```\n",
23
+ "- **Cờ `USE_MAMBA`:** `False` → chạy ra **đúng exp07** (kiểm chứng tái lập ~0.548/0.795). `True` → bật nhánh Mamba.\n",
24
+ " Đây CHÍNH là **ablation \"có/không Mamba\"** cho paper.\n",
25
+ "- WavLM **đóng băng** (chỉ trích đặc trưng) → Mamba head nhỏ → train nhanh, vừa T4.\n",
26
+ "\n",
27
+ "## 2 gotcha Kaggle đã xử trong file\n",
28
+ "1. `mamba-ssm` hay lỗi build CUDA → **nhúng sẵn Mamba thuần PyTorch** (không cần pip); tự dùng `mamba-ssm` nếu import được.\n",
29
+ "2. Cache frame-level RẤT nặng → **cap `MAX_FRAMES`** + lưu **fp16**. Ước lượng: MAX_FRAMES=256, 1024 chiều, fp16\n",
30
+ " ≈ 0.5 MB/wav → train ~12k ≈ 6 GB, dev ~2.7k ≈ 1.4 GB (vừa /kaggle/working). **Save Version** để giữ cache.\n",
31
+ "\n",
32
+ "**Cách chạy:** GPU T4 + Internet On → Add Input dataset Track 2 → sửa `DATA_ROOT` → Run All.\n",
33
+ "Lần đầu đặt `LIMIT_TRAIN=300`, `LIMIT_DEV=20` để soi nhanh; OK rồi đặt `None`."
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "markdown",
38
+ "id": "3fe243f8",
39
+ "metadata": {},
40
+ "source": [
41
+ "## 0. Cấu hình — SỬA Ở ĐÂY"
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "execution_count": null,
47
+ "id": "bd2e582a",
48
+ "metadata": {},
49
+ "outputs": [],
50
+ "source": [
51
+ "import os\n",
52
+ "\n",
53
+ "DATA_ROOT = \"/kaggle/input/vmc2026-track2-full/vmc2026-track2\" # << SỬA slug\n",
54
+ "WAV_DIR = f\"{DATA_ROOT}/wav\"\n",
55
+ "METADATA_CSV = f\"{DATA_ROOT}/metadata.csv\" # wavID|emotion|transcript (KHÔNG header)\n",
56
+ "TRAIN_CSV = f\"{DATA_ROOT}/sets/train.csv\" # lisID|wavID|qMOS|emoCat|eMOS|val|dom|aro\n",
57
+ "DEV_SCP = f\"{DATA_ROOT}/sets/dev.scp\"\n",
58
+ "\n",
59
+ "OUT_DIR = \"/kaggle/working\"\n",
60
+ "CACHE_DIR = \"/kaggle/working/fusion_cache\" # DÙNG CHUNG với exp04/exp07 (e2v_*, sailer_*, utmos_*)\n",
61
+ "SEQ_DIR = \"/kaggle/working/wavlm_seq_cache\" # MỚI: cache frame-level WavLM (fp16)\n",
62
+ "os.makedirs(CACHE_DIR, exist_ok=True)\n",
63
+ "os.makedirs(SEQ_DIR, exist_ok=True)\n",
64
+ "\n",
65
+ "# ── Bật/tắt nhánh Mamba (ablation chính) ─────────────────────────────────────\n",
66
+ "USE_MAMBA = True # False → ra ĐÚNG exp07 (sanity check). True → bật nhánh Mamba.\n",
67
+ "\n",
68
+ "# ── Siêu tham số nhánh Mamba ─────────────────────────────────────────────────\n",
69
+ "WAVLM_NAME = \"microsoft/wavlm-large\" # backbone frame-level (đóng băng). Trả chuỗi (T, 1024).\n",
70
+ "MAX_FRAMES = 256 # cap độ dài chuỗi (256 frame ≈ 5.1s @ 50Hz). Giảm nếu hết đĩa.\n",
71
+ "MAMBA_DMODEL = 256 # chiều ẩn của khối Mamba (proj 1024→256 trước khi vào Mamba)\n",
72
+ "MAMBA_LAYERS = 2 # số khối Mamba xếp chồng\n",
73
+ "MAMBA_DSTATE = 16 # chiều state SSM\n",
74
+ "BIDIRECTIONAL = True # chạy Mamba cả 2 chiều (xuôi + ngược) rồi cộng\n",
75
+ "Z_DIM = 128 # chiều vector z_seq sau attentive-pool, đem concat vào fusion\n",
76
+ "\n",
77
+ "# ── Siêu tham số fusion (giống exp07) ────────────────────────────────────────\n",
78
+ "DEVICE = \"cuda\"\n",
79
+ "TRUNK_HIDDEN = 512\n",
80
+ "HEAD_HIDDEN = 128\n",
81
+ "DROPOUT = 0.3\n",
82
+ "LR = 1e-3\n",
83
+ "EPOCHS = 80\n",
84
+ "BATCH = 32 # nhỏ hơn exp07 (64) vì có nhánh Mamba tốn RAM hơn\n",
85
+ "VAL_FRAC = 0.10\n",
86
+ "PATIENCE = 15\n",
87
+ "SEED = 42\n",
88
+ "\n",
89
+ "USE_UNCERTAINTY = True\n",
90
+ "LOSS_W = {\"qmos\": 1.0, \"emos\": 1.0, \"cat\": 1.0, \"val\": 1.0, \"aro\": 1.0, \"dom\": 1.0}\n",
91
+ "USE_E2V = True\n",
92
+ "USE_SAILER = True\n",
93
+ "USE_CLASSPROB = True\n",
94
+ "USE_UTMOS_FEAT = True\n",
95
+ "\n",
96
+ "LIMIT_TRAIN = None\n",
97
+ "LIMIT_DEV = None\n",
98
+ "\n",
99
+ "# Mốc exp07 để so (đây là hệ thống đang tốt nhất)\n",
100
+ "EXP07 = {\"qmos\": 0.548, \"emos\": 0.795, \"cat_err\": 0.153, \"val\": 0.581, \"aro\": 0.752, \"dom\": 0.705}\n",
101
+ "\n",
102
+ "EMOTIONS5 = [\"angry\", \"happy\", \"neutral\", \"sad\", \"surprised\"]\n",
103
+ "SAILER9 = [\"Anger\", \"Contempt\", \"Disgust\", \"Fear\", \"Happiness\", \"Neutral\", \"Sadness\", \"Surprise\", \"Other\"]\n",
104
+ "\n",
105
+ "_EMO_ALIAS = {\n",
106
+ " \"angry\": \"angry\", \"anger\": \"angry\",\n",
107
+ " \"happy\": \"happy\", \"happiness\": \"happy\", \"joy\": \"happy\",\n",
108
+ " \"neutral\": \"neutral\", \"calm\": \"neutral\",\n",
109
+ " \"sad\": \"sad\", \"sadness\": \"sad\",\n",
110
+ " \"surprise\": \"surprised\", \"surprised\": \"surprised\", \"surprising\": \"surprised\",\n",
111
+ "}\n",
112
+ "\n",
113
+ "def norm_emotion(label):\n",
114
+ " key = str(label).strip().lower()\n",
115
+ " return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)\n",
116
+ "\n",
117
+ "def stem(p):\n",
118
+ " return os.path.splitext(os.path.basename(str(p)))[0]\n",
119
+ "\n",
120
+ "assert USE_E2V or USE_SAILER, \"Phải bật ít nhất 1 backbone pooled.\"\n",
121
+ "print(\"USE_MAMBA =\", USE_MAMBA, \"| nếu False → ra đúng exp07\")\n",
122
+ "print(\"DATA_ROOT:\", DATA_ROOT)\n",
123
+ "for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP]:\n",
124
+ " print((\" ✅ \" if os.path.exists(p) else \" ❌ THIẾU \") + p)"
125
+ ]
126
+ },
127
+ {
128
+ "cell_type": "markdown",
129
+ "id": "5ad58750",
130
+ "metadata": {},
131
+ "source": [
132
+ "## 1. Cài đặt + tải code SAILER\n",
133
+ "Chỉ cài gói còn thiếu (Kaggle có sẵn torch/transformers). KHÔNG đụng numpy (tránh lệch ABI torch — bài học exp12)."
134
+ ]
135
+ },
136
+ {
137
+ "cell_type": "code",
138
+ "execution_count": null,
139
+ "id": "3260eb06",
140
+ "metadata": {},
141
+ "outputs": [],
142
+ "source": [
143
+ "import sys, subprocess\n",
144
+ "\n",
145
+ "def pip_install(*pkgs):\n",
146
+ " subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *pkgs], check=True)\n",
147
+ "\n",
148
+ "pip_install(\"speechmos\", \"funasr\", \"librosa\", \"soundfile\", \"pandas\", \"scipy\", \"scikit-learn\", \"tqdm\")\n",
149
+ "\n",
150
+ "if USE_SAILER:\n",
151
+ " pip_install(\"loralib\", \"speechbrain\")\n",
152
+ " REPO_DIR = \"/kaggle/working/vox-profile-release\"\n",
153
+ " if not os.path.exists(REPO_DIR):\n",
154
+ " subprocess.run([\"git\", \"clone\", \"--depth\", \"1\",\n",
155
+ " \"https://github.com/tiantiaf0627/vox-profile-release.git\", REPO_DIR], check=True)\n",
156
+ " if REPO_DIR not in sys.path:\n",
157
+ " sys.path.insert(0, REPO_DIR)"
158
+ ]
159
+ },
160
+ {
161
+ "cell_type": "markdown",
162
+ "id": "f92c0e17",
163
+ "metadata": {},
164
+ "source": [
165
+ "## 2. Đọc & gộp nhãn theo wavID (giống exp07)"
166
+ ]
167
+ },
168
+ {
169
+ "cell_type": "code",
170
+ "execution_count": null,
171
+ "id": "bab3f8d5",
172
+ "metadata": {},
173
+ "outputs": [],
174
+ "source": [
175
+ "import numpy as np\n",
176
+ "import pandas as pd\n",
177
+ "\n",
178
+ "def load_target_emotions():\n",
179
+ " tgt = {}\n",
180
+ " with open(METADATA_CSV, encoding=\"utf-8\") as f:\n",
181
+ " for ln in f:\n",
182
+ " parts = ln.strip().split(\"|\")\n",
183
+ " if len(parts) < 2:\n",
184
+ " continue\n",
185
+ " tgt[stem(parts[0])] = norm_emotion(parts[1])\n",
186
+ " return tgt\n",
187
+ "\n",
188
+ "def _col(cols_map, *names, default_idx=None, df=None):\n",
189
+ " for n in names:\n",
190
+ " if n in cols_map:\n",
191
+ " return cols_map[n]\n",
192
+ " return list(df.columns)[default_idx] if default_idx is not None else None\n",
193
+ "\n",
194
+ "def parse_emocat_votes(cell):\n",
195
+ " v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
196
+ " for tok in str(cell).replace(\"/\", \",\").replace(\";\", \",\").replace(\"|\", \",\").replace(\" \", \",\").split(\",\"):\n",
197
+ " e = norm_emotion(tok)\n",
198
+ " if e in EMOTIONS5:\n",
199
+ " v[EMOTIONS5.index(e)] += 1.0\n",
200
+ " return v\n",
201
+ "\n",
202
+ "def load_train_labels():\n",
203
+ " df = pd.read_csv(TRAIN_CSV, sep=\"|\")\n",
204
+ " cols = {c.lower().strip(): c for c in df.columns}\n",
205
+ " wav_col = _col(cols, \"wavid\", \"wav\", default_idx=1, df=df)\n",
206
+ " qmos_col = _col(cols, \"qmos\", \"mos\")\n",
207
+ " emos_col = _col(cols, \"emos\", \"emo\", \"emomos\")\n",
208
+ " val_col = _col(cols, \"val\", \"valence\")\n",
209
+ " aro_col = _col(cols, \"aro\", \"arousal\")\n",
210
+ " dom_col = _col(cols, \"dom\", \"dominance\")\n",
211
+ " cat_col = _col(cols, \"emocat\", \"cat\", \"emotion\")\n",
212
+ " assert qmos_col and emos_col, f\"Thiếu cột qMOS/eMOS (cột: {list(df.columns)})\"\n",
213
+ " df[\"_stem\"] = df[wav_col].map(stem)\n",
214
+ " rows = []\n",
215
+ " for sid, g in df.groupby(\"_stem\"):\n",
216
+ " rec = {\"wavID\": sid, \"qmos\": float(g[qmos_col].mean()), \"emos\": float(g[emos_col].mean())}\n",
217
+ " rec[\"val\"] = float(g[val_col].mean()) if val_col else np.nan\n",
218
+ " rec[\"aro\"] = float(g[aro_col].mean()) if aro_col else np.nan\n",
219
+ " rec[\"dom\"] = float(g[dom_col].mean()) if dom_col else np.nan\n",
220
+ " votes = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
221
+ " if cat_col:\n",
222
+ " for cell in g[cat_col]:\n",
223
+ " votes += parse_emocat_votes(cell)\n",
224
+ " s = votes.sum()\n",
225
+ " cat = votes / s if s > 0 else np.full(len(EMOTIONS5), 1.0 / len(EMOTIONS5), dtype=np.float32)\n",
226
+ " for i in range(len(EMOTIONS5)):\n",
227
+ " rec[f\"cat{i}\"] = float(cat[i])\n",
228
+ " rows.append(rec)\n",
229
+ " return pd.DataFrame(rows)\n",
230
+ "\n",
231
+ "target_map = load_target_emotions()\n",
232
+ "train_df = load_train_labels()\n",
233
+ "HAS_VAD = bool(train_df[\"val\"].notna().any())\n",
234
+ "print(f\"Target: {len(target_map)} | wav train (gộp): {len(train_df)} | có VAD: {HAS_VAD}\")"
235
+ ]
236
+ },
237
+ {
238
+ "cell_type": "markdown",
239
+ "id": "a5cd1ff1",
240
+ "metadata": {},
241
+ "source": [
242
+ "## 3. Đặc trưng POOLED (e2v + sailer + UTMOS) — TÁI DÙNG cache exp04/exp07\n",
243
+ "(Y hệt exp07; nếu đã chạy exp07 thì cache `fusion_cache/` còn nguyên → không tính lại.)"
244
+ ]
245
+ },
246
+ {
247
+ "cell_type": "code",
248
+ "execution_count": null,
249
+ "id": "8c31b6a4",
250
+ "metadata": {
251
+ "lines_to_next_cell": 1
252
+ },
253
+ "outputs": [],
254
+ "source": [
255
+ "import torch\n",
256
+ "import torch.nn.functional as F\n",
257
+ "\n",
258
+ "device = DEVICE if torch.cuda.is_available() else \"cpu\"\n",
259
+ "print(\"Device:\", device, (\"✅ \" + torch.cuda.get_device_name(0)) if device == \"cuda\" else \"⚠️ CPU\")\n",
260
+ "\n",
261
+ "def extract_e2v(stems, tag):\n",
262
+ " from tqdm.auto import tqdm\n",
263
+ " cache_path = os.path.join(CACHE_DIR, f\"e2v_{tag}.npz\")\n",
264
+ " store = {}\n",
265
+ " if os.path.exists(cache_path):\n",
266
+ " z = np.load(cache_path, allow_pickle=True)\n",
267
+ " store = {k: z[k] for k in z.files}\n",
268
+ " print(f\"[e2v/{tag}] nạp cache: {len(store)}\")\n",
269
+ " todo = [s for s in stems if s not in store]\n",
270
+ " if todo:\n",
271
+ " from funasr import AutoModel\n",
272
+ " m = AutoModel(model=\"iic/emotion2vec_plus_large\", hub=\"hf\", device=device)\n",
273
+ " for i, s in enumerate(tqdm(todo, desc=f\"e2v {tag}\")):\n",
274
+ " wav = os.path.join(WAV_DIR, s + \".wav\")\n",
275
+ " if not os.path.exists(wav):\n",
276
+ " continue\n",
277
+ " r = m.generate(wav, granularity=\"utterance\", extract_embedding=True)[0]\n",
278
+ " emb = np.asarray(r[\"feats\"], dtype=np.float32).reshape(-1)\n",
279
+ " probs = {e: 0.0 for e in EMOTIONS5}\n",
280
+ " for lab, sc in zip(r[\"labels\"], r[\"scores\"]):\n",
281
+ " name = lab.split(\"/\")[-1]\n",
282
+ " if name in probs:\n",
283
+ " probs[name] = float(sc)\n",
284
+ " tot = sum(probs.values())\n",
285
+ " p5 = np.array([probs[e] / tot if tot > 0 else 0.2 for e in EMOTIONS5], dtype=np.float32)\n",
286
+ " store[s] = np.concatenate([emb, p5]).astype(np.float32)\n",
287
+ " if (i + 1) % 500 == 0:\n",
288
+ " np.savez(cache_path, **store)\n",
289
+ " np.savez(cache_path, **store)\n",
290
+ " del m\n",
291
+ " torch.cuda.empty_cache() if device == \"cuda\" else None\n",
292
+ " return {s: (v[:-5], v[-5:]) for s, v in store.items()}\n",
293
+ "\n",
294
+ "def _pool_feat(features):\n",
295
+ " f = features.detach().cpu().numpy()\n",
296
+ " if f.ndim <= 1:\n",
297
+ " return f.reshape(-1).astype(np.float32)\n",
298
+ " return f.mean(axis=tuple(range(f.ndim - 1))).reshape(-1).astype(np.float32)\n",
299
+ "\n",
300
+ "def extract_sailer(stems, tag):\n",
301
+ " import librosa\n",
302
+ " from tqdm.auto import tqdm\n",
303
+ " cache_path = os.path.join(CACHE_DIR, f\"sailer_{tag}.npz\")\n",
304
+ " store = {}\n",
305
+ " if os.path.exists(cache_path):\n",
306
+ " z = np.load(cache_path, allow_pickle=True)\n",
307
+ " store = {k: z[k] for k in z.files}\n",
308
+ " print(f\"[sailer/{tag}] nạp cache: {len(store)}\")\n",
309
+ " todo = [s for s in stems if s not in store]\n",
310
+ " if todo:\n",
311
+ " from src.model.emotion.wavlm_emotion import WavLMWrapper\n",
312
+ " sailer = WavLMWrapper.from_pretrained(\"tiantiaf/wavlm-large-categorical-emotion\").to(device).eval()\n",
313
+ " with torch.no_grad():\n",
314
+ " for i, s in enumerate(tqdm(todo, desc=f\"sailer {tag}\")):\n",
315
+ " wav = os.path.join(WAV_DIR, s + \".wav\")\n",
316
+ " if not os.path.exists(wav):\n",
317
+ " continue\n",
318
+ " wave, _ = librosa.load(wav, sr=16000, mono=True)\n",
319
+ " wave = wave[: 15 * 16000]\n",
320
+ " data = torch.from_numpy(wave).float().unsqueeze(0).to(device)\n",
321
+ " logits, feat, _det, arousal, valence, dominance = sailer(data, return_feature=True)\n",
322
+ " emb = _pool_feat(feat)\n",
323
+ " p9 = F.softmax(logits, dim=1)[0].detach().cpu().numpy().astype(np.float32)\n",
324
+ " vad3 = np.array([1 + 4 * float(valence.item()),\n",
325
+ " 1 + 4 * float(arousal.item()),\n",
326
+ " 1 + 4 * float(dominance.item())], dtype=np.float32)\n",
327
+ " store[s] = np.concatenate([emb, p9, vad3]).astype(np.float32)\n",
328
+ " if (i + 1) % 500 == 0:\n",
329
+ " np.savez(cache_path, **store)\n",
330
+ " np.savez(cache_path, **store)\n",
331
+ " del sailer\n",
332
+ " torch.cuda.empty_cache() if device == \"cuda\" else None\n",
333
+ " return {s: (v[:-12], v[-12:-3], v[-3:]) for s, v in store.items()}\n",
334
+ "\n",
335
+ "def extract_utmos(names, tag):\n",
336
+ " import librosa\n",
337
+ " from tqdm.auto import tqdm\n",
338
+ " cache_path = os.path.join(CACHE_DIR, f\"utmos_{tag}.npz\")\n",
339
+ " store = {}\n",
340
+ " if os.path.exists(cache_path):\n",
341
+ " z = np.load(cache_path, allow_pickle=True)\n",
342
+ " store = {k: float(z[k]) for k in z.files}\n",
343
+ " print(f\"[utmos/{tag}] nạp cache: {len(store)}\")\n",
344
+ " todo = [n for n in names if stem(n) not in store]\n",
345
+ " if todo:\n",
346
+ " predictor = torch.hub.load(\"tarepan/SpeechMOS:v1.2.0\", \"utmos22_strong\",\n",
347
+ " trust_repo=True).to(device).eval()\n",
348
+ " with torch.no_grad():\n",
349
+ " for i, n in enumerate(tqdm(todo, desc=f\"utmos {tag}\")):\n",
350
+ " wav = os.path.join(WAV_DIR, n if str(n).endswith(\".wav\") else n + \".wav\")\n",
351
+ " if not os.path.exists(wav):\n",
352
+ " continue\n",
353
+ " wave, _ = librosa.load(wav, sr=16000, mono=True)\n",
354
+ " store[stem(n)] = float(predictor(torch.from_numpy(wave).unsqueeze(0).to(device),\n",
355
+ " sr=16000).mean().item())\n",
356
+ " if (i + 1) % 500 == 0:\n",
357
+ " np.savez(cache_path, **{k: np.float32(v) for k, v in store.items()})\n",
358
+ " np.savez(cache_path, **{k: np.float32(v) for k, v in store.items()})\n",
359
+ " del predictor\n",
360
+ " torch.cuda.empty_cache() if device == \"cuda\" else None\n",
361
+ " return store"
362
+ ]
363
+ },
364
+ {
365
+ "cell_type": "markdown",
366
+ "id": "a6a9dfc9",
367
+ "metadata": {},
368
+ "source": [
369
+ "## 3b. Đặc trưng FRAME-LEVEL WavLM (chuỗi T×1024) cho nhánh Mamba — cache fp16\n",
370
+ "Mỗi wav lưu 1 file `.npy` riêng trong `SEQ_DIR` (mảng fp16 [T, 1024], T ≤ MAX_FRAMES).\n",
371
+ "WavLM **đóng băng** (eval, no_grad) → layerdrop tự tắt ở eval, không đụng gotcha checkpoint."
372
+ ]
373
+ },
374
+ {
375
+ "cell_type": "code",
376
+ "execution_count": null,
377
+ "id": "60c9e86e",
378
+ "metadata": {
379
+ "lines_to_next_cell": 1
380
+ },
381
+ "outputs": [],
382
+ "source": [
383
+ "_wavlm = None\n",
384
+ "def _get_wavlm():\n",
385
+ " \"\"\"Lazy-load microsoft/wavlm-large (đóng băng). Trả model + feature_extractor.\"\"\"\n",
386
+ " global _wavlm\n",
387
+ " if _wavlm is None:\n",
388
+ " from transformers import WavLMModel, AutoFeatureExtractor\n",
389
+ " fe = AutoFeatureExtractor.from_pretrained(WAVLM_NAME)\n",
390
+ " mdl = WavLMModel.from_pretrained(WAVLM_NAME).to(device).eval()\n",
391
+ " for p in mdl.parameters():\n",
392
+ " p.requires_grad = False\n",
393
+ " _wavlm = (mdl, fe)\n",
394
+ " return _wavlm\n",
395
+ "\n",
396
+ "def seq_path(sid):\n",
397
+ " return os.path.join(SEQ_DIR, sid + \".npy\")\n",
398
+ "\n",
399
+ "def extract_wavlm_seq(stems, tag):\n",
400
+ " \"\"\"Trích frame-level WavLM cho từng wav, cache fp16 ra .npy. Trả set stem đã có.\"\"\"\n",
401
+ " if not USE_MAMBA:\n",
402
+ " return set()\n",
403
+ " import librosa\n",
404
+ " from tqdm.auto import tqdm\n",
405
+ " todo = [s for s in stems if not os.path.exists(seq_path(s))]\n",
406
+ " if todo:\n",
407
+ " mdl, fe = _get_wavlm()\n",
408
+ " with torch.no_grad():\n",
409
+ " for i, s in enumerate(tqdm(todo, desc=f\"wavlm-seq {tag}\")):\n",
410
+ " wav = os.path.join(WAV_DIR, s + \".wav\")\n",
411
+ " if not os.path.exists(wav):\n",
412
+ " continue\n",
413
+ " wave, _ = librosa.load(wav, sr=16000, mono=True)\n",
414
+ " wave = wave[: 15 * 16000]\n",
415
+ " inp = fe(wave, sampling_rate=16000, return_tensors=\"pt\").input_values.to(device)\n",
416
+ " hs = mdl(inp).last_hidden_state[0] # (T, 1024)\n",
417
+ " if hs.shape[0] > MAX_FRAMES: # cap độ dài (đều theo thời gian)\n",
418
+ " idx = torch.linspace(0, hs.shape[0] - 1, MAX_FRAMES).long()\n",
419
+ " hs = hs[idx]\n",
420
+ " np.save(seq_path(s), hs.cpu().numpy().astype(np.float16))\n",
421
+ " torch.cuda.empty_cache() if device == \"cuda\" else None\n",
422
+ " return {s for s in stems if os.path.exists(seq_path(s))}\n",
423
+ "\n",
424
+ "def load_seq(sid):\n",
425
+ " \"\"\"Đọc chuỗi fp16 → tensor float32 (T, 1024). Thiếu file → None.\"\"\"\n",
426
+ " p = seq_path(sid)\n",
427
+ " if not os.path.exists(p):\n",
428
+ " return None\n",
429
+ " return torch.from_numpy(np.load(p).astype(np.float32))\n",
430
+ "\n",
431
+ "def collate_seqs(sids):\n",
432
+ " \"\"\"Gộp list chuỗi độ dài khác nhau → (B, Lmax, 1024) + mask (B, Lmax) bool (True=thật).\"\"\"\n",
433
+ " seqs = [load_seq(s) for s in sids]\n",
434
+ " lens = [t.shape[0] for t in seqs]\n",
435
+ " Lmax = max(lens)\n",
436
+ " B = len(seqs)\n",
437
+ " x = torch.zeros(B, Lmax, seqs[0].shape[1], dtype=torch.float32)\n",
438
+ " mask = torch.zeros(B, Lmax, dtype=torch.bool)\n",
439
+ " for i, t in enumerate(seqs):\n",
440
+ " x[i, : t.shape[0]] = t\n",
441
+ " mask[i, : t.shape[0]] = True\n",
442
+ " return x, mask"
443
+ ]
444
+ },
445
+ {
446
+ "cell_type": "markdown",
447
+ "id": "328a5f30",
448
+ "metadata": {},
449
+ "source": [
450
+ "## 4. Dựng feature pooled + nhãn cho train (lọc các wav đủ mọi nguồn)"
451
+ ]
452
+ },
453
+ {
454
+ "cell_type": "code",
455
+ "execution_count": null,
456
+ "id": "4449a153",
457
+ "metadata": {},
458
+ "outputs": [],
459
+ "source": [
460
+ "train_stems = list(train_df[\"wavID\"])\n",
461
+ "if LIMIT_TRAIN:\n",
462
+ " train_stems = train_stems[:LIMIT_TRAIN]\n",
463
+ "\n",
464
+ "e2v_tr = extract_e2v(train_stems, \"train\") if USE_E2V else {}\n",
465
+ "sailer_tr = extract_sailer(train_stems, \"train\") if USE_SAILER else {}\n",
466
+ "utmos_tr = extract_utmos(train_stems, \"train\") if USE_UTMOS_FEAT else {}\n",
467
+ "seq_tr = extract_wavlm_seq(train_stems, \"train\")\n",
468
+ "\n",
469
+ "def audio_feature(sid, e2v_map, sailer_map):\n",
470
+ " parts = []\n",
471
+ " if USE_E2V:\n",
472
+ " pk = e2v_map.get(sid)\n",
473
+ " if pk is None:\n",
474
+ " return None\n",
475
+ " emb, p5 = pk\n",
476
+ " parts.append(emb)\n",
477
+ " if USE_CLASSPROB:\n",
478
+ " parts.append(p5)\n",
479
+ " if USE_SAILER:\n",
480
+ " pk = sailer_map.get(sid)\n",
481
+ " if pk is None:\n",
482
+ " return None\n",
483
+ " emb, p9, vad3 = pk\n",
484
+ " parts.append(emb)\n",
485
+ " if USE_CLASSPROB:\n",
486
+ " parts.append(p9); parts.append(vad3)\n",
487
+ " return np.concatenate(parts).astype(np.float32)\n",
488
+ "\n",
489
+ "def onehot_target(tgt):\n",
490
+ " v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
491
+ " if tgt in EMOTIONS5:\n",
492
+ " v[EMOTIONS5.index(tgt)] = 1.0\n",
493
+ " return v\n",
494
+ "\n",
495
+ "lab = train_df.set_index(\"wavID\")\n",
496
+ "keep_sids, X, T, U = [], [], [], []\n",
497
+ "y_qmos, y_emos, y_vad, y_cat = [], [], [], []\n",
498
+ "for s in train_stems:\n",
499
+ " f = audio_feature(s, e2v_tr, sailer_tr)\n",
500
+ " tgt = target_map.get(s)\n",
501
+ " if f is None or tgt is None or s not in lab.index:\n",
502
+ " continue\n",
503
+ " if USE_UTMOS_FEAT and s not in utmos_tr:\n",
504
+ " continue\n",
505
+ " if USE_MAMBA and s not in seq_tr: # cần có chuỗi WavLM nếu bật Mamba\n",
506
+ " continue\n",
507
+ " keep_sids.append(s)\n",
508
+ " X.append(f)\n",
509
+ " T.append(onehot_target(tgt))\n",
510
+ " U.append(utmos_tr.get(s, 3.0) if USE_UTMOS_FEAT else 0.0)\n",
511
+ " y_qmos.append(lab.loc[s, \"qmos\"]); y_emos.append(lab.loc[s, \"emos\"])\n",
512
+ " y_vad.append([lab.loc[s, \"val\"], lab.loc[s, \"aro\"], lab.loc[s, \"dom\"]])\n",
513
+ " y_cat.append([lab.loc[s, f\"cat{i}\"] for i in range(len(EMOTIONS5))])\n",
514
+ "\n",
515
+ "X = np.stack(X).astype(np.float32)\n",
516
+ "T = np.stack(T).astype(np.float32)\n",
517
+ "U = np.array(U, dtype=np.float32).reshape(-1, 1)\n",
518
+ "y_qmos = np.array(y_qmos, dtype=np.float32); y_emos = np.array(y_emos, dtype=np.float32)\n",
519
+ "y_vad = np.array(y_vad, dtype=np.float32); y_cat = np.array(y_cat, dtype=np.float32)\n",
520
+ "FEAT_DIM = X.shape[1]\n",
521
+ "print(f\"Train giữ lại: {len(keep_sids)} wav | X={X.shape} | Mamba={'ON' if USE_MAMBA else 'OFF'}\")\n",
522
+ "\n",
523
+ "# Chuẩn hóa feature pooled + UTMOS + nhãn liên tục (z-score)\n",
524
+ "feat_mean = X.mean(0, keepdims=True); feat_std = X.std(0, keepdims=True) + 1e-6\n",
525
+ "Xn = (X - feat_mean) / feat_std\n",
526
+ "u_mu, u_sd = float(U.mean()), float(U.std() + 1e-6); Un = (U - u_mu) / u_sd\n",
527
+ "qmos_mu, qmos_sd = float(y_qmos.mean()), float(y_qmos.std() + 1e-6); y_qmos_z = (y_qmos - qmos_mu) / qmos_sd\n",
528
+ "emos_mu, emos_sd = float(y_emos.mean()), float(y_emos.std() + 1e-6); y_emos_z = (y_emos - emos_mu) / emos_sd\n",
529
+ "if HAS_VAD:\n",
530
+ " vad_mu = np.nanmean(y_vad, axis=0); vad_sd = np.nanstd(y_vad, axis=0) + 1e-6\n",
531
+ " y_vad_z = (y_vad - vad_mu) / vad_sd\n",
532
+ "else:\n",
533
+ " vad_mu = np.zeros(3, dtype=np.float32); vad_sd = np.ones(3, dtype=np.float32); y_vad_z = np.zeros_like(y_vad)"
534
+ ]
535
+ },
536
+ {
537
+ "cell_type": "markdown",
538
+ "id": "5f0a94ff",
539
+ "metadata": {},
540
+ "source": [
541
+ "## 5a. Khối MAMBA (thuần PyTorch, không cần `mamba-ssm`)\n",
542
+ "Tự dùng `mamba-ssm` nếu import được (nhanh hơn); nếu không → bản thuần PyTorch (selective scan vòng lặp thời gian).\n",
543
+ "Bản này theo \"mamba-minimal\" (johnma2006) — đúng công thức, chỉ chậm hơn kernel CUDA, nhưng head nhỏ nên OK trên T4."
544
+ ]
545
+ },
546
+ {
547
+ "cell_type": "code",
548
+ "execution_count": null,
549
+ "id": "535fcd63",
550
+ "metadata": {
551
+ "lines_to_next_cell": 1
552
+ },
553
+ "outputs": [],
554
+ "source": [
555
+ "import math\n",
556
+ "import torch.nn as nn\n",
557
+ "\n",
558
+ "try:\n",
559
+ " from mamba_ssm import Mamba as _OfficialMamba # nếu cài được thì dùng (tùy chọn)\n",
560
+ " _HAS_MAMBA_SSM = True\n",
561
+ " print(\"✅ Dùng mamba-ssm (CUDA kernel)\")\n",
562
+ "except Exception:\n",
563
+ " _HAS_MAMBA_SSM = False\n",
564
+ " print(\"ℹ️ Không có mamba-ssm → dùng Mamba thuần PyTorch (nhúng sẵn)\")\n",
565
+ "\n",
566
+ "class MambaBlockTorch(nn.Module):\n",
567
+ " \"\"\"Một khối Mamba (selective SSM) thuần PyTorch. d_model = chiều ẩn.\"\"\"\n",
568
+ " def __init__(self, d_model, d_state=16, d_conv=4, expand=2):\n",
569
+ " super().__init__()\n",
570
+ " self.d_inner = expand * d_model\n",
571
+ " self.dt_rank = math.ceil(d_model / 16)\n",
572
+ " self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)\n",
573
+ " self.conv1d = nn.Conv1d(self.d_inner, self.d_inner, kernel_size=d_conv,\n",
574
+ " groups=self.d_inner, padding=d_conv - 1, bias=True)\n",
575
+ " self.x_proj = nn.Linear(self.d_inner, self.dt_rank + d_state * 2, bias=False)\n",
576
+ " self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)\n",
577
+ " A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)\n",
578
+ " self.A_log = nn.Parameter(torch.log(A)) # (d_inner, d_state)\n",
579
+ " self.D = nn.Parameter(torch.ones(self.d_inner))\n",
580
+ " self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)\n",
581
+ " self.d_state = d_state\n",
582
+ "\n",
583
+ " def forward(self, x): # x: (B, L, d_model)\n",
584
+ " B, L, _ = x.shape\n",
585
+ " xz = self.in_proj(x) # (B, L, 2*d_inner)\n",
586
+ " xin, z = xz.chunk(2, dim=-1)\n",
587
+ " xin = xin.transpose(1, 2) # (B, d_inner, L)\n",
588
+ " xin = self.conv1d(xin)[..., :L].transpose(1, 2) # (B, L, d_inner) causal conv\n",
589
+ " xin = F.silu(xin)\n",
590
+ " y = self._ssm(xin) # (B, L, d_inner)\n",
591
+ " y = y * F.silu(z)\n",
592
+ " return self.out_proj(y)\n",
593
+ "\n",
594
+ " def _ssm(self, x): # x: (B, L, d_inner)\n",
595
+ " A = -torch.exp(self.A_log) # (d_inner, d_state)\n",
596
+ " x_dbl = self.x_proj(x) # (B, L, dt_rank + 2*d_state)\n",
597
+ " delta, Bm, Cm = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)\n",
598
+ " delta = F.softplus(self.dt_proj(delta)) # (B, L, d_inner)\n",
599
+ " dA = torch.exp(delta.unsqueeze(-1) * A) # (B, L, d_inner, d_state)\n",
600
+ " dB_x = delta.unsqueeze(-1) * Bm.unsqueeze(2) * x.unsqueeze(-1) # (B, L, d_inner, d_state)\n",
601
+ " h = torch.zeros(x.shape[0], self.d_inner, self.d_state, device=x.device, dtype=x.dtype)\n",
602
+ " ys = []\n",
603
+ " for t in range(x.shape[1]): # selective scan theo thời gian\n",
604
+ " h = dA[:, t] * h + dB_x[:, t]\n",
605
+ " ys.append((h * Cm[:, t].unsqueeze(1)).sum(-1)) # (B, d_inner)\n",
606
+ " y = torch.stack(ys, dim=1) # (B, L, d_inner)\n",
607
+ " return y + x * self.D\n",
608
+ "\n",
609
+ "class MambaLayer(nn.Module):\n",
610
+ " \"\"\"Pre-norm residual quanh 1 khối Mamba (chọn official nếu có).\"\"\"\n",
611
+ " def __init__(self, d_model, d_state):\n",
612
+ " super().__init__()\n",
613
+ " self.norm = nn.LayerNorm(d_model)\n",
614
+ " if _HAS_MAMBA_SSM:\n",
615
+ " self.mix = _OfficialMamba(d_model=d_model, d_state=d_state, d_conv=4, expand=2)\n",
616
+ " else:\n",
617
+ " self.mix = MambaBlockTorch(d_model, d_state=d_state)\n",
618
+ "\n",
619
+ " def forward(self, x):\n",
620
+ " return x + self.mix(self.norm(x))\n",
621
+ "\n",
622
+ "class MambaEncoder(nn.Module):\n",
623
+ " \"\"\"1024 → d_model → [Mamba ×L] (2 chiều nếu BIDIRECTIONAL) → attentive-pool → Z_DIM.\"\"\"\n",
624
+ " def __init__(self, d_in, d_model, n_layers, d_state, z_dim, bidir):\n",
625
+ " super().__init__()\n",
626
+ " self.bidir = bidir\n",
627
+ " self.proj = nn.Linear(d_in, d_model)\n",
628
+ " self.fwd = nn.ModuleList([MambaLayer(d_model, d_state) for _ in range(n_layers)])\n",
629
+ " if bidir:\n",
630
+ " self.bwd = nn.ModuleList([MambaLayer(d_model, d_state) for _ in range(n_layers)])\n",
631
+ " self.attn = nn.Linear(d_model, 1) # attentive pooling\n",
632
+ " self.out = nn.Linear(d_model, z_dim)\n",
633
+ "\n",
634
+ " def _run(self, layers, h):\n",
635
+ " for L in layers:\n",
636
+ " h = L(h)\n",
637
+ " return h\n",
638
+ "\n",
639
+ " def forward(self, x, mask): # x: (B, L, 1024), mask: (B, L) bool\n",
640
+ " h = self.proj(x)\n",
641
+ " out = self._run(self.fwd, h)\n",
642
+ " if self.bidir:\n",
643
+ " rev = torch.flip(h, dims=[1])\n",
644
+ " out = out + torch.flip(self._run(self.bwd, rev), dims=[1])\n",
645
+ " a = self.attn(out).squeeze(-1) # (B, L)\n",
646
+ " a = a.masked_fill(~mask, float(\"-inf\"))\n",
647
+ " w = torch.softmax(a, dim=1).unsqueeze(-1) # (B, L, 1)\n",
648
+ " pooled = (out * w).sum(1) # (B, d_model)\n",
649
+ " return self.out(pooled) # (B, z_dim)"
650
+ ]
651
+ },
652
+ {
653
+ "cell_type": "markdown",
654
+ "id": "a1a3026b",
655
+ "metadata": {},
656
+ "source": [
657
+ "## 5b. Model fusion 6 head + nhánh Mamba + train loop"
658
+ ]
659
+ },
660
+ {
661
+ "cell_type": "code",
662
+ "execution_count": null,
663
+ "id": "e5f743ef",
664
+ "metadata": {
665
+ "lines_to_next_cell": 1
666
+ },
667
+ "outputs": [],
668
+ "source": [
669
+ "from scipy.stats import spearmanr\n",
670
+ "from sklearn.model_selection import train_test_split\n",
671
+ "\n",
672
+ "torch.manual_seed(SEED); np.random.seed(SEED)\n",
673
+ "N_EMO = len(EMOTIONS5)\n",
674
+ "idx_all = np.arange(X.shape[0])\n",
675
+ "tr_idx, va_idx = train_test_split(idx_all, test_size=VAL_FRAC, random_state=SEED)\n",
676
+ "\n",
677
+ "def to_t(a):\n",
678
+ " return torch.tensor(a, dtype=torch.float32, device=device)\n",
679
+ "\n",
680
+ "Xn_t, T_t, Un_t = to_t(Xn), to_t(T), to_t(Un)\n",
681
+ "qmos_t = to_t(y_qmos_z).unsqueeze(1); emos_t = to_t(y_emos_z).unsqueeze(1)\n",
682
+ "vad_t = to_t(y_vad_z); cat_t = to_t(y_cat)\n",
683
+ "\n",
684
+ "class FusionMamba6(nn.Module):\n",
685
+ " def __init__(self, d_in, trunk_h, head_h, p, n_emo, use_utmos, use_mamba):\n",
686
+ " super().__init__()\n",
687
+ " self.use_utmos = use_utmos\n",
688
+ " self.use_mamba = use_mamba\n",
689
+ " z_extra = Z_DIM if use_mamba else 0\n",
690
+ " if use_mamba:\n",
691
+ " self.enc = MambaEncoder(1024, MAMBA_DMODEL, MAMBA_LAYERS, MAMBA_DSTATE, Z_DIM, BIDIRECTIONAL)\n",
692
+ " self.trunk = nn.Sequential(\n",
693
+ " nn.Linear(d_in + z_extra, trunk_h), nn.ReLU(), nn.Dropout(p),\n",
694
+ " nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))\n",
695
+ " self.qmos = nn.Sequential(\n",
696
+ " nn.Linear(trunk_h + (1 if use_utmos else 0), head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))\n",
697
+ " self.emos = nn.Sequential(\n",
698
+ " nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))\n",
699
+ " self.cat = nn.Sequential(\n",
700
+ " nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))\n",
701
+ " self.vad = nn.Sequential(\n",
702
+ " nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))\n",
703
+ "\n",
704
+ " def forward(self, x, tgt, utmos, seq=None, mask=None):\n",
705
+ " if self.use_mamba:\n",
706
+ " z = self.enc(seq, mask)\n",
707
+ " x = torch.cat([x, z], dim=1)\n",
708
+ " h = self.trunk(x)\n",
709
+ " qmos_in = torch.cat([h, utmos], dim=1) if self.use_utmos else h\n",
710
+ " return self.qmos(qmos_in), self.emos(torch.cat([h, tgt], dim=1)), self.cat(h), self.vad(h)\n",
711
+ "\n",
712
+ "model = FusionMamba6(FEAT_DIM, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO, USE_UTMOS_FEAT, USE_MAMBA).to(device)\n",
713
+ "n_par = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
714
+ "print(f\"Tham số train được: {n_par/1e6:.2f} M\")\n",
715
+ "\n",
716
+ "TASKS = [\"qmos\", \"emos\", \"cat\", \"val\", \"aro\", \"dom\"]\n",
717
+ "log_var = nn.Parameter(torch.zeros(len(TASKS), device=device))\n",
718
+ "params = list(model.parameters()) + ([log_var] if USE_UNCERTAINTY else [])\n",
719
+ "opt = torch.optim.Adam(params, lr=LR, weight_decay=1e-5)\n",
720
+ "mse = nn.MSELoss(reduction=\"none\")\n",
721
+ "\n",
722
+ "def soft_ce(logits, target_dist):\n",
723
+ " return -(target_dist * F.log_softmax(logits, dim=1)).sum(dim=1)\n",
724
+ "\n",
725
+ "def task_losses(qmos_p, emos_p, cat_logits, vad_p, b):\n",
726
+ " L = {\"qmos\": mse(qmos_p, qmos_t[b]).mean(),\n",
727
+ " \"emos\": mse(emos_p, emos_t[b]).mean(),\n",
728
+ " \"cat\": soft_ce(cat_logits, cat_t[b]).mean()}\n",
729
+ " if HAS_VAD:\n",
730
+ " L[\"val\"] = mse(vad_p[:, 0:1], vad_t[b, 0:1]).mean()\n",
731
+ " L[\"aro\"] = mse(vad_p[:, 1:2], vad_t[b, 1:2]).mean()\n",
732
+ " L[\"dom\"] = mse(vad_p[:, 2:3], vad_t[b, 2:3]).mean()\n",
733
+ " else:\n",
734
+ " z = torch.zeros((), device=device); L[\"val\"] = L[\"aro\"] = L[\"dom\"] = z\n",
735
+ " return L\n",
736
+ "\n",
737
+ "def combine(L):\n",
738
+ " if USE_UNCERTAINTY:\n",
739
+ " return sum(torch.exp(-log_var[i]) * L[t] + log_var[i] for i, t in enumerate(TASKS))\n",
740
+ " return sum(LOSS_W[t] * L[t] for t in TASKS)\n",
741
+ "\n",
742
+ "# batch theo INDEX (vì nhánh Mamba cần đọc chuỗi theo sid → collate động)\n",
743
+ "sids_arr = np.array(keep_sids)\n",
744
+ "\n",
745
+ "def forward_batch(bidx):\n",
746
+ " \"\"\"bidx: numpy index. Trả output model cho batch (tự collate chuỗi nếu bật Mamba).\"\"\"\n",
747
+ " bt = torch.tensor(bidx, device=device)\n",
748
+ " if USE_MAMBA:\n",
749
+ " seq, mask = collate_seqs(list(sids_arr[bidx]))\n",
750
+ " seq, mask = seq.to(device), mask.to(device)\n",
751
+ " return model(Xn_t[bt], T_t[bt], Un_t[bt], seq, mask)\n",
752
+ " return model(Xn_t[bt], T_t[bt], Un_t[bt])\n",
753
+ "\n",
754
+ "@torch.no_grad()\n",
755
+ "def eval_val():\n",
756
+ " model.eval()\n",
757
+ " qp, ep, vp = [], [], []\n",
758
+ " for i in range(0, len(va_idx), BATCH):\n",
759
+ " b = va_idx[i:i + BATCH]\n",
760
+ " q, e, _cl, v = forward_batch(b)\n",
761
+ " qp.append(q.cpu().numpy().ravel()); ep.append(e.cpu().numpy().ravel()); vp.append(v.cpu().numpy())\n",
762
+ " qp = np.concatenate(qp); ep = np.concatenate(ep); vp = np.concatenate(vp)\n",
763
+ " out = {\"qmos\": spearmanr(qp, y_qmos[va_idx]).correlation,\n",
764
+ " \"emos\": spearmanr(ep, y_emos[va_idx]).correlation}\n",
765
+ " if USE_UTMOS_FEAT:\n",
766
+ " out[\"qmos_utmos\"] = spearmanr(U[va_idx, 0], y_qmos[va_idx]).correlation\n",
767
+ " if HAS_VAD:\n",
768
+ " for j, t in enumerate([\"val\", \"aro\", \"dom\"]):\n",
769
+ " out[t] = spearmanr(vp[:, j], y_vad[va_idx, j]).correlation\n",
770
+ " return out\n",
771
+ "\n",
772
+ "def val_score(m):\n",
773
+ " keys = [\"qmos\", \"emos\"] + ([\"val\", \"aro\", \"dom\"] if HAS_VAD else [])\n",
774
+ " return float(np.mean([m[k] for k in keys]))\n",
775
+ "\n",
776
+ "best_score, best_state, bad = -1e9, None, 0\n",
777
+ "for ep_i in range(1, EPOCHS + 1):\n",
778
+ " model.train()\n",
779
+ " perm = np.random.permutation(tr_idx)\n",
780
+ " run = 0.0\n",
781
+ " for i in range(0, len(perm), BATCH):\n",
782
+ " b = perm[i:i + BATCH]\n",
783
+ " opt.zero_grad()\n",
784
+ " q, e, cl, v = forward_batch(b)\n",
785
+ " loss = combine(task_losses(q, e, cl, v, torch.tensor(b, device=device)))\n",
786
+ " loss.backward(); opt.step()\n",
787
+ " run += loss.item() * len(b)\n",
788
+ " m = eval_val(); sc = val_score(m)\n",
789
+ " if sc > best_score:\n",
790
+ " best_score = sc; bad = 0\n",
791
+ " best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}\n",
792
+ " else:\n",
793
+ " bad += 1\n",
794
+ " if ep_i % 2 == 0 or ep_i == 1:\n",
795
+ " msg = \" \".join(f\"{k}={m[k]:.3f}\" for k in [\"qmos\", \"emos\", \"val\", \"aro\", \"dom\"] if k in m)\n",
796
+ " print(f\"epoch {ep_i:3d} | loss {run/len(perm):.4f} | {msg} | best {best_score:.4f}\")\n",
797
+ " if bad >= PATIENCE:\n",
798
+ " print(f\"Early stop ở epoch {ep_i}.\"); break\n",
799
+ "\n",
800
+ "model.load_state_dict(best_state)\n",
801
+ "final = eval_val()\n",
802
+ "print(f\"\\n✅ VAL (nội bộ) — exp14 (Mamba={'ON' if USE_MAMBA else 'OFF'}):\")\n",
803
+ "print(f\" QMOS={final['qmos']:.4f} (exp07 {EXP07['qmos']}) | EMOS={final['emos']:.4f} (exp07 {EXP07['emos']})\")\n",
804
+ "if HAS_VAD:\n",
805
+ " print(f\" VAL/ARO/DOM={final['val']:.4f}/{final['aro']:.4f}/{final['dom']:.4f}\"\n",
806
+ " f\" (exp07 {EXP07['val']}/{EXP07['aro']}/{EXP07['dom']})\")\n",
807
+ "print(\" → So sánh USE_MAMBA True vs False = ablation Mamba cho paper.\")\n",
808
+ "\n",
809
+ "torch.save({\"state\": best_state, \"feat_mean\": feat_mean, \"feat_std\": feat_std,\n",
810
+ " \"u_mu\": u_mu, \"u_sd\": u_sd, \"qmos_mu\": qmos_mu, \"qmos_sd\": qmos_sd,\n",
811
+ " \"emos_mu\": emos_mu, \"emos_sd\": emos_sd, \"vad_mu\": vad_mu, \"vad_sd\": vad_sd,\n",
812
+ " \"FEAT_DIM\": FEAT_DIM, \"USE_MAMBA\": USE_MAMBA, \"val_score\": best_score},\n",
813
+ " os.path.join(OUT_DIR, \"fusion_mamba_mtl.pt\"))\n",
814
+ "print(\"Đã lưu\", os.path.join(OUT_DIR, \"fusion_mamba_mtl.pt\"))"
815
+ ]
816
+ },
817
+ {
818
+ "cell_type": "markdown",
819
+ "id": "ea38383a",
820
+ "metadata": {},
821
+ "source": [
822
+ "## 6. Dự đoán DEV → `answer.txt` đủ 6 cột"
823
+ ]
824
+ },
825
+ {
826
+ "cell_type": "code",
827
+ "execution_count": null,
828
+ "id": "6e774431",
829
+ "metadata": {
830
+ "lines_to_next_cell": 1
831
+ },
832
+ "outputs": [],
833
+ "source": [
834
+ "def list_dev():\n",
835
+ " with open(DEV_SCP) as f:\n",
836
+ " return [ln.strip() for ln in f if ln.strip()]\n",
837
+ "\n",
838
+ "dev_names = list_dev()\n",
839
+ "if LIMIT_DEV:\n",
840
+ " dev_names = dev_names[:LIMIT_DEV]\n",
841
+ "dev_stems = [stem(n) for n in dev_names]\n",
842
+ "print(\"DEV:\", len(dev_names), \"mẫu\")\n",
843
+ "\n",
844
+ "e2v_dev = extract_e2v(dev_stems, \"dev\") if USE_E2V else {}\n",
845
+ "sailer_dev = extract_sailer(dev_stems, \"dev\") if USE_SAILER else {}\n",
846
+ "utmos_dev = extract_utmos(dev_names, \"dev\") if USE_UTMOS_FEAT else {}\n",
847
+ "seq_dev = extract_wavlm_seq(dev_stems, \"dev\")\n",
848
+ "\n",
849
+ "@torch.no_grad()\n",
850
+ "def predict_all(sid):\n",
851
+ " f = audio_feature(sid, e2v_dev, sailer_dev)\n",
852
+ " if f is None:\n",
853
+ " return None\n",
854
+ " if USE_MAMBA and not os.path.exists(seq_path(sid)):\n",
855
+ " return None\n",
856
+ " fn = (f[None, :] - feat_mean) / feat_std\n",
857
+ " tgt = onehot_target(target_map.get(sid))[None, :]\n",
858
+ " u = np.array([[utmos_dev.get(sid, 3.0)]], dtype=np.float32); un = (u - u_mu) / u_sd\n",
859
+ " model.eval()\n",
860
+ " if USE_MAMBA:\n",
861
+ " seq, mask = collate_seqs([sid]); seq, mask = seq.to(device), mask.to(device)\n",
862
+ " q, e, cl, v = model(to_t(fn), to_t(tgt), to_t(un), seq, mask)\n",
863
+ " else:\n",
864
+ " q, e, cl, v = model(to_t(fn), to_t(tgt), to_t(un))\n",
865
+ " qmos = float(q.item()) * qmos_sd + qmos_mu\n",
866
+ " emos = float(e.item()) * emos_sd + emos_mu\n",
867
+ " cat5 = F.softmax(cl, dim=1)[0].cpu().numpy()\n",
868
+ " vad3 = v[0].cpu().numpy() * vad_sd + vad_mu\n",
869
+ " return qmos, emos, cat5, vad3\n",
870
+ "\n",
871
+ "def fmt_cat(probs5):\n",
872
+ " return \"|\".join(f\"{e}:{probs5[i]:.6g}\" for i, e in enumerate(EMOTIONS5))\n",
873
+ "\n",
874
+ "def build_answer(out_path):\n",
875
+ " from tqdm.auto import tqdm\n",
876
+ " n_real = n_default = 0\n",
877
+ " with open(out_path, \"w\") as f:\n",
878
+ " f.write(\"wav,QMOS,EMOS,CAT,VAL,ARO,DOM\\n\")\n",
879
+ " for name in tqdm(dev_names, desc=\"answer\"):\n",
880
+ " sid = stem(name)\n",
881
+ " pred = predict_all(sid)\n",
882
+ " if pred is None:\n",
883
+ " qmos = utmos_dev.get(sid, 3.0)\n",
884
+ " emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0])\n",
885
+ " n_default += 1\n",
886
+ " else:\n",
887
+ " qmos, emos, cat5, vad3 = pred; n_real += 1\n",
888
+ " f.write(f\"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},\"\n",
889
+ " f\"{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\\n\")\n",
890
+ " print(f\"Ghi {len(dev_names)} dòng → {out_path} | head thật {n_real}, mặc định {n_default}\")\n",
891
+ "\n",
892
+ "answer_path = os.path.join(OUT_DIR, \"answer.txt\")\n",
893
+ "build_answer(answer_path)"
894
+ ]
895
+ },
896
+ {
897
+ "cell_type": "markdown",
898
+ "id": "bcab20d3",
899
+ "metadata": {},
900
+ "source": [
901
+ "## 7. Validate + đóng zip"
902
+ ]
903
+ },
904
+ {
905
+ "cell_type": "code",
906
+ "execution_count": null,
907
+ "id": "e9b2e0ab",
908
+ "metadata": {},
909
+ "outputs": [],
910
+ "source": [
911
+ "def validate(path):\n",
912
+ " import csv\n",
913
+ " with open(path) as f:\n",
914
+ " rows = list(csv.reader(f))\n",
915
+ " header = rows[0]\n",
916
+ " assert header[0] == \"wav\" and \"QMOS\" in header and \"EMOS\" in header, \"Header sai\"\n",
917
+ " for i, r in enumerate(rows[1:], 2):\n",
918
+ " assert len(r) == len(header), f\"Dòng {i} sai số cột\"\n",
919
+ " print(f\"OK: {len(rows)-1} dòng, header = {header}\")\n",
920
+ "\n",
921
+ "validate(answer_path)\n",
922
+ "os.system(f\"cd {OUT_DIR} && zip -j submission_track2_exp14_mamba.zip answer.txt \"\n",
923
+ " f\"&& unzip -l submission_track2_exp14_mamba.zip\")\n",
924
+ "print(\"Sẵn sàng nộp:\", os.path.join(OUT_DIR, \"submission_track2_exp14_mamba.zip\"))"
925
+ ]
926
+ },
927
+ {
928
+ "cell_type": "markdown",
929
+ "id": "7604df81",
930
+ "metadata": {},
931
+ "source": [
932
+ "## Ghi chú\n",
933
+ "- **Ablation chính cho paper:** chạy 2 lần — `USE_MAMBA=False` (= exp07, mốc) và `USE_MAMBA=True`.\n",
934
+ " So QMOS/EMOS/VAD nội bộ → trả lời \"bộ mã hóa thời gian Mamba có hơn mean-pooling không?\".\n",
935
+ "- **Nếu hết đĩa khi cache chuỗi:** giảm `MAX_FRAMES` (256→160) hoặc xóa `wavlm_seq_cache/` sau khi chạy xong.\n",
936
+ "- **Nếu Mamba chậm:** thử `pip install mamba-ssm causal-conv1d` (file tự dùng nếu import được); hoặc giảm\n",
937
+ " `MAMBA_LAYERS`/`MAX_FRAMES`. Bản thuần PyTorch dùng vòng lặp thời gian nên chậm hơn kernel CUDA.\n",
938
+ "- **Save Version** để giữ cache `fusion_cache/` + `wavlm_seq_cache/` cho lần sau.\n",
939
+ "- Ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (mục exp14)."
940
+ ]
941
+ }
942
+ ],
943
+ "metadata": {
944
+ "jupytext": {
945
+ "cell_metadata_filter": "-all",
946
+ "main_language": "python",
947
+ "notebook_metadata_filter": "-all"
948
+ }
949
+ },
950
+ "nbformat": 4,
951
+ "nbformat_minor": 5
952
+ }
track2/exp14_mamba_head_pipeline.py ADDED
@@ -0,0 +1,798 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %% [markdown]
2
+ # # VMC2026 Track 2 — exp14 (MAMBA temporal head, CỘNG vào FUSION 6 cột) — Kaggle
3
+ #
4
+ # **Ý tưởng (theo gợi ý mentor "thử Mamba"):** exp04/exp07 đều **mean-pool** đặc trưng SSL →
5
+ # mỗi wav thành 1 vector → mất hết **động lực theo thời gian** (lên/xuống giọng, ngắt quãng, rung).
6
+ # **Mamba** là State Space Model (SSM) xử lý **chuỗi** với độ phức tạp tuyến tính → cho nó **dãy frame**
7
+ # (chưa pool) để học temporal dynamics, rồi mới pool. Tham khảo: MambaRate (AudioMOS 2025), arXiv:2507.12090.
8
+ #
9
+ # ## exp14 = exp07 + 1 nhánh Mamba (CỘNG thêm, không thay thế)
10
+ # ```
11
+ # ┌─ đặc trưng POOLED [e2v_emb|e2v_p5|sailer_emb|sailer_p9|sailer_vad3] (y hệt exp07 → DÙNG LẠI cache)
12
+ # mỗi wav ──┤
13
+ # └─ WavLM frame-level (chuỗi T×1024) ─► Mamba (2 lớp, 2 chiều) ─► attn-pool ─► z_seq (Z chiều)
14
+ # │
15
+ # concat ──► TRUNK chung ──► 6 head: QMOS · EMOS · CAT · VAL · ARO · DOM
16
+ # ```
17
+ # - **Cờ `USE_MAMBA`:** `False` → chạy ra **đúng exp07** (kiểm chứng tái lập ~0.548/0.795). `True` → bật nhánh Mamba.
18
+ # Đây CHÍNH là **ablation "có/không Mamba"** cho paper.
19
+ # - WavLM **đóng băng** (chỉ trích đặc trưng) → Mamba head nhỏ → train nhanh, vừa T4.
20
+ #
21
+ # ## 2 gotcha Kaggle đã xử trong file
22
+ # 1. `mamba-ssm` hay lỗi build CUDA → **nhúng sẵn Mamba thuần PyTorch** (không cần pip); tự dùng `mamba-ssm` nếu import được.
23
+ # 2. Cache frame-level RẤT nặng → **cap `MAX_FRAMES`** + lưu **fp16**. Ước lượng: MAX_FRAMES=256, 1024 chiều, fp16
24
+ # ≈ 0.5 MB/wav → train ~12k ≈ 6 GB, dev ~2.7k ≈ 1.4 GB (vừa /kaggle/working). **Save Version** để giữ cache.
25
+ #
26
+ # **Cách chạy:** GPU T4 + Internet On → Add Input dataset Track 2 → sửa `DATA_ROOT` → Run All.
27
+ # Lần đầu đặt `LIMIT_TRAIN=300`, `LIMIT_DEV=20` để soi nhanh; OK rồi đặt `None`.
28
+
29
+ # %% [markdown]
30
+ # ## 0. Cấu hình — SỬA Ở ĐÂY
31
+
32
+ # %%
33
+ import os
34
+
35
+ DATA_ROOT = "/kaggle/input/vmc2026-track2-full/vmc2026-track2" # << SỬA slug
36
+ WAV_DIR = f"{DATA_ROOT}/wav"
37
+ METADATA_CSV = f"{DATA_ROOT}/metadata.csv" # wavID|emotion|transcript (KHÔNG header)
38
+ TRAIN_CSV = f"{DATA_ROOT}/sets/train.csv" # lisID|wavID|qMOS|emoCat|eMOS|val|dom|aro
39
+ DEV_SCP = f"{DATA_ROOT}/sets/dev.scp"
40
+
41
+ OUT_DIR = "/kaggle/working"
42
+ CACHE_DIR = "/kaggle/working/fusion_cache" # DÙNG CHUNG với exp04/exp07 (e2v_*, sailer_*, utmos_*)
43
+ SEQ_DIR = "/kaggle/working/wavlm_seq_cache" # MỚI: cache frame-level WavLM (fp16)
44
+ os.makedirs(CACHE_DIR, exist_ok=True)
45
+ os.makedirs(SEQ_DIR, exist_ok=True)
46
+
47
+ # ── Bật/tắt nhánh Mamba (ablation chính) ─────────────────────────────────────
48
+ USE_MAMBA = True # False → ra ĐÚNG exp07 (sanity check). True → bật nhánh Mamba.
49
+
50
+ # ── Siêu tham số nhánh Mamba ─────────────────────────────────────────────────
51
+ WAVLM_NAME = "microsoft/wavlm-large" # backbone frame-level (đóng băng). Trả chuỗi (T, 1024).
52
+ MAX_FRAMES = 256 # cap độ dài chuỗi (256 frame ≈ 5.1s @ 50Hz). Giảm nếu hết đĩa.
53
+ MAMBA_DMODEL = 256 # chiều ẩn của khối Mamba (proj 1024→256 trước khi vào Mamba)
54
+ MAMBA_LAYERS = 2 # số khối Mamba xếp chồng
55
+ MAMBA_DSTATE = 16 # chiều state SSM
56
+ BIDIRECTIONAL = True # chạy Mamba cả 2 chiều (xuôi + ngược) rồi cộng
57
+ Z_DIM = 128 # chiều vector z_seq sau attentive-pool, đem concat vào fusion
58
+
59
+ # ── Siêu tham số fusion (giống exp07) ────────────────────────────────────────
60
+ DEVICE = "cuda"
61
+ TRUNK_HIDDEN = 512
62
+ HEAD_HIDDEN = 128
63
+ DROPOUT = 0.3
64
+ LR = 1e-3
65
+ EPOCHS = 80
66
+ BATCH = 32 # nhỏ hơn exp07 (64) vì có nhánh Mamba tốn RAM hơn
67
+ VAL_FRAC = 0.10
68
+ PATIENCE = 15
69
+ SEED = 42
70
+
71
+ USE_UNCERTAINTY = True
72
+ LOSS_W = {"qmos": 1.0, "emos": 1.0, "cat": 1.0, "val": 1.0, "aro": 1.0, "dom": 1.0}
73
+ USE_E2V = True
74
+ USE_SAILER = True
75
+ USE_CLASSPROB = True
76
+ USE_UTMOS_FEAT = True
77
+
78
+ LIMIT_TRAIN = None
79
+ LIMIT_DEV = None
80
+
81
+ # Mốc exp07 để so (đây là hệ thống đang tốt nhất)
82
+ EXP07 = {"qmos": 0.548, "emos": 0.795, "cat_err": 0.153, "val": 0.581, "aro": 0.752, "dom": 0.705}
83
+
84
+ EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"]
85
+ SAILER9 = ["Anger", "Contempt", "Disgust", "Fear", "Happiness", "Neutral", "Sadness", "Surprise", "Other"]
86
+
87
+ _EMO_ALIAS = {
88
+ "angry": "angry", "anger": "angry",
89
+ "happy": "happy", "happiness": "happy", "joy": "happy",
90
+ "neutral": "neutral", "calm": "neutral",
91
+ "sad": "sad", "sadness": "sad",
92
+ "surprise": "surprised", "surprised": "surprised", "surprising": "surprised",
93
+ }
94
+
95
+ def norm_emotion(label):
96
+ key = str(label).strip().lower()
97
+ return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)
98
+
99
+ def stem(p):
100
+ return os.path.splitext(os.path.basename(str(p)))[0]
101
+
102
+ assert USE_E2V or USE_SAILER, "Phải bật ít nhất 1 backbone pooled."
103
+ print("USE_MAMBA =", USE_MAMBA, "| nếu False → ra đúng exp07")
104
+ print("DATA_ROOT:", DATA_ROOT)
105
+ for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP]:
106
+ print((" ✅ " if os.path.exists(p) else " ❌ THIẾU ") + p)
107
+
108
+ # %% [markdown]
109
+ # ## 1. Cài đặt + tải code SAILER
110
+ # Chỉ cài gói còn thiếu (Kaggle có sẵn torch/transformers). KHÔNG đụng numpy (tránh lệch ABI torch — bài học exp12).
111
+
112
+ # %%
113
+ import sys, subprocess
114
+
115
+ def pip_install(*pkgs):
116
+ subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs], check=True)
117
+
118
+ pip_install("speechmos", "funasr", "librosa", "soundfile", "pandas", "scipy", "scikit-learn", "tqdm")
119
+
120
+ if USE_SAILER:
121
+ pip_install("loralib", "speechbrain")
122
+ REPO_DIR = "/kaggle/working/vox-profile-release"
123
+ if not os.path.exists(REPO_DIR):
124
+ subprocess.run(["git", "clone", "--depth", "1",
125
+ "https://github.com/tiantiaf0627/vox-profile-release.git", REPO_DIR], check=True)
126
+ if REPO_DIR not in sys.path:
127
+ sys.path.insert(0, REPO_DIR)
128
+
129
+ # %% [markdown]
130
+ # ## 2. Đọc & gộp nhãn theo wavID (giống exp07)
131
+
132
+ # %%
133
+ import numpy as np
134
+ import pandas as pd
135
+
136
+ def load_target_emotions():
137
+ tgt = {}
138
+ with open(METADATA_CSV, encoding="utf-8") as f:
139
+ for ln in f:
140
+ parts = ln.strip().split("|")
141
+ if len(parts) < 2:
142
+ continue
143
+ tgt[stem(parts[0])] = norm_emotion(parts[1])
144
+ return tgt
145
+
146
+ def _col(cols_map, *names, default_idx=None, df=None):
147
+ for n in names:
148
+ if n in cols_map:
149
+ return cols_map[n]
150
+ return list(df.columns)[default_idx] if default_idx is not None else None
151
+
152
+ def parse_emocat_votes(cell):
153
+ v = np.zeros(len(EMOTIONS5), dtype=np.float32)
154
+ for tok in str(cell).replace("/", ",").replace(";", ",").replace("|", ",").replace(" ", ",").split(","):
155
+ e = norm_emotion(tok)
156
+ if e in EMOTIONS5:
157
+ v[EMOTIONS5.index(e)] += 1.0
158
+ return v
159
+
160
+ def load_train_labels():
161
+ df = pd.read_csv(TRAIN_CSV, sep="|")
162
+ cols = {c.lower().strip(): c for c in df.columns}
163
+ wav_col = _col(cols, "wavid", "wav", default_idx=1, df=df)
164
+ qmos_col = _col(cols, "qmos", "mos")
165
+ emos_col = _col(cols, "emos", "emo", "emomos")
166
+ val_col = _col(cols, "val", "valence")
167
+ aro_col = _col(cols, "aro", "arousal")
168
+ dom_col = _col(cols, "dom", "dominance")
169
+ cat_col = _col(cols, "emocat", "cat", "emotion")
170
+ assert qmos_col and emos_col, f"Thiếu cột qMOS/eMOS (cột: {list(df.columns)})"
171
+ df["_stem"] = df[wav_col].map(stem)
172
+ rows = []
173
+ for sid, g in df.groupby("_stem"):
174
+ rec = {"wavID": sid, "qmos": float(g[qmos_col].mean()), "emos": float(g[emos_col].mean())}
175
+ rec["val"] = float(g[val_col].mean()) if val_col else np.nan
176
+ rec["aro"] = float(g[aro_col].mean()) if aro_col else np.nan
177
+ rec["dom"] = float(g[dom_col].mean()) if dom_col else np.nan
178
+ votes = np.zeros(len(EMOTIONS5), dtype=np.float32)
179
+ if cat_col:
180
+ for cell in g[cat_col]:
181
+ votes += parse_emocat_votes(cell)
182
+ s = votes.sum()
183
+ cat = votes / s if s > 0 else np.full(len(EMOTIONS5), 1.0 / len(EMOTIONS5), dtype=np.float32)
184
+ for i in range(len(EMOTIONS5)):
185
+ rec[f"cat{i}"] = float(cat[i])
186
+ rows.append(rec)
187
+ return pd.DataFrame(rows)
188
+
189
+ target_map = load_target_emotions()
190
+ train_df = load_train_labels()
191
+ HAS_VAD = bool(train_df["val"].notna().any())
192
+ print(f"Target: {len(target_map)} | wav train (gộp): {len(train_df)} | có VAD: {HAS_VAD}")
193
+
194
+ # %% [markdown]
195
+ # ## 3. Đặc trưng POOLED (e2v + sailer + UTMOS) — TÁI DÙNG cache exp04/exp07
196
+ # (Y hệt exp07; nếu đã chạy exp07 thì cache `fusion_cache/` còn nguyên → không tính lại.)
197
+
198
+ # %%
199
+ import torch
200
+ import torch.nn.functional as F
201
+
202
+ device = DEVICE if torch.cuda.is_available() else "cpu"
203
+ print("Device:", device, ("✅ " + torch.cuda.get_device_name(0)) if device == "cuda" else "⚠️ CPU")
204
+
205
+ def extract_e2v(stems, tag):
206
+ from tqdm.auto import tqdm
207
+ cache_path = os.path.join(CACHE_DIR, f"e2v_{tag}.npz")
208
+ store = {}
209
+ if os.path.exists(cache_path):
210
+ z = np.load(cache_path, allow_pickle=True)
211
+ store = {k: z[k] for k in z.files}
212
+ print(f"[e2v/{tag}] nạp cache: {len(store)}")
213
+ todo = [s for s in stems if s not in store]
214
+ if todo:
215
+ from funasr import AutoModel
216
+ m = AutoModel(model="iic/emotion2vec_plus_large", hub="hf", device=device)
217
+ for i, s in enumerate(tqdm(todo, desc=f"e2v {tag}")):
218
+ wav = os.path.join(WAV_DIR, s + ".wav")
219
+ if not os.path.exists(wav):
220
+ continue
221
+ r = m.generate(wav, granularity="utterance", extract_embedding=True)[0]
222
+ emb = np.asarray(r["feats"], dtype=np.float32).reshape(-1)
223
+ probs = {e: 0.0 for e in EMOTIONS5}
224
+ for lab, sc in zip(r["labels"], r["scores"]):
225
+ name = lab.split("/")[-1]
226
+ if name in probs:
227
+ probs[name] = float(sc)
228
+ tot = sum(probs.values())
229
+ p5 = np.array([probs[e] / tot if tot > 0 else 0.2 for e in EMOTIONS5], dtype=np.float32)
230
+ store[s] = np.concatenate([emb, p5]).astype(np.float32)
231
+ if (i + 1) % 500 == 0:
232
+ np.savez(cache_path, **store)
233
+ np.savez(cache_path, **store)
234
+ del m
235
+ torch.cuda.empty_cache() if device == "cuda" else None
236
+ return {s: (v[:-5], v[-5:]) for s, v in store.items()}
237
+
238
+ def _pool_feat(features):
239
+ f = features.detach().cpu().numpy()
240
+ if f.ndim <= 1:
241
+ return f.reshape(-1).astype(np.float32)
242
+ return f.mean(axis=tuple(range(f.ndim - 1))).reshape(-1).astype(np.float32)
243
+
244
+ def extract_sailer(stems, tag):
245
+ import librosa
246
+ from tqdm.auto import tqdm
247
+ cache_path = os.path.join(CACHE_DIR, f"sailer_{tag}.npz")
248
+ store = {}
249
+ if os.path.exists(cache_path):
250
+ z = np.load(cache_path, allow_pickle=True)
251
+ store = {k: z[k] for k in z.files}
252
+ print(f"[sailer/{tag}] nạp cache: {len(store)}")
253
+ todo = [s for s in stems if s not in store]
254
+ if todo:
255
+ from src.model.emotion.wavlm_emotion import WavLMWrapper
256
+ sailer = WavLMWrapper.from_pretrained("tiantiaf/wavlm-large-categorical-emotion").to(device).eval()
257
+ with torch.no_grad():
258
+ for i, s in enumerate(tqdm(todo, desc=f"sailer {tag}")):
259
+ wav = os.path.join(WAV_DIR, s + ".wav")
260
+ if not os.path.exists(wav):
261
+ continue
262
+ wave, _ = librosa.load(wav, sr=16000, mono=True)
263
+ wave = wave[: 15 * 16000]
264
+ data = torch.from_numpy(wave).float().unsqueeze(0).to(device)
265
+ logits, feat, _det, arousal, valence, dominance = sailer(data, return_feature=True)
266
+ emb = _pool_feat(feat)
267
+ p9 = F.softmax(logits, dim=1)[0].detach().cpu().numpy().astype(np.float32)
268
+ vad3 = np.array([1 + 4 * float(valence.item()),
269
+ 1 + 4 * float(arousal.item()),
270
+ 1 + 4 * float(dominance.item())], dtype=np.float32)
271
+ store[s] = np.concatenate([emb, p9, vad3]).astype(np.float32)
272
+ if (i + 1) % 500 == 0:
273
+ np.savez(cache_path, **store)
274
+ np.savez(cache_path, **store)
275
+ del sailer
276
+ torch.cuda.empty_cache() if device == "cuda" else None
277
+ return {s: (v[:-12], v[-12:-3], v[-3:]) for s, v in store.items()}
278
+
279
+ def extract_utmos(names, tag):
280
+ import librosa
281
+ from tqdm.auto import tqdm
282
+ cache_path = os.path.join(CACHE_DIR, f"utmos_{tag}.npz")
283
+ store = {}
284
+ if os.path.exists(cache_path):
285
+ z = np.load(cache_path, allow_pickle=True)
286
+ store = {k: float(z[k]) for k in z.files}
287
+ print(f"[utmos/{tag}] nạp cache: {len(store)}")
288
+ todo = [n for n in names if stem(n) not in store]
289
+ if todo:
290
+ predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong",
291
+ trust_repo=True).to(device).eval()
292
+ with torch.no_grad():
293
+ for i, n in enumerate(tqdm(todo, desc=f"utmos {tag}")):
294
+ wav = os.path.join(WAV_DIR, n if str(n).endswith(".wav") else n + ".wav")
295
+ if not os.path.exists(wav):
296
+ continue
297
+ wave, _ = librosa.load(wav, sr=16000, mono=True)
298
+ store[stem(n)] = float(predictor(torch.from_numpy(wave).unsqueeze(0).to(device),
299
+ sr=16000).mean().item())
300
+ if (i + 1) % 500 == 0:
301
+ np.savez(cache_path, **{k: np.float32(v) for k, v in store.items()})
302
+ np.savez(cache_path, **{k: np.float32(v) for k, v in store.items()})
303
+ del predictor
304
+ torch.cuda.empty_cache() if device == "cuda" else None
305
+ return store
306
+
307
+ # %% [markdown]
308
+ # ## 3b. Đặc trưng FRAME-LEVEL WavLM (chuỗi T×1024) cho nhánh Mamba — cache fp16
309
+ # Mỗi wav lưu 1 file `.npy` riêng trong `SEQ_DIR` (mảng fp16 [T, 1024], T ≤ MAX_FRAMES).
310
+ # WavLM **đóng băng** (eval, no_grad) → layerdrop tự tắt ở eval, không đụng gotcha checkpoint.
311
+
312
+ # %%
313
+ _wavlm = None
314
+ def _get_wavlm():
315
+ """Lazy-load microsoft/wavlm-large (đóng băng). Trả model + feature_extractor."""
316
+ global _wavlm
317
+ if _wavlm is None:
318
+ from transformers import WavLMModel, AutoFeatureExtractor
319
+ fe = AutoFeatureExtractor.from_pretrained(WAVLM_NAME)
320
+ mdl = WavLMModel.from_pretrained(WAVLM_NAME).to(device).eval()
321
+ for p in mdl.parameters():
322
+ p.requires_grad = False
323
+ _wavlm = (mdl, fe)
324
+ return _wavlm
325
+
326
+ def seq_path(sid):
327
+ return os.path.join(SEQ_DIR, sid + ".npy")
328
+
329
+ def extract_wavlm_seq(stems, tag):
330
+ """Trích frame-level WavLM cho từng wav, cache fp16 ra .npy. Trả set stem đã có."""
331
+ if not USE_MAMBA:
332
+ return set()
333
+ import librosa
334
+ from tqdm.auto import tqdm
335
+ todo = [s for s in stems if not os.path.exists(seq_path(s))]
336
+ if todo:
337
+ mdl, fe = _get_wavlm()
338
+ with torch.no_grad():
339
+ for i, s in enumerate(tqdm(todo, desc=f"wavlm-seq {tag}")):
340
+ wav = os.path.join(WAV_DIR, s + ".wav")
341
+ if not os.path.exists(wav):
342
+ continue
343
+ wave, _ = librosa.load(wav, sr=16000, mono=True)
344
+ wave = wave[: 15 * 16000]
345
+ inp = fe(wave, sampling_rate=16000, return_tensors="pt").input_values.to(device)
346
+ hs = mdl(inp).last_hidden_state[0] # (T, 1024)
347
+ if hs.shape[0] > MAX_FRAMES: # cap độ dài (đều theo thời gian)
348
+ idx = torch.linspace(0, hs.shape[0] - 1, MAX_FRAMES).long()
349
+ hs = hs[idx]
350
+ np.save(seq_path(s), hs.cpu().numpy().astype(np.float16))
351
+ torch.cuda.empty_cache() if device == "cuda" else None
352
+ return {s for s in stems if os.path.exists(seq_path(s))}
353
+
354
+ def load_seq(sid):
355
+ """Đọc chuỗi fp16 → tensor float32 (T, 1024). Thiếu file → None."""
356
+ p = seq_path(sid)
357
+ if not os.path.exists(p):
358
+ return None
359
+ return torch.from_numpy(np.load(p).astype(np.float32))
360
+
361
+ def collate_seqs(sids):
362
+ """Gộp list chuỗi độ dài khác nhau → (B, Lmax, 1024) + mask (B, Lmax) bool (True=thật)."""
363
+ seqs = [load_seq(s) for s in sids]
364
+ lens = [t.shape[0] for t in seqs]
365
+ Lmax = max(lens)
366
+ B = len(seqs)
367
+ x = torch.zeros(B, Lmax, seqs[0].shape[1], dtype=torch.float32)
368
+ mask = torch.zeros(B, Lmax, dtype=torch.bool)
369
+ for i, t in enumerate(seqs):
370
+ x[i, : t.shape[0]] = t
371
+ mask[i, : t.shape[0]] = True
372
+ return x, mask
373
+
374
+ # %% [markdown]
375
+ # ## 4. Dựng feature pooled + nhãn cho train (lọc các wav đủ mọi nguồn)
376
+
377
+ # %%
378
+ train_stems = list(train_df["wavID"])
379
+ if LIMIT_TRAIN:
380
+ train_stems = train_stems[:LIMIT_TRAIN]
381
+
382
+ e2v_tr = extract_e2v(train_stems, "train") if USE_E2V else {}
383
+ sailer_tr = extract_sailer(train_stems, "train") if USE_SAILER else {}
384
+ utmos_tr = extract_utmos(train_stems, "train") if USE_UTMOS_FEAT else {}
385
+ seq_tr = extract_wavlm_seq(train_stems, "train")
386
+
387
+ def audio_feature(sid, e2v_map, sailer_map):
388
+ parts = []
389
+ if USE_E2V:
390
+ pk = e2v_map.get(sid)
391
+ if pk is None:
392
+ return None
393
+ emb, p5 = pk
394
+ parts.append(emb)
395
+ if USE_CLASSPROB:
396
+ parts.append(p5)
397
+ if USE_SAILER:
398
+ pk = sailer_map.get(sid)
399
+ if pk is None:
400
+ return None
401
+ emb, p9, vad3 = pk
402
+ parts.append(emb)
403
+ if USE_CLASSPROB:
404
+ parts.append(p9); parts.append(vad3)
405
+ return np.concatenate(parts).astype(np.float32)
406
+
407
+ def onehot_target(tgt):
408
+ v = np.zeros(len(EMOTIONS5), dtype=np.float32)
409
+ if tgt in EMOTIONS5:
410
+ v[EMOTIONS5.index(tgt)] = 1.0
411
+ return v
412
+
413
+ lab = train_df.set_index("wavID")
414
+ keep_sids, X, T, U = [], [], [], []
415
+ y_qmos, y_emos, y_vad, y_cat = [], [], [], []
416
+ for s in train_stems:
417
+ f = audio_feature(s, e2v_tr, sailer_tr)
418
+ tgt = target_map.get(s)
419
+ if f is None or tgt is None or s not in lab.index:
420
+ continue
421
+ if USE_UTMOS_FEAT and s not in utmos_tr:
422
+ continue
423
+ if USE_MAMBA and s not in seq_tr: # cần có chuỗi WavLM nếu bật Mamba
424
+ continue
425
+ keep_sids.append(s)
426
+ X.append(f)
427
+ T.append(onehot_target(tgt))
428
+ U.append(utmos_tr.get(s, 3.0) if USE_UTMOS_FEAT else 0.0)
429
+ y_qmos.append(lab.loc[s, "qmos"]); y_emos.append(lab.loc[s, "emos"])
430
+ y_vad.append([lab.loc[s, "val"], lab.loc[s, "aro"], lab.loc[s, "dom"]])
431
+ y_cat.append([lab.loc[s, f"cat{i}"] for i in range(len(EMOTIONS5))])
432
+
433
+ X = np.stack(X).astype(np.float32)
434
+ T = np.stack(T).astype(np.float32)
435
+ U = np.array(U, dtype=np.float32).reshape(-1, 1)
436
+ y_qmos = np.array(y_qmos, dtype=np.float32); y_emos = np.array(y_emos, dtype=np.float32)
437
+ y_vad = np.array(y_vad, dtype=np.float32); y_cat = np.array(y_cat, dtype=np.float32)
438
+ FEAT_DIM = X.shape[1]
439
+ print(f"Train giữ lại: {len(keep_sids)} wav | X={X.shape} | Mamba={'ON' if USE_MAMBA else 'OFF'}")
440
+
441
+ # Chuẩn hóa feature pooled + UTMOS + nhãn liên tục (z-score)
442
+ feat_mean = X.mean(0, keepdims=True); feat_std = X.std(0, keepdims=True) + 1e-6
443
+ Xn = (X - feat_mean) / feat_std
444
+ u_mu, u_sd = float(U.mean()), float(U.std() + 1e-6); Un = (U - u_mu) / u_sd
445
+ qmos_mu, qmos_sd = float(y_qmos.mean()), float(y_qmos.std() + 1e-6); y_qmos_z = (y_qmos - qmos_mu) / qmos_sd
446
+ emos_mu, emos_sd = float(y_emos.mean()), float(y_emos.std() + 1e-6); y_emos_z = (y_emos - emos_mu) / emos_sd
447
+ if HAS_VAD:
448
+ vad_mu = np.nanmean(y_vad, axis=0); vad_sd = np.nanstd(y_vad, axis=0) + 1e-6
449
+ y_vad_z = (y_vad - vad_mu) / vad_sd
450
+ else:
451
+ vad_mu = np.zeros(3, dtype=np.float32); vad_sd = np.ones(3, dtype=np.float32); y_vad_z = np.zeros_like(y_vad)
452
+
453
+ # %% [markdown]
454
+ # ## 5a. Khối MAMBA (thuần PyTorch, không cần `mamba-ssm`)
455
+ # Tự dùng `mamba-ssm` nếu import được (nhanh hơn); nếu không → bản thuần PyTorch (selective scan vòng lặp thời gian).
456
+ # Bản này theo "mamba-minimal" (johnma2006) — đúng công thức, chỉ chậm hơn kernel CUDA, nhưng head nhỏ nên OK trên T4.
457
+
458
+ # %%
459
+ import math
460
+ import torch.nn as nn
461
+
462
+ try:
463
+ from mamba_ssm import Mamba as _OfficialMamba # nếu cài được thì dùng (tùy chọn)
464
+ _HAS_MAMBA_SSM = True
465
+ print("✅ Dùng mamba-ssm (CUDA kernel)")
466
+ except Exception:
467
+ _HAS_MAMBA_SSM = False
468
+ print("ℹ️ Không có mamba-ssm → dùng Mamba thuần PyTorch (nhúng sẵn)")
469
+
470
+ class MambaBlockTorch(nn.Module):
471
+ """Một khối Mamba (selective SSM) thuần PyTorch. d_model = chiều ẩn."""
472
+ def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
473
+ super().__init__()
474
+ self.d_inner = expand * d_model
475
+ self.dt_rank = math.ceil(d_model / 16)
476
+ self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
477
+ self.conv1d = nn.Conv1d(self.d_inner, self.d_inner, kernel_size=d_conv,
478
+ groups=self.d_inner, padding=d_conv - 1, bias=True)
479
+ self.x_proj = nn.Linear(self.d_inner, self.dt_rank + d_state * 2, bias=False)
480
+ self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)
481
+ A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)
482
+ self.A_log = nn.Parameter(torch.log(A)) # (d_inner, d_state)
483
+ self.D = nn.Parameter(torch.ones(self.d_inner))
484
+ self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
485
+ self.d_state = d_state
486
+
487
+ def forward(self, x): # x: (B, L, d_model)
488
+ B, L, _ = x.shape
489
+ xz = self.in_proj(x) # (B, L, 2*d_inner)
490
+ xin, z = xz.chunk(2, dim=-1)
491
+ xin = xin.transpose(1, 2) # (B, d_inner, L)
492
+ xin = self.conv1d(xin)[..., :L].transpose(1, 2) # (B, L, d_inner) causal conv
493
+ xin = F.silu(xin)
494
+ y = self._ssm(xin) # (B, L, d_inner)
495
+ y = y * F.silu(z)
496
+ return self.out_proj(y)
497
+
498
+ def _ssm(self, x): # x: (B, L, d_inner)
499
+ A = -torch.exp(self.A_log) # (d_inner, d_state)
500
+ x_dbl = self.x_proj(x) # (B, L, dt_rank + 2*d_state)
501
+ delta, Bm, Cm = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
502
+ delta = F.softplus(self.dt_proj(delta)) # (B, L, d_inner)
503
+ dA = torch.exp(delta.unsqueeze(-1) * A) # (B, L, d_inner, d_state)
504
+ dB_x = delta.unsqueeze(-1) * Bm.unsqueeze(2) * x.unsqueeze(-1) # (B, L, d_inner, d_state)
505
+ h = torch.zeros(x.shape[0], self.d_inner, self.d_state, device=x.device, dtype=x.dtype)
506
+ ys = []
507
+ for t in range(x.shape[1]): # selective scan theo thời gian
508
+ h = dA[:, t] * h + dB_x[:, t]
509
+ ys.append((h * Cm[:, t].unsqueeze(1)).sum(-1)) # (B, d_inner)
510
+ y = torch.stack(ys, dim=1) # (B, L, d_inner)
511
+ return y + x * self.D
512
+
513
+ class MambaLayer(nn.Module):
514
+ """Pre-norm residual quanh 1 khối Mamba (chọn official nếu có)."""
515
+ def __init__(self, d_model, d_state):
516
+ super().__init__()
517
+ self.norm = nn.LayerNorm(d_model)
518
+ if _HAS_MAMBA_SSM:
519
+ self.mix = _OfficialMamba(d_model=d_model, d_state=d_state, d_conv=4, expand=2)
520
+ else:
521
+ self.mix = MambaBlockTorch(d_model, d_state=d_state)
522
+
523
+ def forward(self, x):
524
+ return x + self.mix(self.norm(x))
525
+
526
+ class MambaEncoder(nn.Module):
527
+ """1024 → d_model → [Mamba ×L] (2 chiều nếu BIDIRECTIONAL) → attentive-pool → Z_DIM."""
528
+ def __init__(self, d_in, d_model, n_layers, d_state, z_dim, bidir):
529
+ super().__init__()
530
+ self.bidir = bidir
531
+ self.proj = nn.Linear(d_in, d_model)
532
+ self.fwd = nn.ModuleList([MambaLayer(d_model, d_state) for _ in range(n_layers)])
533
+ if bidir:
534
+ self.bwd = nn.ModuleList([MambaLayer(d_model, d_state) for _ in range(n_layers)])
535
+ self.attn = nn.Linear(d_model, 1) # attentive pooling
536
+ self.out = nn.Linear(d_model, z_dim)
537
+
538
+ def _run(self, layers, h):
539
+ for L in layers:
540
+ h = L(h)
541
+ return h
542
+
543
+ def forward(self, x, mask): # x: (B, L, 1024), mask: (B, L) bool
544
+ h = self.proj(x)
545
+ out = self._run(self.fwd, h)
546
+ if self.bidir:
547
+ rev = torch.flip(h, dims=[1])
548
+ out = out + torch.flip(self._run(self.bwd, rev), dims=[1])
549
+ a = self.attn(out).squeeze(-1) # (B, L)
550
+ a = a.masked_fill(~mask, float("-inf"))
551
+ w = torch.softmax(a, dim=1).unsqueeze(-1) # (B, L, 1)
552
+ pooled = (out * w).sum(1) # (B, d_model)
553
+ return self.out(pooled) # (B, z_dim)
554
+
555
+ # %% [markdown]
556
+ # ## 5b. Model fusion 6 head + nhánh Mamba + train loop
557
+
558
+ # %%
559
+ from scipy.stats import spearmanr
560
+ from sklearn.model_selection import train_test_split
561
+
562
+ torch.manual_seed(SEED); np.random.seed(SEED)
563
+ N_EMO = len(EMOTIONS5)
564
+ idx_all = np.arange(X.shape[0])
565
+ tr_idx, va_idx = train_test_split(idx_all, test_size=VAL_FRAC, random_state=SEED)
566
+
567
+ def to_t(a):
568
+ return torch.tensor(a, dtype=torch.float32, device=device)
569
+
570
+ Xn_t, T_t, Un_t = to_t(Xn), to_t(T), to_t(Un)
571
+ qmos_t = to_t(y_qmos_z).unsqueeze(1); emos_t = to_t(y_emos_z).unsqueeze(1)
572
+ vad_t = to_t(y_vad_z); cat_t = to_t(y_cat)
573
+
574
+ class FusionMamba6(nn.Module):
575
+ def __init__(self, d_in, trunk_h, head_h, p, n_emo, use_utmos, use_mamba):
576
+ super().__init__()
577
+ self.use_utmos = use_utmos
578
+ self.use_mamba = use_mamba
579
+ z_extra = Z_DIM if use_mamba else 0
580
+ if use_mamba:
581
+ self.enc = MambaEncoder(1024, MAMBA_DMODEL, MAMBA_LAYERS, MAMBA_DSTATE, Z_DIM, BIDIRECTIONAL)
582
+ self.trunk = nn.Sequential(
583
+ nn.Linear(d_in + z_extra, trunk_h), nn.ReLU(), nn.Dropout(p),
584
+ nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))
585
+ self.qmos = nn.Sequential(
586
+ nn.Linear(trunk_h + (1 if use_utmos else 0), head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))
587
+ self.emos = nn.Sequential(
588
+ nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))
589
+ self.cat = nn.Sequential(
590
+ nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))
591
+ self.vad = nn.Sequential(
592
+ nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))
593
+
594
+ def forward(self, x, tgt, utmos, seq=None, mask=None):
595
+ if self.use_mamba:
596
+ z = self.enc(seq, mask)
597
+ x = torch.cat([x, z], dim=1)
598
+ h = self.trunk(x)
599
+ qmos_in = torch.cat([h, utmos], dim=1) if self.use_utmos else h
600
+ return self.qmos(qmos_in), self.emos(torch.cat([h, tgt], dim=1)), self.cat(h), self.vad(h)
601
+
602
+ model = FusionMamba6(FEAT_DIM, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO, USE_UTMOS_FEAT, USE_MAMBA).to(device)
603
+ n_par = sum(p.numel() for p in model.parameters() if p.requires_grad)
604
+ print(f"Tham số train được: {n_par/1e6:.2f} M")
605
+
606
+ TASKS = ["qmos", "emos", "cat", "val", "aro", "dom"]
607
+ log_var = nn.Parameter(torch.zeros(len(TASKS), device=device))
608
+ params = list(model.parameters()) + ([log_var] if USE_UNCERTAINTY else [])
609
+ opt = torch.optim.Adam(params, lr=LR, weight_decay=1e-5)
610
+ mse = nn.MSELoss(reduction="none")
611
+
612
+ def soft_ce(logits, target_dist):
613
+ return -(target_dist * F.log_softmax(logits, dim=1)).sum(dim=1)
614
+
615
+ def task_losses(qmos_p, emos_p, cat_logits, vad_p, b):
616
+ L = {"qmos": mse(qmos_p, qmos_t[b]).mean(),
617
+ "emos": mse(emos_p, emos_t[b]).mean(),
618
+ "cat": soft_ce(cat_logits, cat_t[b]).mean()}
619
+ if HAS_VAD:
620
+ L["val"] = mse(vad_p[:, 0:1], vad_t[b, 0:1]).mean()
621
+ L["aro"] = mse(vad_p[:, 1:2], vad_t[b, 1:2]).mean()
622
+ L["dom"] = mse(vad_p[:, 2:3], vad_t[b, 2:3]).mean()
623
+ else:
624
+ z = torch.zeros((), device=device); L["val"] = L["aro"] = L["dom"] = z
625
+ return L
626
+
627
+ def combine(L):
628
+ if USE_UNCERTAINTY:
629
+ return sum(torch.exp(-log_var[i]) * L[t] + log_var[i] for i, t in enumerate(TASKS))
630
+ return sum(LOSS_W[t] * L[t] for t in TASKS)
631
+
632
+ # batch theo INDEX (vì nhánh Mamba cần đọc chuỗi theo sid → collate động)
633
+ sids_arr = np.array(keep_sids)
634
+
635
+ def forward_batch(bidx):
636
+ """bidx: numpy index. Trả output model cho batch (tự collate chuỗi nếu bật Mamba)."""
637
+ bt = torch.tensor(bidx, device=device)
638
+ if USE_MAMBA:
639
+ seq, mask = collate_seqs(list(sids_arr[bidx]))
640
+ seq, mask = seq.to(device), mask.to(device)
641
+ return model(Xn_t[bt], T_t[bt], Un_t[bt], seq, mask)
642
+ return model(Xn_t[bt], T_t[bt], Un_t[bt])
643
+
644
+ @torch.no_grad()
645
+ def eval_val():
646
+ model.eval()
647
+ qp, ep, vp = [], [], []
648
+ for i in range(0, len(va_idx), BATCH):
649
+ b = va_idx[i:i + BATCH]
650
+ q, e, _cl, v = forward_batch(b)
651
+ qp.append(q.cpu().numpy().ravel()); ep.append(e.cpu().numpy().ravel()); vp.append(v.cpu().numpy())
652
+ qp = np.concatenate(qp); ep = np.concatenate(ep); vp = np.concatenate(vp)
653
+ out = {"qmos": spearmanr(qp, y_qmos[va_idx]).correlation,
654
+ "emos": spearmanr(ep, y_emos[va_idx]).correlation}
655
+ if USE_UTMOS_FEAT:
656
+ out["qmos_utmos"] = spearmanr(U[va_idx, 0], y_qmos[va_idx]).correlation
657
+ if HAS_VAD:
658
+ for j, t in enumerate(["val", "aro", "dom"]):
659
+ out[t] = spearmanr(vp[:, j], y_vad[va_idx, j]).correlation
660
+ return out
661
+
662
+ def val_score(m):
663
+ keys = ["qmos", "emos"] + (["val", "aro", "dom"] if HAS_VAD else [])
664
+ return float(np.mean([m[k] for k in keys]))
665
+
666
+ best_score, best_state, bad = -1e9, None, 0
667
+ for ep_i in range(1, EPOCHS + 1):
668
+ model.train()
669
+ perm = np.random.permutation(tr_idx)
670
+ run = 0.0
671
+ for i in range(0, len(perm), BATCH):
672
+ b = perm[i:i + BATCH]
673
+ opt.zero_grad()
674
+ q, e, cl, v = forward_batch(b)
675
+ loss = combine(task_losses(q, e, cl, v, torch.tensor(b, device=device)))
676
+ loss.backward(); opt.step()
677
+ run += loss.item() * len(b)
678
+ m = eval_val(); sc = val_score(m)
679
+ if sc > best_score:
680
+ best_score = sc; bad = 0
681
+ best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
682
+ else:
683
+ bad += 1
684
+ if ep_i % 2 == 0 or ep_i == 1:
685
+ msg = " ".join(f"{k}={m[k]:.3f}" for k in ["qmos", "emos", "val", "aro", "dom"] if k in m)
686
+ print(f"epoch {ep_i:3d} | loss {run/len(perm):.4f} | {msg} | best {best_score:.4f}")
687
+ if bad >= PATIENCE:
688
+ print(f"Early stop ở epoch {ep_i}."); break
689
+
690
+ model.load_state_dict(best_state)
691
+ final = eval_val()
692
+ print(f"\n✅ VAL (nội bộ) — exp14 (Mamba={'ON' if USE_MAMBA else 'OFF'}):")
693
+ print(f" QMOS={final['qmos']:.4f} (exp07 {EXP07['qmos']}) | EMOS={final['emos']:.4f} (exp07 {EXP07['emos']})")
694
+ if HAS_VAD:
695
+ print(f" VAL/ARO/DOM={final['val']:.4f}/{final['aro']:.4f}/{final['dom']:.4f}"
696
+ f" (exp07 {EXP07['val']}/{EXP07['aro']}/{EXP07['dom']})")
697
+ print(" → So sánh USE_MAMBA True vs False = ablation Mamba cho paper.")
698
+
699
+ torch.save({"state": best_state, "feat_mean": feat_mean, "feat_std": feat_std,
700
+ "u_mu": u_mu, "u_sd": u_sd, "qmos_mu": qmos_mu, "qmos_sd": qmos_sd,
701
+ "emos_mu": emos_mu, "emos_sd": emos_sd, "vad_mu": vad_mu, "vad_sd": vad_sd,
702
+ "FEAT_DIM": FEAT_DIM, "USE_MAMBA": USE_MAMBA, "val_score": best_score},
703
+ os.path.join(OUT_DIR, "fusion_mamba_mtl.pt"))
704
+ print("Đã lưu", os.path.join(OUT_DIR, "fusion_mamba_mtl.pt"))
705
+
706
+ # %% [markdown]
707
+ # ## 6. Dự đoán DEV → `answer.txt` đủ 6 cột
708
+
709
+ # %%
710
+ def list_dev():
711
+ with open(DEV_SCP) as f:
712
+ return [ln.strip() for ln in f if ln.strip()]
713
+
714
+ dev_names = list_dev()
715
+ if LIMIT_DEV:
716
+ dev_names = dev_names[:LIMIT_DEV]
717
+ dev_stems = [stem(n) for n in dev_names]
718
+ print("DEV:", len(dev_names), "mẫu")
719
+
720
+ e2v_dev = extract_e2v(dev_stems, "dev") if USE_E2V else {}
721
+ sailer_dev = extract_sailer(dev_stems, "dev") if USE_SAILER else {}
722
+ utmos_dev = extract_utmos(dev_names, "dev") if USE_UTMOS_FEAT else {}
723
+ seq_dev = extract_wavlm_seq(dev_stems, "dev")
724
+
725
+ @torch.no_grad()
726
+ def predict_all(sid):
727
+ f = audio_feature(sid, e2v_dev, sailer_dev)
728
+ if f is None:
729
+ return None
730
+ if USE_MAMBA and not os.path.exists(seq_path(sid)):
731
+ return None
732
+ fn = (f[None, :] - feat_mean) / feat_std
733
+ tgt = onehot_target(target_map.get(sid))[None, :]
734
+ u = np.array([[utmos_dev.get(sid, 3.0)]], dtype=np.float32); un = (u - u_mu) / u_sd
735
+ model.eval()
736
+ if USE_MAMBA:
737
+ seq, mask = collate_seqs([sid]); seq, mask = seq.to(device), mask.to(device)
738
+ q, e, cl, v = model(to_t(fn), to_t(tgt), to_t(un), seq, mask)
739
+ else:
740
+ q, e, cl, v = model(to_t(fn), to_t(tgt), to_t(un))
741
+ qmos = float(q.item()) * qmos_sd + qmos_mu
742
+ emos = float(e.item()) * emos_sd + emos_mu
743
+ cat5 = F.softmax(cl, dim=1)[0].cpu().numpy()
744
+ vad3 = v[0].cpu().numpy() * vad_sd + vad_mu
745
+ return qmos, emos, cat5, vad3
746
+
747
+ def fmt_cat(probs5):
748
+ return "|".join(f"{e}:{probs5[i]:.6g}" for i, e in enumerate(EMOTIONS5))
749
+
750
+ def build_answer(out_path):
751
+ from tqdm.auto import tqdm
752
+ n_real = n_default = 0
753
+ with open(out_path, "w") as f:
754
+ f.write("wav,QMOS,EMOS,CAT,VAL,ARO,DOM\n")
755
+ for name in tqdm(dev_names, desc="answer"):
756
+ sid = stem(name)
757
+ pred = predict_all(sid)
758
+ if pred is None:
759
+ qmos = utmos_dev.get(sid, 3.0)
760
+ emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0])
761
+ n_default += 1
762
+ else:
763
+ qmos, emos, cat5, vad3 = pred; n_real += 1
764
+ f.write(f"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},"
765
+ f"{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\n")
766
+ print(f"Ghi {len(dev_names)} dòng → {out_path} | head thật {n_real}, mặc định {n_default}")
767
+
768
+ answer_path = os.path.join(OUT_DIR, "answer.txt")
769
+ build_answer(answer_path)
770
+
771
+ # %% [markdown]
772
+ # ## 7. Validate + đóng zip
773
+
774
+ # %%
775
+ def validate(path):
776
+ import csv
777
+ with open(path) as f:
778
+ rows = list(csv.reader(f))
779
+ header = rows[0]
780
+ assert header[0] == "wav" and "QMOS" in header and "EMOS" in header, "Header sai"
781
+ for i, r in enumerate(rows[1:], 2):
782
+ assert len(r) == len(header), f"Dòng {i} sai số cột"
783
+ print(f"OK: {len(rows)-1} dòng, header = {header}")
784
+
785
+ validate(answer_path)
786
+ os.system(f"cd {OUT_DIR} && zip -j submission_track2_exp14_mamba.zip answer.txt "
787
+ f"&& unzip -l submission_track2_exp14_mamba.zip")
788
+ print("Sẵn sàng nộp:", os.path.join(OUT_DIR, "submission_track2_exp14_mamba.zip"))
789
+
790
+ # %% [markdown]
791
+ # ## Ghi chú
792
+ # - **Ablation chính cho paper:** chạy 2 lần — `USE_MAMBA=False` (= exp07, mốc) và `USE_MAMBA=True`.
793
+ # So QMOS/EMOS/VAD nội bộ → trả lời "bộ mã hóa thời gian Mamba có hơn mean-pooling không?".
794
+ # - **Nếu hết đĩa khi cache chuỗi:** giảm `MAX_FRAMES` (256→160) hoặc xóa `wavlm_seq_cache/` sau khi chạy xong.
795
+ # - **Nếu Mamba chậm:** thử `pip install mamba-ssm causal-conv1d` (file tự dùng nếu import được); hoặc giảm
796
+ # `MAMBA_LAYERS`/`MAX_FRAMES`. Bản thuần PyTorch dùng vòng lặp thời gian nên chậm hơn kernel CUDA.
797
+ # - **Save Version** để giữ cache `fusion_cache/` + `wavlm_seq_cache/` cho lần sau.
798
+ # - Ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (mục exp14).
track2/exp15_predict.ipynb ADDED
@@ -0,0 +1,698 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "9f2c52f2",
6
+ "metadata": {},
7
+ "source": [
8
+ "# VMC2026 Track 2 — exp15 PREDICT-ONLY (nạp checkpoint → chấm DEV, KHÔNG train) — Kaggle\n",
9
+ "\n",
10
+ "**Mục đích:** bạn ĐÃ có checkpoint exp15 (`ft_mamba_emotion_full*.pt`, lưu cả backbone WavLM + Mamba enc + heads).\n",
11
+ "File này **chỉ inference**: dựng lại đúng kiến trúc → nạp trọng số + thống kê chuẩn hóa TỪ ckpt →\n",
12
+ "dự đoán 5 cột cảm xúc trên tập DEV → ghép QMOS (exp07/UTMOSv2) → `answer.txt` → zip nộp.\n",
13
+ "**KHÔNG** train, **KHÔNG** cần train.csv (chỉ cần wav DEV + metadata.csv để lấy cảm xúc target cho EMOS).\n",
14
+ "\n",
15
+ "## Vì sao nhanh\n",
16
+ "- Không có vòng train → chỉ 1 lượt forward qua DEV (~2730 mẫu). Việc lâu nhất là trích audeering DEV\n",
17
+ " (~vài phút; có cache thì gần như tức thì).\n",
18
+ "\n",
19
+ "## Chuẩn bị input trên Kaggle (Add Input)\n",
20
+ "1. Dataset Track 2 (wav + `metadata.csv` + `sets/dev.scp`).\n",
21
+ "2. **Checkpoint** exp15: dataset chứa `ft_mamba_emotion_full*.pt` (vd `cache_exp8`). Auto-dò; hoặc trỏ `CKPT_PATH`.\n",
22
+ "3. (tùy chọn) cache audeering `aud_dev.npz` để khỏi trích lại.\n",
23
+ "4. (tùy chọn) `answer.txt` exp07 để mượn cột QMOS 0.548.\n",
24
+ "\n",
25
+ "**Cách chạy:** GPU **T4** + Internet **On** → Add Input → Run All."
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "markdown",
30
+ "id": "adbc7c65",
31
+ "metadata": {},
32
+ "source": [
33
+ "## 0. Cấu hình — SỬA Ở ĐÂY"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "code",
38
+ "execution_count": null,
39
+ "id": "7eb066d5",
40
+ "metadata": {},
41
+ "outputs": [],
42
+ "source": [
43
+ "import os, glob\n",
44
+ "\n",
45
+ "# ── TỰ DÒ DATA_ROOT (quét /kaggle/input tìm thư mục có sets + wav/ + metadata.csv) ──\n",
46
+ "def find_data_root(search_root=\"/kaggle/input\"):\n",
47
+ " cands = []\n",
48
+ " for dev_scp in glob.glob(os.path.join(search_root, \"**\", \"sets\", \"dev.scp\"), recursive=True):\n",
49
+ " root = os.path.dirname(os.path.dirname(dev_scp))\n",
50
+ " score = os.path.isdir(os.path.join(root, \"wav\")) + os.path.exists(os.path.join(root, \"metadata.csv\"))\n",
51
+ " cands.append((score, root))\n",
52
+ " cands.sort(reverse=True)\n",
53
+ " return cands\n",
54
+ "\n",
55
+ "_cands = find_data_root(\"/kaggle/input\")\n",
56
+ "if _cands:\n",
57
+ " print(\"🔎 Ứng viên DATA_ROOT:\")\n",
58
+ " for sc, r in _cands:\n",
59
+ " print(f\" [{sc}/2] {r}\")\n",
60
+ " DATA_ROOT = _cands[0][1]\n",
61
+ " print(f\"👉 Tự chọn DATA_ROOT = {DATA_ROOT}\")\n",
62
+ "else:\n",
63
+ " DATA_ROOT = \"/kaggle/input/datasets/minhtoan2\" # dự phòng — sửa tay\n",
64
+ " print(f\"❌ Không thấy sets/dev.scp → dùng dự phòng {DATA_ROOT} (đã Add Input chưa?)\")\n",
65
+ "\n",
66
+ "WAV_DIR = f\"{DATA_ROOT}/wav\"\n",
67
+ "METADATA_CSV = f\"{DATA_ROOT}/metadata.csv\" # wavID|emotion|transcript (KHÔNG header) — lấy cảm xúc target cho EMOS\n",
68
+ "DEV_SCP = f\"{DATA_ROOT}/sets/dev.scp\"\n",
69
+ "\n",
70
+ "OUT_DIR = \"/kaggle/working\"\n",
71
+ "CACHE_DIR = \"/kaggle/working/ft_cache\"\n",
72
+ "os.makedirs(CACHE_DIR, exist_ok=True)\n",
73
+ "\n",
74
+ "# ── CHECKPOINT exp15 (đủ backbone + Mamba + heads) ───────────────────────────\n",
75
+ "CKPT_PATH = \"\" # << \"\" = auto-dò ft_mamba_emotion_full*.pt; hoặc \"/kaggle/input/<slug>/ft_mamba_emotion_full (2).pt\"\n",
76
+ "\n",
77
+ "def find_ckpt(explicit):\n",
78
+ " \"\"\"Tìm checkpoint exp15. Khớp cả tên bị thêm hậu tố trùng, vd 'ft_mamba_emotion_full (2).pt'.\"\"\"\n",
79
+ " if explicit and os.path.exists(explicit):\n",
80
+ " return explicit\n",
81
+ " for base in [\"/kaggle/input\", \"/kaggle/working\"]:\n",
82
+ " hits = sorted(glob.glob(os.path.join(base, \"**\", \"ft_mamba_emotion_full*.pt\"), recursive=True))\n",
83
+ " if hits:\n",
84
+ " return hits[0]\n",
85
+ " return \"\"\n",
86
+ "\n",
87
+ "CKPT_PATH = find_ckpt(CKPT_PATH)\n",
88
+ "assert CKPT_PATH, \"❌ Không thấy checkpoint ft_mamba_emotion_full*.pt. Đã Add Input dataset chứa ckpt chưa?\"\n",
89
+ "print(\"✅ Dùng checkpoint:\", CKPT_PATH)\n",
90
+ "\n",
91
+ "# (Tùy chọn) tái dùng cache audeering DEV — quét đệ quy (file có thể nằm trong archive/)\n",
92
+ "CACHE_INPUT = \"/kaggle/input/cache-exp8\" # << SỬA slug (hoặc \"\")\n",
93
+ "if CACHE_INPUT and os.path.isdir(CACHE_INPUT):\n",
94
+ " import shutil\n",
95
+ " _n = 0\n",
96
+ " for _fp in glob.glob(os.path.join(CACHE_INPUT, \"**\", \"aud_*.npz\"), recursive=True):\n",
97
+ " shutil.copy(_fp, os.path.join(CACHE_DIR, os.path.basename(_fp))); _n += 1\n",
98
+ " print(f\"📦 Copy {_n} file aud_*.npz từ {CACHE_INPUT}\")\n",
99
+ "\n",
100
+ "# Mượn cột QMOS exp07 (0.548). Trỏ answer.txt exp07 nếu có; không thì UTMOSv2.\n",
101
+ "EXP07_ANSWER = \"/kaggle/input/exp07-answer/answer.txt\" # << (tùy chọn)\n",
102
+ "\n",
103
+ "# ── Siêu tham số PHẢI KHỚP lúc train exp15 (ckpt không lưu các số này của Mamba) ──\n",
104
+ "MAMBA_DMODEL = 256\n",
105
+ "MAMBA_LAYERS = 2\n",
106
+ "MAMBA_DSTATE = 16\n",
107
+ "BIDIRECTIONAL = True\n",
108
+ "TRUNK_HIDDEN = 512\n",
109
+ "HEAD_HIDDEN = 128\n",
110
+ "DROPOUT = 0.3 # không ảnh hưởng eval (model.eval() tắt dropout) — chỉ để dựng đúng shape\n",
111
+ "\n",
112
+ "DEVICE = \"cuda\"\n",
113
+ "SR = 16000\n",
114
+ "MAX_SECONDS = 6 # khớp lúc train (exp15 = 6)\n",
115
+ "USE_AMP = True\n",
116
+ "LIMIT_DEV = None # << để None chấm ĐỦ 2730; đặt 20 để smoke-test nhanh\n",
117
+ "\n",
118
+ "EMOTIONS5 = [\"angry\", \"happy\", \"neutral\", \"sad\", \"surprised\"]\n",
119
+ "_EMO_ALIAS = {\n",
120
+ " \"angry\": \"angry\", \"anger\": \"angry\",\n",
121
+ " \"happy\": \"happy\", \"happiness\": \"happy\", \"joy\": \"happy\",\n",
122
+ " \"neutral\": \"neutral\", \"calm\": \"neutral\",\n",
123
+ " \"sad\": \"sad\", \"sadness\": \"sad\",\n",
124
+ " \"surprise\": \"surprised\", \"surprised\": \"surprised\", \"surprising\": \"surprised\",\n",
125
+ "}\n",
126
+ "\n",
127
+ "def norm_emotion(label):\n",
128
+ " key = str(label).strip().lower()\n",
129
+ " return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)\n",
130
+ "\n",
131
+ "def stem(p):\n",
132
+ " return os.path.splitext(os.path.basename(str(p)))[0]\n",
133
+ "\n",
134
+ "print(\"DATA_ROOT:\", DATA_ROOT)\n",
135
+ "for p in [WAV_DIR, METADATA_CSV, DEV_SCP, CKPT_PATH]:\n",
136
+ " print((\" ✅ \" if os.path.exists(p) else \" ❌ THIẾU \") + p)"
137
+ ]
138
+ },
139
+ {
140
+ "cell_type": "markdown",
141
+ "id": "febe8bdc",
142
+ "metadata": {},
143
+ "source": [
144
+ "## 1. Cài đặt + tải code SAILER (để dựng đúng kiến trúc WavLM rồi nạp ckpt đè lên)"
145
+ ]
146
+ },
147
+ {
148
+ "cell_type": "code",
149
+ "execution_count": null,
150
+ "id": "7732e245",
151
+ "metadata": {},
152
+ "outputs": [],
153
+ "source": [
154
+ "import sys, subprocess\n",
155
+ "\n",
156
+ "def pip_install(*pkgs):\n",
157
+ " subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *pkgs], check=True)\n",
158
+ "\n",
159
+ "pip_install(\"loralib\", \"speechbrain\", \"speechmos\", \"librosa\", \"soundfile\",\n",
160
+ " \"scipy\", \"scikit-learn\", \"pandas\", \"tqdm\")\n",
161
+ "\n",
162
+ "# Mamba kernel CUDA (tùy chọn — không có thì dùng Mamba thuần PyTorch, inference vẫn ổn vì chỉ 1 lượt forward)\n",
163
+ "INSTALL_MAMBA_SSM = True\n",
164
+ "if INSTALL_MAMBA_SSM:\n",
165
+ " try:\n",
166
+ " subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", \"ninja\"], check=True)\n",
167
+ " subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", \"--no-build-isolation\", \"causal-conv1d>=1.2.0\"], check=True)\n",
168
+ " subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", \"--no-build-isolation\", \"mamba-ssm\"], check=True)\n",
169
+ " print(\"✅ Cài mamba-ssm xong (dùng kernel CUDA nếu import được).\")\n",
170
+ " except Exception as e:\n",
171
+ " print(\"⚠️ Cài mamba-ssm thất bại:\", repr(e), \"→ Mamba thuần PyTorch (inference vẫn chạy).\")\n",
172
+ "\n",
173
+ "REPO_DIR = \"/kaggle/working/vox-profile-release\"\n",
174
+ "if not os.path.exists(REPO_DIR):\n",
175
+ " subprocess.run([\"git\", \"clone\", \"--depth\", \"1\",\n",
176
+ " \"https://github.com/tiantiaf0627/vox-profile-release.git\", REPO_DIR], check=True)\n",
177
+ "if REPO_DIR not in sys.path:\n",
178
+ " sys.path.insert(0, REPO_DIR)"
179
+ ]
180
+ },
181
+ {
182
+ "cell_type": "markdown",
183
+ "id": "fba12581",
184
+ "metadata": {},
185
+ "source": [
186
+ "## 2. Nạp checkpoint → dựng WavLM → load trọng số backbone đã fine-tune"
187
+ ]
188
+ },
189
+ {
190
+ "cell_type": "code",
191
+ "execution_count": null,
192
+ "id": "61199736",
193
+ "metadata": {
194
+ "lines_to_next_cell": 1
195
+ },
196
+ "outputs": [],
197
+ "source": [
198
+ "import torch\n",
199
+ "import torch.nn as nn\n",
200
+ "import torch.nn.functional as F\n",
201
+ "\n",
202
+ "device = DEVICE if torch.cuda.is_available() else \"cpu\"\n",
203
+ "print(\"Device:\", device, (\"✅ \" + torch.cuda.get_device_name(0)) if device == \"cuda\" else \"⚠️ CPU (chậm)\")\n",
204
+ "\n",
205
+ "ckpt = torch.load(CKPT_PATH, map_location=\"cpu\", weights_only=False) # ckpt có numpy → cần False\n",
206
+ "assert \"wavlm\" in ckpt, \"❌ Checkpoint KHÔNG có 'wavlm' (backbone) → không inference được. Cần ft_mamba_emotion_full*.pt đủ.\"\n",
207
+ "print(\"✅ Nạp ckpt | keys:\", list(ckpt.keys()))\n",
208
+ "\n",
209
+ "# Lấy cấu hình KIẾN TRÚC từ ckpt (để dựng đúng shape head)\n",
210
+ "USE_MAMBA = bool(ckpt.get(\"USE_MAMBA\", True))\n",
211
+ "Z_DIM = int(ckpt.get(\"Z_DIM\", 256))\n",
212
+ "AUD_DIM = int(ckpt.get(\"AUD_DIM\", 0))\n",
213
+ "USE_AUDEERING = AUD_DIM > 0\n",
214
+ "UNFREEZE_TOP_LAYERS = int(ckpt.get(\"UNFREEZE_TOP_LAYERS\", 6))\n",
215
+ "print(f\"Từ ckpt: USE_MAMBA={USE_MAMBA} · Z_DIM={Z_DIM} · AUD_DIM={AUD_DIM} (audeering={'ON' if USE_AUDEERING else 'OFF'})\")\n",
216
+ "\n",
217
+ "def find_hf_backbone(module):\n",
218
+ " cands = []\n",
219
+ " for name, m in module.named_modules():\n",
220
+ " enc = getattr(m, \"encoder\", None)\n",
221
+ " if getattr(m, \"feature_extractor\", None) is not None and enc is not None \\\n",
222
+ " and getattr(enc, \"layers\", None) is not None:\n",
223
+ " cands.append((name, m))\n",
224
+ " if not cands:\n",
225
+ " return None, None\n",
226
+ " cands.sort(key=lambda nm: sum(p.numel() for p in nm[1].parameters()), reverse=True)\n",
227
+ " return cands[0]\n",
228
+ "\n",
229
+ "wavlm = None\n",
230
+ "try:\n",
231
+ " from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402\n",
232
+ " _wrapper = WavLMWrapper.from_pretrained(\"tiantiaf/wavlm-large-categorical-emotion\")\n",
233
+ " name, wavlm = find_hf_backbone(_wrapper)\n",
234
+ " if wavlm is not None:\n",
235
+ " print(f\"✅ Dựng backbone WavLM từ SAILER wrapper tại '.{name}'\")\n",
236
+ "except Exception as e:\n",
237
+ " print(\"⚠️ Lỗi nạp SAILER wrapper:\", repr(e), \"→ fallback WavLM trắng.\")\n",
238
+ "\n",
239
+ "if wavlm is None:\n",
240
+ " from transformers import WavLMModel\n",
241
+ " wavlm = WavLMModel.from_pretrained(\"microsoft/wavlm-large\")\n",
242
+ " print(\"ℹ️ Fallback: microsoft/wavlm-large.\")\n",
243
+ "\n",
244
+ "wavlm = wavlm.to(device)\n",
245
+ "WAVLM_DIM = int(wavlm.config.hidden_size)\n",
246
+ "wavlm.config.layerdrop = 0.0\n",
247
+ "\n",
248
+ "miss, unexp = wavlm.load_state_dict(ckpt[\"wavlm\"], strict=False)\n",
249
+ "print(f\"🔁 load wavlm từ ckpt: thiếu {len(miss)} / dư {len(unexp)} key (kỳ vọng ~0)\")\n",
250
+ "if len(miss) > 20 or len(unexp) > 20:\n",
251
+ " print(\" ⚠️ Lệch key nhiều → kiểm tra backbone có khớp ckpt không.\")\n",
252
+ "wavlm.eval()\n",
253
+ "\n",
254
+ "def frame_mask(T, attn_mask):\n",
255
+ " if attn_mask is None:\n",
256
+ " return torch.ones((1, T), dtype=torch.bool, device=device)\n",
257
+ " try:\n",
258
+ " return wavlm._get_feature_vector_attention_mask(T, attn_mask).bool()\n",
259
+ " except Exception:\n",
260
+ " return torch.ones((attn_mask.shape[0], T), dtype=torch.bool, device=attn_mask.device)\n",
261
+ "\n",
262
+ "def masked_mean(hidden, attn_mask):\n",
263
+ " if attn_mask is None:\n",
264
+ " return hidden.mean(dim=1)\n",
265
+ " fm = frame_mask(hidden.shape[1], attn_mask).unsqueeze(-1).to(hidden.dtype)\n",
266
+ " return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)"
267
+ ]
268
+ },
269
+ {
270
+ "cell_type": "markdown",
271
+ "id": "421e0b6a",
272
+ "metadata": {},
273
+ "source": [
274
+ "## 3. audeering MSP-dim (FROZEN) — chỉ dựng nếu ckpt có dùng (AUD_DIM>0)"
275
+ ]
276
+ },
277
+ {
278
+ "cell_type": "code",
279
+ "execution_count": null,
280
+ "id": "d37d3d53",
281
+ "metadata": {
282
+ "lines_to_next_cell": 1
283
+ },
284
+ "outputs": [],
285
+ "source": [
286
+ "import numpy as np\n",
287
+ "import librosa\n",
288
+ "from tqdm.auto import tqdm\n",
289
+ "\n",
290
+ "aud_backbone = aud_head = aud_proc = None\n",
291
+ "if USE_AUDEERING:\n",
292
+ " from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor\n",
293
+ " from huggingface_hub import hf_hub_download\n",
294
+ " AUD_NAME = \"audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim\"\n",
295
+ " aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)\n",
296
+ " aud_cfg = Wav2Vec2Config.from_pretrained(AUD_NAME)\n",
297
+ " aud_backbone = Wav2Vec2Model(aud_cfg)\n",
298
+ " try:\n",
299
+ " _sd = __import__(\"safetensors.torch\", fromlist=[\"load_file\"]).load_file(\n",
300
+ " hf_hub_download(AUD_NAME, \"model.safetensors\"))\n",
301
+ " except Exception:\n",
302
+ " _sd = torch.load(hf_hub_download(AUD_NAME, \"pytorch_model.bin\"), map_location=\"cpu\")\n",
303
+ " bb_sd = {k[len(\"wav2vec2.\"):]: v for k, v in _sd.items() if k.startswith(\"wav2vec2.\")}\n",
304
+ " aud_backbone.load_state_dict(bb_sd, strict=False)\n",
305
+ " _hid = _sd[\"classifier.dense.weight\"].shape[0]\n",
306
+ " aud_head = nn.Sequential(nn.Linear(_hid, _hid), nn.Tanh(), nn.Linear(_hid, _sd[\"classifier.out_proj.weight\"].shape[0]))\n",
307
+ " aud_head[0].weight.data.copy_(_sd[\"classifier.dense.weight\"]); aud_head[0].bias.data.copy_(_sd[\"classifier.dense.bias\"])\n",
308
+ " aud_head[2].weight.data.copy_(_sd[\"classifier.out_proj.weight\"]); aud_head[2].bias.data.copy_(_sd[\"classifier.out_proj.bias\"])\n",
309
+ " aud_backbone = aud_backbone.to(device).eval()\n",
310
+ " aud_head = aud_head.to(device).eval()\n",
311
+ " assert _hid + 3 == AUD_DIM, f\"⚠️ AUD_DIM dựng ({_hid+3}) ≠ ckpt ({AUD_DIM}) → audeering không khớp!\"\n",
312
+ " print(f\"✅ audeering frozen ({AUD_DIM}-D)\")\n",
313
+ "\n",
314
+ "def load_wav(name_or_stem):\n",
315
+ " p = name_or_stem if os.path.isabs(str(name_or_stem)) else os.path.join(\n",
316
+ " WAV_DIR, name_or_stem if str(name_or_stem).endswith(\".wav\") else str(name_or_stem) + \".wav\")\n",
317
+ " if not os.path.exists(p):\n",
318
+ " return None\n",
319
+ " wave, _ = librosa.load(p, sr=SR, mono=True)\n",
320
+ " return wave[: MAX_SECONDS * SR].astype(np.float32)\n",
321
+ "\n",
322
+ "@torch.no_grad()\n",
323
+ "def extract_audeering(stems, tag):\n",
324
+ " if not USE_AUDEERING:\n",
325
+ " return {}\n",
326
+ " cache_path = os.path.join(CACHE_DIR, f\"aud_{tag}.npz\")\n",
327
+ " store = {}\n",
328
+ " if os.path.exists(cache_path):\n",
329
+ " z = np.load(cache_path, allow_pickle=True)\n",
330
+ " store = {k: z[k] for k in z.files}\n",
331
+ " print(f\"[aud/{tag}] nạp cache: {len(store)}\")\n",
332
+ " todo = [s for s in stems if s not in store]\n",
333
+ " for i, s in enumerate(tqdm(todo, desc=f\"audeering {tag}\")):\n",
334
+ " wave = load_wav(s)\n",
335
+ " if wave is None:\n",
336
+ " continue\n",
337
+ " x = aud_proc(wave, sampling_rate=SR).input_values[0]\n",
338
+ " x = torch.from_numpy(np.asarray(x, dtype=np.float32)).unsqueeze(0).to(device)\n",
339
+ " h = aud_backbone(x)[0].mean(dim=1)\n",
340
+ " out = aud_head(h)[0].cpu().numpy()\n",
341
+ " vad = np.array([1 + 4 * out[2], 1 + 4 * out[0], 1 + 4 * out[1]], dtype=np.float32) # [VAL,ARO,DOM]\n",
342
+ " store[s] = np.concatenate([h[0].cpu().numpy(), vad]).astype(np.float32)\n",
343
+ " if (i + 1) % 500 == 0:\n",
344
+ " np.savez(cache_path, **store)\n",
345
+ " if todo:\n",
346
+ " np.savez(cache_path, **store)\n",
347
+ " return store"
348
+ ]
349
+ },
350
+ {
351
+ "cell_type": "markdown",
352
+ "id": "0a04ef30",
353
+ "metadata": {},
354
+ "source": [
355
+ "## 4. Cảm xúc target theo wavID (cho one-hot điều kiện của head EMOS)"
356
+ ]
357
+ },
358
+ {
359
+ "cell_type": "code",
360
+ "execution_count": null,
361
+ "id": "3c092318",
362
+ "metadata": {
363
+ "lines_to_next_cell": 1
364
+ },
365
+ "outputs": [],
366
+ "source": [
367
+ "def load_target_emotions():\n",
368
+ " tgt = {}\n",
369
+ " with open(METADATA_CSV, encoding=\"utf-8\") as f:\n",
370
+ " for ln in f:\n",
371
+ " parts = ln.strip().split(\"|\")\n",
372
+ " if len(parts) >= 2:\n",
373
+ " tgt[stem(parts[0])] = norm_emotion(parts[1])\n",
374
+ " return tgt\n",
375
+ "\n",
376
+ "target_map = load_target_emotions()\n",
377
+ "print(\"Target cảm xúc:\", len(target_map), \"wav\")\n",
378
+ "\n",
379
+ "def onehot_target(tgt):\n",
380
+ " v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
381
+ " if tgt in EMOTIONS5:\n",
382
+ " v[EMOTIONS5.index(tgt)] = 1.0\n",
383
+ " return v"
384
+ ]
385
+ },
386
+ {
387
+ "cell_type": "markdown",
388
+ "id": "a0d7021a",
389
+ "metadata": {},
390
+ "source": [
391
+ "## 5. Khối Mamba (giống exp15) + MambaEncoder"
392
+ ]
393
+ },
394
+ {
395
+ "cell_type": "code",
396
+ "execution_count": null,
397
+ "id": "d8c31f88",
398
+ "metadata": {
399
+ "lines_to_next_cell": 1
400
+ },
401
+ "outputs": [],
402
+ "source": [
403
+ "import math\n",
404
+ "\n",
405
+ "try:\n",
406
+ " from mamba_ssm import Mamba as _OfficialMamba\n",
407
+ " _HAS_MAMBA_SSM = True\n",
408
+ " print(\"✅ Dùng mamba-ssm (CUDA kernel)\")\n",
409
+ "except Exception:\n",
410
+ " _HAS_MAMBA_SSM = False\n",
411
+ " print(\"ℹ️ Không có mamba-ssm → Mamba thuần PyTorch\")\n",
412
+ "\n",
413
+ "class MambaBlockTorch(nn.Module):\n",
414
+ " def __init__(self, d_model, d_state=16, d_conv=4, expand=2):\n",
415
+ " super().__init__()\n",
416
+ " self.d_inner = expand * d_model\n",
417
+ " self.dt_rank = math.ceil(d_model / 16)\n",
418
+ " self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)\n",
419
+ " self.conv1d = nn.Conv1d(self.d_inner, self.d_inner, kernel_size=d_conv,\n",
420
+ " groups=self.d_inner, padding=d_conv - 1, bias=True)\n",
421
+ " self.x_proj = nn.Linear(self.d_inner, self.dt_rank + d_state * 2, bias=False)\n",
422
+ " self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)\n",
423
+ " A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)\n",
424
+ " self.A_log = nn.Parameter(torch.log(A))\n",
425
+ " self.D = nn.Parameter(torch.ones(self.d_inner))\n",
426
+ " self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)\n",
427
+ " self.d_state = d_state\n",
428
+ "\n",
429
+ " def forward(self, x):\n",
430
+ " B, L, _ = x.shape\n",
431
+ " xin, z = self.in_proj(x).chunk(2, dim=-1)\n",
432
+ " xin = xin.transpose(1, 2)\n",
433
+ " xin = self.conv1d(xin)[..., :L].transpose(1, 2)\n",
434
+ " xin = F.silu(xin)\n",
435
+ " y = self._ssm(xin) * F.silu(z)\n",
436
+ " return self.out_proj(y)\n",
437
+ "\n",
438
+ " def _ssm(self, x):\n",
439
+ " A = -torch.exp(self.A_log)\n",
440
+ " delta, Bm, Cm = torch.split(self.x_proj(x), [self.dt_rank, self.d_state, self.d_state], dim=-1)\n",
441
+ " delta = F.softplus(self.dt_proj(delta))\n",
442
+ " dA = torch.exp(delta.unsqueeze(-1) * A)\n",
443
+ " dB_x = delta.unsqueeze(-1) * Bm.unsqueeze(2) * x.unsqueeze(-1)\n",
444
+ " h = torch.zeros(x.shape[0], self.d_inner, self.d_state, device=x.device, dtype=x.dtype)\n",
445
+ " ys = []\n",
446
+ " for t in range(x.shape[1]):\n",
447
+ " h = dA[:, t] * h + dB_x[:, t]\n",
448
+ " ys.append((h * Cm[:, t].unsqueeze(1)).sum(-1))\n",
449
+ " return torch.stack(ys, dim=1) + x * self.D\n",
450
+ "\n",
451
+ "class MambaLayer(nn.Module):\n",
452
+ " def __init__(self, d_model, d_state):\n",
453
+ " super().__init__()\n",
454
+ " self.norm = nn.LayerNorm(d_model)\n",
455
+ " self.mix = _OfficialMamba(d_model=d_model, d_state=d_state, d_conv=4, expand=2) \\\n",
456
+ " if _HAS_MAMBA_SSM else MambaBlockTorch(d_model, d_state=d_state)\n",
457
+ " def forward(self, x):\n",
458
+ " return x + self.mix(self.norm(x))\n",
459
+ "\n",
460
+ "class MambaEncoder(nn.Module):\n",
461
+ " def __init__(self, d_in, d_model, n_layers, d_state, z_dim, bidir):\n",
462
+ " super().__init__()\n",
463
+ " self.bidir = bidir\n",
464
+ " self.proj = nn.Linear(d_in, d_model)\n",
465
+ " self.fwd = nn.ModuleList([MambaLayer(d_model, d_state) for _ in range(n_layers)])\n",
466
+ " if bidir:\n",
467
+ " self.bwd = nn.ModuleList([MambaLayer(d_model, d_state) for _ in range(n_layers)])\n",
468
+ " self.attn = nn.Linear(d_model, 1)\n",
469
+ " self.out = nn.Linear(d_model, z_dim)\n",
470
+ "\n",
471
+ " @staticmethod\n",
472
+ " def _run(layers, h):\n",
473
+ " for L in layers:\n",
474
+ " h = L(h)\n",
475
+ " return h\n",
476
+ "\n",
477
+ " def forward(self, x, mask):\n",
478
+ " with torch.cuda.amp.autocast(enabled=False):\n",
479
+ " x = x.float()\n",
480
+ " h = self.proj(x)\n",
481
+ " out = self._run(self.fwd, h)\n",
482
+ " if self.bidir:\n",
483
+ " out = out + torch.flip(self._run(self.bwd, torch.flip(h, dims=[1])), dims=[1])\n",
484
+ " a = self.attn(out).squeeze(-1).masked_fill(~mask, float(\"-inf\"))\n",
485
+ " w = torch.softmax(a, dim=1).unsqueeze(-1)\n",
486
+ " return self.out((out * w).sum(1))"
487
+ ]
488
+ },
489
+ {
490
+ "cell_type": "markdown",
491
+ "id": "c8369a6b",
492
+ "metadata": {},
493
+ "source": [
494
+ "## 6. Dựng enc + heads → nạp trọng số từ ckpt + lấy chuẩn hóa từ ckpt"
495
+ ]
496
+ },
497
+ {
498
+ "cell_type": "code",
499
+ "execution_count": null,
500
+ "id": "1c5e8556",
501
+ "metadata": {
502
+ "lines_to_next_cell": 1
503
+ },
504
+ "outputs": [],
505
+ "source": [
506
+ "N_EMO = len(EMOTIONS5)\n",
507
+ "WAVLM_BRANCH = Z_DIM if USE_MAMBA else WAVLM_DIM\n",
508
+ "TRUNK_IN = WAVLM_BRANCH + (AUD_DIM if USE_AUDEERING else 0)\n",
509
+ "\n",
510
+ "enc = MambaEncoder(WAVLM_DIM, MAMBA_DMODEL, MAMBA_LAYERS, MAMBA_DSTATE, Z_DIM, BIDIRECTIONAL).to(device) \\\n",
511
+ " if USE_MAMBA else None\n",
512
+ "\n",
513
+ "class EmoHeads(nn.Module):\n",
514
+ " def __init__(self, d_in, trunk_h, head_h, p, n_emo):\n",
515
+ " super().__init__()\n",
516
+ " self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),\n",
517
+ " nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))\n",
518
+ " self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))\n",
519
+ " self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))\n",
520
+ " self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))\n",
521
+ " def forward(self, feat, tgt):\n",
522
+ " h = self.trunk(feat)\n",
523
+ " return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)\n",
524
+ "\n",
525
+ "heads = EmoHeads(TRUNK_IN, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device)\n",
526
+ "hm, hu = heads.load_state_dict(ckpt[\"heads\"], strict=False)\n",
527
+ "print(f\"🔁 load heads từ ckpt: thiếu {len(hm)} / dư {len(hu)} key (kỳ vọng 0)\")\n",
528
+ "if USE_MAMBA:\n",
529
+ " assert ckpt.get(\"enc\") is not None, \"❌ ckpt USE_MAMBA=True nhưng KHÔNG có 'enc' → không inference đúng được.\"\n",
530
+ " em, eu = enc.load_state_dict(ckpt[\"enc\"], strict=False)\n",
531
+ " print(f\"🔁 load Mamba enc từ ckpt: thiếu {len(em)} / dư {len(eu)} key (kỳ vọng 0)\")\n",
532
+ "heads.eval()\n",
533
+ "if USE_MAMBA:\n",
534
+ " enc.eval()\n",
535
+ "\n",
536
+ "# Chuẩn hóa LẤY TỪ ckpt (head dự đoán ở thang z-score này → phải giải chuẩn đúng thang)\n",
537
+ "emos_mu = float(ckpt[\"emos_mu\"]); emos_sd = float(ckpt[\"emos_sd\"])\n",
538
+ "vad_mu = np.asarray(ckpt[\"vad_mu\"], dtype=np.float32); vad_sd = np.asarray(ckpt[\"vad_sd\"], dtype=np.float32)\n",
539
+ "print(f\"Chuẩn hóa từ ckpt: emos μ={emos_mu:.3f} σ={emos_sd:.3f} | vad μ={np.round(vad_mu,2)}\")\n",
540
+ "\n",
541
+ "def wavlm_branch(input_values, attn_mask):\n",
542
+ " out = wavlm(input_values, attention_mask=attn_mask).last_hidden_state\n",
543
+ " if USE_MAMBA:\n",
544
+ " return enc(out, frame_mask(out.shape[1], attn_mask))\n",
545
+ " return masked_mean(out, attn_mask)\n",
546
+ "\n",
547
+ "print(f\"Trunk input = {TRUNK_IN} (wavlm-branch {WAVLM_BRANCH} [{'Mamba' if USE_MAMBA else 'mean-pool'}] + aud {AUD_DIM if USE_AUDEERING else 0})\")"
548
+ ]
549
+ },
550
+ {
551
+ "cell_type": "markdown",
552
+ "id": "fdcf05c2",
553
+ "metadata": {},
554
+ "source": [
555
+ "## 7. Dự đoán DEV → answer.txt (5 cột cảm xúc; QMOS mượn exp07/UTMOSv2)"
556
+ ]
557
+ },
558
+ {
559
+ "cell_type": "code",
560
+ "execution_count": null,
561
+ "id": "4d225f54",
562
+ "metadata": {
563
+ "lines_to_next_cell": 1
564
+ },
565
+ "outputs": [],
566
+ "source": [
567
+ "def list_dev():\n",
568
+ " with open(DEV_SCP) as f:\n",
569
+ " return [ln.strip() for ln in f if ln.strip()]\n",
570
+ "\n",
571
+ "dev_names = list_dev()\n",
572
+ "if LIMIT_DEV:\n",
573
+ " dev_names = dev_names[:LIMIT_DEV]\n",
574
+ "dev_stems = [stem(n) for n in dev_names]\n",
575
+ "print(\"DEV:\", len(dev_names), \"mẫu\")\n",
576
+ "aud_dev = extract_audeering(dev_stems, \"dev\")\n",
577
+ "\n",
578
+ "def load_exp07_qmos():\n",
579
+ " if EXP07_ANSWER and os.path.exists(EXP07_ANSWER):\n",
580
+ " import csv\n",
581
+ " d = {}\n",
582
+ " with open(EXP07_ANSWER) as f:\n",
583
+ " for row in csv.DictReader(f):\n",
584
+ " d[row[\"wav\"]] = float(row[\"QMOS\"]); d[stem(row[\"wav\"])] = float(row[\"QMOS\"])\n",
585
+ " print(f\"✅ Mượn QMOS exp07 ({EXP07_ANSWER}): {len(d)//2} wav\")\n",
586
+ " return d\n",
587
+ " return None\n",
588
+ "\n",
589
+ "qmos_map = load_exp07_qmos()\n",
590
+ "if qmos_map is None:\n",
591
+ " print(\"ℹ️ Không có answer.txt exp07 → chấm QMOS bằng UTMOSv2 (T05, vô địch VMC2024).\")\n",
592
+ " pip_install(\"git+https://github.com/sarulab-speech/UTMOSv2.git\")\n",
593
+ " import utmosv2\n",
594
+ " v2 = utmosv2.create_model(pretrained=True)\n",
595
+ " qmos_map = {}\n",
596
+ " for n in tqdm(dev_names, desc=\"UTMOSv2\"):\n",
597
+ " wav = os.path.join(WAV_DIR, n if str(n).endswith(\".wav\") else str(n) + \".wav\")\n",
598
+ " if not os.path.exists(wav):\n",
599
+ " continue\n",
600
+ " out = v2.predict(input_path=wav)\n",
601
+ " qmos_map[n] = float(out[\"predicted_mos\"]) if isinstance(out, dict) else float(out)\n",
602
+ " del v2; torch.cuda.empty_cache() if device == \"cuda\" else None\n",
603
+ "\n",
604
+ "@torch.no_grad()\n",
605
+ "def predict_emotion(sid):\n",
606
+ " wave = load_wav(sid)\n",
607
+ " if wave is None or (USE_AUDEERING and sid not in aud_dev):\n",
608
+ " return None\n",
609
+ " iv = torch.from_numpy(wave).unsqueeze(0).to(device)\n",
610
+ " am = torch.ones((1, len(wave)), dtype=torch.long, device=device)\n",
611
+ " tgt = torch.from_numpy(onehot_target(target_map.get(sid))).unsqueeze(0).to(device)\n",
612
+ " with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
613
+ " fw = wavlm_branch(iv, am)\n",
614
+ " feat = torch.cat([fw, torch.from_numpy(aud_dev[sid]).unsqueeze(0).to(device)], dim=1) if USE_AUDEERING else fw\n",
615
+ " emos_p, cat_l, vad_p = heads(feat, tgt)\n",
616
+ " emos = float(emos_p.item()) * emos_sd + emos_mu\n",
617
+ " cat5 = F.softmax(cat_l, 1)[0].float().cpu().numpy()\n",
618
+ " vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu\n",
619
+ " return emos, cat5, vad3\n",
620
+ "\n",
621
+ "def fmt_cat(p5):\n",
622
+ " return \"|\".join(f\"{e}:{p5[i]:.6g}\" for i, e in enumerate(EMOTIONS5))\n",
623
+ "\n",
624
+ "def build_answer(out_path):\n",
625
+ " n_real = n_def = 0\n",
626
+ " with open(out_path, \"w\") as f:\n",
627
+ " f.write(\"wav,QMOS,EMOS,CAT,VAL,ARO,DOM\\n\")\n",
628
+ " for name in tqdm(dev_names, desc=\"answer\"):\n",
629
+ " sid = stem(name)\n",
630
+ " pr = predict_emotion(sid)\n",
631
+ " if pr is None:\n",
632
+ " emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0]); n_def += 1\n",
633
+ " else:\n",
634
+ " emos, cat5, vad3 = pr; n_real += 1\n",
635
+ " qmos = qmos_map.get(name, qmos_map.get(sid, 3.0))\n",
636
+ " f.write(f\"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\\n\")\n",
637
+ " print(f\"Ghi {len(dev_names)} dòng → {out_path} | cảm xúc thật {n_real}, mặc định {n_def}\")\n",
638
+ "\n",
639
+ "answer_path = os.path.join(OUT_DIR, \"answer.txt\")\n",
640
+ "build_answer(answer_path)"
641
+ ]
642
+ },
643
+ {
644
+ "cell_type": "markdown",
645
+ "id": "42503595",
646
+ "metadata": {},
647
+ "source": [
648
+ "## 8. Validate + đóng zip"
649
+ ]
650
+ },
651
+ {
652
+ "cell_type": "code",
653
+ "execution_count": null,
654
+ "id": "42dec31f",
655
+ "metadata": {},
656
+ "outputs": [],
657
+ "source": [
658
+ "def validate(path):\n",
659
+ " import csv\n",
660
+ " with open(path) as f:\n",
661
+ " rows = list(csv.reader(f))\n",
662
+ " assert rows[0][0] == \"wav\" and \"QMOS\" in rows[0] and \"EMOS\" in rows[0], \"Header sai\"\n",
663
+ " for i, r in enumerate(rows[1:], 2):\n",
664
+ " assert len(r) == len(rows[0]), f\"Dòng {i} sai số cột\"\n",
665
+ " print(f\"OK: {len(rows)-1} dòng, header = {rows[0]}\")\n",
666
+ "\n",
667
+ "validate(answer_path)\n",
668
+ "os.system(f\"cd {OUT_DIR} && zip -j submission_track2_exp15_predict.zip answer.txt \"\n",
669
+ " f\"&& unzip -l submission_track2_exp15_predict.zip\")\n",
670
+ "print(\"Sẵn sàng nộp:\", os.path.join(OUT_DIR, \"submission_track2_exp15_predict.zip\"))"
671
+ ]
672
+ },
673
+ {
674
+ "cell_type": "markdown",
675
+ "id": "fbef2a21",
676
+ "metadata": {},
677
+ "source": [
678
+ "## Ghi chú\n",
679
+ "- File này **chỉ inference** — không train, không cần train.csv. Dùng khi đã có `ft_mamba_emotion_full*.pt`.\n",
680
+ "- ⚠️ **Siêu tham số Mamba/heads (MAMBA_DMODEL/LAYERS/DSTATE, TRUNK_HIDDEN, HEAD_HIDDEN) PHẢI khớp lúc train**\n",
681
+ " (ckpt không lưu các số này) — nếu lúc train exp15 bạn đổi, hãy sửa cho khớp ở cell 0, nếu không load_state_dict\n",
682
+ " sẽ lệch key / sai shape.\n",
683
+ "- `USE_MAMBA`, `Z_DIM`, `AUD_DIM`, `UNFREEZE_TOP_LAYERS` thì **đọc tự động từ ckpt**.\n",
684
+ "- QMOS: tốt nhất Add Input `answer.txt` exp07 (0.548); không có thì tự chấm UTMOSv2.\n",
685
+ "- Smoke-test: đặt `LIMIT_DEV=20` chạy thử cho nhanh, OK rồi đặt lại `None` để chấm đủ 2730."
686
+ ]
687
+ }
688
+ ],
689
+ "metadata": {
690
+ "jupytext": {
691
+ "cell_metadata_filter": "-all",
692
+ "main_language": "python",
693
+ "notebook_metadata_filter": "-all"
694
+ }
695
+ },
696
+ "nbformat": 4,
697
+ "nbformat_minor": 5
698
+ }
track2/exp15_predict_pipeline.py ADDED
@@ -0,0 +1,554 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %% [markdown]
2
+ # # VMC2026 Track 2 — exp15 PREDICT-ONLY (nạp checkpoint → chấm DEV, KHÔNG train) — Kaggle
3
+ #
4
+ # **Mục đích:** bạn ĐÃ có checkpoint exp15 (`ft_mamba_emotion_full*.pt`, lưu cả backbone WavLM + Mamba enc + heads).
5
+ # File này **chỉ inference**: dựng lại đúng kiến trúc → nạp trọng số + thống kê chuẩn hóa TỪ ckpt →
6
+ # dự đoán 5 cột cảm xúc trên tập DEV → ghép QMOS (exp07/UTMOSv2) → `answer.txt` → zip nộp.
7
+ # **KHÔNG** train, **KHÔNG** cần train.csv (chỉ cần wav DEV + metadata.csv để lấy cảm xúc target cho EMOS).
8
+ #
9
+ # ## Vì sao nhanh
10
+ # - Không có vòng train → chỉ 1 lượt forward qua DEV (~2730 mẫu). Việc lâu nhất là trích audeering DEV
11
+ # (~vài phút; có cache thì gần như tức thì).
12
+ #
13
+ # ## Chuẩn bị input trên Kaggle (Add Input)
14
+ # 1. Dataset Track 2 (wav + `metadata.csv` + `sets/dev.scp`).
15
+ # 2. **Checkpoint** exp15: dataset chứa `ft_mamba_emotion_full*.pt` (vd `cache_exp8`). Auto-dò; hoặc trỏ `CKPT_PATH`.
16
+ # 3. (tùy chọn) cache audeering `aud_dev.npz` để khỏi trích lại.
17
+ # 4. (tùy chọn) `answer.txt` exp07 để mượn cột QMOS 0.548.
18
+ #
19
+ # **Cách chạy:** GPU **T4** + Internet **On** → Add Input → Run All.
20
+
21
+ # %% [markdown]
22
+ # ## 0. Cấu hình — SỬA Ở ĐÂY
23
+
24
+ # %%
25
+ import os, glob
26
+
27
+ # ── TỰ DÒ DATA_ROOT (quét /kaggle/input tìm thư mục có sets + wav/ + metadata.csv) ──
28
+ def find_data_root(search_root="/kaggle/input"):
29
+ cands = []
30
+ for dev_scp in glob.glob(os.path.join(search_root, "**", "sets", "dev.scp"), recursive=True):
31
+ root = os.path.dirname(os.path.dirname(dev_scp))
32
+ score = os.path.isdir(os.path.join(root, "wav")) + os.path.exists(os.path.join(root, "metadata.csv"))
33
+ cands.append((score, root))
34
+ cands.sort(reverse=True)
35
+ return cands
36
+
37
+ _cands = find_data_root("/kaggle/input")
38
+ if _cands:
39
+ print("🔎 Ứng viên DATA_ROOT:")
40
+ for sc, r in _cands:
41
+ print(f" [{sc}/2] {r}")
42
+ DATA_ROOT = _cands[0][1]
43
+ print(f"👉 Tự chọn DATA_ROOT = {DATA_ROOT}")
44
+ else:
45
+ DATA_ROOT = "/kaggle/input/datasets/minhtoan2" # dự phòng — sửa tay
46
+ print(f"❌ Không thấy sets/dev.scp → dùng dự phòng {DATA_ROOT} (đã Add Input chưa?)")
47
+
48
+ WAV_DIR = f"{DATA_ROOT}/wav"
49
+ METADATA_CSV = f"{DATA_ROOT}/metadata.csv" # wavID|emotion|transcript (KHÔNG header) — lấy cảm xúc target cho EMOS
50
+ DEV_SCP = f"{DATA_ROOT}/sets/dev.scp"
51
+
52
+ OUT_DIR = "/kaggle/working"
53
+ CACHE_DIR = "/kaggle/working/ft_cache"
54
+ os.makedirs(CACHE_DIR, exist_ok=True)
55
+
56
+ # ── CHECKPOINT exp15 (đủ backbone + Mamba + heads) ───────────────────────────
57
+ CKPT_PATH = "" # << "" = auto-dò ft_mamba_emotion_full*.pt; hoặc "/kaggle/input/<slug>/ft_mamba_emotion_full (2).pt"
58
+
59
+ def find_ckpt(explicit):
60
+ """Tìm checkpoint exp15. Khớp cả tên bị thêm hậu tố trùng, vd 'ft_mamba_emotion_full (2).pt'."""
61
+ if explicit and os.path.exists(explicit):
62
+ return explicit
63
+ for base in ["/kaggle/input", "/kaggle/working"]:
64
+ hits = sorted(glob.glob(os.path.join(base, "**", "ft_mamba_emotion_full*.pt"), recursive=True))
65
+ if hits:
66
+ return hits[0]
67
+ return ""
68
+
69
+ CKPT_PATH = find_ckpt(CKPT_PATH)
70
+ assert CKPT_PATH, "❌ Không thấy checkpoint ft_mamba_emotion_full*.pt. Đã Add Input dataset chứa ckpt chưa?"
71
+ print("✅ Dùng checkpoint:", CKPT_PATH)
72
+
73
+ # (Tùy chọn) tái dùng cache audeering DEV — quét đệ quy (file có thể nằm trong archive/)
74
+ CACHE_INPUT = "/kaggle/input/cache-exp8" # << SỬA slug (hoặc "")
75
+ if CACHE_INPUT and os.path.isdir(CACHE_INPUT):
76
+ import shutil
77
+ _n = 0
78
+ for _fp in glob.glob(os.path.join(CACHE_INPUT, "**", "aud_*.npz"), recursive=True):
79
+ shutil.copy(_fp, os.path.join(CACHE_DIR, os.path.basename(_fp))); _n += 1
80
+ print(f"📦 Copy {_n} file aud_*.npz từ {CACHE_INPUT}")
81
+
82
+ # Mượn cột QMOS exp07 (0.548). Trỏ answer.txt exp07 nếu có; không thì UTMOSv2.
83
+ EXP07_ANSWER = "/kaggle/input/exp07-answer/answer.txt" # << (tùy chọn)
84
+
85
+ # ── Siêu tham số PHẢI KHỚP lúc train exp15 (ckpt không lưu các số này của Mamba) ──
86
+ MAMBA_DMODEL = 256
87
+ MAMBA_LAYERS = 2
88
+ MAMBA_DSTATE = 16
89
+ BIDIRECTIONAL = True
90
+ TRUNK_HIDDEN = 512
91
+ HEAD_HIDDEN = 128
92
+ DROPOUT = 0.3 # không ảnh hưởng eval (model.eval() tắt dropout) — chỉ để dựng đúng shape
93
+
94
+ DEVICE = "cuda"
95
+ SR = 16000
96
+ MAX_SECONDS = 6 # khớp lúc train (exp15 = 6)
97
+ USE_AMP = True
98
+ LIMIT_DEV = None # << để None chấm ĐỦ 2730; đặt 20 để smoke-test nhanh
99
+
100
+ EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"]
101
+ _EMO_ALIAS = {
102
+ "angry": "angry", "anger": "angry",
103
+ "happy": "happy", "happiness": "happy", "joy": "happy",
104
+ "neutral": "neutral", "calm": "neutral",
105
+ "sad": "sad", "sadness": "sad",
106
+ "surprise": "surprised", "surprised": "surprised", "surprising": "surprised",
107
+ }
108
+
109
+ def norm_emotion(label):
110
+ key = str(label).strip().lower()
111
+ return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)
112
+
113
+ def stem(p):
114
+ return os.path.splitext(os.path.basename(str(p)))[0]
115
+
116
+ print("DATA_ROOT:", DATA_ROOT)
117
+ for p in [WAV_DIR, METADATA_CSV, DEV_SCP, CKPT_PATH]:
118
+ print((" ✅ " if os.path.exists(p) else " ❌ THIẾU ") + p)
119
+
120
+ # %% [markdown]
121
+ # ## 1. Cài đặt + tải code SAILER (để dựng đúng kiến trúc WavLM rồi nạp ckpt đè lên)
122
+
123
+ # %%
124
+ import sys, subprocess
125
+
126
+ def pip_install(*pkgs):
127
+ subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs], check=True)
128
+
129
+ pip_install("loralib", "speechbrain", "speechmos", "librosa", "soundfile",
130
+ "scipy", "scikit-learn", "pandas", "tqdm")
131
+
132
+ # Mamba kernel CUDA (tùy chọn — không có thì dùng Mamba thuần PyTorch, inference vẫn ổn vì chỉ 1 lượt forward)
133
+ INSTALL_MAMBA_SSM = True
134
+ if INSTALL_MAMBA_SSM:
135
+ try:
136
+ subprocess.run([sys.executable, "-m", "pip", "install", "-q", "ninja"], check=True)
137
+ subprocess.run([sys.executable, "-m", "pip", "install", "-q", "--no-build-isolation", "causal-conv1d>=1.2.0"], check=True)
138
+ subprocess.run([sys.executable, "-m", "pip", "install", "-q", "--no-build-isolation", "mamba-ssm"], check=True)
139
+ print("✅ Cài mamba-ssm xong (dùng kernel CUDA nếu import được).")
140
+ except Exception as e:
141
+ print("⚠️ Cài mamba-ssm thất bại:", repr(e), "→ Mamba thuần PyTorch (inference vẫn chạy).")
142
+
143
+ REPO_DIR = "/kaggle/working/vox-profile-release"
144
+ if not os.path.exists(REPO_DIR):
145
+ subprocess.run(["git", "clone", "--depth", "1",
146
+ "https://github.com/tiantiaf0627/vox-profile-release.git", REPO_DIR], check=True)
147
+ if REPO_DIR not in sys.path:
148
+ sys.path.insert(0, REPO_DIR)
149
+
150
+ # %% [markdown]
151
+ # ## 2. Nạp checkpoint → dựng WavLM → load trọng số backbone đã fine-tune
152
+
153
+ # %%
154
+ import torch
155
+ import torch.nn as nn
156
+ import torch.nn.functional as F
157
+
158
+ device = DEVICE if torch.cuda.is_available() else "cpu"
159
+ print("Device:", device, ("✅ " + torch.cuda.get_device_name(0)) if device == "cuda" else "⚠️ CPU (chậm)")
160
+
161
+ ckpt = torch.load(CKPT_PATH, map_location="cpu", weights_only=False) # ckpt có numpy → cần False
162
+ assert "wavlm" in ckpt, "❌ Checkpoint KHÔNG có 'wavlm' (backbone) → không inference được. Cần ft_mamba_emotion_full*.pt đủ."
163
+ print("✅ Nạp ckpt | keys:", list(ckpt.keys()))
164
+
165
+ # Lấy cấu hình KIẾN TRÚC từ ckpt (để dựng đúng shape head)
166
+ USE_MAMBA = bool(ckpt.get("USE_MAMBA", True))
167
+ Z_DIM = int(ckpt.get("Z_DIM", 256))
168
+ AUD_DIM = int(ckpt.get("AUD_DIM", 0))
169
+ USE_AUDEERING = AUD_DIM > 0
170
+ UNFREEZE_TOP_LAYERS = int(ckpt.get("UNFREEZE_TOP_LAYERS", 6))
171
+ print(f"Từ ckpt: USE_MAMBA={USE_MAMBA} · Z_DIM={Z_DIM} · AUD_DIM={AUD_DIM} (audeering={'ON' if USE_AUDEERING else 'OFF'})")
172
+
173
+ def find_hf_backbone(module):
174
+ cands = []
175
+ for name, m in module.named_modules():
176
+ enc = getattr(m, "encoder", None)
177
+ if getattr(m, "feature_extractor", None) is not None and enc is not None \
178
+ and getattr(enc, "layers", None) is not None:
179
+ cands.append((name, m))
180
+ if not cands:
181
+ return None, None
182
+ cands.sort(key=lambda nm: sum(p.numel() for p in nm[1].parameters()), reverse=True)
183
+ return cands[0]
184
+
185
+ wavlm = None
186
+ try:
187
+ from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402
188
+ _wrapper = WavLMWrapper.from_pretrained("tiantiaf/wavlm-large-categorical-emotion")
189
+ name, wavlm = find_hf_backbone(_wrapper)
190
+ if wavlm is not None:
191
+ print(f"✅ Dựng backbone WavLM từ SAILER wrapper tại '.{name}'")
192
+ except Exception as e:
193
+ print("⚠️ Lỗi nạp SAILER wrapper:", repr(e), "→ fallback WavLM trắng.")
194
+
195
+ if wavlm is None:
196
+ from transformers import WavLMModel
197
+ wavlm = WavLMModel.from_pretrained("microsoft/wavlm-large")
198
+ print("ℹ️ Fallback: microsoft/wavlm-large.")
199
+
200
+ wavlm = wavlm.to(device)
201
+ WAVLM_DIM = int(wavlm.config.hidden_size)
202
+ wavlm.config.layerdrop = 0.0
203
+
204
+ miss, unexp = wavlm.load_state_dict(ckpt["wavlm"], strict=False)
205
+ print(f"🔁 load wavlm từ ckpt: thiếu {len(miss)} / dư {len(unexp)} key (kỳ vọng ~0)")
206
+ if len(miss) > 20 or len(unexp) > 20:
207
+ print(" ⚠️ Lệch key nhiều → kiểm tra backbone có khớp ckpt không.")
208
+ wavlm.eval()
209
+
210
+ def frame_mask(T, attn_mask):
211
+ if attn_mask is None:
212
+ return torch.ones((1, T), dtype=torch.bool, device=device)
213
+ try:
214
+ return wavlm._get_feature_vector_attention_mask(T, attn_mask).bool()
215
+ except Exception:
216
+ return torch.ones((attn_mask.shape[0], T), dtype=torch.bool, device=attn_mask.device)
217
+
218
+ def masked_mean(hidden, attn_mask):
219
+ if attn_mask is None:
220
+ return hidden.mean(dim=1)
221
+ fm = frame_mask(hidden.shape[1], attn_mask).unsqueeze(-1).to(hidden.dtype)
222
+ return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)
223
+
224
+ # %% [markdown]
225
+ # ## 3. audeering MSP-dim (FROZEN) — chỉ dựng nếu ckpt có dùng (AUD_DIM>0)
226
+
227
+ # %%
228
+ import numpy as np
229
+ import librosa
230
+ from tqdm.auto import tqdm
231
+
232
+ aud_backbone = aud_head = aud_proc = None
233
+ if USE_AUDEERING:
234
+ from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor
235
+ from huggingface_hub import hf_hub_download
236
+ AUD_NAME = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
237
+ aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)
238
+ aud_cfg = Wav2Vec2Config.from_pretrained(AUD_NAME)
239
+ aud_backbone = Wav2Vec2Model(aud_cfg)
240
+ try:
241
+ _sd = __import__("safetensors.torch", fromlist=["load_file"]).load_file(
242
+ hf_hub_download(AUD_NAME, "model.safetensors"))
243
+ except Exception:
244
+ _sd = torch.load(hf_hub_download(AUD_NAME, "pytorch_model.bin"), map_location="cpu")
245
+ bb_sd = {k[len("wav2vec2."):]: v for k, v in _sd.items() if k.startswith("wav2vec2.")}
246
+ aud_backbone.load_state_dict(bb_sd, strict=False)
247
+ _hid = _sd["classifier.dense.weight"].shape[0]
248
+ aud_head = nn.Sequential(nn.Linear(_hid, _hid), nn.Tanh(), nn.Linear(_hid, _sd["classifier.out_proj.weight"].shape[0]))
249
+ aud_head[0].weight.data.copy_(_sd["classifier.dense.weight"]); aud_head[0].bias.data.copy_(_sd["classifier.dense.bias"])
250
+ aud_head[2].weight.data.copy_(_sd["classifier.out_proj.weight"]); aud_head[2].bias.data.copy_(_sd["classifier.out_proj.bias"])
251
+ aud_backbone = aud_backbone.to(device).eval()
252
+ aud_head = aud_head.to(device).eval()
253
+ assert _hid + 3 == AUD_DIM, f"⚠️ AUD_DIM dựng ({_hid+3}) ≠ ckpt ({AUD_DIM}) → audeering không khớp!"
254
+ print(f"✅ audeering frozen ({AUD_DIM}-D)")
255
+
256
+ def load_wav(name_or_stem):
257
+ p = name_or_stem if os.path.isabs(str(name_or_stem)) else os.path.join(
258
+ WAV_DIR, name_or_stem if str(name_or_stem).endswith(".wav") else str(name_or_stem) + ".wav")
259
+ if not os.path.exists(p):
260
+ return None
261
+ wave, _ = librosa.load(p, sr=SR, mono=True)
262
+ return wave[: MAX_SECONDS * SR].astype(np.float32)
263
+
264
+ @torch.no_grad()
265
+ def extract_audeering(stems, tag):
266
+ if not USE_AUDEERING:
267
+ return {}
268
+ cache_path = os.path.join(CACHE_DIR, f"aud_{tag}.npz")
269
+ store = {}
270
+ if os.path.exists(cache_path):
271
+ z = np.load(cache_path, allow_pickle=True)
272
+ store = {k: z[k] for k in z.files}
273
+ print(f"[aud/{tag}] nạp cache: {len(store)}")
274
+ todo = [s for s in stems if s not in store]
275
+ for i, s in enumerate(tqdm(todo, desc=f"audeering {tag}")):
276
+ wave = load_wav(s)
277
+ if wave is None:
278
+ continue
279
+ x = aud_proc(wave, sampling_rate=SR).input_values[0]
280
+ x = torch.from_numpy(np.asarray(x, dtype=np.float32)).unsqueeze(0).to(device)
281
+ h = aud_backbone(x)[0].mean(dim=1)
282
+ out = aud_head(h)[0].cpu().numpy()
283
+ vad = np.array([1 + 4 * out[2], 1 + 4 * out[0], 1 + 4 * out[1]], dtype=np.float32) # [VAL,ARO,DOM]
284
+ store[s] = np.concatenate([h[0].cpu().numpy(), vad]).astype(np.float32)
285
+ if (i + 1) % 500 == 0:
286
+ np.savez(cache_path, **store)
287
+ if todo:
288
+ np.savez(cache_path, **store)
289
+ return store
290
+
291
+ # %% [markdown]
292
+ # ## 4. Cảm xúc target theo wavID (cho one-hot điều kiện của head EMOS)
293
+
294
+ # %%
295
+ def load_target_emotions():
296
+ tgt = {}
297
+ with open(METADATA_CSV, encoding="utf-8") as f:
298
+ for ln in f:
299
+ parts = ln.strip().split("|")
300
+ if len(parts) >= 2:
301
+ tgt[stem(parts[0])] = norm_emotion(parts[1])
302
+ return tgt
303
+
304
+ target_map = load_target_emotions()
305
+ print("Target cảm xúc:", len(target_map), "wav")
306
+
307
+ def onehot_target(tgt):
308
+ v = np.zeros(len(EMOTIONS5), dtype=np.float32)
309
+ if tgt in EMOTIONS5:
310
+ v[EMOTIONS5.index(tgt)] = 1.0
311
+ return v
312
+
313
+ # %% [markdown]
314
+ # ## 5. Khối Mamba (giống exp15) + MambaEncoder
315
+
316
+ # %%
317
+ import math
318
+
319
+ try:
320
+ from mamba_ssm import Mamba as _OfficialMamba
321
+ _HAS_MAMBA_SSM = True
322
+ print("✅ Dùng mamba-ssm (CUDA kernel)")
323
+ except Exception:
324
+ _HAS_MAMBA_SSM = False
325
+ print("ℹ️ Không có mamba-ssm → Mamba thuần PyTorch")
326
+
327
+ class MambaBlockTorch(nn.Module):
328
+ def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
329
+ super().__init__()
330
+ self.d_inner = expand * d_model
331
+ self.dt_rank = math.ceil(d_model / 16)
332
+ self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
333
+ self.conv1d = nn.Conv1d(self.d_inner, self.d_inner, kernel_size=d_conv,
334
+ groups=self.d_inner, padding=d_conv - 1, bias=True)
335
+ self.x_proj = nn.Linear(self.d_inner, self.dt_rank + d_state * 2, bias=False)
336
+ self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)
337
+ A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)
338
+ self.A_log = nn.Parameter(torch.log(A))
339
+ self.D = nn.Parameter(torch.ones(self.d_inner))
340
+ self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
341
+ self.d_state = d_state
342
+
343
+ def forward(self, x):
344
+ B, L, _ = x.shape
345
+ xin, z = self.in_proj(x).chunk(2, dim=-1)
346
+ xin = xin.transpose(1, 2)
347
+ xin = self.conv1d(xin)[..., :L].transpose(1, 2)
348
+ xin = F.silu(xin)
349
+ y = self._ssm(xin) * F.silu(z)
350
+ return self.out_proj(y)
351
+
352
+ def _ssm(self, x):
353
+ A = -torch.exp(self.A_log)
354
+ delta, Bm, Cm = torch.split(self.x_proj(x), [self.dt_rank, self.d_state, self.d_state], dim=-1)
355
+ delta = F.softplus(self.dt_proj(delta))
356
+ dA = torch.exp(delta.unsqueeze(-1) * A)
357
+ dB_x = delta.unsqueeze(-1) * Bm.unsqueeze(2) * x.unsqueeze(-1)
358
+ h = torch.zeros(x.shape[0], self.d_inner, self.d_state, device=x.device, dtype=x.dtype)
359
+ ys = []
360
+ for t in range(x.shape[1]):
361
+ h = dA[:, t] * h + dB_x[:, t]
362
+ ys.append((h * Cm[:, t].unsqueeze(1)).sum(-1))
363
+ return torch.stack(ys, dim=1) + x * self.D
364
+
365
+ class MambaLayer(nn.Module):
366
+ def __init__(self, d_model, d_state):
367
+ super().__init__()
368
+ self.norm = nn.LayerNorm(d_model)
369
+ self.mix = _OfficialMamba(d_model=d_model, d_state=d_state, d_conv=4, expand=2) \
370
+ if _HAS_MAMBA_SSM else MambaBlockTorch(d_model, d_state=d_state)
371
+ def forward(self, x):
372
+ return x + self.mix(self.norm(x))
373
+
374
+ class MambaEncoder(nn.Module):
375
+ def __init__(self, d_in, d_model, n_layers, d_state, z_dim, bidir):
376
+ super().__init__()
377
+ self.bidir = bidir
378
+ self.proj = nn.Linear(d_in, d_model)
379
+ self.fwd = nn.ModuleList([MambaLayer(d_model, d_state) for _ in range(n_layers)])
380
+ if bidir:
381
+ self.bwd = nn.ModuleList([MambaLayer(d_model, d_state) for _ in range(n_layers)])
382
+ self.attn = nn.Linear(d_model, 1)
383
+ self.out = nn.Linear(d_model, z_dim)
384
+
385
+ @staticmethod
386
+ def _run(layers, h):
387
+ for L in layers:
388
+ h = L(h)
389
+ return h
390
+
391
+ def forward(self, x, mask):
392
+ with torch.cuda.amp.autocast(enabled=False):
393
+ x = x.float()
394
+ h = self.proj(x)
395
+ out = self._run(self.fwd, h)
396
+ if self.bidir:
397
+ out = out + torch.flip(self._run(self.bwd, torch.flip(h, dims=[1])), dims=[1])
398
+ a = self.attn(out).squeeze(-1).masked_fill(~mask, float("-inf"))
399
+ w = torch.softmax(a, dim=1).unsqueeze(-1)
400
+ return self.out((out * w).sum(1))
401
+
402
+ # %% [markdown]
403
+ # ## 6. Dựng enc + heads → nạp trọng số từ ckpt + lấy chuẩn hóa từ ckpt
404
+
405
+ # %%
406
+ N_EMO = len(EMOTIONS5)
407
+ WAVLM_BRANCH = Z_DIM if USE_MAMBA else WAVLM_DIM
408
+ TRUNK_IN = WAVLM_BRANCH + (AUD_DIM if USE_AUDEERING else 0)
409
+
410
+ enc = MambaEncoder(WAVLM_DIM, MAMBA_DMODEL, MAMBA_LAYERS, MAMBA_DSTATE, Z_DIM, BIDIRECTIONAL).to(device) \
411
+ if USE_MAMBA else None
412
+
413
+ class EmoHeads(nn.Module):
414
+ def __init__(self, d_in, trunk_h, head_h, p, n_emo):
415
+ super().__init__()
416
+ self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),
417
+ nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))
418
+ self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))
419
+ self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))
420
+ self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))
421
+ def forward(self, feat, tgt):
422
+ h = self.trunk(feat)
423
+ return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)
424
+
425
+ heads = EmoHeads(TRUNK_IN, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device)
426
+ hm, hu = heads.load_state_dict(ckpt["heads"], strict=False)
427
+ print(f"🔁 load heads từ ckpt: thiếu {len(hm)} / dư {len(hu)} key (kỳ vọng 0)")
428
+ if USE_MAMBA:
429
+ assert ckpt.get("enc") is not None, "❌ ckpt USE_MAMBA=True nhưng KHÔNG có 'enc' → không inference đúng được."
430
+ em, eu = enc.load_state_dict(ckpt["enc"], strict=False)
431
+ print(f"🔁 load Mamba enc từ ckpt: thiếu {len(em)} / dư {len(eu)} key (kỳ vọng 0)")
432
+ heads.eval()
433
+ if USE_MAMBA:
434
+ enc.eval()
435
+
436
+ # Chuẩn hóa LẤY TỪ ckpt (head dự đoán ở thang z-score này → phải giải chuẩn đúng thang)
437
+ emos_mu = float(ckpt["emos_mu"]); emos_sd = float(ckpt["emos_sd"])
438
+ vad_mu = np.asarray(ckpt["vad_mu"], dtype=np.float32); vad_sd = np.asarray(ckpt["vad_sd"], dtype=np.float32)
439
+ print(f"Chuẩn hóa từ ckpt: emos μ={emos_mu:.3f} σ={emos_sd:.3f} | vad μ={np.round(vad_mu,2)}")
440
+
441
+ def wavlm_branch(input_values, attn_mask):
442
+ out = wavlm(input_values, attention_mask=attn_mask).last_hidden_state
443
+ if USE_MAMBA:
444
+ return enc(out, frame_mask(out.shape[1], attn_mask))
445
+ return masked_mean(out, attn_mask)
446
+
447
+ print(f"Trunk input = {TRUNK_IN} (wavlm-branch {WAVLM_BRANCH} [{'Mamba' if USE_MAMBA else 'mean-pool'}] + aud {AUD_DIM if USE_AUDEERING else 0})")
448
+
449
+ # %% [markdown]
450
+ # ## 7. Dự đoán DEV → answer.txt (5 cột cảm xúc; QMOS mượn exp07/UTMOSv2)
451
+
452
+ # %%
453
+ def list_dev():
454
+ with open(DEV_SCP) as f:
455
+ return [ln.strip() for ln in f if ln.strip()]
456
+
457
+ dev_names = list_dev()
458
+ if LIMIT_DEV:
459
+ dev_names = dev_names[:LIMIT_DEV]
460
+ dev_stems = [stem(n) for n in dev_names]
461
+ print("DEV:", len(dev_names), "mẫu")
462
+ aud_dev = extract_audeering(dev_stems, "dev")
463
+
464
+ def load_exp07_qmos():
465
+ if EXP07_ANSWER and os.path.exists(EXP07_ANSWER):
466
+ import csv
467
+ d = {}
468
+ with open(EXP07_ANSWER) as f:
469
+ for row in csv.DictReader(f):
470
+ d[row["wav"]] = float(row["QMOS"]); d[stem(row["wav"])] = float(row["QMOS"])
471
+ print(f"✅ Mượn QMOS exp07 ({EXP07_ANSWER}): {len(d)//2} wav")
472
+ return d
473
+ return None
474
+
475
+ qmos_map = load_exp07_qmos()
476
+ if qmos_map is None:
477
+ print("ℹ️ Không có answer.txt exp07 → chấm QMOS bằng UTMOSv2 (T05, vô địch VMC2024).")
478
+ pip_install("git+https://github.com/sarulab-speech/UTMOSv2.git")
479
+ import utmosv2
480
+ v2 = utmosv2.create_model(pretrained=True)
481
+ qmos_map = {}
482
+ for n in tqdm(dev_names, desc="UTMOSv2"):
483
+ wav = os.path.join(WAV_DIR, n if str(n).endswith(".wav") else str(n) + ".wav")
484
+ if not os.path.exists(wav):
485
+ continue
486
+ out = v2.predict(input_path=wav)
487
+ qmos_map[n] = float(out["predicted_mos"]) if isinstance(out, dict) else float(out)
488
+ del v2; torch.cuda.empty_cache() if device == "cuda" else None
489
+
490
+ @torch.no_grad()
491
+ def predict_emotion(sid):
492
+ wave = load_wav(sid)
493
+ if wave is None or (USE_AUDEERING and sid not in aud_dev):
494
+ return None
495
+ iv = torch.from_numpy(wave).unsqueeze(0).to(device)
496
+ am = torch.ones((1, len(wave)), dtype=torch.long, device=device)
497
+ tgt = torch.from_numpy(onehot_target(target_map.get(sid))).unsqueeze(0).to(device)
498
+ with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
499
+ fw = wavlm_branch(iv, am)
500
+ feat = torch.cat([fw, torch.from_numpy(aud_dev[sid]).unsqueeze(0).to(device)], dim=1) if USE_AUDEERING else fw
501
+ emos_p, cat_l, vad_p = heads(feat, tgt)
502
+ emos = float(emos_p.item()) * emos_sd + emos_mu
503
+ cat5 = F.softmax(cat_l, 1)[0].float().cpu().numpy()
504
+ vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu
505
+ return emos, cat5, vad3
506
+
507
+ def fmt_cat(p5):
508
+ return "|".join(f"{e}:{p5[i]:.6g}" for i, e in enumerate(EMOTIONS5))
509
+
510
+ def build_answer(out_path):
511
+ n_real = n_def = 0
512
+ with open(out_path, "w") as f:
513
+ f.write("wav,QMOS,EMOS,CAT,VAL,ARO,DOM\n")
514
+ for name in tqdm(dev_names, desc="answer"):
515
+ sid = stem(name)
516
+ pr = predict_emotion(sid)
517
+ if pr is None:
518
+ emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0]); n_def += 1
519
+ else:
520
+ emos, cat5, vad3 = pr; n_real += 1
521
+ qmos = qmos_map.get(name, qmos_map.get(sid, 3.0))
522
+ f.write(f"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\n")
523
+ print(f"Ghi {len(dev_names)} dòng → {out_path} | cảm xúc thật {n_real}, mặc định {n_def}")
524
+
525
+ answer_path = os.path.join(OUT_DIR, "answer.txt")
526
+ build_answer(answer_path)
527
+
528
+ # %% [markdown]
529
+ # ## 8. Validate + đóng zip
530
+
531
+ # %%
532
+ def validate(path):
533
+ import csv
534
+ with open(path) as f:
535
+ rows = list(csv.reader(f))
536
+ assert rows[0][0] == "wav" and "QMOS" in rows[0] and "EMOS" in rows[0], "Header sai"
537
+ for i, r in enumerate(rows[1:], 2):
538
+ assert len(r) == len(rows[0]), f"Dòng {i} sai số cột"
539
+ print(f"OK: {len(rows)-1} dòng, header = {rows[0]}")
540
+
541
+ validate(answer_path)
542
+ os.system(f"cd {OUT_DIR} && zip -j submission_track2_exp15_predict.zip answer.txt "
543
+ f"&& unzip -l submission_track2_exp15_predict.zip")
544
+ print("Sẵn sàng nộp:", os.path.join(OUT_DIR, "submission_track2_exp15_predict.zip"))
545
+
546
+ # %% [markdown]
547
+ # ## Ghi chú
548
+ # - File này **chỉ inference** — không train, không cần train.csv. Dùng khi đã có `ft_mamba_emotion_full*.pt`.
549
+ # - ⚠️ **Siêu tham số Mamba/heads (MAMBA_DMODEL/LAYERS/DSTATE, TRUNK_HIDDEN, HEAD_HIDDEN) PHẢI khớp lúc train**
550
+ # (ckpt không lưu các số này) — nếu lúc train exp15 bạn đổi, hãy sửa cho khớp ở cell 0, nếu không load_state_dict
551
+ # sẽ lệch key / sai shape.
552
+ # - `USE_MAMBA`, `Z_DIM`, `AUD_DIM`, `UNFREEZE_TOP_LAYERS` thì **đọc tự động từ ckpt**.
553
+ # - QMOS: tốt nhất Add Input `answer.txt` exp07 (0.548); không có thì tự chấm UTMOSv2.
554
+ # - Smoke-test: đặt `LIMIT_DEV=20` chạy thử cho nhanh, OK rồi đặt lại `None` để chấm đủ 2730.
track2/exp15_wavlm_mamba_emotion.ipynb ADDED
@@ -0,0 +1,1081 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "5b4b651f",
6
+ "metadata": {},
7
+ "source": [
8
+ "# VMC2026 Track 2 — exp15 (WavLM FINE-TUNE + MAMBA head cho 5 cột cảm xúc) — Kaggle\n",
9
+ "\n",
10
+ "**Ý tưởng:** exp08 fine-tune WavLM nhưng vẫn **mean-pool** đặc trưng theo thời gian → 1 vector/wav\n",
11
+ "(vứt bỏ động lực thời gian: lên/xuống giọng, ngắt quãng, run giọng — rất quan trọng cho cảm xúc).\n",
12
+ "exp15 **thay mean-pool bằng MAMBA head** (bộ mã hóa chuỗi học được, độ phức tạp tuyến tính) → kỳ vọng\n",
13
+ "nắm temporal dynamics tốt hơn. Tham khảo: MambaRate (AudioMOS 2025, arXiv:2507.12090).\n",
14
+ "\n",
15
+ "## Kiến trúc (= exp08 đổi đúng 1 chỗ: pool → Mamba)\n",
16
+ "```\n",
17
+ " wav ─► WavLM-large (SAILER warm-start, mở băng N lớp, TRAINABLE) ─► hidden states (B, T, 1024)\n",
18
+ " │ (KHÔNG mean-pool)\n",
19
+ " MambaEncoder (proj 1024→d, Mamba×L 2 chiều,\n",
20
+ " attentive-pool có mask) ─► z (B, Z_DIM)\n",
21
+ " │\n",
22
+ " (tùy chọn) audeering MSP-dim FROZEN [emb|vad3] ──concat──► TRUNK ─┬─► EMOS (+ one-hot target)\n",
23
+ " ├─► CAT (5, soft-CE)\n",
24
+ " └─► VAD (3)\n",
25
+ " QMOS: KHÔNG train ở đây → mượn cột QMOS exp07 (0.548) hoặc UTMOSv2.\n",
26
+ "```\n",
27
+ "- **Cờ `USE_MAMBA`:** True = Mamba head; False = quay về `masked_mean` = **đúng exp08**\n",
28
+ " → đây là **ablation chính cho paper** (\"Mamba temporal head vs mean-pooling\", CÙNG backbone fine-tune).\n",
29
+ "\n",
30
+ "## ⚠️ Đánh đổi / gotcha (đã phòng trong code)\n",
31
+ "- Fine-tune = chạy lại WavLM mỗi epoch (không cache được) → **lần đầu BẮT BUỘC `LIMIT_TRAIN=300`, `LIMIT_DEV=20`**.\n",
32
+ "- `mamba-ssm` khó cài Kaggle → tự fallback **Mamba thuần PyTorch** (vòng-lặp-thời-gian). Bản này khi fine-tune\n",
33
+ " **chậm + nặng RAM hơn** → cap `MAX_SECONDS=6`, `BATCH=2`. OOM/quá chậm → hạ MAX_SECONDS→5, MAMBA_LAYERS→1,\n",
34
+ " hoặc thử cài `mamba-ssm causal-conv1d`.\n",
35
+ "- `layerdrop=0` (tránh CheckpointError khi grad-ckpt — bài học exp12). KHÔNG đụng numpy (lệch ABI).\n",
36
+ "- **Checkpoint lưu CẢ backbone + Mamba + heads mỗi best** (bài học exp08 mất backbone).\n",
37
+ "\n",
38
+ "## 🔁 RESUME (yêu cầu của user): \"nếu có checkpoint thì train TIẾP, không train lại từ đầu\"\n",
39
+ "- Notebook **tự dò** `ft_mamba_emotion_full.pt` trong `/kaggle/input` và `/kaggle/working` (hoặc trỏ tay `RESUME_CKPT`).\n",
40
+ "- Có ckpt đủ (backbone WavLM + Mamba enc + heads) → **nạp lại trạng thái + thống kê chuẩn hóa TỪ ckpt** rồi train tiếp;\n",
41
+ " `best` khởi tạo = điểm VAL của ckpt → chỉ ghi đè khi train tiếp **TỐT HƠN** (không sợ tụt). `RESUME_LR_SCALE<1` để hạ LR.\n",
42
+ "- KHÔNG có ckpt → train mới từ SAILER warm-start như cũ (hành vi exp15 gốc giữ nguyên).\n",
43
+ "\n",
44
+ "**Cách chạy Kaggle:** GPU **T4** + Internet **On** → Add Input dataset Track 2 (+ Add Input checkpoint cũ nếu muốn resume)\n",
45
+ "→ sửa `DATA_ROOT` → Run All."
46
+ ]
47
+ },
48
+ {
49
+ "cell_type": "markdown",
50
+ "id": "194bcd01",
51
+ "metadata": {},
52
+ "source": [
53
+ "## 0. Cấu hình — SỬA Ở ĐÂY"
54
+ ]
55
+ },
56
+ {
57
+ "cell_type": "code",
58
+ "execution_count": null,
59
+ "id": "8ed47b3b",
60
+ "metadata": {},
61
+ "outputs": [],
62
+ "source": [
63
+ "import os, glob\n",
64
+ "\n",
65
+ "# ── TỰ DÒ DATA_ROOT (quét /kaggle/input tìm thư mục có sets/train.csv + wav/ + metadata.csv) ──\n",
66
+ "def find_data_root(search_root=\"/kaggle/input\"):\n",
67
+ " cands = []\n",
68
+ " for train_csv in glob.glob(os.path.join(search_root, \"**\", \"sets\", \"train.csv\"), recursive=True):\n",
69
+ " root = os.path.dirname(os.path.dirname(train_csv)) # .../<root>/sets/train.csv → <root>\n",
70
+ " score = os.path.isdir(os.path.join(root, \"wav\")) + os.path.exists(os.path.join(root, \"metadata.csv\"))\n",
71
+ " cands.append((score, root))\n",
72
+ " cands.sort(reverse=True) # ưu tiên thư mục đủ wav + metadata\n",
73
+ " return cands\n",
74
+ "\n",
75
+ "_cands = find_data_root(\"/kaggle/input\")\n",
76
+ "if _cands:\n",
77
+ " print(\"🔎 Ứng viên DATA_ROOT (điểm cao = đủ wav+metadata):\")\n",
78
+ " for sc, r in _cands:\n",
79
+ " print(f\" [{sc}/2] {r}\")\n",
80
+ " DATA_ROOT = _cands[0][1]\n",
81
+ " print(f\"👉 Tự chọn DATA_ROOT = {DATA_ROOT}\")\n",
82
+ "else:\n",
83
+ " DATA_ROOT = \"/kaggle/input/datasets/minhtoan2\" # dự phòng — sửa tay nếu auto-dò không thấy\n",
84
+ " print(f\"❌ Không thấy sets/train.csv trong /kaggle/input → dùng dự phòng {DATA_ROOT} (đã Add Input chưa?)\")\n",
85
+ "\n",
86
+ "WAV_DIR = f\"{DATA_ROOT}/wav\"\n",
87
+ "METADATA_CSV = f\"{DATA_ROOT}/metadata.csv\" # wavID|emotion|transcript (KHÔNG header)\n",
88
+ "TRAIN_CSV = f\"{DATA_ROOT}/sets/train.csv\" # lisID|wavID|qMOS|emoCat|eMOS|val|dom|aro\n",
89
+ "DEV_SCP = f\"{DATA_ROOT}/sets/dev.scp\"\n",
90
+ "\n",
91
+ "OUT_DIR = \"/kaggle/working\"\n",
92
+ "CACHE_DIR = \"/kaggle/working/ft_cache\" # cache audeering (.npz) — WavLM/Mamba KHÔNG cache (đang train)\n",
93
+ "os.makedirs(CACHE_DIR, exist_ok=True)\n",
94
+ "\n",
95
+ "# (Tùy chọn) tái dùng cache audeering cũ (read-only /kaggle/input → copy sang working)\n",
96
+ "# Dataset cache_exp8: aud_*.npz nằm trong thư mục con archive/ → quét ĐỆ QUY để bắt mọi vị trí.\n",
97
+ "CACHE_INPUT = \"/kaggle/input/cache-exp8\" # << SỬA slug (dataset cache_exp8 → Kaggle đổi _→-); hoặc \"\"\n",
98
+ "if CACHE_INPUT and os.path.isdir(CACHE_INPUT):\n",
99
+ " import shutil\n",
100
+ " _n = 0\n",
101
+ " for _fp in glob.glob(os.path.join(CACHE_INPUT, \"**\", \"aud_*.npz\"), recursive=True):\n",
102
+ " shutil.copy(_fp, os.path.join(CACHE_DIR, os.path.basename(_fp))); _n += 1\n",
103
+ " print(f\"📦 Tái dùng cache: copy {_n} file aud_*.npz (quét đệ quy {CACHE_INPUT})\")\n",
104
+ "else:\n",
105
+ " print(f\"ℹ️ Không thấy CACHE_INPUT={CACHE_INPUT} → sẽ tự trích audeering.\")\n",
106
+ "\n",
107
+ "# Mượn cột QMOS exp07 (0.548). Trỏ answer.txt exp07 nếu có; không thì UTMOSv2.\n",
108
+ "EXP07_ANSWER = \"/kaggle/input/exp07-answer/answer.txt\" # << (tùy chọn)\n",
109
+ "\n",
110
+ "# ── Cờ Mamba (ablation chính) ────────────────────────────────────────────────\n",
111
+ "USE_MAMBA = True # True = Mamba head; False = mean-pool = ĐÚNG exp08\n",
112
+ "\n",
113
+ "# ── Siêu tham số Mamba head ──────────────────────────────────────────────────\n",
114
+ "MAMBA_DMODEL = 256\n",
115
+ "MAMBA_LAYERS = 2\n",
116
+ "MAMBA_DSTATE = 16\n",
117
+ "BIDIRECTIONAL = True\n",
118
+ "Z_DIM = 256 # chiều vector ra sau attentive-pool, thay cho emb WavLM mean-pool\n",
119
+ "\n",
120
+ "# ── Fine-tune / siêu tham số (kế thừa exp08) ─────────────────────────────────\n",
121
+ "DEVICE = \"cuda\"\n",
122
+ "SR = 16000\n",
123
+ "MAX_SECONDS = 6 # giảm từ 8 (exp08) vì Mamba backprop-through-time nặng RAM hơn\n",
124
+ "UNFREEZE_TOP_LAYERS = 6 # số lớp Transformer trên cùng được train (0 = freeze hết)\n",
125
+ "TRUNK_HIDDEN = 512\n",
126
+ "HEAD_HIDDEN = 128\n",
127
+ "DROPOUT = 0.3\n",
128
+ "LR_BACKBONE = 1e-5\n",
129
+ "LR_HEAD = 1e-3 # cho Mamba + trunk + head (train từ đầu)\n",
130
+ "WEIGHT_DECAY = 1e-5\n",
131
+ "EPOCHS = 12\n",
132
+ "PATIENCE = 3\n",
133
+ "BATCH = 2 # nhỏ (backbone to + Mamba); bù bằng ACCUM\n",
134
+ "ACCUM = 16 # effective batch = 32\n",
135
+ "VAL_FRAC = 0.10\n",
136
+ "SEED = 42\n",
137
+ "USE_AMP = True\n",
138
+ "USE_GRAD_CKPT = True\n",
139
+ "USE_AUDEERING = True\n",
140
+ "USE_UNCERTAINTY = True\n",
141
+ "RANK_LAMBDA = 0.3 # 0 = chỉ MSE (cũ). >0 = thêm pairwise ranking loss (tối ưu thẳng SRCC) cho emos/val/aro/dom\n",
142
+ " # ⚠️ ranking cần NHIỀU cặp/batch mới mạnh → BATCH nhỏ (2) thì tác dụng yếu (xem Ghi chú)\n",
143
+ "\n",
144
+ "LIMIT_TRAIN = 300 # << LẦN ĐẦU 300; chạy thật None\n",
145
+ "LIMIT_DEV = 20 # << LẦN ĐẦU 20; chạy thật None\n",
146
+ "\n",
147
+ "# ── RESUME — train TIẾP từ checkpoint, KHÔNG train lại từ đầu ─────────────────\n",
148
+ "# Để \"\" + auto-dò: nếu thấy `ft_mamba_emotion_full.pt` (đủ backbone+Mamba+heads) trong /kaggle/input\n",
149
+ "# hoặc /kaggle/working → nạp lại rồi train tiếp. Trỏ tay RESUME_CKPT nếu muốn chỉ định file cụ thể.\n",
150
+ "RESUME_CKPT = \"\" # << \"\" = auto-dò; hoặc \"/kaggle/input/<slug>/ft_mamba_emotion_full.pt\"\n",
151
+ "RESUME_LR_SCALE = 1.0 # <1.0 hạ LR khi train tiếp (vd 0.5 nếu val đã chững)\n",
152
+ "\n",
153
+ "def find_resume_ckpt(explicit):\n",
154
+ " \"\"\"Tìm checkpoint exp15 để train tiếp. Ưu tiên đường dẫn user trỏ; không thì auto-dò.\n",
155
+ " Khớp cả tên bị Kaggle/Windows thêm hậu tố trùng, vd 'ft_mamba_emotion_full (2).pt'.\"\"\"\n",
156
+ " if explicit and os.path.exists(explicit):\n",
157
+ " return explicit\n",
158
+ " for base in [\"/kaggle/input\", \"/kaggle/working\"]:\n",
159
+ " hits = sorted(glob.glob(os.path.join(base, \"**\", \"ft_mamba_emotion_full*.pt\"), recursive=True))\n",
160
+ " if hits:\n",
161
+ " return hits[0]\n",
162
+ " return \"\"\n",
163
+ "\n",
164
+ "RESUME_CKPT = find_resume_ckpt(RESUME_CKPT)\n",
165
+ "RESUME = bool(RESUME_CKPT)\n",
166
+ "print(\"🔁 RESUME =\", RESUME, (\"→ train tiếp từ: \" + RESUME_CKPT) if RESUME else \"(không thấy ckpt → train MỚI từ đầu)\")\n",
167
+ "\n",
168
+ "# Mốc so (exp08 fine-tune + mean-pool — đối thủ trực tiếp của Mamba head)\n",
169
+ "EXP08 = {\"emos\": 0.811, \"val\": 0.659, \"aro\": 0.793, \"dom\": 0.751}\n",
170
+ "\n",
171
+ "EMOTIONS5 = [\"angry\", \"happy\", \"neutral\", \"sad\", \"surprised\"]\n",
172
+ "\n",
173
+ "_EMO_ALIAS = {\n",
174
+ " \"angry\": \"angry\", \"anger\": \"angry\",\n",
175
+ " \"happy\": \"happy\", \"happiness\": \"happy\", \"joy\": \"happy\",\n",
176
+ " \"neutral\": \"neutral\", \"calm\": \"neutral\",\n",
177
+ " \"sad\": \"sad\", \"sadness\": \"sad\",\n",
178
+ " \"surprise\": \"surprised\", \"surprised\": \"surprised\", \"surprising\": \"surprised\",\n",
179
+ "}\n",
180
+ "\n",
181
+ "def norm_emotion(label):\n",
182
+ " key = str(label).strip().lower()\n",
183
+ " return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)\n",
184
+ "\n",
185
+ "def stem(p):\n",
186
+ " return os.path.splitext(os.path.basename(str(p)))[0]\n",
187
+ "\n",
188
+ "print(\"USE_MAMBA =\", USE_MAMBA, \"(False → ra đúng exp08)\")\n",
189
+ "print(\"DATA_ROOT:\", DATA_ROOT)\n",
190
+ "for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP]:\n",
191
+ " print((\" ✅ \" if os.path.exists(p) else \" ❌ THIẾU \") + p)\n",
192
+ "print(f\"Fine-tune: mở băng {UNFREEZE_TOP_LAYERS} lớp · BATCH {BATCH}×ACCUM {ACCUM} · MAX {MAX_SECONDS}s\")"
193
+ ]
194
+ },
195
+ {
196
+ "cell_type": "markdown",
197
+ "id": "c8010473",
198
+ "metadata": {},
199
+ "source": [
200
+ "## 1. Cài đặt + tải code SAILER (clone + sys.path)"
201
+ ]
202
+ },
203
+ {
204
+ "cell_type": "code",
205
+ "execution_count": null,
206
+ "id": "1b8d9fad",
207
+ "metadata": {},
208
+ "outputs": [],
209
+ "source": [
210
+ "import sys, subprocess\n",
211
+ "\n",
212
+ "def pip_install(*pkgs):\n",
213
+ " subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *pkgs], check=True)\n",
214
+ "\n",
215
+ "pip_install(\"loralib\", \"speechbrain\", \"speechmos\", \"librosa\", \"soundfile\",\n",
216
+ " \"scipy\", \"scikit-learn\", \"pandas\", \"tqdm\")\n",
217
+ "\n",
218
+ "# Cài kernel CUDA Mamba (nhanh + nhẹ RAM hơn bản thuần PyTorch nhiều). Build hay lỗi/chậm trên Kaggle\n",
219
+ "# → bọc try/except: lỗi thì BỎ QUA, mục 6a tự fallback Mamba thuần PyTorch. KHÔNG để chết notebook.\n",
220
+ "INSTALL_MAMBA_SSM = True # đặt False nếu muốn BỎ QUA, dùng thẳng Mamba thuần PyTorch\n",
221
+ "if INSTALL_MAMBA_SSM and USE_MAMBA:\n",
222
+ " try:\n",
223
+ " # --no-build-isolation cho CẢ HAI → dùng torch+CUDA sẵn có của Kaggle để biên dịch (đừng kéo torch khác).\n",
224
+ " # Cần ninja để build nhanh. -q ẩn log nên bước này có thể \"treo\" vài phút khi đang compile — bình thường.\n",
225
+ " subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", \"ninja\"], check=True)\n",
226
+ " subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\",\n",
227
+ " \"--no-build-isolation\", \"causal-conv1d>=1.2.0\"], check=True)\n",
228
+ " subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\",\n",
229
+ " \"--no-build-isolation\", \"mamba-ssm\"], check=True)\n",
230
+ " print(\"✅ Cài mamba-ssm + causal-conv1d xong (sẽ dùng kernel CUDA nếu import được).\")\n",
231
+ " except Exception as e:\n",
232
+ " print(\"⚠️ Cài mamba-ssm thất bại:\", repr(e), \"→ dùng Mamba thuần PyTorch (chậm hơn).\")\n",
233
+ " print(\" ℹ️ Vẫn chạy bình thường. Nếu chạy THẬT (LIMIT=None) quá chậm → xem Ghi chú cuối notebook.\")\n",
234
+ "\n",
235
+ "REPO_DIR = \"/kaggle/working/vox-profile-release\"\n",
236
+ "if not os.path.exists(REPO_DIR):\n",
237
+ " subprocess.run([\"git\", \"clone\", \"--depth\", \"1\",\n",
238
+ " \"https://github.com/tiantiaf0627/vox-profile-release.git\", REPO_DIR], check=True)\n",
239
+ "if REPO_DIR not in sys.path:\n",
240
+ " sys.path.insert(0, REPO_DIR)"
241
+ ]
242
+ },
243
+ {
244
+ "cell_type": "markdown",
245
+ "id": "598d74d9",
246
+ "metadata": {},
247
+ "source": [
248
+ "## 2. Nạp SAILER → lấy backbone WavLM bên trong để FINE-TUNE (warm-start)"
249
+ ]
250
+ },
251
+ {
252
+ "cell_type": "code",
253
+ "execution_count": null,
254
+ "id": "5346a63d",
255
+ "metadata": {
256
+ "lines_to_next_cell": 1
257
+ },
258
+ "outputs": [],
259
+ "source": [
260
+ "import torch\n",
261
+ "import torch.nn as nn\n",
262
+ "import torch.nn.functional as F\n",
263
+ "\n",
264
+ "device = DEVICE if torch.cuda.is_available() else \"cpu\"\n",
265
+ "print(\"Device:\", device, (\"✅ \" + torch.cuda.get_device_name(0)) if device == \"cuda\" else \"⚠️ CPU (rất chậm!)\")\n",
266
+ "\n",
267
+ "def find_hf_backbone(module):\n",
268
+ " \"\"\"Tìm submodule kiểu HF WavLM backbone: có .feature_extractor và .encoder.layers.\"\"\"\n",
269
+ " cands = []\n",
270
+ " for name, m in module.named_modules():\n",
271
+ " enc = getattr(m, \"encoder\", None)\n",
272
+ " if getattr(m, \"feature_extractor\", None) is not None and enc is not None \\\n",
273
+ " and getattr(enc, \"layers\", None) is not None:\n",
274
+ " cands.append((name, m))\n",
275
+ " if not cands:\n",
276
+ " return None, None\n",
277
+ " cands.sort(key=lambda nm: sum(p.numel() for p in nm[1].parameters()), reverse=True)\n",
278
+ " return cands[0]\n",
279
+ "\n",
280
+ "wavlm = None\n",
281
+ "try:\n",
282
+ " from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402\n",
283
+ " _wrapper = WavLMWrapper.from_pretrained(\"tiantiaf/wavlm-large-categorical-emotion\")\n",
284
+ " name, wavlm = find_hf_backbone(_wrapper)\n",
285
+ " if wavlm is not None:\n",
286
+ " print(f\"✅ Warm-start SAILER: backbone WavLM tại '.{name}' \"\n",
287
+ " f\"({sum(p.numel() for p in wavlm.parameters())/1e6:.0f}M params)\")\n",
288
+ " else:\n",
289
+ " print(\"⚠️ Không tìm thấy backbone HF trong wrapper SAILER → fallback WavLM trắng.\")\n",
290
+ "except Exception as e:\n",
291
+ " print(\"⚠️ Lỗi nạp SAILER wrapper:\", repr(e), \"→ fallback WavLM trắng.\")\n",
292
+ "\n",
293
+ "if wavlm is None:\n",
294
+ " from transformers import WavLMModel\n",
295
+ " wavlm = WavLMModel.from_pretrained(\"microsoft/wavlm-large\")\n",
296
+ " print(\"ℹ️ Fallback: microsoft/wavlm-large (KHÔNG warm-start SAILER).\")\n",
297
+ "\n",
298
+ "wavlm = wavlm.to(device)\n",
299
+ "WAVLM_DIM = int(wavlm.config.hidden_size)\n",
300
+ "wavlm.config.layerdrop = 0.0 # ⚠️ tránh CheckpointError khi grad-ckpt (bài học exp12)\n",
301
+ "\n",
302
+ "# ── RESUME: nạp trọng số backbone đã fine-tune từ checkpoint (đè lên warm-start SAILER) ──\n",
303
+ "resume_ckpt = None\n",
304
+ "if RESUME:\n",
305
+ " resume_ckpt = torch.load(RESUME_CKPT, map_location=\"cpu\", weights_only=False) # ckpt có numpy → cần False\n",
306
+ " assert \"wavlm\" in resume_ckpt, (\"❌ Checkpoint KHÔNG có 'wavlm' (backbone) → không resume được. \"\n",
307
+ " \"Dùng file ft_mamba_emotion_full.pt do exp15 lưu.\")\n",
308
+ " if resume_ckpt.get(\"USE_MAMBA\", USE_MAMBA) != USE_MAMBA:\n",
309
+ " print(f\" ⚠️ ckpt USE_MAMBA={resume_ckpt.get('USE_MAMBA')} ≠ cấu hình hiện tại {USE_MAMBA} → kiến trúc LỆCH! \"\n",
310
+ " \"Đặt USE_MAMBA cho khớp ckpt.\")\n",
311
+ " miss, unexp = wavlm.load_state_dict(resume_ckpt[\"wavlm\"], strict=False)\n",
312
+ " print(f\"🔁 RESUME load wavlm từ ckpt: thiếu {len(miss)} / dư {len(unexp)} key (kỳ vọng ~0). keys ckpt:\", list(resume_ckpt.keys()))\n",
313
+ " if len(miss) > 20 or len(unexp) > 20:\n",
314
+ " print(\" ⚠️ Lệch key nhiều → kiểm tra UNFREEZE_TOP_LAYERS / backbone có khớp ckpt không.\")\n",
315
+ "\n",
316
+ "# ── Đóng băng partial: feature-extractor + tất cả trừ UNFREEZE_TOP_LAYERS lớp trên ──\n",
317
+ "for p in wavlm.parameters():\n",
318
+ " p.requires_grad = False\n",
319
+ "enc_layers = wavlm.encoder.layers\n",
320
+ "n_layers = len(enc_layers)\n",
321
+ "for layer in enc_layers[max(0, n_layers - UNFREEZE_TOP_LAYERS):]:\n",
322
+ " for p in layer.parameters():\n",
323
+ " p.requires_grad = True\n",
324
+ "n_train = sum(p.numel() for p in wavlm.parameters() if p.requires_grad)\n",
325
+ "print(f\"WavLM: {n_layers} lớp · mở băng {min(UNFREEZE_TOP_LAYERS, n_layers)} → {n_train/1e6:.1f}M param train (dim {WAVLM_DIM})\")\n",
326
+ "\n",
327
+ "if USE_GRAD_CKPT:\n",
328
+ " wavlm.gradient_checkpointing_enable()\n",
329
+ " if hasattr(wavlm, \"enable_input_require_grads\"):\n",
330
+ " wavlm.enable_input_require_grads()\n",
331
+ "\n",
332
+ "def frame_mask(T, attn_mask):\n",
333
+ " \"\"\"attn_mask (B, Lwav) → frame-mask (B, T) bool (True=frame thật). Khớp downsample của WavLM.\"\"\"\n",
334
+ " if attn_mask is None:\n",
335
+ " return torch.ones((1, T), dtype=torch.bool, device=device)\n",
336
+ " try:\n",
337
+ " fm = wavlm._get_feature_vector_attention_mask(T, attn_mask)\n",
338
+ " return fm.bool()\n",
339
+ " except Exception:\n",
340
+ " return torch.ones((attn_mask.shape[0], T), dtype=torch.bool, device=attn_mask.device)\n",
341
+ "\n",
342
+ "def masked_mean(hidden, attn_mask):\n",
343
+ " \"\"\"Mean-pool theo thời gian bỏ pad (đường exp08 khi USE_MAMBA=False).\"\"\"\n",
344
+ " if attn_mask is None:\n",
345
+ " return hidden.mean(dim=1)\n",
346
+ " fm = frame_mask(hidden.shape[1], attn_mask).unsqueeze(-1).to(hidden.dtype)\n",
347
+ " return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)"
348
+ ]
349
+ },
350
+ {
351
+ "cell_type": "markdown",
352
+ "id": "c72d5983",
353
+ "metadata": {},
354
+ "source": [
355
+ "## 3. Nạp audeering MSP-dim (FROZEN) — đặc trưng phụ (như exp08)"
356
+ ]
357
+ },
358
+ {
359
+ "cell_type": "code",
360
+ "execution_count": null,
361
+ "id": "d967397d",
362
+ "metadata": {},
363
+ "outputs": [],
364
+ "source": [
365
+ "AUD_DIM = 0\n",
366
+ "aud_backbone = aud_head = aud_proc = None\n",
367
+ "if USE_AUDEERING:\n",
368
+ " from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor\n",
369
+ " from huggingface_hub import hf_hub_download\n",
370
+ " AUD_NAME = \"audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim\"\n",
371
+ " aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)\n",
372
+ " aud_cfg = Wav2Vec2Config.from_pretrained(AUD_NAME)\n",
373
+ " aud_backbone = Wav2Vec2Model(aud_cfg)\n",
374
+ " try:\n",
375
+ " _sd = __import__(\"safetensors.torch\", fromlist=[\"load_file\"]).load_file(\n",
376
+ " hf_hub_download(AUD_NAME, \"model.safetensors\"))\n",
377
+ " except Exception:\n",
378
+ " _sd = torch.load(hf_hub_download(AUD_NAME, \"pytorch_model.bin\"), map_location=\"cpu\")\n",
379
+ " bb_sd = {k[len(\"wav2vec2.\"):]: v for k, v in _sd.items() if k.startswith(\"wav2vec2.\")}\n",
380
+ " aud_backbone.load_state_dict(bb_sd, strict=False)\n",
381
+ " _hid = _sd[\"classifier.dense.weight\"].shape[0]\n",
382
+ " _out = _sd[\"classifier.out_proj.weight\"].shape[0]\n",
383
+ " aud_head = nn.Sequential(nn.Linear(_hid, _hid), nn.Tanh(), nn.Linear(_hid, _out))\n",
384
+ " aud_head[0].weight.data.copy_(_sd[\"classifier.dense.weight\"]); aud_head[0].bias.data.copy_(_sd[\"classifier.dense.bias\"])\n",
385
+ " aud_head[2].weight.data.copy_(_sd[\"classifier.out_proj.weight\"]); aud_head[2].bias.data.copy_(_sd[\"classifier.out_proj.bias\"])\n",
386
+ " aud_backbone = aud_backbone.to(device).eval()\n",
387
+ " aud_head = aud_head.to(device).eval()\n",
388
+ " AUD_DIM = _hid + 3\n",
389
+ " print(f\"✅ audeering frozen (đặc trưng phụ {AUD_DIM}-D = emb {_hid} + vad 3)\")"
390
+ ]
391
+ },
392
+ {
393
+ "cell_type": "code",
394
+ "execution_count": null,
395
+ "id": "1a5f1592",
396
+ "metadata": {
397
+ "lines_to_next_cell": 1
398
+ },
399
+ "outputs": [],
400
+ "source": [
401
+ "import numpy as np\n",
402
+ "import librosa\n",
403
+ "from tqdm.auto import tqdm\n",
404
+ "\n",
405
+ "def load_wav(name_or_stem):\n",
406
+ " p = name_or_stem if os.path.isabs(str(name_or_stem)) else os.path.join(\n",
407
+ " WAV_DIR, name_or_stem if str(name_or_stem).endswith(\".wav\") else str(name_or_stem) + \".wav\")\n",
408
+ " if not os.path.exists(p):\n",
409
+ " return None\n",
410
+ " wave, _ = librosa.load(p, sr=SR, mono=True)\n",
411
+ " return wave[: MAX_SECONDS * SR].astype(np.float32)\n",
412
+ "\n",
413
+ "@torch.no_grad()\n",
414
+ "def extract_audeering(stems, tag):\n",
415
+ " if not USE_AUDEERING:\n",
416
+ " return {}\n",
417
+ " cache_path = os.path.join(CACHE_DIR, f\"aud_{tag}.npz\")\n",
418
+ " store = {}\n",
419
+ " if os.path.exists(cache_path):\n",
420
+ " z = np.load(cache_path, allow_pickle=True)\n",
421
+ " store = {k: z[k] for k in z.files}\n",
422
+ " print(f\"[aud/{tag}] nạp cache: {len(store)}\")\n",
423
+ " todo = [s for s in stems if s not in store]\n",
424
+ " for i, s in enumerate(tqdm(todo, desc=f\"audeering {tag}\")):\n",
425
+ " wave = load_wav(s)\n",
426
+ " if wave is None:\n",
427
+ " continue\n",
428
+ " x = aud_proc(wave, sampling_rate=SR).input_values[0]\n",
429
+ " x = torch.from_numpy(np.asarray(x, dtype=np.float32)).unsqueeze(0).to(device)\n",
430
+ " h = aud_backbone(x)[0].mean(dim=1)\n",
431
+ " out = aud_head(h)[0].cpu().numpy() # [arousal, dominance, valence] ∈[0,1]\n",
432
+ " vad = np.array([1 + 4 * out[2], 1 + 4 * out[0], 1 + 4 * out[1]], dtype=np.float32) # [VAL,ARO,DOM]\n",
433
+ " store[s] = np.concatenate([h[0].cpu().numpy(), vad]).astype(np.float32)\n",
434
+ " if (i + 1) % 500 == 0:\n",
435
+ " np.savez(cache_path, **store)\n",
436
+ " if todo:\n",
437
+ " np.savez(cache_path, **store)\n",
438
+ " return store"
439
+ ]
440
+ },
441
+ {
442
+ "cell_type": "markdown",
443
+ "id": "50717e09",
444
+ "metadata": {},
445
+ "source": [
446
+ "## 4. Đọc & gộp nhãn theo wavID (EMOS / VAD / CAT) — như exp08"
447
+ ]
448
+ },
449
+ {
450
+ "cell_type": "code",
451
+ "execution_count": null,
452
+ "id": "b5c3e935",
453
+ "metadata": {},
454
+ "outputs": [],
455
+ "source": [
456
+ "import pandas as pd\n",
457
+ "\n",
458
+ "def load_target_emotions():\n",
459
+ " tgt = {}\n",
460
+ " with open(METADATA_CSV, encoding=\"utf-8\") as f:\n",
461
+ " for ln in f:\n",
462
+ " parts = ln.strip().split(\"|\")\n",
463
+ " if len(parts) >= 2:\n",
464
+ " tgt[stem(parts[0])] = norm_emotion(parts[1])\n",
465
+ " return tgt\n",
466
+ "\n",
467
+ "def _col(cols_map, *names, df=None, default_idx=None):\n",
468
+ " for n in names:\n",
469
+ " if n in cols_map:\n",
470
+ " return cols_map[n]\n",
471
+ " return list(df.columns)[default_idx] if default_idx is not None else None\n",
472
+ "\n",
473
+ "def parse_emocat_votes(cell):\n",
474
+ " v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
475
+ " for tok in str(cell).replace(\"/\", \",\").replace(\";\", \",\").replace(\"|\", \",\").replace(\" \", \",\").split(\",\"):\n",
476
+ " e = norm_emotion(tok)\n",
477
+ " if e in EMOTIONS5:\n",
478
+ " v[EMOTIONS5.index(e)] += 1.0\n",
479
+ " return v\n",
480
+ "\n",
481
+ "def load_train_labels():\n",
482
+ " df = pd.read_csv(TRAIN_CSV, sep=\"|\")\n",
483
+ " cols = {c.lower().strip(): c for c in df.columns}\n",
484
+ " wav_col = _col(cols, \"wavid\", \"wav\", df=df, default_idx=1)\n",
485
+ " emos_col = _col(cols, \"emos\", \"emo\", \"emomos\")\n",
486
+ " val_col = _col(cols, \"val\", \"valence\"); aro_col = _col(cols, \"aro\", \"arousal\"); dom_col = _col(cols, \"dom\", \"dominance\")\n",
487
+ " cat_col = _col(cols, \"emocat\", \"cat\", \"emotion\")\n",
488
+ " assert emos_col, f\"Không thấy cột eMOS (cột: {list(df.columns)})\"\n",
489
+ " df[\"_stem\"] = df[wav_col].map(stem)\n",
490
+ " rows = []\n",
491
+ " for sid, g in df.groupby(\"_stem\"):\n",
492
+ " rec = {\"wavID\": sid, \"emos\": float(g[emos_col].mean())}\n",
493
+ " rec[\"val\"] = float(g[val_col].mean()) if val_col else np.nan\n",
494
+ " rec[\"aro\"] = float(g[aro_col].mean()) if aro_col else np.nan\n",
495
+ " rec[\"dom\"] = float(g[dom_col].mean()) if dom_col else np.nan\n",
496
+ " votes = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
497
+ " if cat_col:\n",
498
+ " for cell in g[cat_col]:\n",
499
+ " votes += parse_emocat_votes(cell)\n",
500
+ " s = votes.sum()\n",
501
+ " cat = votes / s if s > 0 else np.full(len(EMOTIONS5), 0.2, dtype=np.float32)\n",
502
+ " for i in range(len(EMOTIONS5)):\n",
503
+ " rec[f\"cat{i}\"] = float(cat[i])\n",
504
+ " rows.append(rec)\n",
505
+ " return pd.DataFrame(rows)\n",
506
+ "\n",
507
+ "target_map = load_target_emotions()\n",
508
+ "train_df = load_train_labels()\n",
509
+ "HAS_VAD = bool(train_df[\"val\"].notna().any())\n",
510
+ "print(f\"Target: {len(target_map)} | wav train (gộp): {len(train_df)} | có VAD: {HAS_VAD}\")"
511
+ ]
512
+ },
513
+ {
514
+ "cell_type": "markdown",
515
+ "id": "b5ab79d3",
516
+ "metadata": {},
517
+ "source": [
518
+ "## 5. Dataset / DataLoader (load wav theo batch — KHÔNG cache WavLM vì đang train)"
519
+ ]
520
+ },
521
+ {
522
+ "cell_type": "code",
523
+ "execution_count": null,
524
+ "id": "9989f142",
525
+ "metadata": {},
526
+ "outputs": [],
527
+ "source": [
528
+ "from torch.utils.data import Dataset, DataLoader\n",
529
+ "\n",
530
+ "train_stems = [s for s in train_df[\"wavID\"] if target_map.get(s) is not None]\n",
531
+ "if LIMIT_TRAIN:\n",
532
+ " train_stems = train_stems[:LIMIT_TRAIN]\n",
533
+ "aud_tr = extract_audeering(train_stems, \"train\")\n",
534
+ "\n",
535
+ "lab = train_df.set_index(\"wavID\")\n",
536
+ "\n",
537
+ "def _zfit(arr):\n",
538
+ " a = np.asarray(arr, dtype=np.float32)\n",
539
+ " return float(np.nanmean(a)), float(np.nanstd(a) + 1e-6)\n",
540
+ "\n",
541
+ "if RESUME and resume_ckpt is not None:\n",
542
+ " # QUAN TRỌNG: lấy chuẩn hóa TỪ ckpt (head đã train theo thang này) — KHÔNG tính lại để khỏi lệch thang\n",
543
+ " emos_mu = float(resume_ckpt[\"emos_mu\"]); emos_sd = float(resume_ckpt[\"emos_sd\"])\n",
544
+ " vad_mu = np.asarray(resume_ckpt[\"vad_mu\"], dtype=np.float32)\n",
545
+ " vad_sd = np.asarray(resume_ckpt[\"vad_sd\"], dtype=np.float32)\n",
546
+ " print(f\"🔁 RESUME: dùng chuẩn hóa TỪ ckpt: emos μ={emos_mu:.3f} σ={emos_sd:.3f} | vad μ={np.round(vad_mu,2)}\")\n",
547
+ "else:\n",
548
+ " emos_mu, emos_sd = _zfit([lab.loc[s, \"emos\"] for s in train_stems])\n",
549
+ " if HAS_VAD:\n",
550
+ " vad_mu = np.array([_zfit([lab.loc[s, c] for s in train_stems])[0] for c in [\"val\", \"aro\", \"dom\"]], dtype=np.float32)\n",
551
+ " vad_sd = np.array([_zfit([lab.loc[s, c] for s in train_stems])[1] for c in [\"val\", \"aro\", \"dom\"]], dtype=np.float32)\n",
552
+ " else:\n",
553
+ " vad_mu = np.zeros(3, dtype=np.float32); vad_sd = np.ones(3, dtype=np.float32)\n",
554
+ "\n",
555
+ "def onehot_target(tgt):\n",
556
+ " v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
557
+ " if tgt in EMOTIONS5:\n",
558
+ " v[EMOTIONS5.index(tgt)] = 1.0\n",
559
+ " return v\n",
560
+ "\n",
561
+ "class EmoDataset(Dataset):\n",
562
+ " def __init__(self, stems):\n",
563
+ " self.stems = [s for s in stems if (load_wav(s) is not None) and ((not USE_AUDEERING) or s in aud_tr)]\n",
564
+ " def __len__(self):\n",
565
+ " return len(self.stems)\n",
566
+ " def __getitem__(self, i):\n",
567
+ " s = self.stems[i]\n",
568
+ " wave = load_wav(s)\n",
569
+ " emos = (float(lab.loc[s, \"emos\"]) - emos_mu) / emos_sd\n",
570
+ " if HAS_VAD:\n",
571
+ " vad = (np.array([lab.loc[s, \"val\"], lab.loc[s, \"aro\"], lab.loc[s, \"dom\"]], np.float32) - vad_mu) / vad_sd\n",
572
+ " else:\n",
573
+ " vad = np.zeros(3, dtype=np.float32)\n",
574
+ " cat = np.array([lab.loc[s, f\"cat{j}\"] for j in range(len(EMOTIONS5))], dtype=np.float32)\n",
575
+ " aud = aud_tr[s] if USE_AUDEERING else np.zeros(0, dtype=np.float32)\n",
576
+ " return {\"wave\": wave, \"tgt\": onehot_target(target_map.get(s)), \"aud\": aud,\n",
577
+ " \"emos\": np.float32(emos), \"vad\": vad, \"cat\": cat,\n",
578
+ " \"emos_raw\": np.float32(lab.loc[s, \"emos\"]),\n",
579
+ " \"vad_raw\": np.array([lab.loc[s, \"val\"], lab.loc[s, \"aro\"], lab.loc[s, \"dom\"]], np.float32)}\n",
580
+ "\n",
581
+ "def collate(batch):\n",
582
+ " L = max(len(b[\"wave\"]) for b in batch)\n",
583
+ " waves = np.zeros((len(batch), L), dtype=np.float32)\n",
584
+ " mask = np.zeros((len(batch), L), dtype=np.float32)\n",
585
+ " for i, b in enumerate(batch):\n",
586
+ " waves[i, : len(b[\"wave\"])] = b[\"wave\"]; mask[i, : len(b[\"wave\"])] = 1.0\n",
587
+ " return {\n",
588
+ " \"input_values\": torch.from_numpy(waves), \"attn_mask\": torch.from_numpy(mask).long(),\n",
589
+ " \"tgt\": torch.from_numpy(np.stack([b[\"tgt\"] for b in batch])),\n",
590
+ " \"aud\": torch.from_numpy(np.stack([b[\"aud\"] for b in batch])) if USE_AUDEERING else None,\n",
591
+ " \"emos\": torch.from_numpy(np.stack([b[\"emos\"] for b in batch])).unsqueeze(1),\n",
592
+ " \"vad\": torch.from_numpy(np.stack([b[\"vad\"] for b in batch])),\n",
593
+ " \"cat\": torch.from_numpy(np.stack([b[\"cat\"] for b in batch])),\n",
594
+ " \"emos_raw\": np.stack([b[\"emos_raw\"] for b in batch]),\n",
595
+ " \"vad_raw\": np.stack([b[\"vad_raw\"] for b in batch]),\n",
596
+ " }\n",
597
+ "\n",
598
+ "from sklearn.model_selection import train_test_split\n",
599
+ "ds = EmoDataset(train_stems)\n",
600
+ "print(\"Dataset hợp lệ:\", len(ds), \"wav\")\n",
601
+ "tr_i, va_i = train_test_split(np.arange(len(ds)), test_size=VAL_FRAC, random_state=SEED)\n",
602
+ "tr_loader = DataLoader(torch.utils.data.Subset(ds, tr_i), batch_size=BATCH, shuffle=True, collate_fn=collate, num_workers=2)\n",
603
+ "va_loader = DataLoader(torch.utils.data.Subset(ds, va_i), batch_size=BATCH, shuffle=False, collate_fn=collate, num_workers=2)"
604
+ ]
605
+ },
606
+ {
607
+ "cell_type": "markdown",
608
+ "id": "6006ec6c",
609
+ "metadata": {},
610
+ "source": [
611
+ "## 6a. Khối MAMBA (thuần PyTorch, fallback nếu không có `mamba-ssm`)\n",
612
+ "Theo \"mamba-minimal\" — đúng công thức selective SSM, chỉ chậm hơn kernel CUDA. Chạy trong fp32 cho ổn định."
613
+ ]
614
+ },
615
+ {
616
+ "cell_type": "code",
617
+ "execution_count": null,
618
+ "id": "b9089952",
619
+ "metadata": {
620
+ "lines_to_next_cell": 1
621
+ },
622
+ "outputs": [],
623
+ "source": [
624
+ "import math\n",
625
+ "\n",
626
+ "try:\n",
627
+ " from mamba_ssm import Mamba as _OfficialMamba\n",
628
+ " _HAS_MAMBA_SSM = True\n",
629
+ " print(\"✅ Dùng mamba-ssm (CUDA kernel)\")\n",
630
+ "except Exception:\n",
631
+ " _HAS_MAMBA_SSM = False\n",
632
+ " print(\"ℹ️ Không có mamba-ssm → Mamba thuần PyTorch (chậm hơn khi fine-tune)\")\n",
633
+ "\n",
634
+ "class MambaBlockTorch(nn.Module):\n",
635
+ " def __init__(self, d_model, d_state=16, d_conv=4, expand=2):\n",
636
+ " super().__init__()\n",
637
+ " self.d_inner = expand * d_model\n",
638
+ " self.dt_rank = math.ceil(d_model / 16)\n",
639
+ " self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)\n",
640
+ " self.conv1d = nn.Conv1d(self.d_inner, self.d_inner, kernel_size=d_conv,\n",
641
+ " groups=self.d_inner, padding=d_conv - 1, bias=True)\n",
642
+ " self.x_proj = nn.Linear(self.d_inner, self.dt_rank + d_state * 2, bias=False)\n",
643
+ " self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)\n",
644
+ " A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)\n",
645
+ " self.A_log = nn.Parameter(torch.log(A))\n",
646
+ " self.D = nn.Parameter(torch.ones(self.d_inner))\n",
647
+ " self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)\n",
648
+ " self.d_state = d_state\n",
649
+ "\n",
650
+ " def forward(self, x): # x: (B, L, d_model)\n",
651
+ " B, L, _ = x.shape\n",
652
+ " xin, z = self.in_proj(x).chunk(2, dim=-1)\n",
653
+ " xin = xin.transpose(1, 2)\n",
654
+ " xin = self.conv1d(xin)[..., :L].transpose(1, 2)\n",
655
+ " xin = F.silu(xin)\n",
656
+ " y = self._ssm(xin) * F.silu(z)\n",
657
+ " return self.out_proj(y)\n",
658
+ "\n",
659
+ " def _ssm(self, x):\n",
660
+ " A = -torch.exp(self.A_log)\n",
661
+ " delta, Bm, Cm = torch.split(self.x_proj(x), [self.dt_rank, self.d_state, self.d_state], dim=-1)\n",
662
+ " delta = F.softplus(self.dt_proj(delta))\n",
663
+ " dA = torch.exp(delta.unsqueeze(-1) * A)\n",
664
+ " dB_x = delta.unsqueeze(-1) * Bm.unsqueeze(2) * x.unsqueeze(-1)\n",
665
+ " h = torch.zeros(x.shape[0], self.d_inner, self.d_state, device=x.device, dtype=x.dtype)\n",
666
+ " ys = []\n",
667
+ " for t in range(x.shape[1]):\n",
668
+ " h = dA[:, t] * h + dB_x[:, t]\n",
669
+ " ys.append((h * Cm[:, t].unsqueeze(1)).sum(-1))\n",
670
+ " return torch.stack(ys, dim=1) + x * self.D\n",
671
+ "\n",
672
+ "class MambaLayer(nn.Module):\n",
673
+ " def __init__(self, d_model, d_state):\n",
674
+ " super().__init__()\n",
675
+ " self.norm = nn.LayerNorm(d_model)\n",
676
+ " self.mix = _OfficialMamba(d_model=d_model, d_state=d_state, d_conv=4, expand=2) \\\n",
677
+ " if _HAS_MAMBA_SSM else MambaBlockTorch(d_model, d_state=d_state)\n",
678
+ " def forward(self, x):\n",
679
+ " return x + self.mix(self.norm(x))\n",
680
+ "\n",
681
+ "class MambaEncoder(nn.Module):\n",
682
+ " \"\"\"1024 → d_model → [Mamba ×L] (2 chiều) → attentive-pool (có mask) → Z_DIM.\"\"\"\n",
683
+ " def __init__(self, d_in, d_model, n_layers, d_state, z_dim, bidir):\n",
684
+ " super().__init__()\n",
685
+ " self.bidir = bidir\n",
686
+ " self.proj = nn.Linear(d_in, d_model)\n",
687
+ " self.fwd = nn.ModuleList([MambaLayer(d_model, d_state) for _ in range(n_layers)])\n",
688
+ " if bidir:\n",
689
+ " self.bwd = nn.ModuleList([MambaLayer(d_model, d_state) for _ in range(n_layers)])\n",
690
+ " self.attn = nn.Linear(d_model, 1)\n",
691
+ " self.out = nn.Linear(d_model, z_dim)\n",
692
+ "\n",
693
+ " @staticmethod\n",
694
+ " def _run(layers, h):\n",
695
+ " for L in layers:\n",
696
+ " h = L(h)\n",
697
+ " return h\n",
698
+ "\n",
699
+ " def forward(self, x, mask): # x:(B,L,1024) mask:(B,L) bool\n",
700
+ " with torch.cuda.amp.autocast(enabled=False): # SSM chạy fp32 cho ổn định\n",
701
+ " x = x.float()\n",
702
+ " h = self.proj(x)\n",
703
+ " out = self._run(self.fwd, h)\n",
704
+ " if self.bidir:\n",
705
+ " out = out + torch.flip(self._run(self.bwd, torch.flip(h, dims=[1])), dims=[1])\n",
706
+ " a = self.attn(out).squeeze(-1).masked_fill(~mask, float(\"-inf\"))\n",
707
+ " w = torch.softmax(a, dim=1).unsqueeze(-1)\n",
708
+ " return self.out((out * w).sum(1))"
709
+ ]
710
+ },
711
+ {
712
+ "cell_type": "markdown",
713
+ "id": "ff1cec20",
714
+ "metadata": {},
715
+ "source": [
716
+ "## 6b. Head cảm xúc + train loop (AMP + grad-accum + uncertainty weighting)"
717
+ ]
718
+ },
719
+ {
720
+ "cell_type": "code",
721
+ "execution_count": null,
722
+ "id": "c414e504",
723
+ "metadata": {
724
+ "lines_to_next_cell": 1
725
+ },
726
+ "outputs": [],
727
+ "source": [
728
+ "from scipy.stats import spearmanr\n",
729
+ "\n",
730
+ "torch.manual_seed(SEED); np.random.seed(SEED)\n",
731
+ "N_EMO = len(EMOTIONS5)\n",
732
+ "WAVLM_BRANCH = Z_DIM if USE_MAMBA else WAVLM_DIM\n",
733
+ "TRUNK_IN = WAVLM_BRANCH + (AUD_DIM if USE_AUDEERING else 0)\n",
734
+ "\n",
735
+ "enc = MambaEncoder(WAVLM_DIM, MAMBA_DMODEL, MAMBA_LAYERS, MAMBA_DSTATE, Z_DIM, BIDIRECTIONAL).to(device) \\\n",
736
+ " if USE_MAMBA else None\n",
737
+ "\n",
738
+ "class EmoHeads(nn.Module):\n",
739
+ " def __init__(self, d_in, trunk_h, head_h, p, n_emo):\n",
740
+ " super().__init__()\n",
741
+ " self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),\n",
742
+ " nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))\n",
743
+ " self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))\n",
744
+ " self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))\n",
745
+ " self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))\n",
746
+ " def forward(self, feat, tgt):\n",
747
+ " h = self.trunk(feat)\n",
748
+ " return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)\n",
749
+ "\n",
750
+ "heads = EmoHeads(TRUNK_IN, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device)\n",
751
+ "print(f\"Trunk input = {TRUNK_IN} (wavlm-branch {WAVLM_BRANCH} [{'Mamba' if USE_MAMBA else 'mean-pool'}] + aud {AUD_DIM if USE_AUDEERING else 0})\")\n",
752
+ "if USE_MAMBA:\n",
753
+ " print(f\"Mamba encoder: {sum(p.numel() for p in enc.parameters())/1e6:.2f}M param\")\n",
754
+ "\n",
755
+ "# ── RESUME: nạp heads (+ Mamba enc) từ checkpoint ──\n",
756
+ "if RESUME and resume_ckpt is not None:\n",
757
+ " hm, hu = heads.load_state_dict(resume_ckpt[\"heads\"], strict=False)\n",
758
+ " print(f\"🔁 RESUME load heads từ ckpt: thiếu {len(hm)} / dư {len(hu)} key (kỳ vọng 0)\")\n",
759
+ " if USE_MAMBA and resume_ckpt.get(\"enc\") is not None:\n",
760
+ " em, eu = enc.load_state_dict(resume_ckpt[\"enc\"], strict=False)\n",
761
+ " print(f\"🔁 RESUME load Mamba enc từ ckpt: thiếu {len(em)} / dư {len(eu)} key (kỳ vọng 0)\")\n",
762
+ " elif USE_MAMBA:\n",
763
+ " print(\" ⚠️ ckpt KHÔNG có 'enc' (Mamba) → Mamba head train lại từ đầu (chỉ resume backbone+heads).\")\n",
764
+ "\n",
765
+ "TASKS = [\"emos\", \"cat\", \"val\", \"aro\", \"dom\"]\n",
766
+ "log_var = nn.Parameter(torch.zeros(len(TASKS), device=device))\n",
767
+ "bb_params = [p for p in wavlm.parameters() if p.requires_grad]\n",
768
+ "head_params = list(heads.parameters()) + (list(enc.parameters()) if USE_MAMBA else []) \\\n",
769
+ " + ([log_var] if USE_UNCERTAINTY else [])\n",
770
+ "_lr_scale = RESUME_LR_SCALE if RESUME else 1.0\n",
771
+ "opt = torch.optim.AdamW([\n",
772
+ " {\"params\": bb_params, \"lr\": LR_BACKBONE * _lr_scale},\n",
773
+ " {\"params\": head_params, \"lr\": LR_HEAD * _lr_scale},\n",
774
+ "], weight_decay=WEIGHT_DECAY)\n",
775
+ "if RESUME and _lr_scale != 1.0:\n",
776
+ " print(f\"🔁 RESUME: LR ×{_lr_scale} → backbone {LR_BACKBONE*_lr_scale:.1e} · head {LR_HEAD*_lr_scale:.1e}\")\n",
777
+ "scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP and device == \"cuda\")\n",
778
+ "mse = nn.MSELoss()\n",
779
+ "\n",
780
+ "def soft_ce(logits, target_dist):\n",
781
+ " return -(target_dist * F.log_softmax(logits, dim=1)).sum(1).mean()\n",
782
+ "\n",
783
+ "def wavlm_branch(input_values, attn_mask):\n",
784
+ " out = wavlm(input_values, attention_mask=attn_mask).last_hidden_state # (B,T,D)\n",
785
+ " if USE_MAMBA:\n",
786
+ " return enc(out, frame_mask(out.shape[1], attn_mask)) # (B, Z_DIM)\n",
787
+ " return masked_mean(out, attn_mask) # (B, D)\n",
788
+ "\n",
789
+ "def forward_batch(b):\n",
790
+ " fw = wavlm_branch(b[\"input_values\"].to(device), b[\"attn_mask\"].to(device))\n",
791
+ " feat = torch.cat([fw, b[\"aud\"].to(device)], dim=1) if USE_AUDEERING else fw\n",
792
+ " return heads(feat, b[\"tgt\"].to(device))\n",
793
+ "\n",
794
+ "def pairwise_rank_loss(pred, target):\n",
795
+ " \"\"\"Hinge ranking trên MỌI cặp trong batch → tối ưu thẳng thứ hạng (≈ SRCC). Khả vi (backprop được).\n",
796
+ " Cần ≥2 mẫu/batch mới có cặp; batch càng to càng nhiều cặp → tín hiệu càng mạnh.\"\"\"\n",
797
+ " p = pred.reshape(-1); t = target.reshape(-1)\n",
798
+ " if p.numel() < 2:\n",
799
+ " return torch.zeros((), device=p.device)\n",
800
+ " sign = torch.sign(t.unsqueeze(0) - t.unsqueeze(1)) # +1 nếu câu i ĐÁNG cao hơn câu j\n",
801
+ " diff = p.unsqueeze(0) - p.unsqueeze(1) # chênh lệch model dự đoán\n",
802
+ " return torch.relu(-sign * diff).mean() # phạt khi xếp sai thứ tự\n",
803
+ "\n",
804
+ "def compute_loss(emos_p, cat_l, vad_p, b):\n",
805
+ " L = {\"emos\": mse(emos_p, b[\"emos\"].to(device)), \"cat\": soft_ce(cat_l, b[\"cat\"].to(device))}\n",
806
+ " if HAS_VAD:\n",
807
+ " vt = b[\"vad\"].to(device)\n",
808
+ " L[\"val\"] = mse(vad_p[:, 0:1], vt[:, 0:1]); L[\"aro\"] = mse(vad_p[:, 1:2], vt[:, 1:2]); L[\"dom\"] = mse(vad_p[:, 2:3], vt[:, 2:3])\n",
809
+ " else:\n",
810
+ " vt = None\n",
811
+ " z = torch.zeros((), device=device); L[\"val\"] = L[\"aro\"] = L[\"dom\"] = z\n",
812
+ " # Ranking loss CHỈ cho các cột chấm SRCC (emos/val/aro/dom). CAT là ERR phân bố → giữ soft-CE.\n",
813
+ " if RANK_LAMBDA > 0:\n",
814
+ " L[\"emos\"] = L[\"emos\"] + RANK_LAMBDA * pairwise_rank_loss(emos_p, b[\"emos\"].to(device))\n",
815
+ " if HAS_VAD:\n",
816
+ " L[\"val\"] = L[\"val\"] + RANK_LAMBDA * pairwise_rank_loss(vad_p[:, 0:1], vt[:, 0:1])\n",
817
+ " L[\"aro\"] = L[\"aro\"] + RANK_LAMBDA * pairwise_rank_loss(vad_p[:, 1:2], vt[:, 1:2])\n",
818
+ " L[\"dom\"] = L[\"dom\"] + RANK_LAMBDA * pairwise_rank_loss(vad_p[:, 2:3], vt[:, 2:3])\n",
819
+ " if USE_UNCERTAINTY:\n",
820
+ " return sum(torch.exp(-log_var[i]) * L[t] + log_var[i] for i, t in enumerate(TASKS))\n",
821
+ " return sum(L.values())\n",
822
+ "\n",
823
+ "def set_mode(train):\n",
824
+ " wavlm.train(train); heads.train(train)\n",
825
+ " if USE_MAMBA:\n",
826
+ " enc.train(train)\n",
827
+ "\n",
828
+ "@torch.no_grad()\n",
829
+ "def evaluate():\n",
830
+ " set_mode(False)\n",
831
+ " P = {\"emos\": [], \"val\": [], \"aro\": [], \"dom\": []}; Y = {\"emos\": [], \"val\": [], \"aro\": [], \"dom\": []}\n",
832
+ " catP, catY = [], []\n",
833
+ " for b in va_loader:\n",
834
+ " with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
835
+ " emos_p, cat_l, vad_p = forward_batch(b)\n",
836
+ " P[\"emos\"] += emos_p.float().cpu().numpy().ravel().tolist(); Y[\"emos\"] += b[\"emos_raw\"].tolist()\n",
837
+ " vad_p = vad_p.float().cpu().numpy()\n",
838
+ " for j, t in enumerate([\"val\", \"aro\", \"dom\"]):\n",
839
+ " P[t] += vad_p[:, j].tolist(); Y[t] += b[\"vad_raw\"][:, j].tolist()\n",
840
+ " catP.append(F.softmax(cat_l, 1).float().cpu().numpy()); catY.append(b[\"cat\"])\n",
841
+ " out = {t: spearmanr(P[t], Y[t]).correlation for t in [\"emos\"] + ([\"val\", \"aro\", \"dom\"] if HAS_VAD else [])}\n",
842
+ " q = np.concatenate(catP); p = np.concatenate(catY)\n",
843
+ " out[\"cat_err\"] = float(np.abs(q - p).sum(1).mean())\n",
844
+ " return out\n",
845
+ "\n",
846
+ "def mean_srcc(m):\n",
847
+ " keys = [\"emos\"] + ([\"val\", \"aro\", \"dom\"] if HAS_VAD else [])\n",
848
+ " return float(np.mean([m[k] for k in keys]))\n",
849
+ "\n",
850
+ "CKPT_PATH = os.path.join(OUT_DIR, \"ft_mamba_emotion_full.pt\")\n",
851
+ "def save_full_ckpt(state, val_emos=float(\"nan\")):\n",
852
+ " torch.save({\"wavlm\": state[\"wavlm\"], \"heads\": state[\"heads\"], \"enc\": state.get(\"enc\"),\n",
853
+ " \"USE_MAMBA\": USE_MAMBA, \"emos_mu\": emos_mu, \"emos_sd\": emos_sd,\n",
854
+ " \"vad_mu\": vad_mu, \"vad_sd\": vad_sd, \"WAVLM_DIM\": WAVLM_DIM, \"AUD_DIM\": AUD_DIM,\n",
855
+ " \"Z_DIM\": Z_DIM, \"UNFREEZE_TOP_LAYERS\": UNFREEZE_TOP_LAYERS,\n",
856
+ " \"val_emos\": float(val_emos)}, CKPT_PATH)\n",
857
+ "\n",
858
+ "def snapshot():\n",
859
+ " s = {\"wavlm\": {k: v.cpu().clone() for k, v in wavlm.state_dict().items()},\n",
860
+ " \"heads\": {k: v.cpu().clone() for k, v in heads.state_dict().items()}}\n",
861
+ " if USE_MAMBA:\n",
862
+ " s[\"enc\"] = {k: v.cpu().clone() for k, v in enc.state_dict().items()}\n",
863
+ " return s\n",
864
+ "\n",
865
+ "# RESUME: init best = điểm VAL của ckpt hiện tại → chỉ ghi đè nếu train tiếp TỐT HƠN (không sợ tụt)\n",
866
+ "if RESUME and resume_ckpt is not None:\n",
867
+ " m0 = evaluate(); best = mean_srcc(m0); best_state = snapshot(); bad = 0\n",
868
+ " print(f\"📍 RESUME — checkpoint hiện tại: mean SRCC={best:.4f} | \"\n",
869
+ " + \" \".join(f\"{k}={m0[k]:.3f}\" for k in ['emos', 'val', 'aro', 'dom'] if k in m0))\n",
870
+ "else:\n",
871
+ " m0 = None\n",
872
+ " best, best_state, bad = -1e9, None, 0\n",
873
+ "for ep in range(1, EPOCHS + 1):\n",
874
+ " set_mode(True)\n",
875
+ " opt.zero_grad(); run = 0.0; nb = 0\n",
876
+ " for step, b in enumerate(tqdm(tr_loader, desc=f\"epoch {ep}\")):\n",
877
+ " with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
878
+ " emos_p, cat_l, vad_p = forward_batch(b)\n",
879
+ " loss = compute_loss(emos_p, cat_l, vad_p, b) / ACCUM\n",
880
+ " scaler.scale(loss).backward()\n",
881
+ " if (step + 1) % ACCUM == 0:\n",
882
+ " scaler.step(opt); scaler.update(); opt.zero_grad()\n",
883
+ " run += loss.item() * ACCUM; nb += 1\n",
884
+ " m = evaluate(); sc = mean_srcc(m)\n",
885
+ " msg = \" \".join(f\"{k}={m[k]:.3f}\" for k in [\"emos\", \"val\", \"aro\", \"dom\"] if k in m)\n",
886
+ " print(f\"epoch {ep:2d} | loss {run/max(nb,1):.4f} | {msg} | cat_err {m['cat_err']:.3f} | mean {sc:.4f} (best {max(best,sc):.4f})\")\n",
887
+ " if sc > best:\n",
888
+ " best = sc; bad = 0\n",
889
+ " best_state = snapshot()\n",
890
+ " save_full_ckpt(best_state, m[\"emos\"])\n",
891
+ " print(f\" 💾 lưu best → {CKPT_PATH} (epoch {ep}, mean {sc:.4f})\")\n",
892
+ " else:\n",
893
+ " bad += 1\n",
894
+ " if bad >= PATIENCE:\n",
895
+ " print(f\"Early stop ở epoch {ep}.\"); break\n",
896
+ "\n",
897
+ "if best_state:\n",
898
+ " wavlm.load_state_dict(best_state[\"wavlm\"]); heads.load_state_dict(best_state[\"heads\"])\n",
899
+ " if USE_MAMBA:\n",
900
+ " enc.load_state_dict(best_state[\"enc\"])\n",
901
+ "final = evaluate()\n",
902
+ "if RESUME and m0 is not None:\n",
903
+ " print(f\"\\n🔁 RESUME: mean SRCC ckpt {mean_srcc(m0):.4f} → sau train tiếp {mean_srcc(final):.4f} \"\n",
904
+ " + (\"🚀 cải thiện → đã ghi đè ckpt\" if mean_srcc(final) > mean_srcc(m0) + 1e-4 else \"➖ không cải thiện (giữ best cũ)\"))\n",
905
+ "print(f\"\\n✅ VAL (nội bộ) — exp15 (Mamba={'ON' if USE_MAMBA else 'OFF'}):\")\n",
906
+ "print(f\" EMOS={final['emos']:.4f} (exp08 {EXP08['emos']})\")\n",
907
+ "if HAS_VAD:\n",
908
+ " print(f\" VAL/ARO/DOM={final['val']:.4f}/{final['aro']:.4f}/{final['dom']:.4f} \"\n",
909
+ " f\"(exp08 {EXP08['val']}/{EXP08['aro']}/{EXP08['dom']})\")\n",
910
+ "warn = [f\"EMOS {final['emos']:.3f}<{EXP08['emos']}\"] if final[\"emos\"] < EXP08[\"emos\"] - 0.005 else []\n",
911
+ "if HAS_VAD:\n",
912
+ " warn += [f\"{t.upper()} {final[t]:.3f}<{EXP08[t]}\" for t in [\"val\", \"aro\", \"dom\"] if final[t] < EXP08[t] - 0.005]\n",
913
+ "print(\" ⚠️ Mamba head CHƯA thắng exp08 ở:\", \"; \".join(warn), \"(vẫn là kết quả cho paper)\" if warn else \"\")\n",
914
+ "if not warn:\n",
915
+ " print(\" ✅ Mamba head thắng/ngang exp08 ở mọi cột → temporal modeling có ích!\")\n",
916
+ "save_full_ckpt(best_state if best_state else\n",
917
+ " {\"wavlm\": wavlm.state_dict(), \"heads\": heads.state_dict(),\n",
918
+ " \"enc\": enc.state_dict() if USE_MAMBA else None}, final[\"emos\"])\n",
919
+ "print(f\"✅ Đã lưu {CKPT_PATH} (CÓ backbone + Mamba + heads). NHỚ Save Version!\")"
920
+ ]
921
+ },
922
+ {
923
+ "cell_type": "markdown",
924
+ "id": "9c748af2",
925
+ "metadata": {},
926
+ "source": [
927
+ "## 7. Dự đoán DEV → answer.txt (5 cột cảm xúc exp15; QMOS mượn exp07/UTMOSv2)"
928
+ ]
929
+ },
930
+ {
931
+ "cell_type": "code",
932
+ "execution_count": null,
933
+ "id": "92d43e56",
934
+ "metadata": {
935
+ "lines_to_next_cell": 1
936
+ },
937
+ "outputs": [],
938
+ "source": [
939
+ "def list_dev():\n",
940
+ " with open(DEV_SCP) as f:\n",
941
+ " return [ln.strip() for ln in f if ln.strip()]\n",
942
+ "\n",
943
+ "dev_names = list_dev()\n",
944
+ "if LIMIT_DEV:\n",
945
+ " dev_names = dev_names[:LIMIT_DEV]\n",
946
+ "dev_stems = [stem(n) for n in dev_names]\n",
947
+ "print(\"DEV:\", len(dev_names), \"mẫu\")\n",
948
+ "aud_dev = extract_audeering(dev_stems, \"dev\")\n",
949
+ "\n",
950
+ "def load_exp07_qmos():\n",
951
+ " if EXP07_ANSWER and os.path.exists(EXP07_ANSWER):\n",
952
+ " import csv\n",
953
+ " d = {}\n",
954
+ " with open(EXP07_ANSWER) as f:\n",
955
+ " for row in csv.DictReader(f):\n",
956
+ " d[row[\"wav\"]] = float(row[\"QMOS\"]); d[stem(row[\"wav\"])] = float(row[\"QMOS\"])\n",
957
+ " print(f\"✅ Mượn QMOS exp07 ({EXP07_ANSWER}): {len(d)//2} wav\")\n",
958
+ " return d\n",
959
+ " return None\n",
960
+ "\n",
961
+ "qmos_map = load_exp07_qmos()\n",
962
+ "if qmos_map is None:\n",
963
+ " print(\"ℹ️ Không có answer.txt exp07 → chấm QMOS bằng UTMOSv2 (T05, vô địch VMC2024).\")\n",
964
+ " pip_install(\"git+https://github.com/sarulab-speech/UTMOSv2.git\")\n",
965
+ " import utmosv2\n",
966
+ " v2 = utmosv2.create_model(pretrained=True)\n",
967
+ " qmos_map = {}\n",
968
+ " for n in tqdm(dev_names, desc=\"UTMOSv2\"):\n",
969
+ " wav = os.path.join(WAV_DIR, n if str(n).endswith(\".wav\") else str(n) + \".wav\")\n",
970
+ " if not os.path.exists(wav):\n",
971
+ " continue\n",
972
+ " out = v2.predict(input_path=wav)\n",
973
+ " qmos_map[n] = float(out[\"predicted_mos\"]) if isinstance(out, dict) else float(out)\n",
974
+ " del v2; torch.cuda.empty_cache() if device == \"cuda\" else None\n",
975
+ "\n",
976
+ "@torch.no_grad()\n",
977
+ "def predict_emotion(sid):\n",
978
+ " wave = load_wav(sid)\n",
979
+ " if wave is None or (USE_AUDEERING and sid not in aud_dev):\n",
980
+ " return None\n",
981
+ " set_mode(False)\n",
982
+ " iv = torch.from_numpy(wave).unsqueeze(0).to(device)\n",
983
+ " am = torch.ones((1, len(wave)), dtype=torch.long, device=device)\n",
984
+ " tgt = torch.from_numpy(onehot_target(target_map.get(sid))).unsqueeze(0).to(device)\n",
985
+ " with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
986
+ " fw = wavlm_branch(iv, am)\n",
987
+ " feat = torch.cat([fw, torch.from_numpy(aud_dev[sid]).unsqueeze(0).to(device)], dim=1) if USE_AUDEERING else fw\n",
988
+ " emos_p, cat_l, vad_p = heads(feat, tgt)\n",
989
+ " emos = float(emos_p.item()) * emos_sd + emos_mu\n",
990
+ " cat5 = F.softmax(cat_l, 1)[0].float().cpu().numpy()\n",
991
+ " vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu\n",
992
+ " return emos, cat5, vad3\n",
993
+ "\n",
994
+ "def fmt_cat(p5):\n",
995
+ " return \"|\".join(f\"{e}:{p5[i]:.6g}\" for i, e in enumerate(EMOTIONS5))\n",
996
+ "\n",
997
+ "def build_answer(out_path):\n",
998
+ " n_real = n_def = 0\n",
999
+ " with open(out_path, \"w\") as f:\n",
1000
+ " f.write(\"wav,QMOS,EMOS,CAT,VAL,ARO,DOM\\n\")\n",
1001
+ " for name in tqdm(dev_names, desc=\"answer\"):\n",
1002
+ " sid = stem(name)\n",
1003
+ " pr = predict_emotion(sid)\n",
1004
+ " if pr is None:\n",
1005
+ " emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0]); n_def += 1\n",
1006
+ " else:\n",
1007
+ " emos, cat5, vad3 = pr; n_real += 1\n",
1008
+ " qmos = qmos_map.get(name, qmos_map.get(sid, 3.0))\n",
1009
+ " f.write(f\"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\\n\")\n",
1010
+ " print(f\"Ghi {len(dev_names)} dòng → {out_path} | cảm xúc thật {n_real}, mặc định {n_def}\")\n",
1011
+ "\n",
1012
+ "answer_path = os.path.join(OUT_DIR, \"answer.txt\")\n",
1013
+ "build_answer(answer_path)"
1014
+ ]
1015
+ },
1016
+ {
1017
+ "cell_type": "markdown",
1018
+ "id": "20ec4343",
1019
+ "metadata": {},
1020
+ "source": [
1021
+ "## 8. Validate + đóng zip"
1022
+ ]
1023
+ },
1024
+ {
1025
+ "cell_type": "code",
1026
+ "execution_count": null,
1027
+ "id": "e289ea27",
1028
+ "metadata": {},
1029
+ "outputs": [],
1030
+ "source": [
1031
+ "def validate(path):\n",
1032
+ " import csv\n",
1033
+ " with open(path) as f:\n",
1034
+ " rows = list(csv.reader(f))\n",
1035
+ " assert rows[0][0] == \"wav\" and \"QMOS\" in rows[0] and \"EMOS\" in rows[0], \"Header sai\"\n",
1036
+ " for i, r in enumerate(rows[1:], 2):\n",
1037
+ " assert len(r) == len(rows[0]), f\"Dòng {i} sai số cột\"\n",
1038
+ " print(f\"OK: {len(rows)-1} dòng, header = {rows[0]}\")\n",
1039
+ "\n",
1040
+ "validate(answer_path)\n",
1041
+ "os.system(f\"cd {OUT_DIR} && zip -j submission_track2_exp15_mamba-emotion.zip answer.txt \"\n",
1042
+ " f\"&& unzip -l submission_track2_exp15_mamba-emotion.zip\")\n",
1043
+ "print(\"Sẵn sàng nộp:\", os.path.join(OUT_DIR, \"submission_track2_exp15_mamba-emotion.zip\"))"
1044
+ ]
1045
+ },
1046
+ {
1047
+ "cell_type": "markdown",
1048
+ "id": "7aeeb9ea",
1049
+ "metadata": {},
1050
+ "source": [
1051
+ "## Ghi chú\n",
1052
+ "- **🔁 RESUME (train tiếp, không train lại từ đầu):** Add Input dataset chứa `ft_mamba_emotion_full.pt` của lần\n",
1053
+ " chạy trước (hoặc để nó nằm sẵn trong `/kaggle/working` khi chạy nối phiên) → notebook tự dò & train tiếp.\n",
1054
+ " `EPOCHS` lúc này là **số epoch train THÊM**. Val chững → đặt `RESUME_LR_SCALE=0.5`. Muốn ép train mới: `RESUME_CKPT=\"—\"`\n",
1055
+ " (đường dẫn không tồn tại) hoặc xóa ckpt khỏi input. ⚠️ `USE_MAMBA` phải KHỚP ckpt (code sẽ cảnh báo nếu lệch).\n",
1056
+ "- **Lần đầu** `LIMIT_TRAIN=300`, `LIMIT_DEV=20` → kiểm 1 epoch không OOM / không CheckpointError; rồi đặt `None`.\n",
1057
+ "- **Ablation chính cho paper:** chạy `USE_MAMBA=True` vs `USE_MAMBA=False` (=exp08) → so EMOS/VAL/ARO/DOM nội bộ\n",
1058
+ " → trả lời \"Mamba temporal head có hơn mean-pooling không?\".\n",
1059
+ "- **OOM / quá chậm trên T4 (nhất là khi dùng Mamba thuần PyTorch):** giảm theo thứ tự\n",
1060
+ " `MAX_SECONDS` (6→5) → `MAMBA_LAYERS` (2→1) → `UNFREEZE_TOP_LAYERS` (6→4) → `BATCH` (2→1, tăng `ACCUM`).\n",
1061
+ " Hoặc thử cài `mamba-ssm causal-conv1d` (nhanh + nhẹ RAM hơn nhiều) — code tự dùng nếu import được.\n",
1062
+ "- **Ranking loss (`RANK_LAMBDA`):** thêm pairwise ranking cho 4 cột SRCC (emos/val/aro/dom) → khớp metric\n",
1063
+ " UTT-SRCC hơn MSE. ⚠️ **Điểm yếu:** ranking tính trên các cặp TRONG 1 mini-batch; `BATCH=2` → mỗi forward\n",
1064
+ " chỉ có 1 cặp → tín hiệu YẾU. Muốn ranking mạnh: tăng `BATCH` (4→8 nếu VRAM chịu được). Ở các exp head\n",
1065
+ " ĐÓNG BĂNG (exp06/07, BATCH=64) ranking mạnh hơn nhiều. A/B `RANK_LAMBDA=0` vs `0.3` → bảng ablation cho paper.\n",
1066
+ "- **QMOS:** Add Input answer.txt exp07 vào `/kaggle/input/exp07-answer/answer.txt` để mượn QMOS 0.548;\n",
1067
+ " không có thì tự chấm UTMOSv2 (cần Internet On).\n",
1068
+ "- Ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (mục exp15)."
1069
+ ]
1070
+ }
1071
+ ],
1072
+ "metadata": {
1073
+ "jupytext": {
1074
+ "cell_metadata_filter": "-all",
1075
+ "main_language": "python",
1076
+ "notebook_metadata_filter": "-all"
1077
+ }
1078
+ },
1079
+ "nbformat": 4,
1080
+ "nbformat_minor": 5
1081
+ }
track2/exp15_wavlm_mamba_emotion_pipeline.py ADDED
@@ -0,0 +1,920 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %% [markdown]
2
+ # # VMC2026 Track 2 — exp15 (WavLM FINE-TUNE + MAMBA head cho 5 cột cảm xúc) — Kaggle
3
+ #
4
+ # **Ý tưởng:** exp08 fine-tune WavLM nhưng vẫn **mean-pool** đặc trưng theo thời gian → 1 vector/wav
5
+ # (vứt bỏ động lực thời gian: lên/xuống giọng, ngắt quãng, run giọng — rất quan trọng cho cảm xúc).
6
+ # exp15 **thay mean-pool bằng MAMBA head** (bộ mã hóa chuỗi học được, độ phức tạp tuyến tính) → kỳ vọng
7
+ # nắm temporal dynamics tốt hơn. Tham khảo: MambaRate (AudioMOS 2025, arXiv:2507.12090).
8
+ #
9
+ # ## Kiến trúc (= exp08 đổi đúng 1 chỗ: pool → Mamba)
10
+ # ```
11
+ # wav ─► WavLM-large (SAILER warm-start, mở băng N lớp, TRAINABLE) ─► hidden states (B, T, 1024)
12
+ # │ (KHÔNG mean-pool)
13
+ # MambaEncoder (proj 1024→d, Mamba×L 2 chiều,
14
+ # attentive-pool có mask) ─► z (B, Z_DIM)
15
+ # │
16
+ # (tùy chọn) audeering MSP-dim FROZEN [emb|vad3] ──concat──► TRUNK ─┬─► EMOS (+ one-hot target)
17
+ # ├─► CAT (5, soft-CE)
18
+ # └─► VAD (3)
19
+ # QMOS: KHÔNG train ở đây → mượn cột QMOS exp07 (0.548) hoặc UTMOSv2.
20
+ # ```
21
+ # - **Cờ `USE_MAMBA`:** True = Mamba head; False = quay về `masked_mean` = **đúng exp08**
22
+ # → đây là **ablation chính cho paper** ("Mamba temporal head vs mean-pooling", CÙNG backbone fine-tune).
23
+ #
24
+ # ## ⚠️ Đánh đổi / gotcha (đã phòng trong code)
25
+ # - Fine-tune = chạy lại WavLM mỗi epoch (không cache được) → **lần đầu BẮT BUỘC `LIMIT_TRAIN=300`, `LIMIT_DEV=20`**.
26
+ # - `mamba-ssm` khó cài Kaggle → tự fallback **Mamba thuần PyTorch** (vòng-lặp-thời-gian). Bản này khi fine-tune
27
+ # **chậm + nặng RAM hơn** → cap `MAX_SECONDS=6`, `BATCH=2`. OOM/quá chậm → hạ MAX_SECONDS→5, MAMBA_LAYERS→1,
28
+ # hoặc thử cài `mamba-ssm causal-conv1d`.
29
+ # - `layerdrop=0` (tránh CheckpointError khi grad-ckpt — bài học exp12). KHÔNG đụng numpy (lệch ABI).
30
+ # - **Checkpoint lưu CẢ backbone + Mamba + heads mỗi best** (bài học exp08 mất backbone).
31
+ #
32
+ # ## 🔁 RESUME (yêu cầu của user): "nếu có checkpoint thì train TIẾP, không train lại từ đầu"
33
+ # - Notebook **tự dò** `ft_mamba_emotion_full.pt` trong `/kaggle/input` và `/kaggle/working` (hoặc trỏ tay `RESUME_CKPT`).
34
+ # - Có ckpt đủ (backbone WavLM + Mamba enc + heads) → **nạp lại trạng thái + thống kê chuẩn hóa TỪ ckpt** rồi train tiếp;
35
+ # `best` khởi tạo = điểm VAL của ckpt → chỉ ghi đè khi train tiếp **TỐT HƠN** (không sợ tụt). `RESUME_LR_SCALE<1` để hạ LR.
36
+ # - KHÔNG có ckpt → train mới từ SAILER warm-start như cũ (hành vi exp15 gốc giữ nguyên).
37
+ #
38
+ # **Cách chạy Kaggle:** GPU **T4** + Internet **On** → Add Input dataset Track 2 (+ Add Input checkpoint cũ nếu muốn resume)
39
+ # → sửa `DATA_ROOT` → Run All.
40
+
41
+ # %% [markdown]
42
+ # ## 0. Cấu hình — SỬA Ở ĐÂY
43
+
44
+ # %%
45
+ import os, glob
46
+
47
+ # ── TỰ DÒ DATA_ROOT (quét /kaggle/input tìm thư mục có sets/train.csv + wav/ + metadata.csv) ──
48
+ def find_data_root(search_root="/kaggle/input"):
49
+ cands = []
50
+ for train_csv in glob.glob(os.path.join(search_root, "**", "sets", "train.csv"), recursive=True):
51
+ root = os.path.dirname(os.path.dirname(train_csv)) # .../<root>/sets/train.csv → <root>
52
+ score = os.path.isdir(os.path.join(root, "wav")) + os.path.exists(os.path.join(root, "metadata.csv"))
53
+ cands.append((score, root))
54
+ cands.sort(reverse=True) # ưu tiên thư mục đủ wav + metadata
55
+ return cands
56
+
57
+ _cands = find_data_root("/kaggle/input")
58
+ if _cands:
59
+ print("🔎 Ứng viên DATA_ROOT (điểm cao = đủ wav+metadata):")
60
+ for sc, r in _cands:
61
+ print(f" [{sc}/2] {r}")
62
+ DATA_ROOT = _cands[0][1]
63
+ print(f"👉 Tự chọn DATA_ROOT = {DATA_ROOT}")
64
+ else:
65
+ DATA_ROOT = "/kaggle/input/datasets/minhtoan2" # dự phòng — sửa tay nếu auto-dò không thấy
66
+ print(f"❌ Không thấy sets/train.csv trong /kaggle/input → dùng dự phòng {DATA_ROOT} (đã Add Input chưa?)")
67
+
68
+ WAV_DIR = f"{DATA_ROOT}/wav"
69
+ METADATA_CSV = f"{DATA_ROOT}/metadata.csv" # wavID|emotion|transcript (KHÔNG header)
70
+ TRAIN_CSV = f"{DATA_ROOT}/sets/train.csv" # lisID|wavID|qMOS|emoCat|eMOS|val|dom|aro
71
+ DEV_SCP = f"{DATA_ROOT}/sets/dev.scp"
72
+
73
+ OUT_DIR = "/kaggle/working"
74
+ CACHE_DIR = "/kaggle/working/ft_cache" # cache audeering (.npz) — WavLM/Mamba KHÔNG cache (đang train)
75
+ os.makedirs(CACHE_DIR, exist_ok=True)
76
+
77
+ # (Tùy chọn) tái dùng cache audeering cũ (read-only /kaggle/input → copy sang working)
78
+ # Dataset cache_exp8: aud_*.npz nằm trong thư mục con archive/ → quét ĐỆ QUY để bắt mọi vị trí.
79
+ CACHE_INPUT = "/kaggle/input/cache-exp8" # << SỬA slug (dataset cache_exp8 → Kaggle đổi _→-); hoặc ""
80
+ if CACHE_INPUT and os.path.isdir(CACHE_INPUT):
81
+ import shutil
82
+ _n = 0
83
+ for _fp in glob.glob(os.path.join(CACHE_INPUT, "**", "aud_*.npz"), recursive=True):
84
+ shutil.copy(_fp, os.path.join(CACHE_DIR, os.path.basename(_fp))); _n += 1
85
+ print(f"📦 Tái dùng cache: copy {_n} file aud_*.npz (quét đệ quy {CACHE_INPUT})")
86
+ else:
87
+ print(f"ℹ️ Không thấy CACHE_INPUT={CACHE_INPUT} → sẽ tự trích audeering.")
88
+
89
+ # Mượn cột QMOS exp07 (0.548). Trỏ answer.txt exp07 nếu có; không thì UTMOSv2.
90
+ EXP07_ANSWER = "/kaggle/input/exp07-answer/answer.txt" # << (tùy chọn)
91
+
92
+ # ── Cờ Mamba (ablation chính) ────────────────────────────────────────────────
93
+ USE_MAMBA = True # True = Mamba head; False = mean-pool = ĐÚNG exp08
94
+
95
+ # ── Siêu tham số Mamba head ──────────────────────────────────────────────────
96
+ MAMBA_DMODEL = 256
97
+ MAMBA_LAYERS = 2
98
+ MAMBA_DSTATE = 16
99
+ BIDIRECTIONAL = True
100
+ Z_DIM = 256 # chiều vector ra sau attentive-pool, thay cho emb WavLM mean-pool
101
+
102
+ # ── Fine-tune / siêu tham số (kế thừa exp08) ─────────────────────────────────
103
+ DEVICE = "cuda"
104
+ SR = 16000
105
+ MAX_SECONDS = 6 # giảm từ 8 (exp08) vì Mamba backprop-through-time nặng RAM hơn
106
+ UNFREEZE_TOP_LAYERS = 6 # số lớp Transformer trên cùng được train (0 = freeze hết)
107
+ TRUNK_HIDDEN = 512
108
+ HEAD_HIDDEN = 128
109
+ DROPOUT = 0.3
110
+ LR_BACKBONE = 1e-5
111
+ LR_HEAD = 1e-3 # cho Mamba + trunk + head (train từ đầu)
112
+ WEIGHT_DECAY = 1e-5
113
+ EPOCHS = 12
114
+ PATIENCE = 3
115
+ BATCH = 2 # nhỏ (backbone to + Mamba); bù bằng ACCUM
116
+ ACCUM = 16 # effective batch = 32
117
+ VAL_FRAC = 0.10
118
+ SEED = 42
119
+ USE_AMP = True
120
+ USE_GRAD_CKPT = True
121
+ USE_AUDEERING = True
122
+ USE_UNCERTAINTY = True
123
+ RANK_LAMBDA = 0.3 # 0 = chỉ MSE (cũ). >0 = thêm pairwise ranking loss (tối ưu thẳng SRCC) cho emos/val/aro/dom
124
+ # ⚠️ ranking cần NHIỀU cặp/batch mới mạnh → BATCH nhỏ (2) thì tác dụng yếu (xem Ghi chú)
125
+
126
+ LIMIT_TRAIN = 300 # << LẦN ĐẦU 300; chạy thật None
127
+ LIMIT_DEV = 20 # << LẦN ĐẦU 20; chạy thật None
128
+
129
+ # ── RESUME — train TIẾP từ checkpoint, KHÔNG train lại từ đầu ─────────────────
130
+ # Để "" + auto-dò: nếu thấy `ft_mamba_emotion_full.pt` (đủ backbone+Mamba+heads) trong /kaggle/input
131
+ # hoặc /kaggle/working → nạp lại rồi train tiếp. Trỏ tay RESUME_CKPT nếu muốn chỉ định file cụ thể.
132
+ RESUME_CKPT = "" # << "" = auto-dò; hoặc "/kaggle/input/<slug>/ft_mamba_emotion_full.pt"
133
+ RESUME_LR_SCALE = 1.0 # <1.0 hạ LR khi train tiếp (vd 0.5 nếu val đã chững)
134
+
135
+ def find_resume_ckpt(explicit):
136
+ """Tìm checkpoint exp15 để train tiếp. Ưu tiên đường dẫn user trỏ; không thì auto-dò.
137
+ Khớp cả tên bị Kaggle/Windows thêm hậu tố trùng, vd 'ft_mamba_emotion_full (2).pt'."""
138
+ if explicit and os.path.exists(explicit):
139
+ return explicit
140
+ for base in ["/kaggle/input", "/kaggle/working"]:
141
+ hits = sorted(glob.glob(os.path.join(base, "**", "ft_mamba_emotion_full*.pt"), recursive=True))
142
+ if hits:
143
+ return hits[0]
144
+ return ""
145
+
146
+ RESUME_CKPT = find_resume_ckpt(RESUME_CKPT)
147
+ RESUME = bool(RESUME_CKPT)
148
+ print("🔁 RESUME =", RESUME, ("→ train tiếp từ: " + RESUME_CKPT) if RESUME else "(không thấy ckpt → train MỚI từ đầu)")
149
+
150
+ # Mốc so (exp08 fine-tune + mean-pool — đối thủ trực tiếp của Mamba head)
151
+ EXP08 = {"emos": 0.811, "val": 0.659, "aro": 0.793, "dom": 0.751}
152
+
153
+ EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"]
154
+
155
+ _EMO_ALIAS = {
156
+ "angry": "angry", "anger": "angry",
157
+ "happy": "happy", "happiness": "happy", "joy": "happy",
158
+ "neutral": "neutral", "calm": "neutral",
159
+ "sad": "sad", "sadness": "sad",
160
+ "surprise": "surprised", "surprised": "surprised", "surprising": "surprised",
161
+ }
162
+
163
+ def norm_emotion(label):
164
+ key = str(label).strip().lower()
165
+ return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)
166
+
167
+ def stem(p):
168
+ return os.path.splitext(os.path.basename(str(p)))[0]
169
+
170
+ print("USE_MAMBA =", USE_MAMBA, "(False → ra đúng exp08)")
171
+ print("DATA_ROOT:", DATA_ROOT)
172
+ for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP]:
173
+ print((" ✅ " if os.path.exists(p) else " ❌ THIẾU ") + p)
174
+ print(f"Fine-tune: mở băng {UNFREEZE_TOP_LAYERS} lớp · BATCH {BATCH}×ACCUM {ACCUM} · MAX {MAX_SECONDS}s")
175
+
176
+ # %% [markdown]
177
+ # ## 1. Cài đặt + tải code SAILER (clone + sys.path)
178
+
179
+ # %%
180
+ import sys, subprocess
181
+
182
+ def pip_install(*pkgs):
183
+ subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs], check=True)
184
+
185
+ pip_install("loralib", "speechbrain", "speechmos", "librosa", "soundfile",
186
+ "scipy", "scikit-learn", "pandas", "tqdm")
187
+
188
+ # Cài kernel CUDA Mamba (nhanh + nhẹ RAM hơn bản thuần PyTorch nhiều). Build hay lỗi/chậm trên Kaggle
189
+ # → bọc try/except: lỗi thì BỎ QUA, mục 6a tự fallback Mamba thuần PyTorch. KHÔNG để chết notebook.
190
+ INSTALL_MAMBA_SSM = True # đặt False nếu muốn BỎ QUA, dùng thẳng Mamba thuần PyTorch
191
+ if INSTALL_MAMBA_SSM and USE_MAMBA:
192
+ try:
193
+ # --no-build-isolation cho CẢ HAI → dùng torch+CUDA sẵn có của Kaggle để biên dịch (đừng kéo torch khác).
194
+ # Cần ninja để build nhanh. -q ẩn log nên bước này có thể "treo" vài phút khi đang compile — bình thường.
195
+ subprocess.run([sys.executable, "-m", "pip", "install", "-q", "ninja"], check=True)
196
+ subprocess.run([sys.executable, "-m", "pip", "install", "-q",
197
+ "--no-build-isolation", "causal-conv1d>=1.2.0"], check=True)
198
+ subprocess.run([sys.executable, "-m", "pip", "install", "-q",
199
+ "--no-build-isolation", "mamba-ssm"], check=True)
200
+ print("✅ Cài mamba-ssm + causal-conv1d xong (sẽ dùng kernel CUDA nếu import được).")
201
+ except Exception as e:
202
+ print("⚠️ Cài mamba-ssm thất bại:", repr(e), "→ dùng Mamba thuần PyTorch (chậm hơn).")
203
+ print(" ℹ️ Vẫn chạy bình thường. Nếu chạy THẬT (LIMIT=None) quá chậm → xem Ghi chú cuối notebook.")
204
+
205
+ REPO_DIR = "/kaggle/working/vox-profile-release"
206
+ if not os.path.exists(REPO_DIR):
207
+ subprocess.run(["git", "clone", "--depth", "1",
208
+ "https://github.com/tiantiaf0627/vox-profile-release.git", REPO_DIR], check=True)
209
+ if REPO_DIR not in sys.path:
210
+ sys.path.insert(0, REPO_DIR)
211
+
212
+ # %% [markdown]
213
+ # ## 2. Nạp SAILER → lấy backbone WavLM bên trong để FINE-TUNE (warm-start)
214
+
215
+ # %%
216
+ import torch
217
+ import torch.nn as nn
218
+ import torch.nn.functional as F
219
+
220
+ device = DEVICE if torch.cuda.is_available() else "cpu"
221
+ print("Device:", device, ("✅ " + torch.cuda.get_device_name(0)) if device == "cuda" else "⚠️ CPU (rất chậm!)")
222
+
223
+ def find_hf_backbone(module):
224
+ """Tìm submodule kiểu HF WavLM backbone: có .feature_extractor và .encoder.layers."""
225
+ cands = []
226
+ for name, m in module.named_modules():
227
+ enc = getattr(m, "encoder", None)
228
+ if getattr(m, "feature_extractor", None) is not None and enc is not None \
229
+ and getattr(enc, "layers", None) is not None:
230
+ cands.append((name, m))
231
+ if not cands:
232
+ return None, None
233
+ cands.sort(key=lambda nm: sum(p.numel() for p in nm[1].parameters()), reverse=True)
234
+ return cands[0]
235
+
236
+ wavlm = None
237
+ try:
238
+ from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402
239
+ _wrapper = WavLMWrapper.from_pretrained("tiantiaf/wavlm-large-categorical-emotion")
240
+ name, wavlm = find_hf_backbone(_wrapper)
241
+ if wavlm is not None:
242
+ print(f"✅ Warm-start SAILER: backbone WavLM tại '.{name}' "
243
+ f"({sum(p.numel() for p in wavlm.parameters())/1e6:.0f}M params)")
244
+ else:
245
+ print("⚠️ Không tìm thấy backbone HF trong wrapper SAILER → fallback WavLM trắng.")
246
+ except Exception as e:
247
+ print("⚠️ Lỗi nạp SAILER wrapper:", repr(e), "→ fallback WavLM trắng.")
248
+
249
+ if wavlm is None:
250
+ from transformers import WavLMModel
251
+ wavlm = WavLMModel.from_pretrained("microsoft/wavlm-large")
252
+ print("ℹ️ Fallback: microsoft/wavlm-large (KHÔNG warm-start SAILER).")
253
+
254
+ wavlm = wavlm.to(device)
255
+ WAVLM_DIM = int(wavlm.config.hidden_size)
256
+ wavlm.config.layerdrop = 0.0 # ⚠️ tránh CheckpointError khi grad-ckpt (bài học exp12)
257
+
258
+ # ── RESUME: nạp trọng số backbone đã fine-tune từ checkpoint (đè lên warm-start SAILER) ──
259
+ resume_ckpt = None
260
+ if RESUME:
261
+ resume_ckpt = torch.load(RESUME_CKPT, map_location="cpu", weights_only=False) # ckpt có numpy → cần False
262
+ assert "wavlm" in resume_ckpt, ("❌ Checkpoint KHÔNG có 'wavlm' (backbone) → không resume được. "
263
+ "Dùng file ft_mamba_emotion_full.pt do exp15 lưu.")
264
+ if resume_ckpt.get("USE_MAMBA", USE_MAMBA) != USE_MAMBA:
265
+ print(f" ⚠️ ckpt USE_MAMBA={resume_ckpt.get('USE_MAMBA')} ≠ cấu hình hiện tại {USE_MAMBA} → kiến trúc LỆCH! "
266
+ "Đặt USE_MAMBA cho khớp ckpt.")
267
+ miss, unexp = wavlm.load_state_dict(resume_ckpt["wavlm"], strict=False)
268
+ print(f"🔁 RESUME load wavlm từ ckpt: thiếu {len(miss)} / dư {len(unexp)} key (kỳ vọng ~0). keys ckpt:", list(resume_ckpt.keys()))
269
+ if len(miss) > 20 or len(unexp) > 20:
270
+ print(" ⚠️ Lệch key nhiều → kiểm tra UNFREEZE_TOP_LAYERS / backbone có khớp ckpt không.")
271
+
272
+ # ── Đóng băng partial: feature-extractor + tất cả trừ UNFREEZE_TOP_LAYERS lớp trên ──
273
+ for p in wavlm.parameters():
274
+ p.requires_grad = False
275
+ enc_layers = wavlm.encoder.layers
276
+ n_layers = len(enc_layers)
277
+ for layer in enc_layers[max(0, n_layers - UNFREEZE_TOP_LAYERS):]:
278
+ for p in layer.parameters():
279
+ p.requires_grad = True
280
+ n_train = sum(p.numel() for p in wavlm.parameters() if p.requires_grad)
281
+ print(f"WavLM: {n_layers} lớp · mở băng {min(UNFREEZE_TOP_LAYERS, n_layers)} → {n_train/1e6:.1f}M param train (dim {WAVLM_DIM})")
282
+
283
+ if USE_GRAD_CKPT:
284
+ wavlm.gradient_checkpointing_enable()
285
+ if hasattr(wavlm, "enable_input_require_grads"):
286
+ wavlm.enable_input_require_grads()
287
+
288
+ def frame_mask(T, attn_mask):
289
+ """attn_mask (B, Lwav) → frame-mask (B, T) bool (True=frame thật). Khớp downsample của WavLM."""
290
+ if attn_mask is None:
291
+ return torch.ones((1, T), dtype=torch.bool, device=device)
292
+ try:
293
+ fm = wavlm._get_feature_vector_attention_mask(T, attn_mask)
294
+ return fm.bool()
295
+ except Exception:
296
+ return torch.ones((attn_mask.shape[0], T), dtype=torch.bool, device=attn_mask.device)
297
+
298
+ def masked_mean(hidden, attn_mask):
299
+ """Mean-pool theo thời gian bỏ pad (đường exp08 khi USE_MAMBA=False)."""
300
+ if attn_mask is None:
301
+ return hidden.mean(dim=1)
302
+ fm = frame_mask(hidden.shape[1], attn_mask).unsqueeze(-1).to(hidden.dtype)
303
+ return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)
304
+
305
+ # %% [markdown]
306
+ # ## 3. Nạp audeering MSP-dim (FROZEN) — đặc trưng phụ (như exp08)
307
+
308
+ # %%
309
+ AUD_DIM = 0
310
+ aud_backbone = aud_head = aud_proc = None
311
+ if USE_AUDEERING:
312
+ from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor
313
+ from huggingface_hub import hf_hub_download
314
+ AUD_NAME = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
315
+ aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)
316
+ aud_cfg = Wav2Vec2Config.from_pretrained(AUD_NAME)
317
+ aud_backbone = Wav2Vec2Model(aud_cfg)
318
+ try:
319
+ _sd = __import__("safetensors.torch", fromlist=["load_file"]).load_file(
320
+ hf_hub_download(AUD_NAME, "model.safetensors"))
321
+ except Exception:
322
+ _sd = torch.load(hf_hub_download(AUD_NAME, "pytorch_model.bin"), map_location="cpu")
323
+ bb_sd = {k[len("wav2vec2."):]: v for k, v in _sd.items() if k.startswith("wav2vec2.")}
324
+ aud_backbone.load_state_dict(bb_sd, strict=False)
325
+ _hid = _sd["classifier.dense.weight"].shape[0]
326
+ _out = _sd["classifier.out_proj.weight"].shape[0]
327
+ aud_head = nn.Sequential(nn.Linear(_hid, _hid), nn.Tanh(), nn.Linear(_hid, _out))
328
+ aud_head[0].weight.data.copy_(_sd["classifier.dense.weight"]); aud_head[0].bias.data.copy_(_sd["classifier.dense.bias"])
329
+ aud_head[2].weight.data.copy_(_sd["classifier.out_proj.weight"]); aud_head[2].bias.data.copy_(_sd["classifier.out_proj.bias"])
330
+ aud_backbone = aud_backbone.to(device).eval()
331
+ aud_head = aud_head.to(device).eval()
332
+ AUD_DIM = _hid + 3
333
+ print(f"✅ audeering frozen (đặc trưng phụ {AUD_DIM}-D = emb {_hid} + vad 3)")
334
+
335
+ # %%
336
+ import numpy as np
337
+ import librosa
338
+ from tqdm.auto import tqdm
339
+
340
+ def load_wav(name_or_stem):
341
+ p = name_or_stem if os.path.isabs(str(name_or_stem)) else os.path.join(
342
+ WAV_DIR, name_or_stem if str(name_or_stem).endswith(".wav") else str(name_or_stem) + ".wav")
343
+ if not os.path.exists(p):
344
+ return None
345
+ wave, _ = librosa.load(p, sr=SR, mono=True)
346
+ return wave[: MAX_SECONDS * SR].astype(np.float32)
347
+
348
+ @torch.no_grad()
349
+ def extract_audeering(stems, tag):
350
+ if not USE_AUDEERING:
351
+ return {}
352
+ cache_path = os.path.join(CACHE_DIR, f"aud_{tag}.npz")
353
+ store = {}
354
+ if os.path.exists(cache_path):
355
+ z = np.load(cache_path, allow_pickle=True)
356
+ store = {k: z[k] for k in z.files}
357
+ print(f"[aud/{tag}] nạp cache: {len(store)}")
358
+ todo = [s for s in stems if s not in store]
359
+ for i, s in enumerate(tqdm(todo, desc=f"audeering {tag}")):
360
+ wave = load_wav(s)
361
+ if wave is None:
362
+ continue
363
+ x = aud_proc(wave, sampling_rate=SR).input_values[0]
364
+ x = torch.from_numpy(np.asarray(x, dtype=np.float32)).unsqueeze(0).to(device)
365
+ h = aud_backbone(x)[0].mean(dim=1)
366
+ out = aud_head(h)[0].cpu().numpy() # [arousal, dominance, valence] ∈[0,1]
367
+ vad = np.array([1 + 4 * out[2], 1 + 4 * out[0], 1 + 4 * out[1]], dtype=np.float32) # [VAL,ARO,DOM]
368
+ store[s] = np.concatenate([h[0].cpu().numpy(), vad]).astype(np.float32)
369
+ if (i + 1) % 500 == 0:
370
+ np.savez(cache_path, **store)
371
+ if todo:
372
+ np.savez(cache_path, **store)
373
+ return store
374
+
375
+ # %% [markdown]
376
+ # ## 4. Đọc & gộp nhãn theo wavID (EMOS / VAD / CAT) — như exp08
377
+
378
+ # %%
379
+ import pandas as pd
380
+
381
+ def load_target_emotions():
382
+ tgt = {}
383
+ with open(METADATA_CSV, encoding="utf-8") as f:
384
+ for ln in f:
385
+ parts = ln.strip().split("|")
386
+ if len(parts) >= 2:
387
+ tgt[stem(parts[0])] = norm_emotion(parts[1])
388
+ return tgt
389
+
390
+ def _col(cols_map, *names, df=None, default_idx=None):
391
+ for n in names:
392
+ if n in cols_map:
393
+ return cols_map[n]
394
+ return list(df.columns)[default_idx] if default_idx is not None else None
395
+
396
+ def parse_emocat_votes(cell):
397
+ v = np.zeros(len(EMOTIONS5), dtype=np.float32)
398
+ for tok in str(cell).replace("/", ",").replace(";", ",").replace("|", ",").replace(" ", ",").split(","):
399
+ e = norm_emotion(tok)
400
+ if e in EMOTIONS5:
401
+ v[EMOTIONS5.index(e)] += 1.0
402
+ return v
403
+
404
+ def load_train_labels():
405
+ df = pd.read_csv(TRAIN_CSV, sep="|")
406
+ cols = {c.lower().strip(): c for c in df.columns}
407
+ wav_col = _col(cols, "wavid", "wav", df=df, default_idx=1)
408
+ emos_col = _col(cols, "emos", "emo", "emomos")
409
+ val_col = _col(cols, "val", "valence"); aro_col = _col(cols, "aro", "arousal"); dom_col = _col(cols, "dom", "dominance")
410
+ cat_col = _col(cols, "emocat", "cat", "emotion")
411
+ assert emos_col, f"Không thấy cột eMOS (cột: {list(df.columns)})"
412
+ df["_stem"] = df[wav_col].map(stem)
413
+ rows = []
414
+ for sid, g in df.groupby("_stem"):
415
+ rec = {"wavID": sid, "emos": float(g[emos_col].mean())}
416
+ rec["val"] = float(g[val_col].mean()) if val_col else np.nan
417
+ rec["aro"] = float(g[aro_col].mean()) if aro_col else np.nan
418
+ rec["dom"] = float(g[dom_col].mean()) if dom_col else np.nan
419
+ votes = np.zeros(len(EMOTIONS5), dtype=np.float32)
420
+ if cat_col:
421
+ for cell in g[cat_col]:
422
+ votes += parse_emocat_votes(cell)
423
+ s = votes.sum()
424
+ cat = votes / s if s > 0 else np.full(len(EMOTIONS5), 0.2, dtype=np.float32)
425
+ for i in range(len(EMOTIONS5)):
426
+ rec[f"cat{i}"] = float(cat[i])
427
+ rows.append(rec)
428
+ return pd.DataFrame(rows)
429
+
430
+ target_map = load_target_emotions()
431
+ train_df = load_train_labels()
432
+ HAS_VAD = bool(train_df["val"].notna().any())
433
+ print(f"Target: {len(target_map)} | wav train (gộp): {len(train_df)} | có VAD: {HAS_VAD}")
434
+
435
+ # %% [markdown]
436
+ # ## 5. Dataset / DataLoader (load wav theo batch — KHÔNG cache WavLM vì đang train)
437
+
438
+ # %%
439
+ from torch.utils.data import Dataset, DataLoader
440
+
441
+ train_stems = [s for s in train_df["wavID"] if target_map.get(s) is not None]
442
+ if LIMIT_TRAIN:
443
+ train_stems = train_stems[:LIMIT_TRAIN]
444
+ aud_tr = extract_audeering(train_stems, "train")
445
+
446
+ lab = train_df.set_index("wavID")
447
+
448
+ def _zfit(arr):
449
+ a = np.asarray(arr, dtype=np.float32)
450
+ return float(np.nanmean(a)), float(np.nanstd(a) + 1e-6)
451
+
452
+ if RESUME and resume_ckpt is not None:
453
+ # QUAN TRỌNG: lấy chuẩn hóa TỪ ckpt (head đã train theo thang này) — KHÔNG tính lại để khỏi lệch thang
454
+ emos_mu = float(resume_ckpt["emos_mu"]); emos_sd = float(resume_ckpt["emos_sd"])
455
+ vad_mu = np.asarray(resume_ckpt["vad_mu"], dtype=np.float32)
456
+ vad_sd = np.asarray(resume_ckpt["vad_sd"], dtype=np.float32)
457
+ print(f"🔁 RESUME: dùng chuẩn hóa TỪ ckpt: emos μ={emos_mu:.3f} σ={emos_sd:.3f} | vad μ={np.round(vad_mu,2)}")
458
+ else:
459
+ emos_mu, emos_sd = _zfit([lab.loc[s, "emos"] for s in train_stems])
460
+ if HAS_VAD:
461
+ vad_mu = np.array([_zfit([lab.loc[s, c] for s in train_stems])[0] for c in ["val", "aro", "dom"]], dtype=np.float32)
462
+ vad_sd = np.array([_zfit([lab.loc[s, c] for s in train_stems])[1] for c in ["val", "aro", "dom"]], dtype=np.float32)
463
+ else:
464
+ vad_mu = np.zeros(3, dtype=np.float32); vad_sd = np.ones(3, dtype=np.float32)
465
+
466
+ def onehot_target(tgt):
467
+ v = np.zeros(len(EMOTIONS5), dtype=np.float32)
468
+ if tgt in EMOTIONS5:
469
+ v[EMOTIONS5.index(tgt)] = 1.0
470
+ return v
471
+
472
+ class EmoDataset(Dataset):
473
+ def __init__(self, stems):
474
+ self.stems = [s for s in stems if (load_wav(s) is not None) and ((not USE_AUDEERING) or s in aud_tr)]
475
+ def __len__(self):
476
+ return len(self.stems)
477
+ def __getitem__(self, i):
478
+ s = self.stems[i]
479
+ wave = load_wav(s)
480
+ emos = (float(lab.loc[s, "emos"]) - emos_mu) / emos_sd
481
+ if HAS_VAD:
482
+ vad = (np.array([lab.loc[s, "val"], lab.loc[s, "aro"], lab.loc[s, "dom"]], np.float32) - vad_mu) / vad_sd
483
+ else:
484
+ vad = np.zeros(3, dtype=np.float32)
485
+ cat = np.array([lab.loc[s, f"cat{j}"] for j in range(len(EMOTIONS5))], dtype=np.float32)
486
+ aud = aud_tr[s] if USE_AUDEERING else np.zeros(0, dtype=np.float32)
487
+ return {"wave": wave, "tgt": onehot_target(target_map.get(s)), "aud": aud,
488
+ "emos": np.float32(emos), "vad": vad, "cat": cat,
489
+ "emos_raw": np.float32(lab.loc[s, "emos"]),
490
+ "vad_raw": np.array([lab.loc[s, "val"], lab.loc[s, "aro"], lab.loc[s, "dom"]], np.float32)}
491
+
492
+ def collate(batch):
493
+ L = max(len(b["wave"]) for b in batch)
494
+ waves = np.zeros((len(batch), L), dtype=np.float32)
495
+ mask = np.zeros((len(batch), L), dtype=np.float32)
496
+ for i, b in enumerate(batch):
497
+ waves[i, : len(b["wave"])] = b["wave"]; mask[i, : len(b["wave"])] = 1.0
498
+ return {
499
+ "input_values": torch.from_numpy(waves), "attn_mask": torch.from_numpy(mask).long(),
500
+ "tgt": torch.from_numpy(np.stack([b["tgt"] for b in batch])),
501
+ "aud": torch.from_numpy(np.stack([b["aud"] for b in batch])) if USE_AUDEERING else None,
502
+ "emos": torch.from_numpy(np.stack([b["emos"] for b in batch])).unsqueeze(1),
503
+ "vad": torch.from_numpy(np.stack([b["vad"] for b in batch])),
504
+ "cat": torch.from_numpy(np.stack([b["cat"] for b in batch])),
505
+ "emos_raw": np.stack([b["emos_raw"] for b in batch]),
506
+ "vad_raw": np.stack([b["vad_raw"] for b in batch]),
507
+ }
508
+
509
+ from sklearn.model_selection import train_test_split
510
+ ds = EmoDataset(train_stems)
511
+ print("Dataset hợp lệ:", len(ds), "wav")
512
+ tr_i, va_i = train_test_split(np.arange(len(ds)), test_size=VAL_FRAC, random_state=SEED)
513
+ tr_loader = DataLoader(torch.utils.data.Subset(ds, tr_i), batch_size=BATCH, shuffle=True, collate_fn=collate, num_workers=2)
514
+ va_loader = DataLoader(torch.utils.data.Subset(ds, va_i), batch_size=BATCH, shuffle=False, collate_fn=collate, num_workers=2)
515
+
516
+ # %% [markdown]
517
+ # ## 6a. Khối MAMBA (thuần PyTorch, fallback nếu không có `mamba-ssm`)
518
+ # Theo "mamba-minimal" — đúng công thức selective SSM, chỉ chậm hơn kernel CUDA. Chạy trong fp32 cho ổn định.
519
+
520
+ # %%
521
+ import math
522
+
523
+ try:
524
+ from mamba_ssm import Mamba as _OfficialMamba
525
+ _HAS_MAMBA_SSM = True
526
+ print("✅ Dùng mamba-ssm (CUDA kernel)")
527
+ except Exception:
528
+ _HAS_MAMBA_SSM = False
529
+ print("ℹ️ Không có mamba-ssm → Mamba thuần PyTorch (chậm hơn khi fine-tune)")
530
+
531
+ class MambaBlockTorch(nn.Module):
532
+ def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
533
+ super().__init__()
534
+ self.d_inner = expand * d_model
535
+ self.dt_rank = math.ceil(d_model / 16)
536
+ self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
537
+ self.conv1d = nn.Conv1d(self.d_inner, self.d_inner, kernel_size=d_conv,
538
+ groups=self.d_inner, padding=d_conv - 1, bias=True)
539
+ self.x_proj = nn.Linear(self.d_inner, self.dt_rank + d_state * 2, bias=False)
540
+ self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)
541
+ A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)
542
+ self.A_log = nn.Parameter(torch.log(A))
543
+ self.D = nn.Parameter(torch.ones(self.d_inner))
544
+ self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
545
+ self.d_state = d_state
546
+
547
+ def forward(self, x): # x: (B, L, d_model)
548
+ B, L, _ = x.shape
549
+ xin, z = self.in_proj(x).chunk(2, dim=-1)
550
+ xin = xin.transpose(1, 2)
551
+ xin = self.conv1d(xin)[..., :L].transpose(1, 2)
552
+ xin = F.silu(xin)
553
+ y = self._ssm(xin) * F.silu(z)
554
+ return self.out_proj(y)
555
+
556
+ def _ssm(self, x):
557
+ A = -torch.exp(self.A_log)
558
+ delta, Bm, Cm = torch.split(self.x_proj(x), [self.dt_rank, self.d_state, self.d_state], dim=-1)
559
+ delta = F.softplus(self.dt_proj(delta))
560
+ dA = torch.exp(delta.unsqueeze(-1) * A)
561
+ dB_x = delta.unsqueeze(-1) * Bm.unsqueeze(2) * x.unsqueeze(-1)
562
+ h = torch.zeros(x.shape[0], self.d_inner, self.d_state, device=x.device, dtype=x.dtype)
563
+ ys = []
564
+ for t in range(x.shape[1]):
565
+ h = dA[:, t] * h + dB_x[:, t]
566
+ ys.append((h * Cm[:, t].unsqueeze(1)).sum(-1))
567
+ return torch.stack(ys, dim=1) + x * self.D
568
+
569
+ class MambaLayer(nn.Module):
570
+ def __init__(self, d_model, d_state):
571
+ super().__init__()
572
+ self.norm = nn.LayerNorm(d_model)
573
+ self.mix = _OfficialMamba(d_model=d_model, d_state=d_state, d_conv=4, expand=2) \
574
+ if _HAS_MAMBA_SSM else MambaBlockTorch(d_model, d_state=d_state)
575
+ def forward(self, x):
576
+ return x + self.mix(self.norm(x))
577
+
578
+ class MambaEncoder(nn.Module):
579
+ """1024 → d_model → [Mamba ×L] (2 chiều) → attentive-pool (có mask) → Z_DIM."""
580
+ def __init__(self, d_in, d_model, n_layers, d_state, z_dim, bidir):
581
+ super().__init__()
582
+ self.bidir = bidir
583
+ self.proj = nn.Linear(d_in, d_model)
584
+ self.fwd = nn.ModuleList([MambaLayer(d_model, d_state) for _ in range(n_layers)])
585
+ if bidir:
586
+ self.bwd = nn.ModuleList([MambaLayer(d_model, d_state) for _ in range(n_layers)])
587
+ self.attn = nn.Linear(d_model, 1)
588
+ self.out = nn.Linear(d_model, z_dim)
589
+
590
+ @staticmethod
591
+ def _run(layers, h):
592
+ for L in layers:
593
+ h = L(h)
594
+ return h
595
+
596
+ def forward(self, x, mask): # x:(B,L,1024) mask:(B,L) bool
597
+ with torch.cuda.amp.autocast(enabled=False): # SSM chạy fp32 cho ổn định
598
+ x = x.float()
599
+ h = self.proj(x)
600
+ out = self._run(self.fwd, h)
601
+ if self.bidir:
602
+ out = out + torch.flip(self._run(self.bwd, torch.flip(h, dims=[1])), dims=[1])
603
+ a = self.attn(out).squeeze(-1).masked_fill(~mask, float("-inf"))
604
+ w = torch.softmax(a, dim=1).unsqueeze(-1)
605
+ return self.out((out * w).sum(1))
606
+
607
+ # %% [markdown]
608
+ # ## 6b. Head cảm xúc + train loop (AMP + grad-accum + uncertainty weighting)
609
+
610
+ # %%
611
+ from scipy.stats import spearmanr
612
+
613
+ torch.manual_seed(SEED); np.random.seed(SEED)
614
+ N_EMO = len(EMOTIONS5)
615
+ WAVLM_BRANCH = Z_DIM if USE_MAMBA else WAVLM_DIM
616
+ TRUNK_IN = WAVLM_BRANCH + (AUD_DIM if USE_AUDEERING else 0)
617
+
618
+ enc = MambaEncoder(WAVLM_DIM, MAMBA_DMODEL, MAMBA_LAYERS, MAMBA_DSTATE, Z_DIM, BIDIRECTIONAL).to(device) \
619
+ if USE_MAMBA else None
620
+
621
+ class EmoHeads(nn.Module):
622
+ def __init__(self, d_in, trunk_h, head_h, p, n_emo):
623
+ super().__init__()
624
+ self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),
625
+ nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))
626
+ self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))
627
+ self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))
628
+ self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))
629
+ def forward(self, feat, tgt):
630
+ h = self.trunk(feat)
631
+ return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)
632
+
633
+ heads = EmoHeads(TRUNK_IN, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device)
634
+ print(f"Trunk input = {TRUNK_IN} (wavlm-branch {WAVLM_BRANCH} [{'Mamba' if USE_MAMBA else 'mean-pool'}] + aud {AUD_DIM if USE_AUDEERING else 0})")
635
+ if USE_MAMBA:
636
+ print(f"Mamba encoder: {sum(p.numel() for p in enc.parameters())/1e6:.2f}M param")
637
+
638
+ # ── RESUME: nạp heads (+ Mamba enc) từ checkpoint ──
639
+ if RESUME and resume_ckpt is not None:
640
+ hm, hu = heads.load_state_dict(resume_ckpt["heads"], strict=False)
641
+ print(f"🔁 RESUME load heads từ ckpt: thiếu {len(hm)} / dư {len(hu)} key (kỳ vọng 0)")
642
+ if USE_MAMBA and resume_ckpt.get("enc") is not None:
643
+ em, eu = enc.load_state_dict(resume_ckpt["enc"], strict=False)
644
+ print(f"🔁 RESUME load Mamba enc từ ckpt: thiếu {len(em)} / dư {len(eu)} key (kỳ vọng 0)")
645
+ elif USE_MAMBA:
646
+ print(" ⚠️ ckpt KHÔNG có 'enc' (Mamba) → Mamba head train lại từ đầu (chỉ resume backbone+heads).")
647
+
648
+ TASKS = ["emos", "cat", "val", "aro", "dom"]
649
+ log_var = nn.Parameter(torch.zeros(len(TASKS), device=device))
650
+ bb_params = [p for p in wavlm.parameters() if p.requires_grad]
651
+ head_params = list(heads.parameters()) + (list(enc.parameters()) if USE_MAMBA else []) \
652
+ + ([log_var] if USE_UNCERTAINTY else [])
653
+ _lr_scale = RESUME_LR_SCALE if RESUME else 1.0
654
+ opt = torch.optim.AdamW([
655
+ {"params": bb_params, "lr": LR_BACKBONE * _lr_scale},
656
+ {"params": head_params, "lr": LR_HEAD * _lr_scale},
657
+ ], weight_decay=WEIGHT_DECAY)
658
+ if RESUME and _lr_scale != 1.0:
659
+ print(f"🔁 RESUME: LR ×{_lr_scale} → backbone {LR_BACKBONE*_lr_scale:.1e} · head {LR_HEAD*_lr_scale:.1e}")
660
+ scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP and device == "cuda")
661
+ mse = nn.MSELoss()
662
+
663
+ def soft_ce(logits, target_dist):
664
+ return -(target_dist * F.log_softmax(logits, dim=1)).sum(1).mean()
665
+
666
+ def wavlm_branch(input_values, attn_mask):
667
+ out = wavlm(input_values, attention_mask=attn_mask).last_hidden_state # (B,T,D)
668
+ if USE_MAMBA:
669
+ return enc(out, frame_mask(out.shape[1], attn_mask)) # (B, Z_DIM)
670
+ return masked_mean(out, attn_mask) # (B, D)
671
+
672
+ def forward_batch(b):
673
+ fw = wavlm_branch(b["input_values"].to(device), b["attn_mask"].to(device))
674
+ feat = torch.cat([fw, b["aud"].to(device)], dim=1) if USE_AUDEERING else fw
675
+ return heads(feat, b["tgt"].to(device))
676
+
677
+ def pairwise_rank_loss(pred, target):
678
+ """Hinge ranking trên MỌI cặp trong batch → tối ưu thẳng thứ hạng (≈ SRCC). Khả vi (backprop được).
679
+ Cần ≥2 mẫu/batch mới có cặp; batch càng to càng nhiều cặp → tín hiệu càng mạnh."""
680
+ p = pred.reshape(-1); t = target.reshape(-1)
681
+ if p.numel() < 2:
682
+ return torch.zeros((), device=p.device)
683
+ sign = torch.sign(t.unsqueeze(0) - t.unsqueeze(1)) # +1 nếu câu i ĐÁNG cao hơn câu j
684
+ diff = p.unsqueeze(0) - p.unsqueeze(1) # chênh lệch model dự đoán
685
+ return torch.relu(-sign * diff).mean() # phạt khi xếp sai thứ tự
686
+
687
+ def compute_loss(emos_p, cat_l, vad_p, b):
688
+ L = {"emos": mse(emos_p, b["emos"].to(device)), "cat": soft_ce(cat_l, b["cat"].to(device))}
689
+ if HAS_VAD:
690
+ vt = b["vad"].to(device)
691
+ L["val"] = mse(vad_p[:, 0:1], vt[:, 0:1]); L["aro"] = mse(vad_p[:, 1:2], vt[:, 1:2]); L["dom"] = mse(vad_p[:, 2:3], vt[:, 2:3])
692
+ else:
693
+ vt = None
694
+ z = torch.zeros((), device=device); L["val"] = L["aro"] = L["dom"] = z
695
+ # Ranking loss CHỈ cho các cột chấm SRCC (emos/val/aro/dom). CAT là ERR phân bố → giữ soft-CE.
696
+ if RANK_LAMBDA > 0:
697
+ L["emos"] = L["emos"] + RANK_LAMBDA * pairwise_rank_loss(emos_p, b["emos"].to(device))
698
+ if HAS_VAD:
699
+ L["val"] = L["val"] + RANK_LAMBDA * pairwise_rank_loss(vad_p[:, 0:1], vt[:, 0:1])
700
+ L["aro"] = L["aro"] + RANK_LAMBDA * pairwise_rank_loss(vad_p[:, 1:2], vt[:, 1:2])
701
+ L["dom"] = L["dom"] + RANK_LAMBDA * pairwise_rank_loss(vad_p[:, 2:3], vt[:, 2:3])
702
+ if USE_UNCERTAINTY:
703
+ return sum(torch.exp(-log_var[i]) * L[t] + log_var[i] for i, t in enumerate(TASKS))
704
+ return sum(L.values())
705
+
706
+ def set_mode(train):
707
+ wavlm.train(train); heads.train(train)
708
+ if USE_MAMBA:
709
+ enc.train(train)
710
+
711
+ @torch.no_grad()
712
+ def evaluate():
713
+ set_mode(False)
714
+ P = {"emos": [], "val": [], "aro": [], "dom": []}; Y = {"emos": [], "val": [], "aro": [], "dom": []}
715
+ catP, catY = [], []
716
+ for b in va_loader:
717
+ with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
718
+ emos_p, cat_l, vad_p = forward_batch(b)
719
+ P["emos"] += emos_p.float().cpu().numpy().ravel().tolist(); Y["emos"] += b["emos_raw"].tolist()
720
+ vad_p = vad_p.float().cpu().numpy()
721
+ for j, t in enumerate(["val", "aro", "dom"]):
722
+ P[t] += vad_p[:, j].tolist(); Y[t] += b["vad_raw"][:, j].tolist()
723
+ catP.append(F.softmax(cat_l, 1).float().cpu().numpy()); catY.append(b["cat"])
724
+ out = {t: spearmanr(P[t], Y[t]).correlation for t in ["emos"] + (["val", "aro", "dom"] if HAS_VAD else [])}
725
+ q = np.concatenate(catP); p = np.concatenate(catY)
726
+ out["cat_err"] = float(np.abs(q - p).sum(1).mean())
727
+ return out
728
+
729
+ def mean_srcc(m):
730
+ keys = ["emos"] + (["val", "aro", "dom"] if HAS_VAD else [])
731
+ return float(np.mean([m[k] for k in keys]))
732
+
733
+ CKPT_PATH = os.path.join(OUT_DIR, "ft_mamba_emotion_full.pt")
734
+ def save_full_ckpt(state, val_emos=float("nan")):
735
+ torch.save({"wavlm": state["wavlm"], "heads": state["heads"], "enc": state.get("enc"),
736
+ "USE_MAMBA": USE_MAMBA, "emos_mu": emos_mu, "emos_sd": emos_sd,
737
+ "vad_mu": vad_mu, "vad_sd": vad_sd, "WAVLM_DIM": WAVLM_DIM, "AUD_DIM": AUD_DIM,
738
+ "Z_DIM": Z_DIM, "UNFREEZE_TOP_LAYERS": UNFREEZE_TOP_LAYERS,
739
+ "val_emos": float(val_emos)}, CKPT_PATH)
740
+
741
+ def snapshot():
742
+ s = {"wavlm": {k: v.cpu().clone() for k, v in wavlm.state_dict().items()},
743
+ "heads": {k: v.cpu().clone() for k, v in heads.state_dict().items()}}
744
+ if USE_MAMBA:
745
+ s["enc"] = {k: v.cpu().clone() for k, v in enc.state_dict().items()}
746
+ return s
747
+
748
+ # RESUME: init best = điểm VAL của ckpt hiện tại → chỉ ghi đè nếu train tiếp TỐT HƠN (không sợ tụt)
749
+ if RESUME and resume_ckpt is not None:
750
+ m0 = evaluate(); best = mean_srcc(m0); best_state = snapshot(); bad = 0
751
+ print(f"📍 RESUME — checkpoint hiện tại: mean SRCC={best:.4f} | "
752
+ + " ".join(f"{k}={m0[k]:.3f}" for k in ['emos', 'val', 'aro', 'dom'] if k in m0))
753
+ else:
754
+ m0 = None
755
+ best, best_state, bad = -1e9, None, 0
756
+ for ep in range(1, EPOCHS + 1):
757
+ set_mode(True)
758
+ opt.zero_grad(); run = 0.0; nb = 0
759
+ for step, b in enumerate(tqdm(tr_loader, desc=f"epoch {ep}")):
760
+ with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
761
+ emos_p, cat_l, vad_p = forward_batch(b)
762
+ loss = compute_loss(emos_p, cat_l, vad_p, b) / ACCUM
763
+ scaler.scale(loss).backward()
764
+ if (step + 1) % ACCUM == 0:
765
+ scaler.step(opt); scaler.update(); opt.zero_grad()
766
+ run += loss.item() * ACCUM; nb += 1
767
+ m = evaluate(); sc = mean_srcc(m)
768
+ msg = " ".join(f"{k}={m[k]:.3f}" for k in ["emos", "val", "aro", "dom"] if k in m)
769
+ print(f"epoch {ep:2d} | loss {run/max(nb,1):.4f} | {msg} | cat_err {m['cat_err']:.3f} | mean {sc:.4f} (best {max(best,sc):.4f})")
770
+ if sc > best:
771
+ best = sc; bad = 0
772
+ best_state = snapshot()
773
+ save_full_ckpt(best_state, m["emos"])
774
+ print(f" 💾 lưu best → {CKPT_PATH} (epoch {ep}, mean {sc:.4f})")
775
+ else:
776
+ bad += 1
777
+ if bad >= PATIENCE:
778
+ print(f"Early stop ở epoch {ep}."); break
779
+
780
+ if best_state:
781
+ wavlm.load_state_dict(best_state["wavlm"]); heads.load_state_dict(best_state["heads"])
782
+ if USE_MAMBA:
783
+ enc.load_state_dict(best_state["enc"])
784
+ final = evaluate()
785
+ if RESUME and m0 is not None:
786
+ print(f"\n🔁 RESUME: mean SRCC ckpt {mean_srcc(m0):.4f} → sau train tiếp {mean_srcc(final):.4f} "
787
+ + ("🚀 cải thiện → đã ghi đè ckpt" if mean_srcc(final) > mean_srcc(m0) + 1e-4 else "➖ không cải thiện (giữ best cũ)"))
788
+ print(f"\n✅ VAL (nội bộ) — exp15 (Mamba={'ON' if USE_MAMBA else 'OFF'}):")
789
+ print(f" EMOS={final['emos']:.4f} (exp08 {EXP08['emos']})")
790
+ if HAS_VAD:
791
+ print(f" VAL/ARO/DOM={final['val']:.4f}/{final['aro']:.4f}/{final['dom']:.4f} "
792
+ f"(exp08 {EXP08['val']}/{EXP08['aro']}/{EXP08['dom']})")
793
+ warn = [f"EMOS {final['emos']:.3f}<{EXP08['emos']}"] if final["emos"] < EXP08["emos"] - 0.005 else []
794
+ if HAS_VAD:
795
+ warn += [f"{t.upper()} {final[t]:.3f}<{EXP08[t]}" for t in ["val", "aro", "dom"] if final[t] < EXP08[t] - 0.005]
796
+ print(" ⚠️ Mamba head CHƯA thắng exp08 ở:", "; ".join(warn), "(vẫn là kết quả cho paper)" if warn else "")
797
+ if not warn:
798
+ print(" ✅ Mamba head thắng/ngang exp08 ở mọi cột → temporal modeling có ích!")
799
+ save_full_ckpt(best_state if best_state else
800
+ {"wavlm": wavlm.state_dict(), "heads": heads.state_dict(),
801
+ "enc": enc.state_dict() if USE_MAMBA else None}, final["emos"])
802
+ print(f"✅ Đã lưu {CKPT_PATH} (CÓ backbone + Mamba + heads). NHỚ Save Version!")
803
+
804
+ # %% [markdown]
805
+ # ## 7. Dự đoán DEV → answer.txt (5 cột cảm xúc exp15; QMOS mượn exp07/UTMOSv2)
806
+
807
+ # %%
808
+ def list_dev():
809
+ with open(DEV_SCP) as f:
810
+ return [ln.strip() for ln in f if ln.strip()]
811
+
812
+ dev_names = list_dev()
813
+ if LIMIT_DEV:
814
+ dev_names = dev_names[:LIMIT_DEV]
815
+ dev_stems = [stem(n) for n in dev_names]
816
+ print("DEV:", len(dev_names), "mẫu")
817
+ aud_dev = extract_audeering(dev_stems, "dev")
818
+
819
+ def load_exp07_qmos():
820
+ if EXP07_ANSWER and os.path.exists(EXP07_ANSWER):
821
+ import csv
822
+ d = {}
823
+ with open(EXP07_ANSWER) as f:
824
+ for row in csv.DictReader(f):
825
+ d[row["wav"]] = float(row["QMOS"]); d[stem(row["wav"])] = float(row["QMOS"])
826
+ print(f"✅ Mượn QMOS exp07 ({EXP07_ANSWER}): {len(d)//2} wav")
827
+ return d
828
+ return None
829
+
830
+ qmos_map = load_exp07_qmos()
831
+ if qmos_map is None:
832
+ print("ℹ️ Không có answer.txt exp07 → chấm QMOS bằng UTMOSv2 (T05, vô địch VMC2024).")
833
+ pip_install("git+https://github.com/sarulab-speech/UTMOSv2.git")
834
+ import utmosv2
835
+ v2 = utmosv2.create_model(pretrained=True)
836
+ qmos_map = {}
837
+ for n in tqdm(dev_names, desc="UTMOSv2"):
838
+ wav = os.path.join(WAV_DIR, n if str(n).endswith(".wav") else str(n) + ".wav")
839
+ if not os.path.exists(wav):
840
+ continue
841
+ out = v2.predict(input_path=wav)
842
+ qmos_map[n] = float(out["predicted_mos"]) if isinstance(out, dict) else float(out)
843
+ del v2; torch.cuda.empty_cache() if device == "cuda" else None
844
+
845
+ @torch.no_grad()
846
+ def predict_emotion(sid):
847
+ wave = load_wav(sid)
848
+ if wave is None or (USE_AUDEERING and sid not in aud_dev):
849
+ return None
850
+ set_mode(False)
851
+ iv = torch.from_numpy(wave).unsqueeze(0).to(device)
852
+ am = torch.ones((1, len(wave)), dtype=torch.long, device=device)
853
+ tgt = torch.from_numpy(onehot_target(target_map.get(sid))).unsqueeze(0).to(device)
854
+ with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
855
+ fw = wavlm_branch(iv, am)
856
+ feat = torch.cat([fw, torch.from_numpy(aud_dev[sid]).unsqueeze(0).to(device)], dim=1) if USE_AUDEERING else fw
857
+ emos_p, cat_l, vad_p = heads(feat, tgt)
858
+ emos = float(emos_p.item()) * emos_sd + emos_mu
859
+ cat5 = F.softmax(cat_l, 1)[0].float().cpu().numpy()
860
+ vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu
861
+ return emos, cat5, vad3
862
+
863
+ def fmt_cat(p5):
864
+ return "|".join(f"{e}:{p5[i]:.6g}" for i, e in enumerate(EMOTIONS5))
865
+
866
+ def build_answer(out_path):
867
+ n_real = n_def = 0
868
+ with open(out_path, "w") as f:
869
+ f.write("wav,QMOS,EMOS,CAT,VAL,ARO,DOM\n")
870
+ for name in tqdm(dev_names, desc="answer"):
871
+ sid = stem(name)
872
+ pr = predict_emotion(sid)
873
+ if pr is None:
874
+ emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0]); n_def += 1
875
+ else:
876
+ emos, cat5, vad3 = pr; n_real += 1
877
+ qmos = qmos_map.get(name, qmos_map.get(sid, 3.0))
878
+ f.write(f"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\n")
879
+ print(f"Ghi {len(dev_names)} dòng → {out_path} | cảm xúc thật {n_real}, mặc định {n_def}")
880
+
881
+ answer_path = os.path.join(OUT_DIR, "answer.txt")
882
+ build_answer(answer_path)
883
+
884
+ # %% [markdown]
885
+ # ## 8. Validate + đóng zip
886
+
887
+ # %%
888
+ def validate(path):
889
+ import csv
890
+ with open(path) as f:
891
+ rows = list(csv.reader(f))
892
+ assert rows[0][0] == "wav" and "QMOS" in rows[0] and "EMOS" in rows[0], "Header sai"
893
+ for i, r in enumerate(rows[1:], 2):
894
+ assert len(r) == len(rows[0]), f"Dòng {i} sai số cột"
895
+ print(f"OK: {len(rows)-1} dòng, header = {rows[0]}")
896
+
897
+ validate(answer_path)
898
+ os.system(f"cd {OUT_DIR} && zip -j submission_track2_exp15_mamba-emotion.zip answer.txt "
899
+ f"&& unzip -l submission_track2_exp15_mamba-emotion.zip")
900
+ print("Sẵn sàng nộp:", os.path.join(OUT_DIR, "submission_track2_exp15_mamba-emotion.zip"))
901
+
902
+ # %% [markdown]
903
+ # ## Ghi chú
904
+ # - **🔁 RESUME (train tiếp, không train lại từ đầu):** Add Input dataset chứa `ft_mamba_emotion_full.pt` của lần
905
+ # chạy trước (hoặc để nó nằm sẵn trong `/kaggle/working` khi chạy nối phiên) → notebook tự dò & train tiếp.
906
+ # `EPOCHS` lúc này là **số epoch train THÊM**. Val chững → đặt `RESUME_LR_SCALE=0.5`. Muốn ép train mới: `RESUME_CKPT="—"`
907
+ # (đường dẫn không tồn tại) hoặc xóa ckpt khỏi input. ⚠️ `USE_MAMBA` phải KHỚP ckpt (code sẽ cảnh báo nếu lệch).
908
+ # - **Lần đầu** `LIMIT_TRAIN=300`, `LIMIT_DEV=20` → kiểm 1 epoch không OOM / không CheckpointError; rồi đặt `None`.
909
+ # - **Ablation chính cho paper:** chạy `USE_MAMBA=True` vs `USE_MAMBA=False` (=exp08) → so EMOS/VAL/ARO/DOM nội bộ
910
+ # → trả lời "Mamba temporal head có hơn mean-pooling không?".
911
+ # - **OOM / quá chậm trên T4 (nhất là khi dùng Mamba thuần PyTorch):** giảm theo thứ tự
912
+ # `MAX_SECONDS` (6→5) → `MAMBA_LAYERS` (2→1) → `UNFREEZE_TOP_LAYERS` (6→4) → `BATCH` (2→1, tăng `ACCUM`).
913
+ # Hoặc thử cài `mamba-ssm causal-conv1d` (nhanh + nhẹ RAM hơn nhiều) — code tự dùng nếu import được.
914
+ # - **Ranking loss (`RANK_LAMBDA`):** thêm pairwise ranking cho 4 cột SRCC (emos/val/aro/dom) → khớp metric
915
+ # UTT-SRCC hơn MSE. ⚠️ **Điểm yếu:** ranking tính trên các cặp TRONG 1 mini-batch; `BATCH=2` → mỗi forward
916
+ # chỉ có 1 cặp → tín hiệu YẾU. Muốn ranking mạnh: tăng `BATCH` (4→8 nếu VRAM chịu được). Ở các exp head
917
+ # ĐÓNG BĂNG (exp06/07, BATCH=64) ranking mạnh hơn nhiều. A/B `RANK_LAMBDA=0` vs `0.3` → bảng ablation cho paper.
918
+ # - **QMOS:** Add Input answer.txt exp07 vào `/kaggle/input/exp07-answer/answer.txt` để mượn QMOS 0.548;
919
+ # không có thì tự chấm UTMOSv2 (cần Internet On).
920
+ # - Ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (mục exp15).
track2/exp16_llm_judge.ipynb ADDED
@@ -0,0 +1,650 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "7bae4e03",
6
+ "metadata": {},
7
+ "source": [
8
+ "# exp16 — Audio-LLM-as-Judge cho MOS cảm xúc (Track 2)\n",
9
+ "\n",
10
+ "**Ý tưởng:** đưa thẳng audio cho một **audio-LLM** (Gemini / GPT-4o-audio) qua **API** + prompt có\n",
11
+ "cấu trúc → bắt nó chấm cả 6 cột (`QMOS, EMOS, CAT, VAL, ARO, DOM`) → ráp `answer.txt` → nộp CodaBench.\n",
12
+ "\n",
13
+ "**Mục tiêu chính = NOVELTY cho paper** (khảo sát có hệ thống audio-LLM-as-judge cho MOS cảm xúc),\n",
14
+ "so với hệ SSL đã train (exp07 QMOS 0.548 · exp08 EMOS 0.811…). KHÔNG cần GPU — thuần gọi API.\n",
15
+ "\n",
16
+ "| Đặc điểm | Giá trị |\n",
17
+ "|---|---|\n",
18
+ "| GPU | ❌ không cần (chỉ network I/O) |\n",
19
+ "| Tốn phí | ✅ API trả tiền theo token/audio → **cache + resume bắt buộc** |\n",
20
+ "| Provider | `gemini` (mặc định, đã có billing) · `openai` (GPT-4o-audio, để so 2 LLM) |\n",
21
+ "| Output | `answer.txt` 6 cột giống exp07 |\n",
22
+ "\n",
23
+ "**Cách dùng Kaggle:** Internet = **On**; Add-ons → Secrets: `GEMINI_API_KEY` (và `OPENAI_API_KEY`\n",
24
+ "nếu chạy provider openai). Settings GPU **không cần**. Sửa `DATA_ROOT` cho khớp slug rồi Run All.\n",
25
+ "\n",
26
+ "⚠️ **Model ID có thể đã đổi** theo thời gian → kiểm tra `GEMINI_MODEL` / `OPENAI_MODEL` còn nhận\n",
27
+ "audio không trước khi chạy full (xem mục 1)."
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "markdown",
32
+ "id": "720a7dc2",
33
+ "metadata": {},
34
+ "source": [
35
+ "## 0. Cấu hình — SỬA Ở ĐÂY"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": null,
41
+ "id": "c583b4dc",
42
+ "metadata": {},
43
+ "outputs": [],
44
+ "source": [
45
+ "import os, io, re, json, time, base64, glob\n",
46
+ "\n",
47
+ "# ── Data Track 2 trên Kaggle ────────────────────────────────────────────────\n",
48
+ "DATA_ROOT = \"/kaggle/input/vmc2026-track2-full/vmc2026-track2\" # << SỬA slug\n",
49
+ "WAV_DIR = f\"{DATA_ROOT}/wav\"\n",
50
+ "METADATA_CSV = f\"{DATA_ROOT}/metadata.csv\" # wavID|emotion|transcript (KHÔNG header) — nhãn cảm xúc target\n",
51
+ "DEV_SCP = f\"{DATA_ROOT}/sets/dev.scp\" # danh sách wav DEV cần nộp (train phase)\n",
52
+ "TRAIN_CSV = f\"{DATA_ROOT}/sets/train.csv\" # chỉ cần khi SHOT_MODE=\"few_shot\"\n",
53
+ "\n",
54
+ "OUT_DIR = \"/kaggle/working\"\n",
55
+ "CACHE_DIR = \"/kaggle/working/exp16_llm_cache\" # nên Save Version / lưu Dataset để KHÔNG gọi lại API\n",
56
+ "os.makedirs(CACHE_DIR, exist_ok=True)\n",
57
+ "\n",
58
+ "# ── Provider & model ────────────────────────────────────────────────────────\n",
59
+ "PROVIDER = \"gemini\" # \"gemini\" | \"openai\"\n",
60
+ "GEMINI_MODEL = \"gemini-2.5-flash\" # << xác nhận model audio hiện hành (baseline dùng họ gemini-*-flash)\n",
61
+ "OPENAI_MODEL = \"gpt-4o-audio-preview\" # << model audio của OpenAI; cần OPENAI_API_KEY\n",
62
+ "TEMPERATURE = 0.0 # cố định để TÁI LẬP (paper)\n",
63
+ "\n",
64
+ "# ── Chế độ chạy ─────────────────────────────────────────────────────────────\n",
65
+ "SHOT_MODE = \"zero_shot\" # \"zero_shot\" | \"few_shot\" (nhét K ví dụ audio có nhãn từ train.csv)\n",
66
+ "FEW_K = 2 # số ví dụ few-shot (mỗi ví dụ = 1 audio + nhãn vàng) — tốn thêm token!\n",
67
+ "LIMIT = 20 # << số nhỏ (20) để smoke test; None = full DEV (~2730) — CHẠY THỬ TRƯỚC\n",
68
+ "MAX_SECONDS = 12 # cắt audio cho rẻ + nhanh\n",
69
+ "WORKERS = 4 # luồng gọi song song (giảm nếu dính rate limit)\n",
70
+ "MAX_RETRY = 3 # số lần thử lại 1 wav khi lỗi mạng / JSON hỏng\n",
71
+ "RETRY_SLEEP = 2.0 # giây nghỉ giữa các lần thử\n",
72
+ "\n",
73
+ "TAG = f\"{PROVIDER}_{(GEMINI_MODEL if PROVIDER=='gemini' else OPENAI_MODEL)}_{SHOT_MODE}\".replace(\"/\", \"-\")\n",
74
+ "CACHE_PATH = os.path.join(CACHE_DIR, f\"{TAG}.jsonl\") # 1 dòng JSON / wav (raw + parsed) → resume\n",
75
+ "print(\"TAG:\", TAG, \"| cache:\", CACHE_PATH)"
76
+ ]
77
+ },
78
+ {
79
+ "cell_type": "markdown",
80
+ "id": "f2cc876e",
81
+ "metadata": {},
82
+ "source": [
83
+ "## 0b. Nhãn cảm xúc target + chuẩn hóa lớp (tái dùng quy ước baseline)"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "code",
88
+ "execution_count": null,
89
+ "id": "5b5c7f92",
90
+ "metadata": {
91
+ "lines_to_next_cell": 1
92
+ },
93
+ "outputs": [],
94
+ "source": [
95
+ "EMOTIONS5 = [\"angry\", \"happy\", \"neutral\", \"sad\", \"surprised\"] # THỨ TỰ chuẩn cho cột CAT\n",
96
+ "\n",
97
+ "_EMO_ALIAS = {\n",
98
+ " \"angry\": \"angry\", \"anger\": \"angry\",\n",
99
+ " \"happy\": \"happy\", \"happiness\": \"happy\", \"joy\": \"happy\",\n",
100
+ " \"neutral\": \"neutral\", \"calm\": \"neutral\",\n",
101
+ " \"sad\": \"sad\", \"sadness\": \"sad\",\n",
102
+ " \"surprise\": \"surprised\", \"surprised\": \"surprised\", \"surprising\": \"surprised\",\n",
103
+ "}\n",
104
+ "\n",
105
+ "def norm_emotion(label):\n",
106
+ " key = str(label).strip().lower()\n",
107
+ " return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)\n",
108
+ "\n",
109
+ "def stem(name):\n",
110
+ " return os.path.splitext(os.path.basename(name))[0]\n",
111
+ "\n",
112
+ "def load_target_emotions():\n",
113
+ " \"\"\"metadata.csv (wavID|emotion|transcript, không header) → {stem: emotion_chuẩn}.\"\"\"\n",
114
+ " tgt = {}\n",
115
+ " if not (METADATA_CSV and os.path.exists(METADATA_CSV)):\n",
116
+ " print(\"⚠️ Không thấy metadata.csv → EMOS sẽ thiếu cảm xúc target.\")\n",
117
+ " return tgt\n",
118
+ " with open(METADATA_CSV, encoding=\"utf-8\") as f:\n",
119
+ " for ln in f:\n",
120
+ " parts = ln.strip().split(\"|\")\n",
121
+ " if len(parts) < 2:\n",
122
+ " continue\n",
123
+ " tgt[stem(parts[0])] = norm_emotion(parts[1])\n",
124
+ " return tgt\n",
125
+ "\n",
126
+ "target_map = load_target_emotions()\n",
127
+ "print(\"Nhãn cảm xúc target:\", len(target_map))\n",
128
+ "\n",
129
+ "def list_dev():\n",
130
+ " with open(DEV_SCP) as f:\n",
131
+ " return [ln.strip() for ln in f if ln.strip()]\n",
132
+ "\n",
133
+ "dev_names = list_dev()\n",
134
+ "if LIMIT:\n",
135
+ " dev_names = dev_names[:LIMIT]\n",
136
+ "print(\"DEV cần chấm:\", len(dev_names), \"mẫu\", \"| LIMIT =\", LIMIT)"
137
+ ]
138
+ },
139
+ {
140
+ "cell_type": "markdown",
141
+ "id": "881a37a5",
142
+ "metadata": {},
143
+ "source": [
144
+ "## 1. Cài SDK + nạp key\n",
145
+ "\n",
146
+ "Gemini dùng SDK mới `google-genai`; OpenAI dùng `openai`. Trên Kaggle **Internet phải On**."
147
+ ]
148
+ },
149
+ {
150
+ "cell_type": "code",
151
+ "execution_count": null,
152
+ "id": "a1d0c66b",
153
+ "metadata": {},
154
+ "outputs": [],
155
+ "source": [
156
+ "!pip -q install google-genai openai soundfile librosa\n",
157
+ "\n",
158
+ "def setup_keys():\n",
159
+ " \"\"\"Nạp API key từ Kaggle Secrets (fallback: biến môi trường đã set sẵn).\"\"\"\n",
160
+ " try:\n",
161
+ " from kaggle_secrets import UserSecretsClient\n",
162
+ " sec = UserSecretsClient()\n",
163
+ " for k in [\"GEMINI_API_KEY\", \"OPENAI_API_KEY\"]:\n",
164
+ " try:\n",
165
+ " os.environ[k] = sec.get_secret(k)\n",
166
+ " print(f\"Đã nạp {k} từ Secrets\")\n",
167
+ " except Exception:\n",
168
+ " pass\n",
169
+ " except Exception as e:\n",
170
+ " print(\"Không dùng được Kaggle Secrets:\", e, \"→ set tay os.environ[...] nếu cần\")\n",
171
+ "\n",
172
+ "setup_keys()"
173
+ ]
174
+ },
175
+ {
176
+ "cell_type": "markdown",
177
+ "id": "a4ceeacf",
178
+ "metadata": {},
179
+ "source": [
180
+ "## 2. Đọc + chuẩn hóa audio (16kHz mono, cắt MAX_SECONDS) → bytes WAV trong RAM"
181
+ ]
182
+ },
183
+ {
184
+ "cell_type": "code",
185
+ "execution_count": null,
186
+ "id": "68d431ff",
187
+ "metadata": {
188
+ "lines_to_next_cell": 1
189
+ },
190
+ "outputs": [],
191
+ "source": [
192
+ "import numpy as np\n",
193
+ "\n",
194
+ "def load_wav_bytes(path, sr=16000, max_seconds=MAX_SECONDS):\n",
195
+ " \"\"\"Trả (wav_bytes, base64_str). Cắt ≤ max_seconds, resample 16k mono, encode WAV PCM16.\"\"\"\n",
196
+ " import soundfile as sf\n",
197
+ " try:\n",
198
+ " import librosa\n",
199
+ " y, _ = librosa.load(path, sr=sr, mono=True)\n",
200
+ " except Exception:\n",
201
+ " y, in_sr = sf.read(path)\n",
202
+ " if y.ndim > 1:\n",
203
+ " y = y.mean(axis=1)\n",
204
+ " if in_sr != sr: # fallback resample tuyến tính nếu không có librosa\n",
205
+ " idx = np.linspace(0, len(y) - 1, int(len(y) * sr / in_sr))\n",
206
+ " y = np.interp(idx, np.arange(len(y)), y)\n",
207
+ " if max_seconds:\n",
208
+ " y = y[: int(sr * max_seconds)]\n",
209
+ " buf = io.BytesIO()\n",
210
+ " sf.write(buf, y.astype(np.float32), sr, format=\"WAV\", subtype=\"PCM_16\")\n",
211
+ " raw = buf.getvalue()\n",
212
+ " return raw, base64.b64encode(raw).decode(\"ascii\")"
213
+ ]
214
+ },
215
+ {
216
+ "cell_type": "markdown",
217
+ "id": "9d846428",
218
+ "metadata": {},
219
+ "source": [
220
+ "## 3. Prompt — định nghĩa 6 metric + ép JSON nghiêm ngặt\n",
221
+ "\n",
222
+ "QMOS = chất lượng/độ tự nhiên (sạch, không méo/robot). EMOS = độ KHỚP với **cảm xúc target**.\n",
223
+ "CAT = phân phối vote 5 lớp. VAD = Valence/Arousal/Dominance. Tất cả thang **1–5** (CAT là tỉ lệ 0–1)."
224
+ ]
225
+ },
226
+ {
227
+ "cell_type": "code",
228
+ "execution_count": null,
229
+ "id": "f7046919",
230
+ "metadata": {
231
+ "lines_to_next_cell": 1
232
+ },
233
+ "outputs": [],
234
+ "source": [
235
+ "SYSTEM_INSTRUCTION = (\n",
236
+ " \"You are an expert evaluator of emotional text-to-speech. \"\n",
237
+ " \"Listen to the audio and rate it. Respond with ONLY a compact JSON object, no prose.\"\n",
238
+ ")\n",
239
+ "\n",
240
+ "def build_prompt(target_emo):\n",
241
+ " tgt = target_emo if target_emo else \"unknown\"\n",
242
+ " return (\n",
243
+ " \"Rate this speech utterance. The INTENDED (target) emotion is: \"\n",
244
+ " f\"\\\"{tgt}\\\".\\n\\n\"\n",
245
+ " \"Return a JSON object with EXACTLY these keys (numbers on a 1-5 scale unless stated):\\n\"\n",
246
+ " \" \\\"qmos\\\": overall audio QUALITY / naturalness (1=very unnatural/robotic/distorted, 5=clean & human-like).\\n\"\n",
247
+ " \" \\\"emos\\\": how well the emotion expressed MATCHES the target emotion above \"\n",
248
+ " \"(1=not matching at all, 5=perfectly matching).\\n\"\n",
249
+ " \" \\\"cat\\\": an object with probabilities (summing to 1.0) over the 5 perceived emotions: \"\n",
250
+ " \"{\\\"neutral\\\":_, \\\"happy\\\":_, \\\"sad\\\":_, \\\"angry\\\":_, \\\"surprised\\\":_}.\\n\"\n",
251
+ " \" \\\"val\\\": valence (1=very negative, 5=very positive).\\n\"\n",
252
+ " \" \\\"aro\\\": arousal (1=very calm, 5=very excited).\\n\"\n",
253
+ " \" \\\"dom\\\": dominance (1=very submissive, 5=very dominant).\\n\\n\"\n",
254
+ " \"Example format: \"\n",
255
+ " \"{\\\"qmos\\\":3.5,\\\"emos\\\":4.0,\"\n",
256
+ " \"\\\"cat\\\":{\\\"neutral\\\":0.1,\\\"happy\\\":0.7,\\\"sad\\\":0.0,\\\"angry\\\":0.1,\\\"surprised\\\":0.1},\"\n",
257
+ " \"\\\"val\\\":4.0,\\\"aro\\\":3.5,\\\"dom\\\":3.0}\\n\"\n",
258
+ " \"Respond with ONLY the JSON.\"\n",
259
+ " )"
260
+ ]
261
+ },
262
+ {
263
+ "cell_type": "markdown",
264
+ "id": "fe8d1303",
265
+ "metadata": {},
266
+ "source": [
267
+ "## 3b. (tùy chọn) Few-shot — lấy K ví dụ audio có nhãn vàng từ train.csv\n",
268
+ "\n",
269
+ "Bật khi `SHOT_MODE=\"few_shot\"`. Mỗi ví dụ = 1 audio train + nhãn vàng (gộp TB theo wav). Tốn thêm token."
270
+ ]
271
+ },
272
+ {
273
+ "cell_type": "code",
274
+ "execution_count": null,
275
+ "id": "5fdf89c7",
276
+ "metadata": {},
277
+ "outputs": [],
278
+ "source": [
279
+ "few_shot_examples = [] # list[(audio_b64, audio_bytes, gold_json_str)]\n",
280
+ "\n",
281
+ "def _agg_train_labels():\n",
282
+ " \"\"\"Gộp train.csv (sep='|') theo wavID → nhãn vàng trung bình; CAT = tỉ lệ vote.\"\"\"\n",
283
+ " import pandas as pd\n",
284
+ " df = pd.read_csv(TRAIN_CSV, sep=\"|\")\n",
285
+ " rows = {}\n",
286
+ " for wav, g in df.groupby(\"wavID\"):\n",
287
+ " votes = np.zeros(5, np.float32)\n",
288
+ " for cell in g[\"emoCat\"].astype(str):\n",
289
+ " for tok in cell.split(\",\"):\n",
290
+ " e = norm_emotion(tok)\n",
291
+ " if e in EMOTIONS5:\n",
292
+ " votes[EMOTIONS5.index(e)] += 1\n",
293
+ " s = votes.sum()\n",
294
+ " cat = (votes / s) if s > 0 else np.full(5, 0.2, np.float32)\n",
295
+ " rows[stem(wav)] = dict(\n",
296
+ " qmos=float(g[\"qMOS\"].mean()), emos=float(g[\"eMOS\"].mean()),\n",
297
+ " val=float(g[\"val\"].mean()), aro=float(g[\"aro\"].mean()), dom=float(g[\"dom\"].mean()),\n",
298
+ " cat={EMOTIONS5[i]: round(float(cat[i]), 4) for i in range(5)},\n",
299
+ " )\n",
300
+ " return rows\n",
301
+ "\n",
302
+ "def build_few_shot():\n",
303
+ " if SHOT_MODE != \"few_shot\":\n",
304
+ " return\n",
305
+ " labels = _agg_train_labels()\n",
306
+ " picked = list(labels.keys())[:FEW_K]\n",
307
+ " for sid in picked:\n",
308
+ " wavp = os.path.join(WAV_DIR, sid + \".wav\")\n",
309
+ " if not os.path.exists(wavp):\n",
310
+ " continue\n",
311
+ " raw, b64 = load_wav_bytes(wavp)\n",
312
+ " gold = labels[sid]\n",
313
+ " gold_json = json.dumps({\n",
314
+ " \"qmos\": round(gold[\"qmos\"], 2), \"emos\": round(gold[\"emos\"], 2),\n",
315
+ " \"cat\": gold[\"cat\"], \"val\": round(gold[\"val\"], 2),\n",
316
+ " \"aro\": round(gold[\"aro\"], 2), \"dom\": round(gold[\"dom\"], 2),\n",
317
+ " })\n",
318
+ " few_shot_examples.append((b64, raw, gold_json))\n",
319
+ " print(f\"Few-shot: {len(few_shot_examples)} ví dụ\")\n",
320
+ "\n",
321
+ "build_few_shot()"
322
+ ]
323
+ },
324
+ {
325
+ "cell_type": "markdown",
326
+ "id": "4d3c1fef",
327
+ "metadata": {},
328
+ "source": [
329
+ "## 4. Gọi API — trừu tượng hóa provider (gemini / openai)\n",
330
+ "\n",
331
+ "Mỗi provider tự dựng message của nó (kèm few-shot nếu có). Trả về **text thô** để parse ở mục 5."
332
+ ]
333
+ },
334
+ {
335
+ "cell_type": "code",
336
+ "execution_count": null,
337
+ "id": "ae85c4bf",
338
+ "metadata": {
339
+ "lines_to_next_cell": 1
340
+ },
341
+ "outputs": [],
342
+ "source": [
343
+ "_client = {\"gemini\": None, \"openai\": None}\n",
344
+ "\n",
345
+ "def _gemini_client():\n",
346
+ " if _client[\"gemini\"] is None:\n",
347
+ " from google import genai\n",
348
+ " _client[\"gemini\"] = genai.Client(api_key=os.environ[\"GEMINI_API_KEY\"])\n",
349
+ " return _client[\"gemini\"]\n",
350
+ "\n",
351
+ "def _openai_client():\n",
352
+ " if _client[\"openai\"] is None:\n",
353
+ " from openai import OpenAI\n",
354
+ " _client[\"openai\"] = OpenAI(api_key=os.environ[\"OPENAI_API_KEY\"])\n",
355
+ " return _client[\"openai\"]\n",
356
+ "\n",
357
+ "def call_gemini(audio_b64, audio_bytes, prompt):\n",
358
+ " from google.genai import types\n",
359
+ " client = _gemini_client()\n",
360
+ " contents = []\n",
361
+ " for ex_b64, ex_bytes, ex_gold in few_shot_examples: # few-shot: audio ví dụ + nhãn vàng\n",
362
+ " contents.append(types.Content(role=\"user\", parts=[\n",
363
+ " types.Part.from_bytes(data=ex_bytes, mime_type=\"audio/wav\"),\n",
364
+ " types.Part.from_text(text=build_prompt(None)),\n",
365
+ " ]))\n",
366
+ " contents.append(types.Content(role=\"model\", parts=[types.Part.from_text(text=ex_gold)]))\n",
367
+ " contents.append(types.Content(role=\"user\", parts=[\n",
368
+ " types.Part.from_bytes(data=audio_bytes, mime_type=\"audio/wav\"),\n",
369
+ " types.Part.from_text(text=prompt),\n",
370
+ " ]))\n",
371
+ " resp = client.models.generate_content(\n",
372
+ " model=GEMINI_MODEL, contents=contents,\n",
373
+ " config=types.GenerateContentConfig(\n",
374
+ " system_instruction=SYSTEM_INSTRUCTION, temperature=TEMPERATURE),\n",
375
+ " )\n",
376
+ " return resp.text\n",
377
+ "\n",
378
+ "def call_openai(audio_b64, audio_bytes, prompt):\n",
379
+ " client = _openai_client()\n",
380
+ " messages = [{\"role\": \"system\", \"content\": SYSTEM_INSTRUCTION}]\n",
381
+ " for ex_b64, ex_bytes, ex_gold in few_shot_examples:\n",
382
+ " messages.append({\"role\": \"user\", \"content\": [\n",
383
+ " {\"type\": \"text\", \"text\": build_prompt(None)},\n",
384
+ " {\"type\": \"input_audio\", \"input_audio\": {\"data\": ex_b64, \"format\": \"wav\"}},\n",
385
+ " ]})\n",
386
+ " messages.append({\"role\": \"assistant\", \"content\": ex_gold})\n",
387
+ " messages.append({\"role\": \"user\", \"content\": [\n",
388
+ " {\"type\": \"text\", \"text\": prompt},\n",
389
+ " {\"type\": \"input_audio\", \"input_audio\": {\"data\": audio_b64, \"format\": \"wav\"}},\n",
390
+ " ]})\n",
391
+ " resp = client.chat.completions.create(\n",
392
+ " model=OPENAI_MODEL, messages=messages, temperature=TEMPERATURE,\n",
393
+ " modalities=[\"text\"],\n",
394
+ " )\n",
395
+ " return resp.choices[0].message.content\n",
396
+ "\n",
397
+ "def call_llm(audio_b64, audio_bytes, prompt):\n",
398
+ " return call_gemini(audio_b64, audio_bytes, prompt) if PROVIDER == \"gemini\" \\\n",
399
+ " else call_openai(audio_b64, audio_bytes, prompt)"
400
+ ]
401
+ },
402
+ {
403
+ "cell_type": "markdown",
404
+ "id": "f6ef0abc",
405
+ "metadata": {},
406
+ "source": [
407
+ "## 5. Parse JSON chịu lỗi → 6 cột; clamp [1,5]; chuẩn hóa CAT"
408
+ ]
409
+ },
410
+ {
411
+ "cell_type": "code",
412
+ "execution_count": null,
413
+ "id": "74507a4a",
414
+ "metadata": {
415
+ "lines_to_next_cell": 1
416
+ },
417
+ "outputs": [],
418
+ "source": [
419
+ "def _clamp(x, lo=1.0, hi=5.0, default=3.0):\n",
420
+ " try:\n",
421
+ " v = float(x)\n",
422
+ " except Exception:\n",
423
+ " return default\n",
424
+ " return max(lo, min(hi, v))\n",
425
+ "\n",
426
+ "def parse_response(text):\n",
427
+ " \"\"\"text thô LLM → dict {qmos,emos,cat5(list theo EMOTIONS5),val,aro,dom} hoặc None nếu hỏng.\"\"\"\n",
428
+ " if not text:\n",
429
+ " return None\n",
430
+ " m = re.search(r\"\\{.*\\}\", text, re.DOTALL) # trích khối JSON đầu tiên\n",
431
+ " if not m:\n",
432
+ " return None\n",
433
+ " try:\n",
434
+ " d = json.loads(m.group(0))\n",
435
+ " except Exception:\n",
436
+ " return None\n",
437
+ " cat_in = d.get(\"cat\", {}) or {}\n",
438
+ " cat = np.zeros(5, np.float32)\n",
439
+ " for k, v in cat_in.items():\n",
440
+ " e = norm_emotion(k)\n",
441
+ " if e in EMOTIONS5:\n",
442
+ " try:\n",
443
+ " cat[EMOTIONS5.index(e)] = max(0.0, float(v))\n",
444
+ " except Exception:\n",
445
+ " pass\n",
446
+ " cat = cat / cat.sum() if cat.sum() > 0 else np.full(5, 0.2, np.float32)\n",
447
+ " return dict(\n",
448
+ " qmos=_clamp(d.get(\"qmos\")), emos=_clamp(d.get(\"emos\")),\n",
449
+ " cat5=cat.tolist(),\n",
450
+ " val=_clamp(d.get(\"val\")), aro=_clamp(d.get(\"aro\")), dom=_clamp(d.get(\"dom\")),\n",
451
+ " )"
452
+ ]
453
+ },
454
+ {
455
+ "cell_type": "markdown",
456
+ "id": "462449d1",
457
+ "metadata": {},
458
+ "source": [
459
+ "## 6. Vòng chấm có CACHE + RESUME (KHÔNG gọi lại wav đã có trong cache)"
460
+ ]
461
+ },
462
+ {
463
+ "cell_type": "code",
464
+ "execution_count": null,
465
+ "id": "ee30edbc",
466
+ "metadata": {
467
+ "lines_to_next_cell": 1
468
+ },
469
+ "outputs": [],
470
+ "source": [
471
+ "def load_cache():\n",
472
+ " done = {}\n",
473
+ " if os.path.exists(CACHE_PATH):\n",
474
+ " with open(CACHE_PATH, encoding=\"utf-8\") as f:\n",
475
+ " for ln in f:\n",
476
+ " try:\n",
477
+ " r = json.loads(ln)\n",
478
+ " done[r[\"stem\"]] = r\n",
479
+ " except Exception:\n",
480
+ " continue\n",
481
+ " return done\n",
482
+ "\n",
483
+ "def score_one(name):\n",
484
+ " \"\"\"Gọi LLM cho 1 wav, retry; trả record dict {stem,name,raw,parsed}.\"\"\"\n",
485
+ " sid = stem(name)\n",
486
+ " wavp = os.path.join(WAV_DIR, name if name.endswith(\".wav\") else name + \".wav\")\n",
487
+ " tgt = target_map.get(sid)\n",
488
+ " prompt = build_prompt(tgt)\n",
489
+ " last_err = None\n",
490
+ " for attempt in range(MAX_RETRY):\n",
491
+ " try:\n",
492
+ " _, b64 = (None, None)\n",
493
+ " raw_bytes, b64 = load_wav_bytes(wavp)\n",
494
+ " text = call_llm(b64, raw_bytes, prompt)\n",
495
+ " parsed = parse_response(text)\n",
496
+ " if parsed is not None:\n",
497
+ " return dict(stem=sid, name=name, raw=text, parsed=parsed, ok=True)\n",
498
+ " last_err = \"parse_fail\"\n",
499
+ " except Exception as e:\n",
500
+ " last_err = str(e)\n",
501
+ " time.sleep(RETRY_SLEEP * (attempt + 1))\n",
502
+ " return dict(stem=sid, name=name, raw=None, parsed=None, ok=False, err=last_err)\n",
503
+ "\n",
504
+ "def run_scoring():\n",
505
+ " from concurrent.futures import ThreadPoolExecutor, as_completed\n",
506
+ " done = load_cache()\n",
507
+ " todo = [n for n in dev_names if stem(n) not in done]\n",
508
+ " print(f\"Cache có {len(done)} | cần chấm thêm {len(todo)} | ước lượng {len(todo)} call API\")\n",
509
+ " if not todo:\n",
510
+ " return done\n",
511
+ " n_ok = n_bad = 0\n",
512
+ " with open(CACHE_PATH, \"a\", encoding=\"utf-8\") as fout, \\\n",
513
+ " ThreadPoolExecutor(max_workers=WORKERS) as ex:\n",
514
+ " futs = {ex.submit(score_one, n): n for n in todo}\n",
515
+ " for i, fut in enumerate(as_completed(futs), 1):\n",
516
+ " rec = fut.result()\n",
517
+ " fout.write(json.dumps(rec, ensure_ascii=False) + \"\\n\")\n",
518
+ " fout.flush()\n",
519
+ " done[rec[\"stem\"]] = rec\n",
520
+ " n_ok += int(rec[\"ok\"]); n_bad += int(not rec[\"ok\"])\n",
521
+ " if i % 50 == 0 or i == len(todo):\n",
522
+ " print(f\" {i}/{len(todo)} | ok={n_ok} bad={n_bad}\")\n",
523
+ " if n_bad:\n",
524
+ " print(f\"⚠️ {n_bad} wav hỏng (parse/API) → sẽ điền mặc định ở build_answer.\")\n",
525
+ " return done\n",
526
+ "\n",
527
+ "records = run_scoring()"
528
+ ]
529
+ },
530
+ {
531
+ "cell_type": "markdown",
532
+ "id": "eda27285",
533
+ "metadata": {},
534
+ "source": [
535
+ "## 7. Ráp `answer.txt` 6 cột (giống exp07) + validate + zip"
536
+ ]
537
+ },
538
+ {
539
+ "cell_type": "code",
540
+ "execution_count": null,
541
+ "id": "b9a4bd65",
542
+ "metadata": {
543
+ "lines_to_next_cell": 1
544
+ },
545
+ "outputs": [],
546
+ "source": [
547
+ "def fmt_cat(probs5):\n",
548
+ " return \"|\".join(f\"{e}:{probs5[i]:.6g}\" for i, e in enumerate(EMOTIONS5))\n",
549
+ "\n",
550
+ "def build_answer(out_path):\n",
551
+ " n_real = n_default = 0\n",
552
+ " with open(out_path, \"w\") as f:\n",
553
+ " f.write(\"wav,QMOS,EMOS,CAT,VAL,ARO,DOM\\n\")\n",
554
+ " for name in dev_names:\n",
555
+ " sid = stem(name)\n",
556
+ " rec = records.get(sid)\n",
557
+ " p = rec[\"parsed\"] if (rec and rec.get(\"parsed\")) else None\n",
558
+ " if p is None:\n",
559
+ " qmos = emos = val = aro = dom = 3.0\n",
560
+ " cat5 = [0.2] * 5\n",
561
+ " n_default += 1\n",
562
+ " else:\n",
563
+ " qmos, emos = p[\"qmos\"], p[\"emos\"]\n",
564
+ " val, aro, dom = p[\"val\"], p[\"aro\"], p[\"dom\"]\n",
565
+ " cat5 = p[\"cat5\"]; n_real += 1\n",
566
+ " f.write(f\"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},\"\n",
567
+ " f\"{val:.6g},{aro:.6g},{dom:.6g}\\n\")\n",
568
+ " print(f\"Ghi {len(dev_names)} dòng → {out_path} | LLM thật {n_real}, mặc định {n_default}\")\n",
569
+ "\n",
570
+ "answer_path = os.path.join(OUT_DIR, \"answer.txt\")\n",
571
+ "build_answer(answer_path)\n",
572
+ "\n",
573
+ "def validate(path):\n",
574
+ " import csv\n",
575
+ " with open(path) as f:\n",
576
+ " rows = list(csv.reader(f))\n",
577
+ " header = rows[0]\n",
578
+ " assert header[0] == \"wav\" and \"QMOS\" in header and \"EMOS\" in header, \"Header sai\"\n",
579
+ " for i, r in enumerate(rows[1:], 2):\n",
580
+ " assert len(r) == len(header), f\"Dòng {i} sai số cột\"\n",
581
+ " print(f\"OK: {len(rows)-1} dòng, header = {header}\")\n",
582
+ "\n",
583
+ "validate(answer_path)\n",
584
+ "!cd /kaggle/working && zip -j submission_track2_exp16.zip answer.txt && unzip -l submission_track2_exp16.zip\n",
585
+ "print(\"Sẵn sàng nộp: /kaggle/working/submission_track2_exp16.zip\")"
586
+ ]
587
+ },
588
+ {
589
+ "cell_type": "markdown",
590
+ "id": "f8eeafb8",
591
+ "metadata": {},
592
+ "source": [
593
+ "## 8. (tùy chọn) Ensemble muộn: trộn THỨ HẠNG điểm LLM + hệ trained\n",
594
+ "\n",
595
+ "Trung bình rank của exp16 với một `answer.txt` đã có (vd bản trộn cột exp07+exp08) cho từng cột số.\n",
596
+ "Đa dạng nguồn → có thể giảm nhiễu. CHỈ chạy khi có sẵn file kia (đặt đường dẫn rồi bỏ comment)."
597
+ ]
598
+ },
599
+ {
600
+ "cell_type": "code",
601
+ "execution_count": null,
602
+ "id": "d1a15c1f",
603
+ "metadata": {},
604
+ "outputs": [],
605
+ "source": [
606
+ "def ensemble_rank_average(answer_a, answer_b, out_path):\n",
607
+ " \"\"\"Trộn 2 answer.txt theo TRUNG BÌNH THỨ HẠNG cho 5 cột số (QMOS/EMOS/VAL/ARO/DOM); CAT lấy theo A.\"\"\"\n",
608
+ " import pandas as pd\n",
609
+ " num_cols = [\"QMOS\", \"EMOS\", \"VAL\", \"ARO\", \"DOM\"]\n",
610
+ " A = pd.read_csv(answer_a); B = pd.read_csv(answer_b)\n",
611
+ " A = A.set_index(\"wav\"); B = B.set_index(\"wav\").reindex(A.index)\n",
612
+ " out = A.copy()\n",
613
+ " for c in num_cols:\n",
614
+ " if c in A.columns and c in B.columns:\n",
615
+ " ra = A[c].rank(); rb = B[c].rank()\n",
616
+ " out[c] = ((ra + rb) / 2.0) # SRCC bất biến với scale → để nguyên rank trung bình\n",
617
+ " out.reset_index().to_csv(out_path, index=False)\n",
618
+ " print(\"Ensemble →\", out_path)\n",
619
+ "\n",
620
+ "# ensemble_rank_average(answer_path,\n",
621
+ "# \"/kaggle/input/.../exp_mix_q07_emo08/answer.txt\",\n",
622
+ "# os.path.join(OUT_DIR, \"answer_ens.txt\"))"
623
+ ]
624
+ },
625
+ {
626
+ "cell_type": "markdown",
627
+ "id": "14172d70",
628
+ "metadata": {},
629
+ "source": [
630
+ "## Ghi chú nộp & paper\n",
631
+ "- Nộp: My Submissions → **Track 2** (bỏ chọn track khác) → `submission_track2_exp16.zip` → đọc SRCC 6 cột.\n",
632
+ "- **Bảng A (paper):** đặt SRCC exp16 (gemini/openai, zero-shot) cạnh exp07 (QMOS 0.548) + exp08\n",
633
+ " (EMOS 0.811 · CAT 0.133 · VAD 0.659/0.793/0.751). Kỳ vọng: LLM khá ở EMOS/CAT, yếu ở QMOS.\n",
634
+ "- **Bảng B:** chạy lại `SHOT_MODE=\"few_shot\"` (1 provider) → so zero vs few-shot.\n",
635
+ "- **Cache:** Save Version để giữ `exp16_llm_cache/*.jsonl` (không trả tiền lại). Lưu thành Kaggle\n",
636
+ " Dataset nếu muốn dùng cho eval phase.\n",
637
+ "- **Khai báo external resource** (API thương mại Gemini/OpenAI) trong `12_system_description.md`."
638
+ ]
639
+ }
640
+ ],
641
+ "metadata": {
642
+ "jupytext": {
643
+ "cell_metadata_filter": "-all",
644
+ "main_language": "python",
645
+ "notebook_metadata_filter": "-all"
646
+ }
647
+ },
648
+ "nbformat": 4,
649
+ "nbformat_minor": 5
650
+ }
track2/exp16_llm_judge_pipeline.py ADDED
@@ -0,0 +1,480 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %% [markdown]
2
+ # # exp16 — Audio-LLM-as-Judge cho MOS cảm xúc (Track 2)
3
+ #
4
+ # **Ý tưởng:** đưa thẳng audio cho một **audio-LLM** (Gemini / GPT-4o-audio) qua **API** + prompt có
5
+ # cấu trúc → bắt nó chấm cả 6 cột (`QMOS, EMOS, CAT, VAL, ARO, DOM`) → ráp `answer.txt` → nộp CodaBench.
6
+ #
7
+ # **Mục tiêu chính = NOVELTY cho paper** (khảo sát có hệ thống audio-LLM-as-judge cho MOS cảm xúc),
8
+ # so với hệ SSL đã train (exp07 QMOS 0.548 · exp08 EMOS 0.811…). KHÔNG cần GPU — thuần gọi API.
9
+ #
10
+ # | Đặc điểm | Giá trị |
11
+ # |---|---|
12
+ # | GPU | ❌ không cần (chỉ network I/O) |
13
+ # | Tốn phí | ✅ API trả tiền theo token/audio → **cache + resume bắt buộc** |
14
+ # | Provider | `gemini` (mặc định, đã có billing) · `openai` (GPT-4o-audio, để so 2 LLM) |
15
+ # | Output | `answer.txt` 6 cột giống exp07 |
16
+ #
17
+ # **Cách dùng Kaggle:** Internet = **On**; Add-ons → Secrets: `GEMINI_API_KEY` (và `OPENAI_API_KEY`
18
+ # nếu chạy provider openai). Settings GPU **không cần**. Sửa `DATA_ROOT` cho khớp slug rồi Run All.
19
+ #
20
+ # ⚠️ **Model ID có thể đã đổi** theo thời gian → kiểm tra `GEMINI_MODEL` / `OPENAI_MODEL` còn nhận
21
+ # audio không trước khi chạy full (xem mục 1).
22
+
23
+ # %% [markdown]
24
+ # ## 0. Cấu hình — SỬA Ở ĐÂY
25
+
26
+ # %%
27
+ import os, io, re, json, time, base64, glob
28
+
29
+ # ── Data Track 2 trên Kaggle ────────────────────────────────────────────────
30
+ DATA_ROOT = "/kaggle/input/vmc2026-track2-full/vmc2026-track2" # << SỬA slug
31
+ WAV_DIR = f"{DATA_ROOT}/wav"
32
+ METADATA_CSV = f"{DATA_ROOT}/metadata.csv" # wavID|emotion|transcript (KHÔNG header) — nhãn cảm xúc target
33
+ DEV_SCP = f"{DATA_ROOT}/sets/dev.scp" # danh sách wav DEV cần nộp (train phase)
34
+ TRAIN_CSV = f"{DATA_ROOT}/sets/train.csv" # chỉ cần khi SHOT_MODE="few_shot"
35
+
36
+ OUT_DIR = "/kaggle/working"
37
+ CACHE_DIR = "/kaggle/working/exp16_llm_cache" # nên Save Version / lưu Dataset để KHÔNG gọi lại API
38
+ os.makedirs(CACHE_DIR, exist_ok=True)
39
+
40
+ # ── Provider & model ────────────────────────────────────────────────────────
41
+ PROVIDER = "gemini" # "gemini" | "openai"
42
+ GEMINI_MODEL = "gemini-2.5-flash" # << xác nhận model audio hiện hành (baseline dùng họ gemini-*-flash)
43
+ OPENAI_MODEL = "gpt-4o-audio-preview" # << model audio của OpenAI; cần OPENAI_API_KEY
44
+ TEMPERATURE = 0.0 # cố định để TÁI LẬP (paper)
45
+
46
+ # ── Chế độ chạy ─────────────────────────────────────────────────────────────
47
+ SHOT_MODE = "zero_shot" # "zero_shot" | "few_shot" (nhét K ví dụ audio có nhãn từ train.csv)
48
+ FEW_K = 2 # số ví dụ few-shot (mỗi ví dụ = 1 audio + nhãn vàng) — tốn thêm token!
49
+ LIMIT = 20 # << số nhỏ (20) để smoke test; None = full DEV (~2730) — CHẠY THỬ TRƯỚC
50
+ MAX_SECONDS = 12 # cắt audio cho rẻ + nhanh
51
+ WORKERS = 4 # luồng gọi song song (giảm nếu dính rate limit)
52
+ MAX_RETRY = 3 # số lần thử lại 1 wav khi lỗi mạng / JSON hỏng
53
+ RETRY_SLEEP = 2.0 # giây nghỉ giữa các lần thử
54
+
55
+ TAG = f"{PROVIDER}_{(GEMINI_MODEL if PROVIDER=='gemini' else OPENAI_MODEL)}_{SHOT_MODE}".replace("/", "-")
56
+ CACHE_PATH = os.path.join(CACHE_DIR, f"{TAG}.jsonl") # 1 dòng JSON / wav (raw + parsed) → resume
57
+ print("TAG:", TAG, "| cache:", CACHE_PATH)
58
+
59
+ # %% [markdown]
60
+ # ## 0b. Nhãn cảm xúc target + chuẩn hóa lớp (tái dùng quy ước baseline)
61
+
62
+ # %%
63
+ EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"] # THỨ TỰ chuẩn cho cột CAT
64
+
65
+ _EMO_ALIAS = {
66
+ "angry": "angry", "anger": "angry",
67
+ "happy": "happy", "happiness": "happy", "joy": "happy",
68
+ "neutral": "neutral", "calm": "neutral",
69
+ "sad": "sad", "sadness": "sad",
70
+ "surprise": "surprised", "surprised": "surprised", "surprising": "surprised",
71
+ }
72
+
73
+ def norm_emotion(label):
74
+ key = str(label).strip().lower()
75
+ return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)
76
+
77
+ def stem(name):
78
+ return os.path.splitext(os.path.basename(name))[0]
79
+
80
+ def load_target_emotions():
81
+ """metadata.csv (wavID|emotion|transcript, không header) → {stem: emotion_chuẩn}."""
82
+ tgt = {}
83
+ if not (METADATA_CSV and os.path.exists(METADATA_CSV)):
84
+ print("⚠️ Không thấy metadata.csv → EMOS sẽ thiếu cảm xúc target.")
85
+ return tgt
86
+ with open(METADATA_CSV, encoding="utf-8") as f:
87
+ for ln in f:
88
+ parts = ln.strip().split("|")
89
+ if len(parts) < 2:
90
+ continue
91
+ tgt[stem(parts[0])] = norm_emotion(parts[1])
92
+ return tgt
93
+
94
+ target_map = load_target_emotions()
95
+ print("Nhãn cảm xúc target:", len(target_map))
96
+
97
+ def list_dev():
98
+ with open(DEV_SCP) as f:
99
+ return [ln.strip() for ln in f if ln.strip()]
100
+
101
+ dev_names = list_dev()
102
+ if LIMIT:
103
+ dev_names = dev_names[:LIMIT]
104
+ print("DEV cần chấm:", len(dev_names), "mẫu", "| LIMIT =", LIMIT)
105
+
106
+ # %% [markdown]
107
+ # ## 1. Cài SDK + nạp key
108
+ #
109
+ # Gemini dùng SDK mới `google-genai`; OpenAI dùng `openai`. Trên Kaggle **Internet phải On**.
110
+
111
+ # %%
112
+ # !pip -q install google-genai openai soundfile librosa
113
+
114
+ def setup_keys():
115
+ """Nạp API key từ Kaggle Secrets (fallback: biến môi trường đã set sẵn)."""
116
+ try:
117
+ from kaggle_secrets import UserSecretsClient
118
+ sec = UserSecretsClient()
119
+ for k in ["GEMINI_API_KEY", "OPENAI_API_KEY"]:
120
+ try:
121
+ os.environ[k] = sec.get_secret(k)
122
+ print(f"Đã nạp {k} từ Secrets")
123
+ except Exception:
124
+ pass
125
+ except Exception as e:
126
+ print("Không dùng được Kaggle Secrets:", e, "→ set tay os.environ[...] nếu cần")
127
+
128
+ setup_keys()
129
+
130
+ # %% [markdown]
131
+ # ## 2. Đọc + chuẩn hóa audio (16kHz mono, cắt MAX_SECONDS) → bytes WAV trong RAM
132
+
133
+ # %%
134
+ import numpy as np
135
+
136
+ def load_wav_bytes(path, sr=16000, max_seconds=MAX_SECONDS):
137
+ """Trả (wav_bytes, base64_str). Cắt ≤ max_seconds, resample 16k mono, encode WAV PCM16."""
138
+ import soundfile as sf
139
+ try:
140
+ import librosa
141
+ y, _ = librosa.load(path, sr=sr, mono=True)
142
+ except Exception:
143
+ y, in_sr = sf.read(path)
144
+ if y.ndim > 1:
145
+ y = y.mean(axis=1)
146
+ if in_sr != sr: # fallback resample tuyến tính nếu không có librosa
147
+ idx = np.linspace(0, len(y) - 1, int(len(y) * sr / in_sr))
148
+ y = np.interp(idx, np.arange(len(y)), y)
149
+ if max_seconds:
150
+ y = y[: int(sr * max_seconds)]
151
+ buf = io.BytesIO()
152
+ sf.write(buf, y.astype(np.float32), sr, format="WAV", subtype="PCM_16")
153
+ raw = buf.getvalue()
154
+ return raw, base64.b64encode(raw).decode("ascii")
155
+
156
+ # %% [markdown]
157
+ # ## 3. Prompt — định nghĩa 6 metric + ép JSON nghiêm ngặt
158
+ #
159
+ # QMOS = chất lượng/độ tự nhiên (sạch, không méo/robot). EMOS = độ KHỚP với **cảm xúc target**.
160
+ # CAT = phân phối vote 5 lớp. VAD = Valence/Arousal/Dominance. Tất cả thang **1–5** (CAT là tỉ lệ 0–1).
161
+
162
+ # %%
163
+ SYSTEM_INSTRUCTION = (
164
+ "You are an expert evaluator of emotional text-to-speech. "
165
+ "Listen to the audio and rate it. Respond with ONLY a compact JSON object, no prose."
166
+ )
167
+
168
+ def build_prompt(target_emo):
169
+ tgt = target_emo if target_emo else "unknown"
170
+ return (
171
+ "Rate this speech utterance. The INTENDED (target) emotion is: "
172
+ f"\"{tgt}\".\n\n"
173
+ "Return a JSON object with EXACTLY these keys (numbers on a 1-5 scale unless stated):\n"
174
+ " \"qmos\": overall audio QUALITY / naturalness (1=very unnatural/robotic/distorted, 5=clean & human-like).\n"
175
+ " \"emos\": how well the emotion expressed MATCHES the target emotion above "
176
+ "(1=not matching at all, 5=perfectly matching).\n"
177
+ " \"cat\": an object with probabilities (summing to 1.0) over the 5 perceived emotions: "
178
+ "{\"neutral\":_, \"happy\":_, \"sad\":_, \"angry\":_, \"surprised\":_}.\n"
179
+ " \"val\": valence (1=very negative, 5=very positive).\n"
180
+ " \"aro\": arousal (1=very calm, 5=very excited).\n"
181
+ " \"dom\": dominance (1=very submissive, 5=very dominant).\n\n"
182
+ "Example format: "
183
+ "{\"qmos\":3.5,\"emos\":4.0,"
184
+ "\"cat\":{\"neutral\":0.1,\"happy\":0.7,\"sad\":0.0,\"angry\":0.1,\"surprised\":0.1},"
185
+ "\"val\":4.0,\"aro\":3.5,\"dom\":3.0}\n"
186
+ "Respond with ONLY the JSON."
187
+ )
188
+
189
+ # %% [markdown]
190
+ # ## 3b. (tùy chọn) Few-shot — lấy K ví dụ audio có nhãn vàng từ train.csv
191
+ #
192
+ # Bật khi `SHOT_MODE="few_shot"`. Mỗi ví dụ = 1 audio train + nhãn vàng (gộp TB theo wav). Tốn thêm token.
193
+
194
+ # %%
195
+ few_shot_examples = [] # list[(audio_b64, audio_bytes, gold_json_str)]
196
+
197
+ def _agg_train_labels():
198
+ """Gộp train.csv (sep='|') theo wavID → nhãn vàng trung bình; CAT = tỉ lệ vote."""
199
+ import pandas as pd
200
+ df = pd.read_csv(TRAIN_CSV, sep="|")
201
+ rows = {}
202
+ for wav, g in df.groupby("wavID"):
203
+ votes = np.zeros(5, np.float32)
204
+ for cell in g["emoCat"].astype(str):
205
+ for tok in cell.split(","):
206
+ e = norm_emotion(tok)
207
+ if e in EMOTIONS5:
208
+ votes[EMOTIONS5.index(e)] += 1
209
+ s = votes.sum()
210
+ cat = (votes / s) if s > 0 else np.full(5, 0.2, np.float32)
211
+ rows[stem(wav)] = dict(
212
+ qmos=float(g["qMOS"].mean()), emos=float(g["eMOS"].mean()),
213
+ val=float(g["val"].mean()), aro=float(g["aro"].mean()), dom=float(g["dom"].mean()),
214
+ cat={EMOTIONS5[i]: round(float(cat[i]), 4) for i in range(5)},
215
+ )
216
+ return rows
217
+
218
+ def build_few_shot():
219
+ if SHOT_MODE != "few_shot":
220
+ return
221
+ labels = _agg_train_labels()
222
+ picked = list(labels.keys())[:FEW_K]
223
+ for sid in picked:
224
+ wavp = os.path.join(WAV_DIR, sid + ".wav")
225
+ if not os.path.exists(wavp):
226
+ continue
227
+ raw, b64 = load_wav_bytes(wavp)
228
+ gold = labels[sid]
229
+ gold_json = json.dumps({
230
+ "qmos": round(gold["qmos"], 2), "emos": round(gold["emos"], 2),
231
+ "cat": gold["cat"], "val": round(gold["val"], 2),
232
+ "aro": round(gold["aro"], 2), "dom": round(gold["dom"], 2),
233
+ })
234
+ few_shot_examples.append((b64, raw, gold_json))
235
+ print(f"Few-shot: {len(few_shot_examples)} ví dụ")
236
+
237
+ build_few_shot()
238
+
239
+ # %% [markdown]
240
+ # ## 4. Gọi API — trừu tượng hóa provider (gemini / openai)
241
+ #
242
+ # Mỗi provider tự dựng message của nó (kèm few-shot nếu có). Trả về **text thô** để parse ở mục 5.
243
+
244
+ # %%
245
+ _client = {"gemini": None, "openai": None}
246
+
247
+ def _gemini_client():
248
+ if _client["gemini"] is None:
249
+ from google import genai
250
+ _client["gemini"] = genai.Client(api_key=os.environ["GEMINI_API_KEY"])
251
+ return _client["gemini"]
252
+
253
+ def _openai_client():
254
+ if _client["openai"] is None:
255
+ from openai import OpenAI
256
+ _client["openai"] = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
257
+ return _client["openai"]
258
+
259
+ def call_gemini(audio_b64, audio_bytes, prompt):
260
+ from google.genai import types
261
+ client = _gemini_client()
262
+ contents = []
263
+ for ex_b64, ex_bytes, ex_gold in few_shot_examples: # few-shot: audio ví dụ + nhãn vàng
264
+ contents.append(types.Content(role="user", parts=[
265
+ types.Part.from_bytes(data=ex_bytes, mime_type="audio/wav"),
266
+ types.Part.from_text(text=build_prompt(None)),
267
+ ]))
268
+ contents.append(types.Content(role="model", parts=[types.Part.from_text(text=ex_gold)]))
269
+ contents.append(types.Content(role="user", parts=[
270
+ types.Part.from_bytes(data=audio_bytes, mime_type="audio/wav"),
271
+ types.Part.from_text(text=prompt),
272
+ ]))
273
+ resp = client.models.generate_content(
274
+ model=GEMINI_MODEL, contents=contents,
275
+ config=types.GenerateContentConfig(
276
+ system_instruction=SYSTEM_INSTRUCTION, temperature=TEMPERATURE),
277
+ )
278
+ return resp.text
279
+
280
+ def call_openai(audio_b64, audio_bytes, prompt):
281
+ client = _openai_client()
282
+ messages = [{"role": "system", "content": SYSTEM_INSTRUCTION}]
283
+ for ex_b64, ex_bytes, ex_gold in few_shot_examples:
284
+ messages.append({"role": "user", "content": [
285
+ {"type": "text", "text": build_prompt(None)},
286
+ {"type": "input_audio", "input_audio": {"data": ex_b64, "format": "wav"}},
287
+ ]})
288
+ messages.append({"role": "assistant", "content": ex_gold})
289
+ messages.append({"role": "user", "content": [
290
+ {"type": "text", "text": prompt},
291
+ {"type": "input_audio", "input_audio": {"data": audio_b64, "format": "wav"}},
292
+ ]})
293
+ resp = client.chat.completions.create(
294
+ model=OPENAI_MODEL, messages=messages, temperature=TEMPERATURE,
295
+ modalities=["text"],
296
+ )
297
+ return resp.choices[0].message.content
298
+
299
+ def call_llm(audio_b64, audio_bytes, prompt):
300
+ return call_gemini(audio_b64, audio_bytes, prompt) if PROVIDER == "gemini" \
301
+ else call_openai(audio_b64, audio_bytes, prompt)
302
+
303
+ # %% [markdown]
304
+ # ## 5. Parse JSON chịu lỗi → 6 cột; clamp [1,5]; chuẩn hóa CAT
305
+
306
+ # %%
307
+ def _clamp(x, lo=1.0, hi=5.0, default=3.0):
308
+ try:
309
+ v = float(x)
310
+ except Exception:
311
+ return default
312
+ return max(lo, min(hi, v))
313
+
314
+ def parse_response(text):
315
+ """text thô LLM → dict {qmos,emos,cat5(list theo EMOTIONS5),val,aro,dom} hoặc None nếu hỏng."""
316
+ if not text:
317
+ return None
318
+ m = re.search(r"\{.*\}", text, re.DOTALL) # trích khối JSON đầu tiên
319
+ if not m:
320
+ return None
321
+ try:
322
+ d = json.loads(m.group(0))
323
+ except Exception:
324
+ return None
325
+ cat_in = d.get("cat", {}) or {}
326
+ cat = np.zeros(5, np.float32)
327
+ for k, v in cat_in.items():
328
+ e = norm_emotion(k)
329
+ if e in EMOTIONS5:
330
+ try:
331
+ cat[EMOTIONS5.index(e)] = max(0.0, float(v))
332
+ except Exception:
333
+ pass
334
+ cat = cat / cat.sum() if cat.sum() > 0 else np.full(5, 0.2, np.float32)
335
+ return dict(
336
+ qmos=_clamp(d.get("qmos")), emos=_clamp(d.get("emos")),
337
+ cat5=cat.tolist(),
338
+ val=_clamp(d.get("val")), aro=_clamp(d.get("aro")), dom=_clamp(d.get("dom")),
339
+ )
340
+
341
+ # %% [markdown]
342
+ # ## 6. Vòng chấm có CACHE + RESUME (KHÔNG gọi lại wav đã có trong cache)
343
+
344
+ # %%
345
+ def load_cache():
346
+ done = {}
347
+ if os.path.exists(CACHE_PATH):
348
+ with open(CACHE_PATH, encoding="utf-8") as f:
349
+ for ln in f:
350
+ try:
351
+ r = json.loads(ln)
352
+ done[r["stem"]] = r
353
+ except Exception:
354
+ continue
355
+ return done
356
+
357
+ def score_one(name):
358
+ """Gọi LLM cho 1 wav, retry; trả record dict {stem,name,raw,parsed}."""
359
+ sid = stem(name)
360
+ wavp = os.path.join(WAV_DIR, name if name.endswith(".wav") else name + ".wav")
361
+ tgt = target_map.get(sid)
362
+ prompt = build_prompt(tgt)
363
+ last_err = None
364
+ for attempt in range(MAX_RETRY):
365
+ try:
366
+ _, b64 = (None, None)
367
+ raw_bytes, b64 = load_wav_bytes(wavp)
368
+ text = call_llm(b64, raw_bytes, prompt)
369
+ parsed = parse_response(text)
370
+ if parsed is not None:
371
+ return dict(stem=sid, name=name, raw=text, parsed=parsed, ok=True)
372
+ last_err = "parse_fail"
373
+ except Exception as e:
374
+ last_err = str(e)
375
+ time.sleep(RETRY_SLEEP * (attempt + 1))
376
+ return dict(stem=sid, name=name, raw=None, parsed=None, ok=False, err=last_err)
377
+
378
+ def run_scoring():
379
+ from concurrent.futures import ThreadPoolExecutor, as_completed
380
+ done = load_cache()
381
+ todo = [n for n in dev_names if stem(n) not in done]
382
+ print(f"Cache có {len(done)} | cần chấm thêm {len(todo)} | ước lượng {len(todo)} call API")
383
+ if not todo:
384
+ return done
385
+ n_ok = n_bad = 0
386
+ with open(CACHE_PATH, "a", encoding="utf-8") as fout, \
387
+ ThreadPoolExecutor(max_workers=WORKERS) as ex:
388
+ futs = {ex.submit(score_one, n): n for n in todo}
389
+ for i, fut in enumerate(as_completed(futs), 1):
390
+ rec = fut.result()
391
+ fout.write(json.dumps(rec, ensure_ascii=False) + "\n")
392
+ fout.flush()
393
+ done[rec["stem"]] = rec
394
+ n_ok += int(rec["ok"]); n_bad += int(not rec["ok"])
395
+ if i % 50 == 0 or i == len(todo):
396
+ print(f" {i}/{len(todo)} | ok={n_ok} bad={n_bad}")
397
+ if n_bad:
398
+ print(f"⚠️ {n_bad} wav hỏng (parse/API) → sẽ điền mặc định ở build_answer.")
399
+ return done
400
+
401
+ records = run_scoring()
402
+
403
+ # %% [markdown]
404
+ # ## 7. Ráp `answer.txt` 6 cột (giống exp07) + validate + zip
405
+
406
+ # %%
407
+ def fmt_cat(probs5):
408
+ return "|".join(f"{e}:{probs5[i]:.6g}" for i, e in enumerate(EMOTIONS5))
409
+
410
+ def build_answer(out_path):
411
+ n_real = n_default = 0
412
+ with open(out_path, "w") as f:
413
+ f.write("wav,QMOS,EMOS,CAT,VAL,ARO,DOM\n")
414
+ for name in dev_names:
415
+ sid = stem(name)
416
+ rec = records.get(sid)
417
+ p = rec["parsed"] if (rec and rec.get("parsed")) else None
418
+ if p is None:
419
+ qmos = emos = val = aro = dom = 3.0
420
+ cat5 = [0.2] * 5
421
+ n_default += 1
422
+ else:
423
+ qmos, emos = p["qmos"], p["emos"]
424
+ val, aro, dom = p["val"], p["aro"], p["dom"]
425
+ cat5 = p["cat5"]; n_real += 1
426
+ f.write(f"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},"
427
+ f"{val:.6g},{aro:.6g},{dom:.6g}\n")
428
+ print(f"Ghi {len(dev_names)} dòng → {out_path} | LLM thật {n_real}, mặc định {n_default}")
429
+
430
+ answer_path = os.path.join(OUT_DIR, "answer.txt")
431
+ build_answer(answer_path)
432
+
433
+ def validate(path):
434
+ import csv
435
+ with open(path) as f:
436
+ rows = list(csv.reader(f))
437
+ header = rows[0]
438
+ assert header[0] == "wav" and "QMOS" in header and "EMOS" in header, "Header sai"
439
+ for i, r in enumerate(rows[1:], 2):
440
+ assert len(r) == len(header), f"Dòng {i} sai số cột"
441
+ print(f"OK: {len(rows)-1} dòng, header = {header}")
442
+
443
+ validate(answer_path)
444
+ # !cd /kaggle/working && zip -j submission_track2_exp16.zip answer.txt && unzip -l submission_track2_exp16.zip
445
+ print("Sẵn sàng nộp: /kaggle/working/submission_track2_exp16.zip")
446
+
447
+ # %% [markdown]
448
+ # ## 8. (tùy chọn) Ensemble muộn: trộn THỨ HẠNG điểm LLM + hệ trained
449
+ #
450
+ # Trung bình rank của exp16 với một `answer.txt` đã có (vd bản trộn cột exp07+exp08) cho từng cột số.
451
+ # Đa dạng nguồn → có thể giảm nhiễu. CHỈ chạy khi có sẵn file kia (đặt đường dẫn rồi bỏ comment).
452
+
453
+ # %%
454
+ def ensemble_rank_average(answer_a, answer_b, out_path):
455
+ """Trộn 2 answer.txt theo TRUNG BÌNH THỨ HẠNG cho 5 cột số (QMOS/EMOS/VAL/ARO/DOM); CAT lấy theo A."""
456
+ import pandas as pd
457
+ num_cols = ["QMOS", "EMOS", "VAL", "ARO", "DOM"]
458
+ A = pd.read_csv(answer_a); B = pd.read_csv(answer_b)
459
+ A = A.set_index("wav"); B = B.set_index("wav").reindex(A.index)
460
+ out = A.copy()
461
+ for c in num_cols:
462
+ if c in A.columns and c in B.columns:
463
+ ra = A[c].rank(); rb = B[c].rank()
464
+ out[c] = ((ra + rb) / 2.0) # SRCC bất biến với scale → để nguyên rank trung bình
465
+ out.reset_index().to_csv(out_path, index=False)
466
+ print("Ensemble →", out_path)
467
+
468
+ # ensemble_rank_average(answer_path,
469
+ # "/kaggle/input/.../exp_mix_q07_emo08/answer.txt",
470
+ # os.path.join(OUT_DIR, "answer_ens.txt"))
471
+
472
+ # %% [markdown]
473
+ # ## Ghi chú nộp & paper
474
+ # - Nộp: My Submissions → **Track 2** (bỏ chọn track khác) → `submission_track2_exp16.zip` → đọc SRCC 6 cột.
475
+ # - **Bảng A (paper):** đặt SRCC exp16 (gemini/openai, zero-shot) cạnh exp07 (QMOS 0.548) + exp08
476
+ # (EMOS 0.811 · CAT 0.133 · VAD 0.659/0.793/0.751). Kỳ vọng: LLM khá ở EMOS/CAT, yếu ở QMOS.
477
+ # - **Bảng B:** chạy lại `SHOT_MODE="few_shot"` (1 provider) → so zero vs few-shot.
478
+ # - **Cache:** Save Version để giữ `exp16_llm_cache/*.jsonl` (không trả tiền lại). Lưu thành Kaggle
479
+ # Dataset nếu muốn dùng cho eval phase.
480
+ # - **Khai báo external resource** (API thương mại Gemini/OpenAI) trong `12_system_description.md`.
track2/track2_baseline.ipynb ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": "# VMC2026 Track 2 — Baseline Pipeline (Kaggle)\n\nQMOS (SpeechMOS) + EmoCat (emotion2vec) + **EMOS (emotion2vec target-prob, mặc định offline)** → gộp `answer.txt`.\n\n**Trước khi chạy:** Accelerator = **GPU T4**, Internet = **On**.\n- **+ Add Input** → tab Datasets → dataset Track 2 đã upload (Kaggle tự giải nén → có thư mục `vmc2026-track2/`).\n- Với mặc định `EMOS_METHOD='emotion2vec'`: **KHÔNG cần** `GEMINI_API_KEY`. Chỉ cần Secrets khi đổi sang `'gemini'` (để có thêm VAD).\n\nChạy được ngay: **QMOS + EmoCat + EMOS** (chỉ cần wav + `metadata.csv` chứa cảm xúc target).\n\n> ⚠️ Train phase: dự đoán tập **DEV** (`sets/dev.scp`, ~2730 mẫu). Thư mục `wav/` có cả train+dev nên KHÔNG glob hết — chỉ lấy đúng dev.scp."
7
+ },
8
+ {
9
+ "cell_type": "markdown",
10
+ "metadata": {},
11
+ "source": [
12
+ "## 0. Config — SỬA Ở ĐÂY"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": null,
18
+ "metadata": {},
19
+ "outputs": [],
20
+ "source": "import os, glob\n\n# ── Data Track 2 trên Kaggle (dataset đã upload, KHÔNG có thư mục con lồng) ──\nDATA_ROOT = '/kaggle/input/vmc2026-track2-full' # << slug dataset bạn upload\nWAV_DIR = f'{DATA_ROOT}/wav'\nMETADATA_CSV = f'{DATA_ROOT}/metadata.csv' # wavID|emotion|transcript (KHÔNG header)\nDEV_SCP = f'{DATA_ROOT}/sets/dev.scp' # danh sách wav tập DEV (tập cần nộp ở train phase)\n\n# Test nhanh trên ESD: trỏ WAV_DIR vào ESD, đặt DEV_SCP=None và METADATA_CSV=None.\n# WAV_DIR = '/kaggle/input/datasets/nguyenthanhlim/emotional-speech-dataset-esd/Emotion Speech Dataset'\n# DEV_SCP = None; METADATA_CSV = None\n\nLIMIT = 20 # << 20 = chạy THỬ nhanh. Đổi None để chạy TOÀN BỘ DEV rồi nộp.\n\n# ── Cách tính EMOS ──────────────────────────────────────────────────────────\n# 'emotion2vec': OFFLINE, MIỄN PHÍ (exp01, khuyến nghị) — P(cảm xúc target) từ emotion2vec → scale 1–5.\n# 'gemini' : LLM-as-judge qua Gemini API (cần GEMINI_API_KEY, tốn phí). Chỉ cách này có VAD.\nEMOS_METHOD = 'emotion2vec'\n\nOUT_DIR = '/kaggle/working'\nRUN_QMOS, RUN_EMOCAT = True, True\n_have_meta = bool(METADATA_CSV) and os.path.exists(METADATA_CSV)\nRUN_EMOS = _have_meta # cả 2 cách đều cần target từ metadata\nRUN_VAD = _have_meta and EMOS_METHOD == 'gemini' # VAD chỉ có ở Gemini\nEMOTIONS5 = ['angry', 'happy', 'neutral', 'sad', 'surprised']\n\n# Chuẩn hóa nhãn cảm xúc target (metadata) → đúng 1 trong 5 lớp của emotion2vec.\n_EMO_ALIAS = {'angry':'angry','anger':'angry','happy':'happy','happiness':'happy','joy':'happy',\n 'neutral':'neutral','calm':'neutral','sad':'sad','sadness':'sad',\n 'surprise':'surprised','surprised':'surprised','surprising':'surprised'}\ndef norm_emotion(label):\n key = str(label).strip().lower()\n return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)\n\ndef list_wavs(d):\n # Có DEV_SCP → đọc danh sách tên file tập DEV (wav nằm phẳng trong wav/).\n # Không có → quét đệ quy mọi .wav (chế độ test ESD, lồng speaker/emotion).\n if DEV_SCP and os.path.exists(DEV_SCP):\n with open(DEV_SCP) as f:\n names = [ln.strip() for ln in f if ln.strip()]\n wavs = [os.path.join(d, n) for n in names]\n else:\n wavs = sorted(glob.glob(os.path.join(d, '**', '*.wav'), recursive=True))\n return wavs[:LIMIT] if LIMIT else wavs\n\nprint('WAV_DIR:', WAV_DIR, '| EMOS_METHOD:', EMOS_METHOD)\nprint('Số wav:', len(list_wavs(WAV_DIR)) if os.path.isdir(WAV_DIR) else '(chưa thấy thư mục)')\nprint('Chế độ DEV (dev.scp):', bool(DEV_SCP and os.path.exists(DEV_SCP)))\nif METADATA_CSV and os.path.exists(METADATA_CSV) and DEV_SCP and os.path.exists(DEV_SCP):\n n_meta = sum(1 for _ in open(METADATA_CSV))\n n_dev = sum(1 for _ in open(DEV_SCP))\n print(f'metadata.csv: {n_meta} dòng | dev.scp: {n_dev} dòng')"
21
+ },
22
+ {
23
+ "cell_type": "markdown",
24
+ "metadata": {},
25
+ "source": [
26
+ "## 1. Cài đặt"
27
+ ]
28
+ },
29
+ {
30
+ "cell_type": "code",
31
+ "execution_count": null,
32
+ "metadata": {},
33
+ "outputs": [],
34
+ "source": [
35
+ "!pip install -q speechmos funasr librosa soundfile pandas google-genai loguru tqdm"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "markdown",
40
+ "metadata": {},
41
+ "source": [
42
+ "## 2. QMOS — SpeechMOS (UTMOS, không cần fairseq)"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": null,
48
+ "metadata": {},
49
+ "outputs": [],
50
+ "source": "def run_qmos(wav_dir):\n import torch, librosa\n dev = 'cuda' if torch.cuda.is_available() else 'cpu'\n predictor = torch.hub.load('tarepan/SpeechMOS:v1.2.0', 'utmos22_strong', trust_repo=True).to(dev) # << GPU\n print('QMOS device:', dev)\n scores, missing = {}, 0\n for w in list_wavs(wav_dir): # w là đường dẫn đầy đủ\n if not os.path.exists(w): # mẫu ESD/DailyTalk chưa lấy ngoài → bỏ qua, không crash\n missing += 1; continue\n wave, _ = librosa.load(w, sr=16000, mono=True)\n wave_t = torch.from_numpy(wave).unsqueeze(0).to(dev) # << đưa input lên GPU\n scores[w] = float(predictor(wave_t, sr=16000).mean().item())\n if missing: print(f'[QMOS] Bỏ qua {missing} file thiếu (chưa có ESD/DailyTalk) → điểm mặc định.')\n return scores\n\nqmos_scores = run_qmos(WAV_DIR) if RUN_QMOS else {}\nprint('QMOS xong:', len(qmos_scores))\nlist(qmos_scores.items())[:3]"
51
+ },
52
+ {
53
+ "cell_type": "markdown",
54
+ "metadata": {},
55
+ "source": [
56
+ "## 3. EmoCat — emotion2vec+ large\n",
57
+ "Đã sửa bug bản gốc + lọc 5 lớp + chuẩn hóa tổng = 1."
58
+ ]
59
+ },
60
+ {
61
+ "cell_type": "code",
62
+ "execution_count": null,
63
+ "metadata": {},
64
+ "outputs": [],
65
+ "source": "def run_emocat(wav_dir):\n import torch\n from funasr import AutoModel\n dev = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n model = AutoModel(model='iic/emotion2vec_plus_large', hub='hf', device=dev) # << chạy GPU\n print('EmoCat device:', dev)\n results, missing = {}, 0\n for w in list_wavs(wav_dir): # w là đường dẫn đầy đủ\n if not os.path.exists(w): # mẫu ESD/DailyTalk chưa lấy ngoài → bỏ qua\n missing += 1; continue\n rec = model.generate(w, granularity='utterance', extract_embedding=False)\n probs = {e: 0.0 for e in EMOTIONS5}\n for lab, sc in zip(rec[0]['labels'], rec[0]['scores']):\n name = lab.split('/')[-1]\n if name in probs:\n probs[name] = float(sc)\n total = sum(probs.values())\n if total > 0:\n probs = {k: v / total for k, v in probs.items()}\n results[w] = probs\n if missing: print(f'[EmoCat] Bỏ qua {missing} file thiếu (chưa có ESD/DailyTalk) → phân bố mặc định.')\n return results\n\nemocat_probs = run_emocat(WAV_DIR) if RUN_EMOCAT else {}\nprint('EmoCat xong:', len(emocat_probs))\nlist(emocat_probs.items())[:2]"
66
+ },
67
+ {
68
+ "cell_type": "markdown",
69
+ "metadata": {},
70
+ "source": "## 4. EMOS — emotion2vec target-prob (mặc định) hoặc Gemini\n**emotion2vec (exp01, offline):** lấy P(cảm xúc target) từ emotion2vec (đã tính ở cell EmoCat), scale [0,1]→[1,5]. Chấm đủ 2.730 mẫu, KHÔNG cần API. SRCC chỉ quan tâm thứ hạng nên scale tuyến tính không đổi tương quan.\n\n**Gemini (`EMOS_METHOD='gemini'`):** LLM-as-judge, cần `GEMINI_API_KEY` + credit; tự lọc metadata về DEV để đỡ tốn. Chỉ cách này có VAD."
71
+ },
72
+ {
73
+ "cell_type": "code",
74
+ "execution_count": null,
75
+ "metadata": {},
76
+ "outputs": [],
77
+ "source": "emos_scores, vad_scores = {}, {} # key = TÊN FILE wav (uttID, có .wav)\n\n# Đọc cảm xúc target từ metadata.csv → {stem: emotion_chuẩn}\ndef load_target_emotions():\n tgt = {}\n if not (METADATA_CSV and os.path.exists(METADATA_CSV)):\n return tgt\n with open(METADATA_CSV, encoding='utf-8') as f:\n for ln in f:\n parts = ln.strip().split('|')\n if len(parts) >= 2:\n stem = os.path.splitext(os.path.basename(parts[0]))[0]\n tgt[stem] = norm_emotion(parts[1])\n return tgt\n\ntarget_map = load_target_emotions()\nprint('Nhãn cảm xúc target đọc được:', len(target_map))\n\nif RUN_EMOS and EMOS_METHOD == 'emotion2vec':\n # ── EMOS OFFLINE: P(cảm xúc target) từ emotion2vec (cell EmoCat), scale [0,1]→[1,5] ──\n assert RUN_EMOCAT and emocat_probs, 'EMOS theo emotion2vec cần chạy cell EmoCat (mục 3) trước.'\n miss_t = miss_p = 0\n for w in list_wavs(WAV_DIR):\n name = os.path.basename(w)\n tgt = target_map.get(os.path.splitext(name)[0])\n probs = emocat_probs.get(w)\n if tgt is None:\n miss_t += 1; continue\n if not probs:\n miss_p += 1; continue\n emos_scores[name] = 1.0 + 4.0 * probs.get(tgt, 0.0) # p=0→1 điểm, p=1→5 điểm\n if miss_t: print(f'[EMOS-e2v] {miss_t} mẫu thiếu nhãn target → mặc định 3.')\n if miss_p: print(f'[EMOS-e2v] {miss_p} mẫu thiếu prob emotion2vec → mặc định 3.')\n print(f'✅ EMOS (emotion2vec) cho {len(emos_scores)} mẫu — không cần API.')\n\nelif RUN_EMOS or RUN_VAD: # EMOS_METHOD == 'gemini'\n try:\n from kaggle_secrets import UserSecretsClient\n os.environ['GEMINI_API_KEY'] = UserSecretsClient().get_secret('GEMINI_API_KEY')\n print('Đã nạp GEMINI_API_KEY từ Secrets')\n except Exception as e:\n print('Chưa nạp được key:', e)\n\n # ── Lọc metadata.csv → CHỈ giữ mẫu thuộc DEV (tránh trả tiền Gemini cho mẫu train) ──\n dev_stems = {os.path.splitext(n.strip())[0] for n in open(DEV_SCP) if n.strip()}\n META_DEV = '/kaggle/working/metadata_dev.csv'\n kept = 0\n with open(METADATA_CSV) as fin, open(META_DEV, 'w') as fout:\n for line in fin:\n if not line.strip():\n continue\n stem = os.path.splitext(os.path.basename(line.split('|')[0].strip()))[0]\n if stem in dev_stems:\n fout.write(line); kept += 1\n print(f'metadata_dev.csv: {kept} dòng (kỳ vọng ~{len(dev_stems)})')\n\n GEMINI_ROWS = f'--end-row {LIMIT}' if LIMIT else ''\n !git clone -q https://github.com/voicemos-challenge/vmc2026-baselines.git /kaggle/working/vmc2026-baselines\n !cd /kaggle/working/vmc2026-baselines/track2/EMOS && python Gemini_EMOS.py --metadata-path $META_DEV --base-path $WAV_DIR --output-file /kaggle/working/emos.csv --workers 4 --resume $GEMINI_ROWS\n !cd /kaggle/working/vmc2026-baselines/track2/VAD && python Gemini_VAD.py --metadata-path $META_DEV --base-path $WAV_DIR --output-file /kaggle/working/vad.csv --workers 4 --resume $GEMINI_ROWS\n\n import pandas as pd\n if os.path.exists('/kaggle/working/emos.csv'):\n d = pd.read_csv('/kaggle/working/emos.csv'); emos_scores = dict(zip(d['uttID'], d['emos']))\n if os.path.exists('/kaggle/working/vad.csv'):\n d = pd.read_csv('/kaggle/working/vad.csv') # cột chuẩn: uttID, val, aro, dom\n for _, r in d.iterrows():\n vad_scores[r['uttID']] = (r['val'], r['aro'], r['dom'])\n\n if emos_scores:\n dev_bases = {os.path.basename(w) for w in list_wavs(WAV_DIR)}\n if not (set(emos_scores) & dev_bases):\n print('⚠️ KEY LỆCH: uttID không khớp tên file dev → EMOS/VAD sẽ về mặc định!')\n else:\n print('✅ Key khớp — EMOS/VAD sẽ gộp đúng.')\n\nprint('EMOS:', len(emos_scores), '| VAD:', len(vad_scores))"
78
+ },
79
+ {
80
+ "cell_type": "markdown",
81
+ "metadata": {},
82
+ "source": [
83
+ "## 5. Gộp answer.txt (tự bỏ cột thiếu)"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "code",
88
+ "execution_count": null,
89
+ "metadata": {},
90
+ "outputs": [],
91
+ "source": "def fmt_cat(p):\n return '|'.join(f'{e}:{p[e]:.6g}' for e in EMOTIONS5)\n\ndef build_answer(out_path):\n wavs = list_wavs(WAV_DIR)\n have_cat = RUN_EMOCAT and len(emocat_probs) > 0\n have_vad = RUN_VAD and len(vad_scores) > 0\n cols = ['wav', 'QMOS', 'EMOS']\n if have_cat: cols.append('CAT')\n if have_vad: cols += ['VAL', 'ARO', 'DOM']\n with open(out_path, 'w') as f:\n f.write(','.join(cols) + '\\n')\n for w in wavs:\n name = os.path.basename(w) # tên file = cột wav & key của emos/vad\n row = [name, f\"{qmos_scores.get(w, 3.0):.6g}\", str(emos_scores.get(name, 3))]\n if have_cat: row.append(fmt_cat(emocat_probs.get(w, {e: 0.2 for e in EMOTIONS5})))\n if have_vad:\n v = vad_scores.get(name, (3, 3, 3)); row += [str(v[0]), str(v[1]), str(v[2])]\n f.write(','.join(row) + '\\n')\n print(f'Ghi {len(wavs)} dòng → {out_path} | cột: {cols}')\n\nanswer_path = os.path.join(OUT_DIR, 'answer.txt')\nbuild_answer(answer_path)\n!head -3 {answer_path}"
92
+ },
93
+ {
94
+ "cell_type": "markdown",
95
+ "metadata": {},
96
+ "source": [
97
+ "## 6. Validate + zip"
98
+ ]
99
+ },
100
+ {
101
+ "cell_type": "code",
102
+ "execution_count": null,
103
+ "metadata": {},
104
+ "outputs": [],
105
+ "source": [
106
+ "import csv\n",
107
+ "with open(answer_path) as f:\n",
108
+ " rows = list(csv.reader(f))\n",
109
+ "header = rows[0]\n",
110
+ "assert header[0] == 'wav' and 'QMOS' in header and 'EMOS' in header, 'Header sai'\n",
111
+ "for i, r in enumerate(rows[1:], 2):\n",
112
+ " assert len(r) == len(header), f'Dòng {i} sai số cột'\n",
113
+ "print(f'OK: {len(rows)-1} dòng, header = {header}')\n",
114
+ "!cd /kaggle/working && zip -j submission_track2.zip answer.txt && unzip -l submission_track2.zip"
115
+ ]
116
+ }
117
+ ],
118
+ "metadata": {
119
+ "kernelspec": {
120
+ "display_name": "Python 3",
121
+ "language": "python",
122
+ "name": "python3"
123
+ },
124
+ "language_info": {
125
+ "name": "python"
126
+ }
127
+ },
128
+ "nbformat": 4,
129
+ "nbformat_minor": 5
130
+ }
track2/track2_baseline_pipeline.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %% [markdown]
2
+ # # VMC2026 Track 2 — Baseline Pipeline (Kaggle)
3
+ #
4
+ # Chạy 4 baseline → gộp thành `answer.txt` đúng chuẩn nộp CodaBench.
5
+ #
6
+ # | Sub-task | Baseline | GPU | Cần label? |
7
+ # |---|---|---|---|
8
+ # | QMOS | SpeechMOS (UTMOS bản pip) | có (nhẹ) | không (chỉ cần wav) |
9
+ # | EmoCat (CAT) | emotion2vec+ large (funasr) | có (nhẹ) | không (chỉ cần wav) |
10
+ # | EMOS | **emotion2vec target-prob** (mặc định, offline) HOẶC Gemini | có (nhẹ) | cần `metadata.csv` (nhãn target) |
11
+ # | VAD | Gemini LLM-as-judge (chỉ khi EMOS_METHOD="gemini") | không | cần `metadata.csv` + API key |
12
+ #
13
+ # **Cách dùng trên Kaggle:**
14
+ # 1. Tạo Notebook, Settings → Accelerator = **GPU T4**, Internet = **On** (cần verify phone).
15
+ # 2. **+ Add Input** → chọn dataset Track 2 đã upload (Kaggle tự giải nén → có thư mục `vmc2026-track2/`).
16
+ # 3. Add-ons → Secrets: thêm `GEMINI_API_KEY` (cho EMOS/VAD).
17
+ # 4. Sửa `DATA_ROOT` ở cell 0 cho khớp slug dataset, rồi chạy lần lượt từng cell.
18
+ #
19
+ # Format đích `answer.txt`: `wav,QMOS,EMOS,CAT,VAL,ARO,DOM` — xem `08_track2_spec.md`.
20
+ # QMOS & EMOS bắt buộc; CAT/VAD tùy chọn. Có thể nộp tập con cột.
21
+ #
22
+ # > ⚠️ Ở **training phase** ta dự đoán cho tập **DEV** (`sets/dev.scp`, ~2730 mẫu) rồi nộp.
23
+ # > Thư mục `wav/` chứa cả train + dev nên KHÔNG glob hết — chỉ lấy đúng danh sách dev.scp.
24
+
25
+ # %% [markdown]
26
+ # ## 0. Cấu hình đường dẫn — SỬA Ở ĐÂY
27
+
28
+ # %%
29
+ import os, glob
30
+
31
+ # ── Data Track 2 trên Kaggle ────────────────────────────────────────────────
32
+ # Kaggle TỰ giải nén .tar.gz khi tạo Dataset → có sẵn thư mục `vmc2026-track2/`.
33
+ # Đổi <track2-data> thành slug dataset của bạn (xem thanh path bên phải khi Add Input).
34
+ DATA_ROOT = "/kaggle/input/<track2-data>/vmc2026-track2" # << SỬA slug
35
+ WAV_DIR = f"{DATA_ROOT}/wav"
36
+ METADATA_CSV = f"{DATA_ROOT}/metadata.csv" # định dạng: wavID|emotion|transcript (KHÔNG header)
37
+ DEV_SCP = f"{DATA_ROOT}/sets/dev.scp" # danh sách wav của tập DEV (tập cần nộp ở train phase)
38
+
39
+ # Muốn test nhanh trên ESD (chưa có data thật): trỏ WAV_DIR vào ESD, đặt DEV_SCP=None, METADATA_CSV=None.
40
+ # WAV_DIR = "/kaggle/input/datasets/nguyenthanhlim/emotional-speech-dataset-esd/Emotion Speech Dataset"
41
+ # DEV_SCP = None; METADATA_CSV = None
42
+
43
+ LIMIT = None # << số nhỏ (vd 20) để chạy thử cho nhanh; None = chạy toàn bộ tập DEV
44
+
45
+ OUT_DIR = "/kaggle/working"
46
+
47
+ # ── Cách tính EMOS ──────────────────────────────────────────────────────────
48
+ # "emotion2vec": OFFLINE, MIỄN PHÍ (exp01, khuyến nghị) — lấy P(cảm xúc target) từ
49
+ # emotion2vec (model đã chạy cho CAT) rồi scale về thang 1–5. Không cần API.
50
+ # "gemini" : LLM-as-judge qua Gemini API (cần GEMINI_API_KEY, tốn phí). Chỉ cách này có VAD.
51
+ EMOS_METHOD = "emotion2vec"
52
+
53
+ _have_meta = bool(METADATA_CSV) and os.path.exists(METADATA_CSV) # cần nhãn cảm xúc target
54
+ RUN_QMOS = True
55
+ RUN_EMOCAT = True
56
+ RUN_EMOS = _have_meta # cả 2 cách đều cần target từ metadata
57
+ RUN_VAD = _have_meta and EMOS_METHOD == "gemini" # VAD chỉ có ở Gemini
58
+
59
+ EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"]
60
+
61
+ # Chuẩn hóa nhãn cảm xúc target (metadata) → đúng 1 trong 5 lớp của emotion2vec.
62
+ _EMO_ALIAS = {
63
+ "angry": "angry", "anger": "angry",
64
+ "happy": "happy", "happiness": "happy", "joy": "happy",
65
+ "neutral": "neutral", "calm": "neutral",
66
+ "sad": "sad", "sadness": "sad",
67
+ "surprise": "surprised", "surprised": "surprised", "surprising": "surprised",
68
+ }
69
+
70
+ def norm_emotion(label):
71
+ """Đưa nhãn cảm xúc bất kỳ về 1 trong EMOTIONS5; None nếu không khớp."""
72
+ key = str(label).strip().lower()
73
+ return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)
74
+
75
+
76
+ def list_wavs(d):
77
+ """Trả về list đường dẫn .wav đầy đủ cần dự đoán.
78
+ - Có DEV_SCP → đọc danh sách tên file tập DEV (đúng tập cần nộp, wav nằm phẳng trong wav/).
79
+ - Không có → quét đệ quy mọi .wav trong thư mục (chế độ test ESD, lồng speaker/emotion)."""
80
+ if DEV_SCP and os.path.exists(DEV_SCP):
81
+ with open(DEV_SCP) as f:
82
+ names = [ln.strip() for ln in f if ln.strip()]
83
+ wavs = [os.path.join(d, n) for n in names]
84
+ else:
85
+ wavs = sorted(glob.glob(os.path.join(d, "**", "*.wav"), recursive=True))
86
+ return wavs[:LIMIT] if LIMIT else wavs
87
+
88
+
89
+ print("WAV_DIR:", WAV_DIR)
90
+ print("Số wav:", len(list_wavs(WAV_DIR)) if os.path.isdir(WAV_DIR) else "(chưa thấy thư mục)")
91
+ print("Chế độ DEV (dev.scp):", bool(DEV_SCP and os.path.exists(DEV_SCP)))
92
+
93
+ # %% [markdown]
94
+ # ## 1. Cài đặt
95
+
96
+ # %%
97
+ # !pip install -q speechmos funasr librosa soundfile pandas google-genai loguru tqdm
98
+
99
+ # %% [markdown]
100
+ # ## 2. QMOS — SpeechMOS (UTMOS)
101
+ # Dùng SpeechMOS qua torch.hub (không cần fairseq). Output: dict {wav: score 1-5}.
102
+
103
+ # %%
104
+ def run_qmos(wav_dir):
105
+ import torch, librosa
106
+ # SpeechMOS yêu cầu 16kHz; input shape (Batch, Time); sr truyền dạng keyword.
107
+ predictor = torch.hub.load(
108
+ "tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True
109
+ )
110
+ scores = {}
111
+ missing = 0
112
+ for w in list_wavs(wav_dir):
113
+ if not os.path.exists(w): # mẫu ESD/DailyTalk chưa lấy ngoài → bỏ qua, không crash
114
+ missing += 1
115
+ continue
116
+ wave, _ = librosa.load(w, sr=16000, mono=True)
117
+ wave_t = torch.from_numpy(wave).unsqueeze(0) # (1, Time)
118
+ score = predictor(wave_t, sr=16000) # → tensor shape (1,)
119
+ scores[w] = float(score.mean().item())
120
+ if missing:
121
+ print(f"[QMOS] Bỏ qua {missing} file thiếu (chưa có ESD/DailyTalk) → sẽ nhận điểm mặc định.")
122
+ return scores
123
+
124
+
125
+ qmos_scores = run_qmos(WAV_DIR) if RUN_QMOS else {}
126
+ print("QMOS xong:", len(qmos_scores), "mẫu")
127
+ list(qmos_scores.items())[:3]
128
+
129
+ # %% [markdown]
130
+ # ## 3. EmoCat — emotion2vec+ large (funasr)
131
+ # Sửa bug bản gốc + lọc 5 lớp + **chuẩn hóa tổng = 1** (đúng format CAT).
132
+ # Output: dict {wav: {angry:p, happy:p, neutral:p, sad:p, surprised:p}}.
133
+
134
+ # %%
135
+ def run_emocat(wav_dir):
136
+ from funasr import AutoModel
137
+ model = AutoModel(model="iic/emotion2vec_plus_large", hub="hf")
138
+ results = {}
139
+ missing = 0
140
+ for w in list_wavs(wav_dir):
141
+ if not os.path.exists(w): # mẫu ESD/DailyTalk chưa lấy ngoài → bỏ qua
142
+ missing += 1
143
+ continue
144
+ rec = model.generate(
145
+ w,
146
+ granularity="utterance",
147
+ extract_embedding=False,
148
+ )
149
+ labels = rec[0]["labels"]
150
+ scores = rec[0]["scores"]
151
+ # gom điểm 5 lớp quan tâm (label có thể dạng "xx/angry")
152
+ probs = {e: 0.0 for e in EMOTIONS5}
153
+ for lab, sc in zip(labels, scores):
154
+ name = lab.split("/")[-1]
155
+ if name in probs:
156
+ probs[name] = float(sc)
157
+ total = sum(probs.values())
158
+ if total > 0: # chuẩn hóa lại trên 5 lớp
159
+ probs = {k: v / total for k, v in probs.items()}
160
+ results[w] = probs
161
+ if missing:
162
+ print(f"[EmoCat] Bỏ qua {missing} file thiếu (chưa có ESD/DailyTalk) → sẽ nhận phân bố mặc định.")
163
+ return results
164
+
165
+
166
+ emocat_probs = run_emocat(WAV_DIR) if RUN_EMOCAT else {}
167
+ print("EmoCat xong:", len(emocat_probs), "mẫu")
168
+ list(emocat_probs.items())[:2]
169
+
170
+ # %% [markdown]
171
+ # ## 4. EMOS — emotion2vec target-prob (mặc định) hoặc Gemini
172
+ # **Cách emotion2vec (exp01, offline):** với mỗi wav, lấy xác suất emotion2vec gán cho ĐÚNG cảm xúc
173
+ # target (đọc từ `metadata.csv`), scale [0,1] → [1,5]. Vì EMOS chấm bằng SRCC (thứ hạng) nên scale
174
+ # tuyến tính không đổi tương quan — chỉ cần thứ tự đúng. KHÔNG cần train, KHÔNG cần API.
175
+ # **Cách Gemini:** gọi script baseline gốc (cần `GEMINI_API_KEY`); chỉ cách này mới có VAD.
176
+ # `metadata.csv` dạng `wavID|emotion|transcript` (không header); `emotion` = cảm xúc target.
177
+
178
+ # %%
179
+ def load_target_emotions():
180
+ """Đọc metadata.csv (wavID|emotion|transcript, không header) → {stem: emotion_chuẩn}."""
181
+ tgt = {}
182
+ if not (METADATA_CSV and os.path.exists(METADATA_CSV)):
183
+ return tgt
184
+ with open(METADATA_CSV, encoding="utf-8") as f:
185
+ for ln in f:
186
+ parts = ln.strip().split("|")
187
+ if len(parts) < 2:
188
+ continue
189
+ stem = os.path.splitext(os.path.basename(parts[0]))[0]
190
+ tgt[stem] = norm_emotion(parts[1])
191
+ return tgt
192
+
193
+
194
+ def run_emos_emotion2vec(wav_dir, target_map):
195
+ """EMOS offline = P(cảm xúc target) từ emotion2vec, scale [0,1] → [1,5].
196
+ Dùng lại emocat_probs (đã tính ở mục 3) nên KHÔNG tốn thêm tính toán."""
197
+ out, miss_tgt, miss_prob = {}, 0, 0
198
+ for w in list_wavs(wav_dir):
199
+ name = os.path.basename(w)
200
+ tgt = target_map.get(os.path.splitext(name)[0])
201
+ probs = emocat_probs.get(w)
202
+ if tgt is None:
203
+ miss_tgt += 1; continue
204
+ if not probs:
205
+ miss_prob += 1; continue
206
+ out[name] = 1.0 + 4.0 * probs.get(tgt, 0.0) # p=0→1 điểm, p=1→5 điểm
207
+ if miss_tgt: print(f"[EMOS-e2v] {miss_tgt} mẫu thiếu nhãn target → mặc định 3.")
208
+ if miss_prob: print(f"[EMOS-e2v] {miss_prob} mẫu thiếu prob emotion2vec → mặc định 3.")
209
+ return out
210
+
211
+
212
+ def setup_gemini_key():
213
+ try:
214
+ from kaggle_secrets import UserSecretsClient
215
+ os.environ["GEMINI_API_KEY"] = UserSecretsClient().get_secret("GEMINI_API_KEY")
216
+ print("Đã nạp GEMINI_API_KEY từ Kaggle Secrets")
217
+ except Exception as e:
218
+ print("Chưa nạp được key từ Secrets:", e, "→ set thủ công os.environ['GEMINI_API_KEY']")
219
+
220
+
221
+ emos_scores = {} # {uttID(tên file .wav): điểm EMOS}
222
+ vad_scores = {} # {uttID: (val, aro, dom)} — chỉ có khi dùng Gemini
223
+
224
+ target_map = load_target_emotions()
225
+ print("Nhãn cảm xúc target đọc được:", len(target_map))
226
+
227
+ if RUN_EMOS and EMOS_METHOD == "emotion2vec":
228
+ assert RUN_EMOCAT and emocat_probs, "EMOS theo emotion2vec cần EmoCat chạy trước (RUN_EMOCAT=True)."
229
+ emos_scores = run_emos_emotion2vec(WAV_DIR, target_map)
230
+
231
+ elif RUN_EMOS and EMOS_METHOD == "gemini":
232
+ setup_gemini_key()
233
+ # !git clone -q https://github.com/voicemos-challenge/vmc2026-baselines.git /kaggle/working/vmc2026-baselines
234
+ # Chạy (1-based, inclusive). Với eval lớn nên chia batch + giảm --workers do quota free tier.
235
+ # Gemini chỉ chấm các dòng metadata.csv trùng với DEV; cách đơn giản: chạy hết rồi lọc lại ở build_answer.
236
+ # !cd /kaggle/working/vmc2026-baselines/track2/EMOS && python Gemini_EMOS.py \
237
+ # --metadata-path {METADATA_CSV} --base-path {WAV_DIR} \
238
+ # --output-file /kaggle/working/emos.csv --workers 4 --resume
239
+ # !cd /kaggle/working/vmc2026-baselines/track2/VAD && python Gemini_VAD.py \
240
+ # --metadata-path {METADATA_CSV} --base-path {WAV_DIR} \
241
+ # --output-file /kaggle/working/vad.csv --workers 4 --resume
242
+ import pandas as pd
243
+ if os.path.exists("/kaggle/working/emos.csv"):
244
+ df = pd.read_csv("/kaggle/working/emos.csv")
245
+ emos_scores = dict(zip(df["uttID"], df["emos"]))
246
+ if os.path.exists("/kaggle/working/vad.csv"):
247
+ df = pd.read_csv("/kaggle/working/vad.csv")
248
+ # cột output VAD chuẩn của Gemini_VAD.py: uttID, val, aro, dom
249
+ for _, r in df.iterrows():
250
+ vad_scores[r["uttID"]] = (r["val"], r["aro"], r["dom"])
251
+
252
+ print("EMOS:", len(emos_scores), "| VAD:", len(vad_scores))
253
+
254
+ # %% [markdown]
255
+ # ## 5. Gộp thành `answer.txt`
256
+ # QMOS & EMOS bắt buộc. Tự bỏ cột nếu thiếu dữ liệu (nộp tập con hợp lệ).
257
+ # Lưu ý key: qmos/emocat theo path đầy đủ; emos/vad theo TÊN FILE → tra cứu bằng basename.
258
+
259
+ # %%
260
+ def fmt_cat(p):
261
+ return "|".join(f"{e}:{p[e]:.6g}" for e in EMOTIONS5)
262
+
263
+
264
+ def build_answer(out_path):
265
+ wavs = list_wavs(WAV_DIR)
266
+ have_emos = RUN_EMOS and len(emos_scores) > 0
267
+ have_cat = RUN_EMOCAT and len(emocat_probs) > 0
268
+ have_vad = RUN_VAD and len(vad_scores) > 0
269
+
270
+ cols = ["wav", "QMOS", "EMOS"] # QMOS+EMOS bắt buộc
271
+ if have_cat: cols.append("CAT")
272
+ if have_vad: cols += ["VAL", "ARO", "DOM"]
273
+
274
+ n = 0
275
+ with open(out_path, "w") as f:
276
+ f.write(",".join(cols) + "\n")
277
+ for w in wavs:
278
+ name = os.path.basename(w) # tên file = key của emos/vad, và là giá trị cột wav
279
+ row = [name,
280
+ f"{qmos_scores.get(w, 3.0):.6g}",
281
+ str(emos_scores.get(name, 3))]
282
+ if have_cat:
283
+ row.append(fmt_cat(emocat_probs.get(w, {e: 0.2 for e in EMOTIONS5})))
284
+ if have_vad:
285
+ v = vad_scores.get(name, (3, 3, 3))
286
+ row += [str(v[0]), str(v[1]), str(v[2])]
287
+ f.write(",".join(row) + "\n")
288
+ n += 1
289
+ print(f"Ghi {n} dòng → {out_path} | cột: {cols}")
290
+ return cols
291
+
292
+
293
+ answer_path = os.path.join(OUT_DIR, "answer.txt")
294
+ cols = build_answer(answer_path)
295
+
296
+ # %% [markdown]
297
+ # ## 6. Validate + đóng zip
298
+
299
+ # %%
300
+ def validate(path):
301
+ import csv
302
+ with open(path) as f:
303
+ rows = list(csv.reader(f))
304
+ header = rows[0]
305
+ assert header[0] == "wav" and "QMOS" in header and "EMOS" in header, "Header sai"
306
+ for i, r in enumerate(rows[1:], 2):
307
+ assert len(r) == len(header), f"Dòng {i} sai số cột"
308
+ print(f"OK: {len(rows)-1} dòng, header = {header}")
309
+
310
+
311
+ validate(answer_path)
312
+ # !cd /kaggle/working && zip -j submission_track2.zip answer.txt && unzip -l submission_track2.zip
313
+ print("Sẵn sàng nộp: /kaggle/working/submission_track2.zip (chứa answer.txt)")
314
+
315
+ # %% [markdown]
316
+ # ## Ghi chú
317
+ # - Nộp: My Submissions → chọn **Track 2**, **bỏ chọn** track khác → upload `submission_track2.zip`.
318
+ # - `metadata.csv` (wavID|emotion|transcript, không header) chứa nhãn cảm xúc target cho Gemini EMOS/VAD.
319
+ # - Train phase: dự đoán tập DEV (`sets/dev.scp`). `sets/train.csv` có nhãn người nghe để train mô hình riêng.
320
+ # - Quota Gemini free tier dễ hết với eval lớn → chia batch `--start-row/--end-row`, giảm `--workers`, dùng `--resume`.
321
+ # - Khi có data thật: sửa `DATA_ROOT` ở cell 0 rồi chạy lại từ đầu.
track2/track2_prepare_data.ipynb ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# VMC2026 Track 2 — Chuẩn bị data (gộp ESD + DailyTalk) trên Kaggle\n",
8
+ "\n",
9
+ "Gói Track 2 thiếu **1.417 mẫu giọng thật** (license tách ra):\n",
10
+ "- **sys006** = ESD (1.379 file) · **sys001** = DailyTalk (38 file)\n",
11
+ "\n",
12
+ "Notebook này: cài **SoX** + build **sv56** → gom đúng utterance từ ESD/DailyTalk →\n",
13
+ "**chuẩn hóa âm lượng** (giống mẫu TTS) → ráp vào `wav/` đủ **15.477 file**.\n",
14
+ "\n",
15
+ "### Cách dùng\n",
16
+ "1. Settings → **Internet = On** (cần tải/biên dịch sv56). GPU không bắt buộc.\n",
17
+ "2. **+ Add Input** 3 dataset:\n",
18
+ " - Gói Track 2 (`vmc2026_track2_..._v3.tar.gz` — Kaggle tự giải nén ra `vmc2026-track2/`).\n",
19
+ " - ESD: `Emotional Speech Dataset (ESD).zip` (Kaggle tự giải nén ra `Emotion Speech Dataset/`).\n",
20
+ " - DailyTalk: `dailytalk.zip` (giải nén ra `dailytalk/data/...`).\n",
21
+ "3. **Run All**. Xong → **Save Version** (Commit) để lưu `wav/` ra output.\n",
22
+ "4. Từ output đó → **Create Dataset** → dùng làm input cho notebook train/baseline.\n",
23
+ "\n",
24
+ "> Notebook tự dò vị trí ESD/DailyTalk dù Kaggle giải nén ra thư mục tên gì."
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "markdown",
29
+ "metadata": {},
30
+ "source": [
31
+ "## 0. Tìm gói Track 2 + copy ra thư mục ghi được"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": null,
37
+ "metadata": {},
38
+ "outputs": [],
39
+ "source": [
40
+ "import os, glob, shutil, subprocess\n",
41
+ "\n",
42
+ "# Tự dò thư mục vmc2026-track2 trong mọi input đã add.\n",
43
+ "_cands = glob.glob(\"/kaggle/input/*/vmc2026-track2\") + glob.glob(\"/kaggle/input/**/vmc2026-track2\", recursive=True)\n",
44
+ "TRACK2_SRC = _cands[0] if _cands else None\n",
45
+ "assert TRACK2_SRC, \"Không thấy thư mục vmc2026-track2 — đã Add Input gói Track 2 chưa?\"\n",
46
+ "\n",
47
+ "WORK = \"/kaggle/working/vmc2026-track2\" # bản ghi được (input là read-only)\n",
48
+ "print(\"Track2 source :\", TRACK2_SRC)\n",
49
+ "\n",
50
+ "# Copy toàn bộ gói ra working (gồm wav/ 14060 file + scripts + csv). Mất vài phút.\n",
51
+ "if not os.path.exists(WORK):\n",
52
+ " print(\"Đang copy gói Track 2 ra working (vài phút)...\")\n",
53
+ " shutil.copytree(TRACK2_SRC, WORK)\n",
54
+ "print(\"Số wav hiện có:\", len(os.listdir(f\"{WORK}/wav\")))"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "markdown",
59
+ "metadata": {},
60
+ "source": [
61
+ "## 1. Cài SoX + build sv56 (cần Internet = On)\n",
62
+ "sv56 = công cụ chuẩn hóa âm lượng của ITU-T, build từ source openitu/STL."
63
+ ]
64
+ },
65
+ {
66
+ "cell_type": "code",
67
+ "execution_count": null,
68
+ "metadata": {},
69
+ "outputs": [],
70
+ "source": [
71
+ "def sh(cmd):\n",
72
+ " print(\"$\", cmd)\n",
73
+ " print(subprocess.run(cmd, shell=True, capture_output=True, text=True).stdout[-2000:])\n",
74
+ "\n",
75
+ "# SoX + trình biên dịch (để build sv56)\n",
76
+ "sh(\"apt-get -qq update && apt-get -qq install -y sox make gcc\")\n",
77
+ "sh(\"which sox && sox --version\")\n",
78
+ "\n",
79
+ "# sv56demo\n",
80
+ "SV56_DIR = \"/kaggle/working/STL-2009\"\n",
81
+ "SV56_BIN_DIR = f\"{SV56_DIR}/src/sv56\"\n",
82
+ "if not os.path.exists(f\"{SV56_BIN_DIR}/sv56demo\"):\n",
83
+ " sh(\"cd /kaggle/working && wget -q https://github.com/openitu/STL/archive/refs/tags/v2009.tar.gz\")\n",
84
+ " sh(\"cd /kaggle/working && tar -xf v2009.tar.gz\")\n",
85
+ " sh(f\"cd {SV56_BIN_DIR} && make -f makefile.unx\")\n",
86
+ "assert os.path.exists(f\"{SV56_BIN_DIR}/sv56demo\"), \"Build sv56 thất bại — kiểm tra Internet=On + log make.\"\n",
87
+ "\n",
88
+ "# Đưa cả sox và sv56demo vào PATH cho các script .sh dùng được\n",
89
+ "os.environ[\"PATH\"] = SV56_BIN_DIR + \":\" + os.environ[\"PATH\"]\n",
90
+ "sh(\"which sv56demo\")"
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "markdown",
95
+ "metadata": {},
96
+ "source": [
97
+ "## 2. Dò vị trí ESD + DailyTalk (tự tìm dù tên thư mục khác nhau)"
98
+ ]
99
+ },
100
+ {
101
+ "cell_type": "code",
102
+ "execution_count": null,
103
+ "metadata": {},
104
+ "outputs": [],
105
+ "source": [
106
+ "def find_root(rel_path):\n",
107
+ " \"\"\"Tìm thư mục ROOT trong /kaggle/input sao cho ROOT/rel_path tồn tại.\"\"\"\n",
108
+ " base = os.path.basename(rel_path)\n",
109
+ " for hit in glob.glob(f\"/kaggle/input/**/{base}\", recursive=True):\n",
110
+ " if hit.endswith(rel_path.replace(\"/\", os.sep)) or hit.endswith(rel_path):\n",
111
+ " return hit[: -len(rel_path)].rstrip(\"/\")\n",
112
+ " return None\n",
113
+ "\n",
114
+ "# ESD: dòng CSV \"0014/Angry/000381.wav\" → file thật là \"0014/Angry/0014_000381.wav\"\n",
115
+ "_esd_first = open(f\"{WORK}/ESD_utts_train_dev.csv\").readline().strip().split(\",\")[0]\n",
116
+ "_p = _esd_first.split(\"/\") # [spk, emo, uttID.wav]\n",
117
+ "ESD_REL = f\"{_p[0]}/{_p[1]}/{_p[0]}_{_p[2]}\"\n",
118
+ "ESD_ROOT = find_root(ESD_REL)\n",
119
+ "\n",
120
+ "# DailyTalk: dòng CSV \"1020/0_1_d1020.wav\" → file thật \".../data/1020/0_1_d1020.wav\"\n",
121
+ "_dt_first = open(f\"{WORK}/DT_utts_train_dev.csv\").readline().strip().split(\",\")[0]\n",
122
+ "DT_ROOT = find_root(_dt_first) # ROOT sao cho ROOT/1020/0_1_d1020.wav tồn tại\n",
123
+ "\n",
124
+ "print(\"ESD_REL :\", ESD_REL)\n",
125
+ "print(\"ESD_ROOT :\", ESD_ROOT)\n",
126
+ "print(\"DT_ROOT :\", DT_ROOT)\n",
127
+ "assert ESD_ROOT, \"Không thấy ESD — đã Add Input 'Emotional Speech Dataset (ESD).zip' chưa?\"\n",
128
+ "assert DT_ROOT, \"Không thấy DailyTalk — đã Add Input 'dailytalk.zip' chưa?\""
129
+ ]
130
+ },
131
+ {
132
+ "cell_type": "markdown",
133
+ "metadata": {},
134
+ "source": [
135
+ "## 3. Gom các utterance cần dùng → thư mục gathered/"
136
+ ]
137
+ },
138
+ {
139
+ "cell_type": "code",
140
+ "execution_count": null,
141
+ "metadata": {},
142
+ "outputs": [],
143
+ "source": [
144
+ "GATHERED = f\"{WORK}/gathered\"\n",
145
+ "os.makedirs(GATHERED, exist_ok=True)\n",
146
+ "\n",
147
+ "# ESD: copy ROOT/spk/emo/spk_uttID → gathered/<tên vmc2026>\n",
148
+ "n_esd = 0\n",
149
+ "for line in open(f\"{WORK}/ESD_utts_train_dev.csv\"):\n",
150
+ " src_rel, dst = line.strip().split(\",\")\n",
151
+ " p = src_rel.split(\"/\")\n",
152
+ " src = f\"{ESD_ROOT}/{p[0]}/{p[1]}/{p[0]}_{p[2]}\"\n",
153
+ " if os.path.exists(src):\n",
154
+ " shutil.copy(src, f\"{GATHERED}/{dst}\")\n",
155
+ " n_esd += 1\n",
156
+ " else:\n",
157
+ " print(\"ESD thiếu:\", src)\n",
158
+ "\n",
159
+ "# DailyTalk: copy ROOT/parts[0] → gathered/<tên vmc2026>\n",
160
+ "n_dt = 0\n",
161
+ "for line in open(f\"{WORK}/DT_utts_train_dev.csv\"):\n",
162
+ " src_rel, dst = line.strip().split(\",\")\n",
163
+ " src = f\"{DT_ROOT}/{src_rel}\"\n",
164
+ " if os.path.exists(src):\n",
165
+ " shutil.copy(src, f\"{GATHERED}/{dst}\")\n",
166
+ " n_dt += 1\n",
167
+ " else:\n",
168
+ " print(\"DailyTalk thiếu:\", src)\n",
169
+ "\n",
170
+ "print(f\"Đã gom: ESD {n_esd}/1379 · DailyTalk {n_dt}/38 · tổng {len(os.listdir(GATHERED))}\")"
171
+ ]
172
+ },
173
+ {
174
+ "cell_type": "markdown",
175
+ "metadata": {},
176
+ "source": [
177
+ "## 4. Chuẩn hóa âm lượng bằng sv56 (mức -26 dB, giữ nguyên sample rate)"
178
+ ]
179
+ },
180
+ {
181
+ "cell_type": "code",
182
+ "execution_count": null,
183
+ "metadata": {},
184
+ "outputs": [],
185
+ "source": [
186
+ "# Dùng chính script gốc trong gói: batch_normRMSE.sh → tạo *_norm.wav trong gathered/\n",
187
+ "sh(f\"bash {WORK}/sv56scripts/batch_normRMSE.sh {GATHERED}\")\n",
188
+ "\n",
189
+ "# Move các file đã chuẩn hóa vào wav/ với đúng tên (bỏ hậu tố _norm)\n",
190
+ "moved = 0\n",
191
+ "for n in os.listdir(GATHERED):\n",
192
+ " if n.endswith(\"_norm.wav\"):\n",
193
+ " final = \"_\".join(n.split(\"_\")[:-1]) + \".wav\" # bỏ \"_norm\"\n",
194
+ " shutil.move(f\"{GATHERED}/{n}\", f\"{WORK}/wav/{final}\")\n",
195
+ " moved += 1\n",
196
+ "print(\"Đã chuẩn hóa & move vào wav/:\", moved, \"file\")"
197
+ ]
198
+ },
199
+ {
200
+ "cell_type": "markdown",
201
+ "metadata": {},
202
+ "source": [
203
+ "## 5. Kiểm tra đủ 15.477 + dọn rác"
204
+ ]
205
+ },
206
+ {
207
+ "cell_type": "code",
208
+ "execution_count": null,
209
+ "metadata": {},
210
+ "outputs": [],
211
+ "source": [
212
+ "total = len(glob.glob(f\"{WORK}/wav/*.wav\"))\n",
213
+ "print(\"Tổng wav trong wav/:\", total)\n",
214
+ "if total == 15477:\n",
215
+ " shutil.rmtree(GATHERED, ignore_errors=True)\n",
216
+ " # Dọn artifact build để dataset lưu ra GỌN (chỉ giữ vmc2026-track2/)\n",
217
+ " shutil.rmtree(SV56_DIR, ignore_errors=True)\n",
218
+ " for f in glob.glob(\"/kaggle/working/v2009.tar.gz\"):\n",
219
+ " os.remove(f)\n",
220
+ " print(\"✅ ĐỦ 15.477 file. Sẵn sàng Save Version → tạo Dataset.\")\n",
221
+ " print(\" Dùng dataset này cho notebook baseline: DATA_ROOT = '/kaggle/input/<dataset-mới>/vmc2026-track2'\")\n",
222
+ "else:\n",
223
+ " print(f\"⚠️ Chưa đủ (đang {total}). Kiểm tra log bước 3-4 xem ESD/DailyTalk có thiếu file nào.\")"
224
+ ]
225
+ },
226
+ {
227
+ "cell_type": "markdown",
228
+ "metadata": {},
229
+ "source": [
230
+ "## Ghi chú\n",
231
+ "- Output nặng (~2-3GB do 15.477 wav). `wav/` đã gồm cả train+dev nên dùng được cho cả fine-tune lẫn inference.\n",
232
+ "- sv56 chuẩn hóa để mẫu ESD/DailyTalk cùng mức âm lượng với mẫu TTS → tránh model bị nhiễu bởi độ to.\n",
233
+ "- Nếu Internet Off: SoX có thể có sẵn nhưng KHÔNG build được sv56 → bắt buộc bật Internet."
234
+ ]
235
+ }
236
+ ],
237
+ "metadata": {
238
+ "kernelspec": {
239
+ "display_name": "Python 3",
240
+ "language": "python",
241
+ "name": "python3"
242
+ },
243
+ "language_info": {
244
+ "name": "python"
245
+ }
246
+ },
247
+ "nbformat": 4,
248
+ "nbformat_minor": 5
249
+ }
track2/track2_prepare_data_pipeline.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %% [markdown]
2
+ # # VMC2026 Track 2 — Chuẩn bị data (gộp ESD + DailyTalk) trên Kaggle
3
+ #
4
+ # Gói Track 2 thiếu **1.417 mẫu giọng thật** (license tách ra):
5
+ # - **sys006** = ESD (1.379 file) · **sys001** = DailyTalk (38 file)
6
+ #
7
+ # Notebook này: cài **SoX** + build **sv56** → gom đúng utterance từ ESD/DailyTalk →
8
+ # **chuẩn hóa âm lượng** (giống mẫu TTS) → ráp vào `wav/` đủ **15.477 file**.
9
+ #
10
+ # ### Cách dùng
11
+ # 1. Settings → **Internet = On** (cần tải/biên dịch sv56). GPU không bắt buộc.
12
+ # 2. **+ Add Input** 3 dataset:
13
+ # - Gói Track 2 (`vmc2026_track2_..._v3.tar.gz` — Kaggle tự giải nén ra `vmc2026-track2/`).
14
+ # - ESD: `Emotional Speech Dataset (ESD).zip` (Kaggle tự giải nén ra `Emotion Speech Dataset/`).
15
+ # - DailyTalk: `dailytalk.zip` (giải nén ra `dailytalk/data/...`).
16
+ # 3. **Run All**. Xong → **Save Version** (Commit) để lưu `wav/` ra output.
17
+ # 4. Từ output đó → **Create Dataset** → dùng làm input cho notebook train/baseline.
18
+ #
19
+ # > Notebook tự dò vị trí ESD/DailyTalk dù Kaggle giải nén ra thư mục tên gì.
20
+
21
+ # %% [markdown]
22
+ # ## 0. Tìm gói Track 2 + copy ra thư mục ghi được
23
+
24
+ # %%
25
+ import os, glob, shutil, subprocess
26
+
27
+ # Tự dò thư mục vmc2026-track2 trong mọi input đã add.
28
+ _cands = glob.glob("/kaggle/input/*/vmc2026-track2") + glob.glob("/kaggle/input/**/vmc2026-track2", recursive=True)
29
+ TRACK2_SRC = _cands[0] if _cands else None
30
+ assert TRACK2_SRC, "Không thấy thư mục vmc2026-track2 — đã Add Input gói Track 2 chưa?"
31
+
32
+ WORK = "/kaggle/working/vmc2026-track2" # bản ghi được (input là read-only)
33
+ print("Track2 source :", TRACK2_SRC)
34
+
35
+ # Copy toàn bộ gói ra working (gồm wav/ 14060 file + scripts + csv). Mất vài phút.
36
+ if not os.path.exists(WORK):
37
+ print("Đang copy gói Track 2 ra working (vài phút)...")
38
+ shutil.copytree(TRACK2_SRC, WORK)
39
+ print("Số wav hiện có:", len(os.listdir(f"{WORK}/wav")))
40
+
41
+ # %% [markdown]
42
+ # ## 1. Cài SoX + build sv56 (cần Internet = On)
43
+ # sv56 = công cụ chuẩn hóa âm lượng của ITU-T, build từ source openitu/STL.
44
+
45
+ # %%
46
+ def sh(cmd):
47
+ print("$", cmd)
48
+ print(subprocess.run(cmd, shell=True, capture_output=True, text=True).stdout[-2000:])
49
+
50
+ # SoX + trình biên dịch (để build sv56)
51
+ sh("apt-get -qq update && apt-get -qq install -y sox make gcc")
52
+ sh("which sox && sox --version")
53
+
54
+ # sv56demo
55
+ SV56_DIR = "/kaggle/working/STL-2009"
56
+ SV56_BIN_DIR = f"{SV56_DIR}/src/sv56"
57
+ if not os.path.exists(f"{SV56_BIN_DIR}/sv56demo"):
58
+ sh("cd /kaggle/working && wget -q https://github.com/openitu/STL/archive/refs/tags/v2009.tar.gz")
59
+ sh("cd /kaggle/working && tar -xf v2009.tar.gz")
60
+ sh(f"cd {SV56_BIN_DIR} && make -f makefile.unx")
61
+ assert os.path.exists(f"{SV56_BIN_DIR}/sv56demo"), "Build sv56 thất bại — kiểm tra Internet=On + log make."
62
+
63
+ # Đưa cả sox và sv56demo vào PATH cho các script .sh dùng được
64
+ os.environ["PATH"] = SV56_BIN_DIR + ":" + os.environ["PATH"]
65
+ sh("which sv56demo")
66
+
67
+ # %% [markdown]
68
+ # ## 2. Dò vị trí ESD + DailyTalk (tự tìm dù tên thư mục khác nhau)
69
+
70
+ # %%
71
+ def find_root(rel_path):
72
+ """Tìm thư mục ROOT trong /kaggle/input sao cho ROOT/rel_path tồn tại."""
73
+ base = os.path.basename(rel_path)
74
+ for hit in glob.glob(f"/kaggle/input/**/{base}", recursive=True):
75
+ if hit.endswith(rel_path.replace("/", os.sep)) or hit.endswith(rel_path):
76
+ return hit[: -len(rel_path)].rstrip("/")
77
+ return None
78
+
79
+ # ESD: dòng CSV "0014/Angry/000381.wav" → file thật là "0014/Angry/0014_000381.wav"
80
+ _esd_first = open(f"{WORK}/ESD_utts_train_dev.csv").readline().strip().split(",")[0]
81
+ _p = _esd_first.split("/") # [spk, emo, uttID.wav]
82
+ ESD_REL = f"{_p[0]}/{_p[1]}/{_p[0]}_{_p[2]}"
83
+ ESD_ROOT = find_root(ESD_REL)
84
+
85
+ # DailyTalk: dòng CSV "1020/0_1_d1020.wav" → file thật ".../data/1020/0_1_d1020.wav"
86
+ _dt_first = open(f"{WORK}/DT_utts_train_dev.csv").readline().strip().split(",")[0]
87
+ DT_ROOT = find_root(_dt_first) # ROOT sao cho ROOT/1020/0_1_d1020.wav tồn tại
88
+
89
+ print("ESD_REL :", ESD_REL)
90
+ print("ESD_ROOT :", ESD_ROOT)
91
+ print("DT_ROOT :", DT_ROOT)
92
+ assert ESD_ROOT, "Không thấy ESD — đã Add Input 'Emotional Speech Dataset (ESD).zip' chưa?"
93
+ assert DT_ROOT, "Không thấy DailyTalk — đã Add Input 'dailytalk.zip' chưa?"
94
+
95
+ # %% [markdown]
96
+ # ## 3. Gom các utterance cần dùng → thư mục gathered/
97
+
98
+ # %%
99
+ GATHERED = f"{WORK}/gathered"
100
+ os.makedirs(GATHERED, exist_ok=True)
101
+
102
+ # ESD: copy ROOT/spk/emo/spk_uttID → gathered/<tên vmc2026>
103
+ n_esd = 0
104
+ for line in open(f"{WORK}/ESD_utts_train_dev.csv"):
105
+ src_rel, dst = line.strip().split(",")
106
+ p = src_rel.split("/")
107
+ src = f"{ESD_ROOT}/{p[0]}/{p[1]}/{p[0]}_{p[2]}"
108
+ if os.path.exists(src):
109
+ shutil.copy(src, f"{GATHERED}/{dst}")
110
+ n_esd += 1
111
+ else:
112
+ print("ESD thiếu:", src)
113
+
114
+ # DailyTalk: copy ROOT/parts[0] → gathered/<tên vmc2026>
115
+ n_dt = 0
116
+ for line in open(f"{WORK}/DT_utts_train_dev.csv"):
117
+ src_rel, dst = line.strip().split(",")
118
+ src = f"{DT_ROOT}/{src_rel}"
119
+ if os.path.exists(src):
120
+ shutil.copy(src, f"{GATHERED}/{dst}")
121
+ n_dt += 1
122
+ else:
123
+ print("DailyTalk thiếu:", src)
124
+
125
+ print(f"Đã gom: ESD {n_esd}/1379 · DailyTalk {n_dt}/38 · tổng {len(os.listdir(GATHERED))}")
126
+
127
+ # %% [markdown]
128
+ # ## 4. Chuẩn hóa âm lượng bằng sv56 (mức -26 dB, giữ nguyên sample rate)
129
+
130
+ # %%
131
+ # Dùng chính script gốc trong gói: batch_normRMSE.sh → tạo *_norm.wav trong gathered/
132
+ sh(f"bash {WORK}/sv56scripts/batch_normRMSE.sh {GATHERED}")
133
+
134
+ # Move các file đã chuẩn hóa vào wav/ với đúng tên (bỏ hậu tố _norm)
135
+ moved = 0
136
+ for n in os.listdir(GATHERED):
137
+ if n.endswith("_norm.wav"):
138
+ final = "_".join(n.split("_")[:-1]) + ".wav" # bỏ "_norm"
139
+ shutil.move(f"{GATHERED}/{n}", f"{WORK}/wav/{final}")
140
+ moved += 1
141
+ print("Đã chuẩn hóa & move vào wav/:", moved, "file")
142
+
143
+ # %% [markdown]
144
+ # ## 5. Kiểm tra đủ 15.477 + dọn rác
145
+
146
+ # %%
147
+ total = len(glob.glob(f"{WORK}/wav/*.wav"))
148
+ print("Tổng wav trong wav/:", total)
149
+ if total == 15477:
150
+ shutil.rmtree(GATHERED, ignore_errors=True)
151
+ # Dọn artifact build để dataset lưu ra GỌN (chỉ giữ vmc2026-track2/)
152
+ shutil.rmtree(SV56_DIR, ignore_errors=True)
153
+ for f in glob.glob("/kaggle/working/v2009.tar.gz"):
154
+ os.remove(f)
155
+ print("✅ ĐỦ 15.477 file. Sẵn sàng Save Version → tạo Dataset.")
156
+ print(" Dùng dataset này cho notebook baseline: DATA_ROOT = '/kaggle/input/<dataset-mới>/vmc2026-track2'")
157
+ else:
158
+ print(f"⚠️ Chưa đủ (đang {total}). Kiểm tra log bước 3-4 xem ESD/DailyTalk có thiếu file nào.")
159
+
160
+ # %% [markdown]
161
+ # ## Ghi chú
162
+ # - Output nặng (~2-3GB do 15.477 wav). `wav/` đã gồm cả train+dev nên dùng được cho cả fine-tune lẫn inference.
163
+ # - sv56 chuẩn hóa để mẫu ESD/DailyTalk cùng mức âm lượng với mẫu TTS → tránh model bị nhiễu bởi độ to.
164
+ # - Nếu Internet Off: SoX có thể có sẵn nhưng KHÔNG build được sv56 → bắt buộc bật Internet.