Minh Toàn commited on
Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- demo_all_tracks_gradio.ipynb +482 -0
- demo_all_tracks_gradio_pipeline.py +390 -0
- demo_run_from_hf.ipynb +119 -0
- demo_run_from_hf_pipeline.py +59 -0
- track1/demo_track1_gradio.ipynb +144 -0
- track1/demo_track1_gradio_pipeline.py +100 -0
- track1/track1_baseline.ipynb +108 -0
- track1/track1_baseline_pipeline.py +64 -0
- track2/demo_track2_emotion_gradio.ipynb +533 -0
- track2/demo_track2_emotion_gradio_pipeline.py +431 -0
- track2/demo_track2_gradio.ipynb +175 -0
- track2/demo_track2_gradio_pipeline.py +120 -0
- track2/exp02_train_emos.ipynb +542 -0
- track2/exp02_train_emos_pipeline.py +407 -0
- track2/exp03_emos_sailer.ipynb +392 -0
- track2/exp03_emos_sailer_pipeline.py +264 -0
- track2/exp04_fusion.ipynb +790 -0
- track2/exp04_fusion_pipeline.py +652 -0
- track2/exp05_vad_audeering.ipynb +443 -0
- track2/exp05_vad_audeering_pipeline.py +303 -0
- track2/exp06_qmos_train.ipynb +628 -0
- track2/exp06_qmos_train_pipeline.py +502 -0
- track2/exp07_fusion_qmos.ipynb +780 -0
- track2/exp07_fusion_qmos_pipeline.py +654 -0
- track2/exp08_finetune_emotion.ipynb +820 -0
- track2/exp08_finetune_emotion_pipeline.py +673 -0
- track2/exp08b_finetune_resume.ipynb +782 -0
- track2/exp08b_finetune_resume_pipeline.py +642 -0
- track2/exp09a_qmos_utmosv2_probe.ipynb +339 -0
- track2/exp09a_qmos_utmosv2_probe_pipeline.py +239 -0
- track2/exp10_finetune_audeering.ipynb +691 -0
- track2/exp10_finetune_audeering_pipeline.py +553 -0
- track2/exp11_finetune_joint.ipynb +805 -0
- track2/exp11_finetune_joint_pipeline.py +665 -0
- track2/exp12_wavlm_scratch.ipynb +690 -0
- track2/exp12_wavlm_scratch_pipeline.py +564 -0
- track2/exp13_finetune_qmos.ipynb +733 -0
- track2/exp13_finetune_qmos_pipeline.py +607 -0
- track2/exp14_mamba_head.ipynb +952 -0
- track2/exp14_mamba_head_pipeline.py +798 -0
- track2/exp15_predict.ipynb +698 -0
- track2/exp15_predict_pipeline.py +554 -0
- track2/exp15_wavlm_mamba_emotion.ipynb +1081 -0
- track2/exp15_wavlm_mamba_emotion_pipeline.py +920 -0
- track2/exp16_llm_judge.ipynb +650 -0
- track2/exp16_llm_judge_pipeline.py +480 -0
- track2/track2_baseline.ipynb +130 -0
- track2/track2_baseline_pipeline.py +321 -0
- track2/track2_prepare_data.ipynb +249 -0
- track2/track2_prepare_data_pipeline.py +164 -0
demo_all_tracks_gradio.ipynb
ADDED
|
@@ -0,0 +1,482 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "739ac809",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"# VMC2026 — Demo Gradio GỘP 3 TRACK (1 link cho mentor)\n",
|
| 9 |
+
"\n",
|
| 10 |
+
"Gộp 3 demo lẻ (`track1/`, `track2/`, `track3/`) vào **1 app Gradio 3 tab**:\n",
|
| 11 |
+
"- **Track 1** · Speech Enhancement → **ACR** (chất lượng A) + **CCR** (so A vs B). Model: URGENT-MOS.\n",
|
| 12 |
+
"- **Track 2** · Emotional TTS → **EMOS / CAT / VAD**. Model TỐT NHẤT = **exp08** (WavLM fine-tune + audeering).\n",
|
| 13 |
+
"- **Track 3** · Speaker/Accent → **spk_sim / acc_sim**. Model: ECAPA fine-tuned (baseline BTC).\n",
|
| 14 |
+
"\n",
|
| 15 |
+
"> **Lazy-load:** mỗi track chỉ nạp model khi bạn bấm \"Dự đoán\" ở tab đó → tab nào thiếu checkpoint/repo\n",
|
| 16 |
+
"> chỉ báo lỗi trong tab đó, KHÔNG sập cả app. Track 1 & 3 chỉ cần Internet; Track 2 cần thêm checkpoint exp08.\n",
|
| 17 |
+
"\n",
|
| 18 |
+
"### Cách chạy trên Kaggle\n",
|
| 19 |
+
"1. Settings → **GPU T4 + Internet On**.\n",
|
| 20 |
+
"2. (Cho Track 2) Add Input: dataset Track 2 (`sets/train.csv`, `wav/`, `metadata.csv`) + dataset chứa\n",
|
| 21 |
+
" `ft_emotion_full_20epoch.pt` (slug `toanminh222/cache-exp8`). Thiếu thì 2 tab kia vẫn chạy.\n",
|
| 22 |
+
"3. **Run All** → cell cuối in link `*.gradio.live` (sống ~72h) → gửi mentor."
|
| 23 |
+
]
|
| 24 |
+
},
|
| 25 |
+
{
|
| 26 |
+
"cell_type": "markdown",
|
| 27 |
+
"id": "6f7119e0",
|
| 28 |
+
"metadata": {},
|
| 29 |
+
"source": [
|
| 30 |
+
"## 1. Cài đặt gói (1 lần cho cả 3 track)"
|
| 31 |
+
]
|
| 32 |
+
},
|
| 33 |
+
{
|
| 34 |
+
"cell_type": "code",
|
| 35 |
+
"execution_count": null,
|
| 36 |
+
"id": "4da07abf",
|
| 37 |
+
"metadata": {
|
| 38 |
+
"lines_to_next_cell": 1
|
| 39 |
+
},
|
| 40 |
+
"outputs": [],
|
| 41 |
+
"source": [
|
| 42 |
+
"!pip install -q gradio librosa soundfile speechbrain torchaudio loralib scipy scikit-learn pandas tqdm\n",
|
| 43 |
+
"\n",
|
| 44 |
+
"import os, sys, glob, subprocess\n",
|
| 45 |
+
"\n",
|
| 46 |
+
"def pip_install(*pkgs):\n",
|
| 47 |
+
" subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *pkgs], check=False)\n",
|
| 48 |
+
"\n",
|
| 49 |
+
"# Cài nhẹ (Kaggle có sẵn torch/transformers/numpy → KHÔNG đụng numpy để tránh lệch ABI)\n",
|
| 50 |
+
"pip_install(\"gradio\", \"librosa\", \"soundfile\", \"speechbrain\", \"torchaudio\",\n",
|
| 51 |
+
" \"loralib\", \"scipy\", \"scikit-learn\", \"pandas\", \"tqdm\")\n",
|
| 52 |
+
"\n",
|
| 53 |
+
"import librosa\n",
|
| 54 |
+
"import numpy as np\n",
|
| 55 |
+
"import torch\n",
|
| 56 |
+
"\n",
|
| 57 |
+
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
|
| 58 |
+
"SR = 16000\n",
|
| 59 |
+
"print(\"Device:\", DEVICE, (\"✅ \" + torch.cuda.get_device_name(0)) if DEVICE == \"cuda\" else \"⚠️ CPU (chậm)\")\n",
|
| 60 |
+
"\n",
|
| 61 |
+
"def _stem(p):\n",
|
| 62 |
+
" return os.path.splitext(os.path.basename(str(p)))[0]\n",
|
| 63 |
+
"\n",
|
| 64 |
+
"def _scalar(x):\n",
|
| 65 |
+
" return float(x.item()) if hasattr(x, \"item\") else float(x)"
|
| 66 |
+
]
|
| 67 |
+
},
|
| 68 |
+
{
|
| 69 |
+
"cell_type": "markdown",
|
| 70 |
+
"id": "8ce385ee",
|
| 71 |
+
"metadata": {},
|
| 72 |
+
"source": [
|
| 73 |
+
"## 2. TRACK 1 — URGENT-MOS (ACR + CCR) · lazy-load"
|
| 74 |
+
]
|
| 75 |
+
},
|
| 76 |
+
{
|
| 77 |
+
"cell_type": "code",
|
| 78 |
+
"execution_count": null,
|
| 79 |
+
"id": "8f09a3ee",
|
| 80 |
+
"metadata": {
|
| 81 |
+
"lines_to_next_cell": 1
|
| 82 |
+
},
|
| 83 |
+
"outputs": [],
|
| 84 |
+
"source": [
|
| 85 |
+
"URGENT_REPO = \"/kaggle/working/URGENT-MOS\"\n",
|
| 86 |
+
"URGENT_CKPT = \"urgent-challenge/urgent-mos-f1c1m5dcorpus\" # tự tải từ HuggingFace\n",
|
| 87 |
+
"_T1 = {}\n",
|
| 88 |
+
"\n",
|
| 89 |
+
"def _t1_load():\n",
|
| 90 |
+
" \"\"\"Nạp URGENT-MOS 1 lần (clone repo + sys.path + checkpoint).\"\"\"\n",
|
| 91 |
+
" if \"m\" in _T1:\n",
|
| 92 |
+
" return _T1[\"m\"]\n",
|
| 93 |
+
" if not os.path.isdir(URGENT_REPO):\n",
|
| 94 |
+
" subprocess.run(f\"git clone -q https://github.com/vvwangvv/URGENT-MOS.git {URGENT_REPO}\",\n",
|
| 95 |
+
" shell=True, check=True)\n",
|
| 96 |
+
" if URGENT_REPO not in sys.path:\n",
|
| 97 |
+
" sys.path.insert(0, URGENT_REPO)\n",
|
| 98 |
+
" import importlib\n",
|
| 99 |
+
" importlib.invalidate_caches()\n",
|
| 100 |
+
" try:\n",
|
| 101 |
+
" importlib.import_module(\"urgent_mos.api.infer\")\n",
|
| 102 |
+
" except Exception:\n",
|
| 103 |
+
" subprocess.run(f\"pip install -q -e {URGENT_REPO}\", shell=True, check=False)\n",
|
| 104 |
+
" importlib.invalidate_caches()\n",
|
| 105 |
+
" from urgent_mos.utils import load_model_from_checkpoint\n",
|
| 106 |
+
" m = load_model_from_checkpoint(URGENT_CKPT, DEVICE)\n",
|
| 107 |
+
" m.eval()\n",
|
| 108 |
+
" _T1[\"m\"] = m\n",
|
| 109 |
+
" return m\n",
|
| 110 |
+
"\n",
|
| 111 |
+
"def t1_predict(audio_a, audio_b):\n",
|
| 112 |
+
" if not audio_a:\n",
|
| 113 |
+
" return \"⚠️ Hãy tải lên ít nhất **Audio A**.\"\n",
|
| 114 |
+
" try:\n",
|
| 115 |
+
" m = _t1_load()\n",
|
| 116 |
+
" from urgent_mos.api.infer import infer, infer_pairs\n",
|
| 117 |
+
" wa = torch.from_numpy(librosa.load(audio_a, sr=SR, mono=True)[0]).float()\n",
|
| 118 |
+
" acr_a = max(1.0, min(5.0, _scalar(infer(m, [wa], sample_rate=[SR],\n",
|
| 119 |
+
" batch_frames=None, num_workers=0)[0][\"mos_overall\"])))\n",
|
| 120 |
+
" out = f\"**ACR (Audio A): {acr_a:.3f}** (chất lượng tuyệt đối, thang 1–5)\"\n",
|
| 121 |
+
" if audio_b:\n",
|
| 122 |
+
" wb = torch.from_numpy(librosa.load(audio_b, sr=SR, mono=True)[0]).float()\n",
|
| 123 |
+
" acr_b = max(1.0, min(5.0, _scalar(infer(m, [wb], sample_rate=[SR],\n",
|
| 124 |
+
" batch_frames=None, num_workers=0)[0][\"mos_overall\"])))\n",
|
| 125 |
+
" ccr = max(-3.0, min(3.0, _scalar(infer_pairs(m, [(wa, wb)], sample_rate=[(SR, SR)],\n",
|
| 126 |
+
" batch_frames=None, num_workers=0)[0][\"mos_overall\"])))\n",
|
| 127 |
+
" out += (f\"\\n\\n**ACR (Audio B): {acr_b:.3f}**\"\n",
|
| 128 |
+
" f\"\\n\\n**CCR (A so với B): {ccr:+.3f}** (>0: A tốt hơn B; thang −3..+3)\")\n",
|
| 129 |
+
" return out\n",
|
| 130 |
+
" except Exception as e:\n",
|
| 131 |
+
" return f\"❌ Track 1 lỗi: `{repr(e)}`\\n\\nKiểm tra **Internet On** (cần tải URGENT-MOS từ GitHub/HuggingFace).\""
|
| 132 |
+
]
|
| 133 |
+
},
|
| 134 |
+
{
|
| 135 |
+
"cell_type": "markdown",
|
| 136 |
+
"id": "21d2f185",
|
| 137 |
+
"metadata": {},
|
| 138 |
+
"source": [
|
| 139 |
+
"## 3. TRACK 3 — ECAPA fine-tuned (spk_sim + acc_sim) · lazy-load"
|
| 140 |
+
]
|
| 141 |
+
},
|
| 142 |
+
{
|
| 143 |
+
"cell_type": "code",
|
| 144 |
+
"execution_count": null,
|
| 145 |
+
"id": "143f5e27",
|
| 146 |
+
"metadata": {
|
| 147 |
+
"lines_to_next_cell": 1
|
| 148 |
+
},
|
| 149 |
+
"outputs": [],
|
| 150 |
+
"source": [
|
| 151 |
+
"T3_REPO = \"/kaggle/working/vmc2026-baselines/track3\"\n",
|
| 152 |
+
"CKPT_SPK = f\"{T3_REPO}/official-egs/spk_sim_adamw_lr1e-3/model_spk_sim_step20000.pt\"\n",
|
| 153 |
+
"CKPT_ACC = f\"{T3_REPO}/official-egs/acc_sim_adamw_lr1e-3/model_acc_sim_step20000.pt\"\n",
|
| 154 |
+
"_T3 = {}\n",
|
| 155 |
+
"\n",
|
| 156 |
+
"def _t3_load():\n",
|
| 157 |
+
" if \"spk\" in _T3:\n",
|
| 158 |
+
" return _T3\n",
|
| 159 |
+
" repo_root = \"/kaggle/working/vmc2026-baselines\"\n",
|
| 160 |
+
" if not os.path.isdir(repo_root):\n",
|
| 161 |
+
" subprocess.run(f\"git clone -q https://github.com/voicemos-challenge/vmc2026-baselines.git {repo_root}\",\n",
|
| 162 |
+
" shell=True, check=True)\n",
|
| 163 |
+
" if T3_REPO not in sys.path:\n",
|
| 164 |
+
" sys.path.insert(0, T3_REPO)\n",
|
| 165 |
+
" from model import Model\n",
|
| 166 |
+
" spk = Model(mlp_heads=[\"spk_sim\"])\n",
|
| 167 |
+
" spk.load_state_dict(torch.load(CKPT_SPK, map_location=\"cpu\"))\n",
|
| 168 |
+
" acc = Model(mlp_heads=[\"acc_sim\"])\n",
|
| 169 |
+
" acc.load_state_dict(torch.load(CKPT_ACC, map_location=\"cpu\"))\n",
|
| 170 |
+
" _T3.update(spk=spk.to(DEVICE).eval(), acc=acc.to(DEVICE).eval())\n",
|
| 171 |
+
" return _T3\n",
|
| 172 |
+
"\n",
|
| 173 |
+
"def t3_predict(audio_test, audio_ref):\n",
|
| 174 |
+
" if not audio_test or not audio_ref:\n",
|
| 175 |
+
" return \"⚠️ Cần **cả 2 file**: audio test + audio reference.\"\n",
|
| 176 |
+
" try:\n",
|
| 177 |
+
" M = _t3_load()\n",
|
| 178 |
+
" ta = torch.from_numpy(librosa.load(audio_test, sr=SR, mono=True)[0]).float().unsqueeze(0).to(DEVICE)\n",
|
| 179 |
+
" tb = torch.from_numpy(librosa.load(audio_ref, sr=SR, mono=True)[0]).float().unsqueeze(0).to(DEVICE)\n",
|
| 180 |
+
" with torch.no_grad():\n",
|
| 181 |
+
" o_spk = M[\"spk\"](ta, tb)\n",
|
| 182 |
+
" spk = float(o_spk[\"spk_sim\"].item())\n",
|
| 183 |
+
" acc = float(M[\"acc\"](ta, tb)[\"acc_sim\"].item())\n",
|
| 184 |
+
" cos = float(o_spk[\"cos_sim\"].item())\n",
|
| 185 |
+
" return (f\"**Speaker similarity: {spk:.3f}** (1–5)\\n\\n\"\n",
|
| 186 |
+
" f\"**Accent similarity : {acc:.3f}** (1–5)\\n\\n\"\n",
|
| 187 |
+
" f\"Cosine zero-shot (tham khảo): {cos:.3f}\")\n",
|
| 188 |
+
" except Exception as e:\n",
|
| 189 |
+
" return f\"❌ Track 3 lỗi: `{repr(e)}`\\n\\nKiểm tra **Internet On** (clone repo baseline chứa checkpoint).\""
|
| 190 |
+
]
|
| 191 |
+
},
|
| 192 |
+
{
|
| 193 |
+
"cell_type": "markdown",
|
| 194 |
+
"id": "6fd2047d",
|
| 195 |
+
"metadata": {},
|
| 196 |
+
"source": [
|
| 197 |
+
"## 4. TRACK 2 — exp08 Emotional TTS Evaluator (EMOS/CAT/VAD) · lazy-load\n",
|
| 198 |
+
"\n",
|
| 199 |
+
"Model TỐT NHẤT: WavLM fine-tune (warm-start SAILER) + audeering frozen → trunk → 3 head.\n",
|
| 200 |
+
"Hằng kiến trúc PHẢI khớp exp08 (ckpt không lưu các số này)."
|
| 201 |
+
]
|
| 202 |
+
},
|
| 203 |
+
{
|
| 204 |
+
"cell_type": "code",
|
| 205 |
+
"execution_count": null,
|
| 206 |
+
"id": "acfc33a7",
|
| 207 |
+
"metadata": {
|
| 208 |
+
"lines_to_next_cell": 1
|
| 209 |
+
},
|
| 210 |
+
"outputs": [],
|
| 211 |
+
"source": [
|
| 212 |
+
"EMO_MAX_SEC, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, USE_AMP = 8, 512, 128, 0.3, True\n",
|
| 213 |
+
"EMOTIONS5 = [\"angry\", \"happy\", \"neutral\", \"sad\", \"surprised\"]\n",
|
| 214 |
+
"_EMO_ALIAS = {\n",
|
| 215 |
+
" \"angry\": \"angry\", \"anger\": \"angry\", \"happy\": \"happy\", \"happiness\": \"happy\", \"joy\": \"happy\",\n",
|
| 216 |
+
" \"neutral\": \"neutral\", \"calm\": \"neutral\", \"sad\": \"sad\", \"sadness\": \"sad\",\n",
|
| 217 |
+
" \"surprise\": \"surprised\", \"surprised\": \"surprised\", \"surprising\": \"surprised\",\n",
|
| 218 |
+
"}\n",
|
| 219 |
+
"def norm_emotion(label):\n",
|
| 220 |
+
" key = str(label).strip().lower()\n",
|
| 221 |
+
" return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)\n",
|
| 222 |
+
"\n",
|
| 223 |
+
"def _t2_find_ckpt():\n",
|
| 224 |
+
" for pat in [\"ft_emotion_full_20epoch*.pt\", \"ft_emotion_full*.pt\"]:\n",
|
| 225 |
+
" for base in [\"/kaggle/input\", \"/kaggle/working\"]:\n",
|
| 226 |
+
" hits = sorted(glob.glob(os.path.join(base, \"**\", pat), recursive=True))\n",
|
| 227 |
+
" if hits:\n",
|
| 228 |
+
" return hits[0]\n",
|
| 229 |
+
" return \"\"\n",
|
| 230 |
+
"\n",
|
| 231 |
+
"_T2 = {}\n",
|
| 232 |
+
"\n",
|
| 233 |
+
"def _t2_load():\n",
|
| 234 |
+
" if \"infer\" in _T2:\n",
|
| 235 |
+
" return _T2[\"infer\"]\n",
|
| 236 |
+
" import torch.nn as nn\n",
|
| 237 |
+
" ckpt_path = _t2_find_ckpt()\n",
|
| 238 |
+
" assert ckpt_path, \"Không thấy ft_emotion_full*.pt — Add Input dataset checkpoint exp08 (slug toanminh222/cache-exp8)?\"\n",
|
| 239 |
+
" # code SAILER để dựng backbone WavLM\n",
|
| 240 |
+
" repo = \"/kaggle/working/vox-profile-release\"\n",
|
| 241 |
+
" if not os.path.exists(repo):\n",
|
| 242 |
+
" subprocess.run([\"git\", \"clone\", \"--depth\", \"1\",\n",
|
| 243 |
+
" \"https://github.com/tiantiaf0627/vox-profile-release.git\", repo], check=True)\n",
|
| 244 |
+
" if repo not in sys.path:\n",
|
| 245 |
+
" sys.path.insert(0, repo)\n",
|
| 246 |
+
"\n",
|
| 247 |
+
" ckpt = torch.load(ckpt_path, map_location=\"cpu\", weights_only=False)\n",
|
| 248 |
+
" assert \"wavlm\" in ckpt and \"heads\" in ckpt, \"Checkpoint thiếu 'wavlm'/'heads' → cần bản đủ ft_emotion_full_20epoch.pt.\"\n",
|
| 249 |
+
" AUD_DIM = int(ckpt.get(\"AUD_DIM\", 0)); USE_AUDEERING = AUD_DIM > 0\n",
|
| 250 |
+
"\n",
|
| 251 |
+
" def find_hf_backbone(module):\n",
|
| 252 |
+
" cands = []\n",
|
| 253 |
+
" for name, m in module.named_modules():\n",
|
| 254 |
+
" enc = getattr(m, \"encoder\", None)\n",
|
| 255 |
+
" if getattr(m, \"feature_extractor\", None) is not None and enc is not None \\\n",
|
| 256 |
+
" and getattr(enc, \"layers\", None) is not None:\n",
|
| 257 |
+
" cands.append((name, m))\n",
|
| 258 |
+
" if not cands:\n",
|
| 259 |
+
" return None, None\n",
|
| 260 |
+
" cands.sort(key=lambda nm: sum(p.numel() for p in nm[1].parameters()), reverse=True)\n",
|
| 261 |
+
" return cands[0]\n",
|
| 262 |
+
"\n",
|
| 263 |
+
" wavlm = None\n",
|
| 264 |
+
" try:\n",
|
| 265 |
+
" from src.model.emotion.wavlm_emotion import WavLMWrapper\n",
|
| 266 |
+
" _wrapper = WavLMWrapper.from_pretrained(\"tiantiaf/wavlm-large-categorical-emotion\")\n",
|
| 267 |
+
" _name, wavlm = find_hf_backbone(_wrapper)\n",
|
| 268 |
+
" except Exception as e:\n",
|
| 269 |
+
" print(\"⚠️ SAILER wrapper lỗi:\", repr(e), \"→ fallback WavLM trắng.\")\n",
|
| 270 |
+
" if wavlm is None:\n",
|
| 271 |
+
" from transformers import WavLMModel\n",
|
| 272 |
+
" wavlm = WavLMModel.from_pretrained(\"microsoft/wavlm-large\")\n",
|
| 273 |
+
" wavlm = wavlm.to(DEVICE).eval()\n",
|
| 274 |
+
" WAVLM_DIM = int(wavlm.config.hidden_size)\n",
|
| 275 |
+
" wavlm.config.layerdrop = 0.0\n",
|
| 276 |
+
" wavlm.load_state_dict(ckpt[\"wavlm\"], strict=False)\n",
|
| 277 |
+
"\n",
|
| 278 |
+
" def masked_mean(hidden, attn_mask):\n",
|
| 279 |
+
" if attn_mask is None:\n",
|
| 280 |
+
" return hidden.mean(dim=1)\n",
|
| 281 |
+
" try:\n",
|
| 282 |
+
" fm = wavlm._get_feature_vector_attention_mask(hidden.shape[1], attn_mask)\n",
|
| 283 |
+
" except Exception:\n",
|
| 284 |
+
" return hidden.mean(dim=1)\n",
|
| 285 |
+
" fm = fm.unsqueeze(-1).to(hidden.dtype)\n",
|
| 286 |
+
" return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)\n",
|
| 287 |
+
"\n",
|
| 288 |
+
" # audeering frozen (nếu ckpt dùng)\n",
|
| 289 |
+
" aud_backbone = aud_head = aud_proc = None\n",
|
| 290 |
+
" if USE_AUDEERING:\n",
|
| 291 |
+
" from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor\n",
|
| 292 |
+
" from huggingface_hub import hf_hub_download\n",
|
| 293 |
+
" AUD_NAME = \"audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim\"\n",
|
| 294 |
+
" aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)\n",
|
| 295 |
+
" aud_backbone = Wav2Vec2Model(Wav2Vec2Config.from_pretrained(AUD_NAME))\n",
|
| 296 |
+
" try:\n",
|
| 297 |
+
" _sd = __import__(\"safetensors.torch\", fromlist=[\"load_file\"]).load_file(\n",
|
| 298 |
+
" hf_hub_download(AUD_NAME, \"model.safetensors\"))\n",
|
| 299 |
+
" except Exception:\n",
|
| 300 |
+
" _sd = torch.load(hf_hub_download(AUD_NAME, \"pytorch_model.bin\"), map_location=\"cpu\")\n",
|
| 301 |
+
" bb_sd = {k[len(\"wav2vec2.\"):]: v for k, v in _sd.items() if k.startswith(\"wav2vec2.\")}\n",
|
| 302 |
+
" aud_backbone.load_state_dict(bb_sd, strict=False)\n",
|
| 303 |
+
" _hid = _sd[\"classifier.dense.weight\"].shape[0]\n",
|
| 304 |
+
" aud_head = nn.Sequential(nn.Linear(_hid, _hid), nn.Tanh(),\n",
|
| 305 |
+
" nn.Linear(_hid, _sd[\"classifier.out_proj.weight\"].shape[0]))\n",
|
| 306 |
+
" aud_head[0].weight.data.copy_(_sd[\"classifier.dense.weight\"]); aud_head[0].bias.data.copy_(_sd[\"classifier.dense.bias\"])\n",
|
| 307 |
+
" aud_head[2].weight.data.copy_(_sd[\"classifier.out_proj.weight\"]); aud_head[2].bias.data.copy_(_sd[\"classifier.out_proj.bias\"])\n",
|
| 308 |
+
" aud_backbone = aud_backbone.to(DEVICE).eval(); aud_head = aud_head.to(DEVICE).eval()\n",
|
| 309 |
+
"\n",
|
| 310 |
+
" @torch.no_grad()\n",
|
| 311 |
+
" def audeering_feat(wave):\n",
|
| 312 |
+
" x = aud_proc(wave, sampling_rate=SR).input_values[0]\n",
|
| 313 |
+
" x = torch.from_numpy(np.asarray(x, dtype=np.float32)).unsqueeze(0).to(DEVICE)\n",
|
| 314 |
+
" h = aud_backbone(x)[0].mean(dim=1)\n",
|
| 315 |
+
" out = aud_head(h)[0].cpu().numpy()\n",
|
| 316 |
+
" vad = np.array([1 + 4 * out[2], 1 + 4 * out[0], 1 + 4 * out[1]], dtype=np.float32)\n",
|
| 317 |
+
" return np.concatenate([h[0].cpu().numpy(), vad]).astype(np.float32)\n",
|
| 318 |
+
"\n",
|
| 319 |
+
" N_EMO = len(EMOTIONS5)\n",
|
| 320 |
+
" TRUNK_IN = WAVLM_DIM + (AUD_DIM if USE_AUDEERING else 0)\n",
|
| 321 |
+
"\n",
|
| 322 |
+
" class EmoHeads(nn.Module):\n",
|
| 323 |
+
" def __init__(self, d_in, trunk_h, head_h, p, n_emo):\n",
|
| 324 |
+
" super().__init__()\n",
|
| 325 |
+
" self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),\n",
|
| 326 |
+
" nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))\n",
|
| 327 |
+
" self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))\n",
|
| 328 |
+
" self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))\n",
|
| 329 |
+
" self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))\n",
|
| 330 |
+
" def forward(self, feat, tgt):\n",
|
| 331 |
+
" h = self.trunk(feat)\n",
|
| 332 |
+
" return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)\n",
|
| 333 |
+
"\n",
|
| 334 |
+
" heads = EmoHeads(TRUNK_IN, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(DEVICE).eval()\n",
|
| 335 |
+
" heads.load_state_dict(ckpt[\"heads\"], strict=False)\n",
|
| 336 |
+
" emos_mu, emos_sd = float(ckpt[\"emos_mu\"]), float(ckpt[\"emos_sd\"])\n",
|
| 337 |
+
" vad_mu = np.asarray(ckpt[\"vad_mu\"], dtype=np.float32); vad_sd = np.asarray(ckpt[\"vad_sd\"], dtype=np.float32)\n",
|
| 338 |
+
"\n",
|
| 339 |
+
" def onehot_target(tgt):\n",
|
| 340 |
+
" v = np.zeros(N_EMO, dtype=np.float32)\n",
|
| 341 |
+
" if tgt in EMOTIONS5:\n",
|
| 342 |
+
" v[EMOTIONS5.index(tgt)] = 1.0\n",
|
| 343 |
+
" return v\n",
|
| 344 |
+
"\n",
|
| 345 |
+
" @torch.no_grad()\n",
|
| 346 |
+
" def infer_wave(wave, target_emotion):\n",
|
| 347 |
+
" wave = wave[: EMO_MAX_SEC * SR].astype(np.float32)\n",
|
| 348 |
+
" iv = torch.from_numpy(wave).unsqueeze(0).to(DEVICE)\n",
|
| 349 |
+
" am = torch.ones((1, len(wave)), dtype=torch.long, device=DEVICE)\n",
|
| 350 |
+
" tgt = torch.from_numpy(onehot_target(norm_emotion(target_emotion) if target_emotion else None)).unsqueeze(0).to(DEVICE)\n",
|
| 351 |
+
" with torch.cuda.amp.autocast(enabled=USE_AMP and DEVICE == \"cuda\"):\n",
|
| 352 |
+
" fw = wavlm(iv, attention_mask=am).last_hidden_state\n",
|
| 353 |
+
" fw = masked_mean(fw, am)\n",
|
| 354 |
+
" if USE_AUDEERING:\n",
|
| 355 |
+
" fw = torch.cat([fw, torch.from_numpy(audeering_feat(wave)).unsqueeze(0).to(DEVICE)], dim=1)\n",
|
| 356 |
+
" emos_p, cat_l, vad_p = heads(fw, tgt)\n",
|
| 357 |
+
" emos = float(emos_p.item()) * emos_sd + emos_mu\n",
|
| 358 |
+
" cat5 = torch.softmax(cat_l, 1)[0].float().cpu().numpy()\n",
|
| 359 |
+
" vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu\n",
|
| 360 |
+
" return emos, cat5, vad3\n",
|
| 361 |
+
"\n",
|
| 362 |
+
" print(f\"✅ Track 2 exp08 nạp xong (audeering {'ON' if USE_AUDEERING else 'OFF'}) từ {ckpt_path}\")\n",
|
| 363 |
+
" _T2[\"infer\"] = infer_wave\n",
|
| 364 |
+
" return infer_wave\n",
|
| 365 |
+
"\n",
|
| 366 |
+
"def t2_predict(audio, target_emotion):\n",
|
| 367 |
+
" \"\"\"Trả: verdict(md), EMOS(number), CAT(label dict), VAL, ARO, DOM.\"\"\"\n",
|
| 368 |
+
" if not audio:\n",
|
| 369 |
+
" return \"### ⚠️ Hãy tải audio (giọng TTS).\", None, {}, None, None, None\n",
|
| 370 |
+
" try:\n",
|
| 371 |
+
" infer_wave = _t2_load()\n",
|
| 372 |
+
" wave, _ = librosa.load(audio, sr=SR, mono=True)\n",
|
| 373 |
+
" emos, cat5, vad3 = infer_wave(wave, target_emotion)\n",
|
| 374 |
+
" cat_dict = {e: float(cat5[i]) for i, e in enumerate(EMOTIONS5)}\n",
|
| 375 |
+
" perceived = EMOTIONS5[int(np.argmax(cat5))]\n",
|
| 376 |
+
" if target_emotion:\n",
|
| 377 |
+
" match = \"✅ **KHỚP** target\" if perceived == norm_emotion(target_emotion) else \"⚠️ **LỆCH** target\"\n",
|
| 378 |
+
" band = \"🟢 tốt\" if emos >= 4 else (\"🟡 khá\" if emos >= 3 else \"🔴 yếu\")\n",
|
| 379 |
+
" verdict = (f\"### Kết luận biểu cảm\\n\"\n",
|
| 380 |
+
" f\"- Cảm xúc cảm nhận: **{perceived}** → {match} (`{target_emotion}`)\\n\"\n",
|
| 381 |
+
" f\"- EMOS = **{emos:.2f}/5** → biểu cảm {band}\")\n",
|
| 382 |
+
" else:\n",
|
| 383 |
+
" verdict = (f\"### Kết luận biểu cảm\\n- Cảm xúc cảm nhận: **{perceived}**\\n\"\n",
|
| 384 |
+
" f\"- *(Chọn cảm xúc target để bật EMOS — độ khớp ý đồ)*\")\n",
|
| 385 |
+
" emos = None\n",
|
| 386 |
+
" return verdict, (round(emos, 3) if emos is not None else None), cat_dict, \\\n",
|
| 387 |
+
" round(float(vad3[0]), 3), round(float(vad3[1]), 3), round(float(vad3[2]), 3)\n",
|
| 388 |
+
" except Exception as e:\n",
|
| 389 |
+
" return f\"### ❌ Track 2 lỗi\\n`{repr(e)}`\\n\\nĐã Add Input checkpoint exp08 + Internet On chưa?\", None, {}, None, None, None"
|
| 390 |
+
]
|
| 391 |
+
},
|
| 392 |
+
{
|
| 393 |
+
"cell_type": "markdown",
|
| 394 |
+
"id": "81da6580",
|
| 395 |
+
"metadata": {},
|
| 396 |
+
"source": [
|
| 397 |
+
"## 5. Giao diện Gradio GỘP — 3 tab + launch"
|
| 398 |
+
]
|
| 399 |
+
},
|
| 400 |
+
{
|
| 401 |
+
"cell_type": "code",
|
| 402 |
+
"execution_count": null,
|
| 403 |
+
"id": "418857af",
|
| 404 |
+
"metadata": {},
|
| 405 |
+
"outputs": [],
|
| 406 |
+
"source": [
|
| 407 |
+
"import gradio as gr\n",
|
| 408 |
+
"\n",
|
| 409 |
+
"INTRO = (\n",
|
| 410 |
+
" \"# 🎙️ VoiceMOS Challenge 2026 — Demo 3 Track\\n\"\n",
|
| 411 |
+
" \"Một link cho cả 3 track. Mỗi tab nhận audio → trả điểm bộ chấm tự động.\\n\\n\"\n",
|
| 412 |
+
" \"| Track | Bài toán | Output |\\n|---|---|---|\\n\"\n",
|
| 413 |
+
" \"| **1** | Speech Enhancement | ACR (chất lượng) · CCR (so sánh cặp) |\\n\"\n",
|
| 414 |
+
" \"| **2** | Emotional TTS | EMOS · CAT · VAD (5 cột cảm xúc) — *model tốt nhất exp08* |\\n\"\n",
|
| 415 |
+
" \"| **3** | Speaker/Accent | spk_sim · acc_sim |\\n\\n\"\n",
|
| 416 |
+
" \"> Model nạp **lần đầu bấm nút** (chờ ~1–2 phút tải). Tab thiếu checkpoint chỉ báo lỗi trong tab đó.\"\n",
|
| 417 |
+
")\n",
|
| 418 |
+
"\n",
|
| 419 |
+
"with gr.Blocks(title=\"VMC2026 — Demo 3 Track\") as demo:\n",
|
| 420 |
+
" gr.Markdown(INTRO)\n",
|
| 421 |
+
"\n",
|
| 422 |
+
" with gr.Tab(\"1️⃣ Track 1 · Chất lượng (ACR/CCR)\"):\n",
|
| 423 |
+
" gr.Markdown(\"Tải **Audio A** → ACR (1–5). Tải thêm **Audio B** → CCR (A vs B, −3..+3, >0 = A tốt hơn).\")\n",
|
| 424 |
+
" t1a = gr.Audio(type=\"filepath\", label=\"Audio A (bắt buộc)\")\n",
|
| 425 |
+
" t1b = gr.Audio(type=\"filepath\", label=\"Audio B (tùy chọn — để tính CCR)\")\n",
|
| 426 |
+
" t1out = gr.Markdown()\n",
|
| 427 |
+
" gr.Button(\"Dự đoán\", variant=\"primary\").click(t1_predict, [t1a, t1b], t1out)\n",
|
| 428 |
+
"\n",
|
| 429 |
+
" with gr.Tab(\"2️⃣ Track 2 · Cảm xúc (EMOS/CAT/VAD)\"):\n",
|
| 430 |
+
" gr.Markdown(\"Model tốt nhất **exp08** (WavLM fine-tune + audeering, offline). \"\n",
|
| 431 |
+
" \"Chọn **cảm xúc target** để bật EMOS (độ khớp ý đồ).\")\n",
|
| 432 |
+
" with gr.Row():\n",
|
| 433 |
+
" with gr.Column(scale=1):\n",
|
| 434 |
+
" t2a = gr.Audio(type=\"filepath\", label=\"Audio (giọng TTS)\")\n",
|
| 435 |
+
" t2tgt = gr.Dropdown(EMOTIONS5, label=\"🎯 Cảm xúc target (cho EMOS)\")\n",
|
| 436 |
+
" t2btn = gr.Button(\"Chấm cảm xúc\", variant=\"primary\")\n",
|
| 437 |
+
" with gr.Column(scale=2):\n",
|
| 438 |
+
" t2verdict = gr.Markdown()\n",
|
| 439 |
+
" t2emos = gr.Number(label=\"EMOS — khớp cảm xúc target (1–5)\", interactive=False)\n",
|
| 440 |
+
" t2cat = gr.Label(label=\"CAT — phân bố cảm xúc cảm nhận (5 lớp)\")\n",
|
| 441 |
+
" gr.Markdown(\"**VAD — toạ độ cảm xúc liên tục (1–5):**\")\n",
|
| 442 |
+
" with gr.Row():\n",
|
| 443 |
+
" t2val = gr.Number(label=\"Valence (tích cực↑)\", interactive=False)\n",
|
| 444 |
+
" t2aro = gr.Number(label=\"Arousal (kích động↑)\", interactive=False)\n",
|
| 445 |
+
" t2dom = gr.Number(label=\"Dominance (chi phối↑)\", interactive=False)\n",
|
| 446 |
+
" t2btn.click(t2_predict, [t2a, t2tgt], [t2verdict, t2emos, t2cat, t2val, t2aro, t2dom])\n",
|
| 447 |
+
"\n",
|
| 448 |
+
" with gr.Tab(\"3️⃣ Track 3 · Speaker/Accent\"):\n",
|
| 449 |
+
" gr.Markdown(\"Tải **audio cần đánh giá** + **audio tham chiếu** → độ giống người nói & accent (1–5).\")\n",
|
| 450 |
+
" t3t = gr.Audio(type=\"filepath\", label=\"Audio cần đánh giá (test)\")\n",
|
| 451 |
+
" t3r = gr.Audio(type=\"filepath\", label=\"Audio tham chiếu (reference)\")\n",
|
| 452 |
+
" t3out = gr.Markdown()\n",
|
| 453 |
+
" gr.Button(\"Dự đoán\", variant=\"primary\").click(t3_predict, [t3t, t3r], t3out)\n",
|
| 454 |
+
"\n",
|
| 455 |
+
"demo.launch(share=True)"
|
| 456 |
+
]
|
| 457 |
+
},
|
| 458 |
+
{
|
| 459 |
+
"cell_type": "markdown",
|
| 460 |
+
"id": "b04c3d5e",
|
| 461 |
+
"metadata": {},
|
| 462 |
+
"source": [
|
| 463 |
+
"## Ghi chú\n",
|
| 464 |
+
"- **Lazy-load:** mỗi `_tN_load()` nạp model 1 lần rồi cache module-level → tab nào không bấm thì không tốn RAM/VRAM.\n",
|
| 465 |
+
"- Track 1 cần URGENT-MOS (GitHub + HuggingFace); Track 3 clone repo baseline (có sẵn checkpoint); Track 2 cần\n",
|
| 466 |
+
" checkpoint exp08 (`ft_emotion_full_20epoch.pt`, slug `toanminh222/cache-exp8`) + tải WavLM/SAILER/audeering.\n",
|
| 467 |
+
"- Hằng `TRUNK_HIDDEN/HEAD_HIDDEN/EMO_MAX_SEC` của Track 2 PHẢI khớp exp08 (ckpt không lưu) — sai là lệch shape.\n",
|
| 468 |
+
"- 3 tab độc lập: thiếu checkpoint/Internet của 1 track chỉ báo lỗi trong tab đó, 2 tab còn lại vẫn chạy.\n",
|
| 469 |
+
"- Cần **GPU T4 + Internet On**. Bản chỉ Track 2 đầy đủ (có tab metric val nội bộ) ở `track2/demo_track2_emotion_gradio`."
|
| 470 |
+
]
|
| 471 |
+
}
|
| 472 |
+
],
|
| 473 |
+
"metadata": {
|
| 474 |
+
"jupytext": {
|
| 475 |
+
"cell_metadata_filter": "-all",
|
| 476 |
+
"main_language": "python",
|
| 477 |
+
"notebook_metadata_filter": "-all"
|
| 478 |
+
}
|
| 479 |
+
},
|
| 480 |
+
"nbformat": 4,
|
| 481 |
+
"nbformat_minor": 5
|
| 482 |
+
}
|
demo_all_tracks_gradio_pipeline.py
ADDED
|
@@ -0,0 +1,390 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %% [markdown]
|
| 2 |
+
# # VMC2026 — Demo Gradio GỘP 3 TRACK (1 link cho mentor)
|
| 3 |
+
#
|
| 4 |
+
# Gộp 3 demo lẻ (`track1/`, `track2/`, `track3/`) vào **1 app Gradio 3 tab**:
|
| 5 |
+
# - **Track 1** · Speech Enhancement → **ACR** (chất lượng A) + **CCR** (so A vs B). Model: URGENT-MOS.
|
| 6 |
+
# - **Track 2** · Emotional TTS → **EMOS / CAT / VAD**. Model TỐT NHẤT = **exp08** (WavLM fine-tune + audeering).
|
| 7 |
+
# - **Track 3** · Speaker/Accent → **spk_sim / acc_sim**. Model: ECAPA fine-tuned (baseline BTC).
|
| 8 |
+
#
|
| 9 |
+
# > **Lazy-load:** mỗi track chỉ nạp model khi bạn bấm "Dự đoán" ở tab đó → tab nào thiếu checkpoint/repo
|
| 10 |
+
# > chỉ báo lỗi trong tab đó, KHÔNG sập cả app. Track 1 & 3 chỉ cần Internet; Track 2 cần thêm checkpoint exp08.
|
| 11 |
+
#
|
| 12 |
+
# ### Cách chạy trên Kaggle
|
| 13 |
+
# 1. Settings → **GPU T4 + Internet On**.
|
| 14 |
+
# 2. (Cho Track 2) Add Input: dataset Track 2 (`sets/train.csv`, `wav/`, `metadata.csv`) + dataset chứa
|
| 15 |
+
# `ft_emotion_full_20epoch.pt` (slug `toanminh222/cache-exp8`). Thiếu thì 2 tab kia vẫn chạy.
|
| 16 |
+
# 3. **Run All** → cell cuối in link `*.gradio.live` (sống ~72h) → gửi mentor.
|
| 17 |
+
|
| 18 |
+
# %% [markdown]
|
| 19 |
+
# ## 1. Cài đặt gói (1 lần cho cả 3 track)
|
| 20 |
+
|
| 21 |
+
# %%
|
| 22 |
+
# !pip install -q gradio librosa soundfile speechbrain torchaudio loralib scipy scikit-learn pandas tqdm
|
| 23 |
+
|
| 24 |
+
import os, sys, glob, subprocess
|
| 25 |
+
|
| 26 |
+
def pip_install(*pkgs):
|
| 27 |
+
subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs], check=False)
|
| 28 |
+
|
| 29 |
+
# Cài nhẹ (Kaggle có sẵn torch/transformers/numpy → KHÔNG đụng numpy để tránh lệch ABI)
|
| 30 |
+
pip_install("gradio", "librosa", "soundfile", "speechbrain", "torchaudio",
|
| 31 |
+
"loralib", "scipy", "scikit-learn", "pandas", "tqdm")
|
| 32 |
+
|
| 33 |
+
import librosa
|
| 34 |
+
import numpy as np
|
| 35 |
+
import torch
|
| 36 |
+
|
| 37 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 38 |
+
SR = 16000
|
| 39 |
+
print("Device:", DEVICE, ("✅ " + torch.cuda.get_device_name(0)) if DEVICE == "cuda" else "⚠️ CPU (chậm)")
|
| 40 |
+
|
| 41 |
+
def _stem(p):
|
| 42 |
+
return os.path.splitext(os.path.basename(str(p)))[0]
|
| 43 |
+
|
| 44 |
+
def _scalar(x):
|
| 45 |
+
return float(x.item()) if hasattr(x, "item") else float(x)
|
| 46 |
+
|
| 47 |
+
# %% [markdown]
|
| 48 |
+
# ## 2. TRACK 1 — URGENT-MOS (ACR + CCR) · lazy-load
|
| 49 |
+
|
| 50 |
+
# %%
|
| 51 |
+
URGENT_REPO = "/kaggle/working/URGENT-MOS"
|
| 52 |
+
URGENT_CKPT = "urgent-challenge/urgent-mos-f1c1m5dcorpus" # tự tải từ HuggingFace
|
| 53 |
+
_T1 = {}
|
| 54 |
+
|
| 55 |
+
def _t1_load():
|
| 56 |
+
"""Nạp URGENT-MOS 1 lần (clone repo + sys.path + checkpoint)."""
|
| 57 |
+
if "m" in _T1:
|
| 58 |
+
return _T1["m"]
|
| 59 |
+
if not os.path.isdir(URGENT_REPO):
|
| 60 |
+
subprocess.run(f"git clone -q https://github.com/vvwangvv/URGENT-MOS.git {URGENT_REPO}",
|
| 61 |
+
shell=True, check=True)
|
| 62 |
+
if URGENT_REPO not in sys.path:
|
| 63 |
+
sys.path.insert(0, URGENT_REPO)
|
| 64 |
+
import importlib
|
| 65 |
+
importlib.invalidate_caches()
|
| 66 |
+
try:
|
| 67 |
+
importlib.import_module("urgent_mos.api.infer")
|
| 68 |
+
except Exception:
|
| 69 |
+
subprocess.run(f"pip install -q -e {URGENT_REPO}", shell=True, check=False)
|
| 70 |
+
importlib.invalidate_caches()
|
| 71 |
+
from urgent_mos.utils import load_model_from_checkpoint
|
| 72 |
+
m = load_model_from_checkpoint(URGENT_CKPT, DEVICE)
|
| 73 |
+
m.eval()
|
| 74 |
+
_T1["m"] = m
|
| 75 |
+
return m
|
| 76 |
+
|
| 77 |
+
def t1_predict(audio_a, audio_b):
|
| 78 |
+
if not audio_a:
|
| 79 |
+
return "⚠️ Hãy tải lên ít nhất **Audio A**."
|
| 80 |
+
try:
|
| 81 |
+
m = _t1_load()
|
| 82 |
+
from urgent_mos.api.infer import infer, infer_pairs
|
| 83 |
+
wa = torch.from_numpy(librosa.load(audio_a, sr=SR, mono=True)[0]).float()
|
| 84 |
+
acr_a = max(1.0, min(5.0, _scalar(infer(m, [wa], sample_rate=[SR],
|
| 85 |
+
batch_frames=None, num_workers=0)[0]["mos_overall"])))
|
| 86 |
+
out = f"**ACR (Audio A): {acr_a:.3f}** (chất lượng tuyệt đối, thang 1–5)"
|
| 87 |
+
if audio_b:
|
| 88 |
+
wb = torch.from_numpy(librosa.load(audio_b, sr=SR, mono=True)[0]).float()
|
| 89 |
+
acr_b = max(1.0, min(5.0, _scalar(infer(m, [wb], sample_rate=[SR],
|
| 90 |
+
batch_frames=None, num_workers=0)[0]["mos_overall"])))
|
| 91 |
+
ccr = max(-3.0, min(3.0, _scalar(infer_pairs(m, [(wa, wb)], sample_rate=[(SR, SR)],
|
| 92 |
+
batch_frames=None, num_workers=0)[0]["mos_overall"])))
|
| 93 |
+
out += (f"\n\n**ACR (Audio B): {acr_b:.3f}**"
|
| 94 |
+
f"\n\n**CCR (A so với B): {ccr:+.3f}** (>0: A tốt hơn B; thang −3..+3)")
|
| 95 |
+
return out
|
| 96 |
+
except Exception as e:
|
| 97 |
+
return f"❌ Track 1 lỗi: `{repr(e)}`\n\nKiểm tra **Internet On** (cần tải URGENT-MOS từ GitHub/HuggingFace)."
|
| 98 |
+
|
| 99 |
+
# %% [markdown]
|
| 100 |
+
# ## 3. TRACK 3 — ECAPA fine-tuned (spk_sim + acc_sim) · lazy-load
|
| 101 |
+
|
| 102 |
+
# %%
|
| 103 |
+
T3_REPO = "/kaggle/working/vmc2026-baselines/track3"
|
| 104 |
+
CKPT_SPK = f"{T3_REPO}/official-egs/spk_sim_adamw_lr1e-3/model_spk_sim_step20000.pt"
|
| 105 |
+
CKPT_ACC = f"{T3_REPO}/official-egs/acc_sim_adamw_lr1e-3/model_acc_sim_step20000.pt"
|
| 106 |
+
_T3 = {}
|
| 107 |
+
|
| 108 |
+
def _t3_load():
|
| 109 |
+
if "spk" in _T3:
|
| 110 |
+
return _T3
|
| 111 |
+
repo_root = "/kaggle/working/vmc2026-baselines"
|
| 112 |
+
if not os.path.isdir(repo_root):
|
| 113 |
+
subprocess.run(f"git clone -q https://github.com/voicemos-challenge/vmc2026-baselines.git {repo_root}",
|
| 114 |
+
shell=True, check=True)
|
| 115 |
+
if T3_REPO not in sys.path:
|
| 116 |
+
sys.path.insert(0, T3_REPO)
|
| 117 |
+
from model import Model
|
| 118 |
+
spk = Model(mlp_heads=["spk_sim"])
|
| 119 |
+
spk.load_state_dict(torch.load(CKPT_SPK, map_location="cpu"))
|
| 120 |
+
acc = Model(mlp_heads=["acc_sim"])
|
| 121 |
+
acc.load_state_dict(torch.load(CKPT_ACC, map_location="cpu"))
|
| 122 |
+
_T3.update(spk=spk.to(DEVICE).eval(), acc=acc.to(DEVICE).eval())
|
| 123 |
+
return _T3
|
| 124 |
+
|
| 125 |
+
def t3_predict(audio_test, audio_ref):
|
| 126 |
+
if not audio_test or not audio_ref:
|
| 127 |
+
return "⚠️ Cần **cả 2 file**: audio test + audio reference."
|
| 128 |
+
try:
|
| 129 |
+
M = _t3_load()
|
| 130 |
+
ta = torch.from_numpy(librosa.load(audio_test, sr=SR, mono=True)[0]).float().unsqueeze(0).to(DEVICE)
|
| 131 |
+
tb = torch.from_numpy(librosa.load(audio_ref, sr=SR, mono=True)[0]).float().unsqueeze(0).to(DEVICE)
|
| 132 |
+
with torch.no_grad():
|
| 133 |
+
o_spk = M["spk"](ta, tb)
|
| 134 |
+
spk = float(o_spk["spk_sim"].item())
|
| 135 |
+
acc = float(M["acc"](ta, tb)["acc_sim"].item())
|
| 136 |
+
cos = float(o_spk["cos_sim"].item())
|
| 137 |
+
return (f"**Speaker similarity: {spk:.3f}** (1–5)\n\n"
|
| 138 |
+
f"**Accent similarity : {acc:.3f}** (1–5)\n\n"
|
| 139 |
+
f"Cosine zero-shot (tham khảo): {cos:.3f}")
|
| 140 |
+
except Exception as e:
|
| 141 |
+
return f"❌ Track 3 lỗi: `{repr(e)}`\n\nKiểm tra **Internet On** (clone repo baseline chứa checkpoint)."
|
| 142 |
+
|
| 143 |
+
# %% [markdown]
|
| 144 |
+
# ## 4. TRACK 2 — exp08 Emotional TTS Evaluator (EMOS/CAT/VAD) · lazy-load
|
| 145 |
+
#
|
| 146 |
+
# Model TỐT NHẤT: WavLM fine-tune (warm-start SAILER) + audeering frozen → trunk → 3 head.
|
| 147 |
+
# Hằng kiến trúc PHẢI khớp exp08 (ckpt không lưu các số này).
|
| 148 |
+
|
| 149 |
+
# %%
|
| 150 |
+
EMO_MAX_SEC, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, USE_AMP = 8, 512, 128, 0.3, True
|
| 151 |
+
EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"]
|
| 152 |
+
_EMO_ALIAS = {
|
| 153 |
+
"angry": "angry", "anger": "angry", "happy": "happy", "happiness": "happy", "joy": "happy",
|
| 154 |
+
"neutral": "neutral", "calm": "neutral", "sad": "sad", "sadness": "sad",
|
| 155 |
+
"surprise": "surprised", "surprised": "surprised", "surprising": "surprised",
|
| 156 |
+
}
|
| 157 |
+
def norm_emotion(label):
|
| 158 |
+
key = str(label).strip().lower()
|
| 159 |
+
return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)
|
| 160 |
+
|
| 161 |
+
def _t2_find_ckpt():
|
| 162 |
+
for pat in ["ft_emotion_full_20epoch*.pt", "ft_emotion_full*.pt"]:
|
| 163 |
+
for base in ["/kaggle/input", "/kaggle/working"]:
|
| 164 |
+
hits = sorted(glob.glob(os.path.join(base, "**", pat), recursive=True))
|
| 165 |
+
if hits:
|
| 166 |
+
return hits[0]
|
| 167 |
+
return ""
|
| 168 |
+
|
| 169 |
+
_T2 = {}
|
| 170 |
+
|
| 171 |
+
def _t2_load():
|
| 172 |
+
if "infer" in _T2:
|
| 173 |
+
return _T2["infer"]
|
| 174 |
+
import torch.nn as nn
|
| 175 |
+
ckpt_path = _t2_find_ckpt()
|
| 176 |
+
assert ckpt_path, "Không thấy ft_emotion_full*.pt — Add Input dataset checkpoint exp08 (slug toanminh222/cache-exp8)?"
|
| 177 |
+
# code SAILER để dựng backbone WavLM
|
| 178 |
+
repo = "/kaggle/working/vox-profile-release"
|
| 179 |
+
if not os.path.exists(repo):
|
| 180 |
+
subprocess.run(["git", "clone", "--depth", "1",
|
| 181 |
+
"https://github.com/tiantiaf0627/vox-profile-release.git", repo], check=True)
|
| 182 |
+
if repo not in sys.path:
|
| 183 |
+
sys.path.insert(0, repo)
|
| 184 |
+
|
| 185 |
+
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
|
| 186 |
+
assert "wavlm" in ckpt and "heads" in ckpt, "Checkpoint thiếu 'wavlm'/'heads' → cần bản đủ ft_emotion_full_20epoch.pt."
|
| 187 |
+
AUD_DIM = int(ckpt.get("AUD_DIM", 0)); USE_AUDEERING = AUD_DIM > 0
|
| 188 |
+
|
| 189 |
+
def find_hf_backbone(module):
|
| 190 |
+
cands = []
|
| 191 |
+
for name, m in module.named_modules():
|
| 192 |
+
enc = getattr(m, "encoder", None)
|
| 193 |
+
if getattr(m, "feature_extractor", None) is not None and enc is not None \
|
| 194 |
+
and getattr(enc, "layers", None) is not None:
|
| 195 |
+
cands.append((name, m))
|
| 196 |
+
if not cands:
|
| 197 |
+
return None, None
|
| 198 |
+
cands.sort(key=lambda nm: sum(p.numel() for p in nm[1].parameters()), reverse=True)
|
| 199 |
+
return cands[0]
|
| 200 |
+
|
| 201 |
+
wavlm = None
|
| 202 |
+
try:
|
| 203 |
+
from src.model.emotion.wavlm_emotion import WavLMWrapper
|
| 204 |
+
_wrapper = WavLMWrapper.from_pretrained("tiantiaf/wavlm-large-categorical-emotion")
|
| 205 |
+
_name, wavlm = find_hf_backbone(_wrapper)
|
| 206 |
+
except Exception as e:
|
| 207 |
+
print("⚠️ SAILER wrapper lỗi:", repr(e), "→ fallback WavLM trắng.")
|
| 208 |
+
if wavlm is None:
|
| 209 |
+
from transformers import WavLMModel
|
| 210 |
+
wavlm = WavLMModel.from_pretrained("microsoft/wavlm-large")
|
| 211 |
+
wavlm = wavlm.to(DEVICE).eval()
|
| 212 |
+
WAVLM_DIM = int(wavlm.config.hidden_size)
|
| 213 |
+
wavlm.config.layerdrop = 0.0
|
| 214 |
+
wavlm.load_state_dict(ckpt["wavlm"], strict=False)
|
| 215 |
+
|
| 216 |
+
def masked_mean(hidden, attn_mask):
|
| 217 |
+
if attn_mask is None:
|
| 218 |
+
return hidden.mean(dim=1)
|
| 219 |
+
try:
|
| 220 |
+
fm = wavlm._get_feature_vector_attention_mask(hidden.shape[1], attn_mask)
|
| 221 |
+
except Exception:
|
| 222 |
+
return hidden.mean(dim=1)
|
| 223 |
+
fm = fm.unsqueeze(-1).to(hidden.dtype)
|
| 224 |
+
return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)
|
| 225 |
+
|
| 226 |
+
# audeering frozen (nếu ckpt dùng)
|
| 227 |
+
aud_backbone = aud_head = aud_proc = None
|
| 228 |
+
if USE_AUDEERING:
|
| 229 |
+
from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor
|
| 230 |
+
from huggingface_hub import hf_hub_download
|
| 231 |
+
AUD_NAME = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
|
| 232 |
+
aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)
|
| 233 |
+
aud_backbone = Wav2Vec2Model(Wav2Vec2Config.from_pretrained(AUD_NAME))
|
| 234 |
+
try:
|
| 235 |
+
_sd = __import__("safetensors.torch", fromlist=["load_file"]).load_file(
|
| 236 |
+
hf_hub_download(AUD_NAME, "model.safetensors"))
|
| 237 |
+
except Exception:
|
| 238 |
+
_sd = torch.load(hf_hub_download(AUD_NAME, "pytorch_model.bin"), map_location="cpu")
|
| 239 |
+
bb_sd = {k[len("wav2vec2."):]: v for k, v in _sd.items() if k.startswith("wav2vec2.")}
|
| 240 |
+
aud_backbone.load_state_dict(bb_sd, strict=False)
|
| 241 |
+
_hid = _sd["classifier.dense.weight"].shape[0]
|
| 242 |
+
aud_head = nn.Sequential(nn.Linear(_hid, _hid), nn.Tanh(),
|
| 243 |
+
nn.Linear(_hid, _sd["classifier.out_proj.weight"].shape[0]))
|
| 244 |
+
aud_head[0].weight.data.copy_(_sd["classifier.dense.weight"]); aud_head[0].bias.data.copy_(_sd["classifier.dense.bias"])
|
| 245 |
+
aud_head[2].weight.data.copy_(_sd["classifier.out_proj.weight"]); aud_head[2].bias.data.copy_(_sd["classifier.out_proj.bias"])
|
| 246 |
+
aud_backbone = aud_backbone.to(DEVICE).eval(); aud_head = aud_head.to(DEVICE).eval()
|
| 247 |
+
|
| 248 |
+
@torch.no_grad()
|
| 249 |
+
def audeering_feat(wave):
|
| 250 |
+
x = aud_proc(wave, sampling_rate=SR).input_values[0]
|
| 251 |
+
x = torch.from_numpy(np.asarray(x, dtype=np.float32)).unsqueeze(0).to(DEVICE)
|
| 252 |
+
h = aud_backbone(x)[0].mean(dim=1)
|
| 253 |
+
out = aud_head(h)[0].cpu().numpy()
|
| 254 |
+
vad = np.array([1 + 4 * out[2], 1 + 4 * out[0], 1 + 4 * out[1]], dtype=np.float32)
|
| 255 |
+
return np.concatenate([h[0].cpu().numpy(), vad]).astype(np.float32)
|
| 256 |
+
|
| 257 |
+
N_EMO = len(EMOTIONS5)
|
| 258 |
+
TRUNK_IN = WAVLM_DIM + (AUD_DIM if USE_AUDEERING else 0)
|
| 259 |
+
|
| 260 |
+
class EmoHeads(nn.Module):
|
| 261 |
+
def __init__(self, d_in, trunk_h, head_h, p, n_emo):
|
| 262 |
+
super().__init__()
|
| 263 |
+
self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),
|
| 264 |
+
nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))
|
| 265 |
+
self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))
|
| 266 |
+
self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))
|
| 267 |
+
self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))
|
| 268 |
+
def forward(self, feat, tgt):
|
| 269 |
+
h = self.trunk(feat)
|
| 270 |
+
return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)
|
| 271 |
+
|
| 272 |
+
heads = EmoHeads(TRUNK_IN, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(DEVICE).eval()
|
| 273 |
+
heads.load_state_dict(ckpt["heads"], strict=False)
|
| 274 |
+
emos_mu, emos_sd = float(ckpt["emos_mu"]), float(ckpt["emos_sd"])
|
| 275 |
+
vad_mu = np.asarray(ckpt["vad_mu"], dtype=np.float32); vad_sd = np.asarray(ckpt["vad_sd"], dtype=np.float32)
|
| 276 |
+
|
| 277 |
+
def onehot_target(tgt):
|
| 278 |
+
v = np.zeros(N_EMO, dtype=np.float32)
|
| 279 |
+
if tgt in EMOTIONS5:
|
| 280 |
+
v[EMOTIONS5.index(tgt)] = 1.0
|
| 281 |
+
return v
|
| 282 |
+
|
| 283 |
+
@torch.no_grad()
|
| 284 |
+
def infer_wave(wave, target_emotion):
|
| 285 |
+
wave = wave[: EMO_MAX_SEC * SR].astype(np.float32)
|
| 286 |
+
iv = torch.from_numpy(wave).unsqueeze(0).to(DEVICE)
|
| 287 |
+
am = torch.ones((1, len(wave)), dtype=torch.long, device=DEVICE)
|
| 288 |
+
tgt = torch.from_numpy(onehot_target(norm_emotion(target_emotion) if target_emotion else None)).unsqueeze(0).to(DEVICE)
|
| 289 |
+
with torch.cuda.amp.autocast(enabled=USE_AMP and DEVICE == "cuda"):
|
| 290 |
+
fw = wavlm(iv, attention_mask=am).last_hidden_state
|
| 291 |
+
fw = masked_mean(fw, am)
|
| 292 |
+
if USE_AUDEERING:
|
| 293 |
+
fw = torch.cat([fw, torch.from_numpy(audeering_feat(wave)).unsqueeze(0).to(DEVICE)], dim=1)
|
| 294 |
+
emos_p, cat_l, vad_p = heads(fw, tgt)
|
| 295 |
+
emos = float(emos_p.item()) * emos_sd + emos_mu
|
| 296 |
+
cat5 = torch.softmax(cat_l, 1)[0].float().cpu().numpy()
|
| 297 |
+
vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu
|
| 298 |
+
return emos, cat5, vad3
|
| 299 |
+
|
| 300 |
+
print(f"✅ Track 2 exp08 nạp xong (audeering {'ON' if USE_AUDEERING else 'OFF'}) từ {ckpt_path}")
|
| 301 |
+
_T2["infer"] = infer_wave
|
| 302 |
+
return infer_wave
|
| 303 |
+
|
| 304 |
+
def t2_predict(audio, target_emotion):
|
| 305 |
+
"""Trả: verdict(md), EMOS(number), CAT(label dict), VAL, ARO, DOM."""
|
| 306 |
+
if not audio:
|
| 307 |
+
return "### ⚠️ Hãy tải audio (giọng TTS).", None, {}, None, None, None
|
| 308 |
+
try:
|
| 309 |
+
infer_wave = _t2_load()
|
| 310 |
+
wave, _ = librosa.load(audio, sr=SR, mono=True)
|
| 311 |
+
emos, cat5, vad3 = infer_wave(wave, target_emotion)
|
| 312 |
+
cat_dict = {e: float(cat5[i]) for i, e in enumerate(EMOTIONS5)}
|
| 313 |
+
perceived = EMOTIONS5[int(np.argmax(cat5))]
|
| 314 |
+
if target_emotion:
|
| 315 |
+
match = "✅ **KHỚP** target" if perceived == norm_emotion(target_emotion) else "⚠️ **LỆCH** target"
|
| 316 |
+
band = "🟢 tốt" if emos >= 4 else ("🟡 khá" if emos >= 3 else "🔴 yếu")
|
| 317 |
+
verdict = (f"### Kết luận biểu cảm\n"
|
| 318 |
+
f"- Cảm xúc cảm nhận: **{perceived}** → {match} (`{target_emotion}`)\n"
|
| 319 |
+
f"- EMOS = **{emos:.2f}/5** → biểu cảm {band}")
|
| 320 |
+
else:
|
| 321 |
+
verdict = (f"### Kết luận biểu cảm\n- Cảm xúc cảm nhận: **{perceived}**\n"
|
| 322 |
+
f"- *(Chọn cảm xúc target để bật EMOS — độ khớp ý đồ)*")
|
| 323 |
+
emos = None
|
| 324 |
+
return verdict, (round(emos, 3) if emos is not None else None), cat_dict, \
|
| 325 |
+
round(float(vad3[0]), 3), round(float(vad3[1]), 3), round(float(vad3[2]), 3)
|
| 326 |
+
except Exception as e:
|
| 327 |
+
return f"### ❌ Track 2 lỗi\n`{repr(e)}`\n\nĐã Add Input checkpoint exp08 + Internet On chưa?", None, {}, None, None, None
|
| 328 |
+
|
| 329 |
+
# %% [markdown]
|
| 330 |
+
# ## 5. Giao diện Gradio GỘP — 3 tab + launch
|
| 331 |
+
|
| 332 |
+
# %%
|
| 333 |
+
import gradio as gr
|
| 334 |
+
|
| 335 |
+
INTRO = (
|
| 336 |
+
"# 🎙️ VoiceMOS Challenge 2026 — Demo 3 Track\n"
|
| 337 |
+
"Một link cho cả 3 track. Mỗi tab nhận audio → trả điểm bộ chấm tự động.\n\n"
|
| 338 |
+
"| Track | Bài toán | Output |\n|---|---|---|\n"
|
| 339 |
+
"| **1** | Speech Enhancement | ACR (chất lượng) · CCR (so sánh cặp) |\n"
|
| 340 |
+
"| **2** | Emotional TTS | EMOS · CAT · VAD (5 cột cảm xúc) — *model tốt nhất exp08* |\n"
|
| 341 |
+
"| **3** | Speaker/Accent | spk_sim · acc_sim |\n\n"
|
| 342 |
+
"> Model nạp **lần đầu bấm nút** (chờ ~1–2 phút tải). Tab thiếu checkpoint chỉ báo lỗi trong tab đó."
|
| 343 |
+
)
|
| 344 |
+
|
| 345 |
+
with gr.Blocks(title="VMC2026 — Demo 3 Track") as demo:
|
| 346 |
+
gr.Markdown(INTRO)
|
| 347 |
+
|
| 348 |
+
with gr.Tab("1️⃣ Track 1 · Chất lượng (ACR/CCR)"):
|
| 349 |
+
gr.Markdown("Tải **Audio A** → ACR (1–5). Tải thêm **Audio B** → CCR (A vs B, −3..+3, >0 = A tốt hơn).")
|
| 350 |
+
t1a = gr.Audio(type="filepath", label="Audio A (bắt buộc)")
|
| 351 |
+
t1b = gr.Audio(type="filepath", label="Audio B (tùy chọn — để tính CCR)")
|
| 352 |
+
t1out = gr.Markdown()
|
| 353 |
+
gr.Button("Dự đoán", variant="primary").click(t1_predict, [t1a, t1b], t1out)
|
| 354 |
+
|
| 355 |
+
with gr.Tab("2️⃣ Track 2 · Cảm xúc (EMOS/CAT/VAD)"):
|
| 356 |
+
gr.Markdown("Model tốt nhất **exp08** (WavLM fine-tune + audeering, offline). "
|
| 357 |
+
"Chọn **cảm xúc target** để bật EMOS (độ khớp ý đồ).")
|
| 358 |
+
with gr.Row():
|
| 359 |
+
with gr.Column(scale=1):
|
| 360 |
+
t2a = gr.Audio(type="filepath", label="Audio (giọng TTS)")
|
| 361 |
+
t2tgt = gr.Dropdown(EMOTIONS5, label="🎯 Cảm xúc target (cho EMOS)")
|
| 362 |
+
t2btn = gr.Button("Chấm cảm xúc", variant="primary")
|
| 363 |
+
with gr.Column(scale=2):
|
| 364 |
+
t2verdict = gr.Markdown()
|
| 365 |
+
t2emos = gr.Number(label="EMOS — khớp cảm xúc target (1–5)", interactive=False)
|
| 366 |
+
t2cat = gr.Label(label="CAT — phân bố cảm xúc cảm nhận (5 lớp)")
|
| 367 |
+
gr.Markdown("**VAD — toạ độ cảm xúc liên tục (1–5):**")
|
| 368 |
+
with gr.Row():
|
| 369 |
+
t2val = gr.Number(label="Valence (tích cực↑)", interactive=False)
|
| 370 |
+
t2aro = gr.Number(label="Arousal (kích động↑)", interactive=False)
|
| 371 |
+
t2dom = gr.Number(label="Dominance (chi phối↑)", interactive=False)
|
| 372 |
+
t2btn.click(t2_predict, [t2a, t2tgt], [t2verdict, t2emos, t2cat, t2val, t2aro, t2dom])
|
| 373 |
+
|
| 374 |
+
with gr.Tab("3️⃣ Track 3 · Speaker/Accent"):
|
| 375 |
+
gr.Markdown("Tải **audio cần đánh giá** + **audio tham chiếu** → độ giống người nói & accent (1–5).")
|
| 376 |
+
t3t = gr.Audio(type="filepath", label="Audio cần đánh giá (test)")
|
| 377 |
+
t3r = gr.Audio(type="filepath", label="Audio tham chiếu (reference)")
|
| 378 |
+
t3out = gr.Markdown()
|
| 379 |
+
gr.Button("Dự đoán", variant="primary").click(t3_predict, [t3t, t3r], t3out)
|
| 380 |
+
|
| 381 |
+
demo.launch(share=True)
|
| 382 |
+
|
| 383 |
+
# %% [markdown]
|
| 384 |
+
# ## Ghi chú
|
| 385 |
+
# - **Lazy-load:** mỗi `_tN_load()` nạp model 1 lần rồi cache module-level → tab nào không bấm thì không tốn RAM/VRAM.
|
| 386 |
+
# - Track 1 cần URGENT-MOS (GitHub + HuggingFace); Track 3 clone repo baseline (có sẵn checkpoint); Track 2 cần
|
| 387 |
+
# checkpoint exp08 (`ft_emotion_full_20epoch.pt`, slug `toanminh222/cache-exp8`) + tải WavLM/SAILER/audeering.
|
| 388 |
+
# - Hằng `TRUNK_HIDDEN/HEAD_HIDDEN/EMO_MAX_SEC` của Track 2 PHẢI khớp exp08 (ckpt không lưu) — sai là lệch shape.
|
| 389 |
+
# - 3 tab độc lập: thiếu checkpoint/Internet của 1 track chỉ báo lỗi trong tab đó, 2 tab còn lại vẫn chạy.
|
| 390 |
+
# - Cần **GPU T4 + Internet On**. Bản chỉ Track 2 đầy đủ (có tab metric val nội bộ) ở `track2/demo_track2_emotion_gradio`.
|
demo_run_from_hf.ipynb
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "50fca144",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"# VMC2026 — Chạy demo Gradio trên KAGGLE bằng cách KÉO code UI từ Hugging Face\n",
|
| 9 |
+
"\n",
|
| 10 |
+
"Chiến lược: **HF = nơi chứa code UI** (Space `tranminhtoan140601/voicemos2026-demo`),\n",
|
| 11 |
+
"**Kaggle = nơi chạy** (GPU T4 free). Notebook này tải `app.py` từ Space về rồi chạy →\n",
|
| 12 |
+
"ra link `*.gradio.live` (sống ~72h) để gửi mentor. KHÔNG tốn GPU trả phí của HF.\n",
|
| 13 |
+
"\n",
|
| 14 |
+
"`app.py` tự nhận môi trường: trên Kaggle → `share=True` (link công khai); checkpoint Track 2\n",
|
| 15 |
+
"tự tải từ HF Models repo `tranminhtoan140601/voicemos2026-track2-emotion`.\n",
|
| 16 |
+
"\n",
|
| 17 |
+
"### Cách chạy\n",
|
| 18 |
+
"1. Settings → **GPU T4 + Internet On**.\n",
|
| 19 |
+
"2. **Run All** → cell cuối in link `*.gradio.live`."
|
| 20 |
+
]
|
| 21 |
+
},
|
| 22 |
+
{
|
| 23 |
+
"cell_type": "markdown",
|
| 24 |
+
"id": "53fd0798",
|
| 25 |
+
"metadata": {},
|
| 26 |
+
"source": [
|
| 27 |
+
"## 1. Cài deps (khớp Space) — KHÔNG đụng numpy/torch có sẵn Kaggle"
|
| 28 |
+
]
|
| 29 |
+
},
|
| 30 |
+
{
|
| 31 |
+
"cell_type": "code",
|
| 32 |
+
"execution_count": null,
|
| 33 |
+
"id": "59468886",
|
| 34 |
+
"metadata": {},
|
| 35 |
+
"outputs": [],
|
| 36 |
+
"source": [
|
| 37 |
+
"!pip install -q gradio==6.17.3 huggingface_hub librosa soundfile speechbrain loralib scipy scikit-learn pandas\n",
|
| 38 |
+
"\n",
|
| 39 |
+
"import subprocess, sys\n",
|
| 40 |
+
"\n",
|
| 41 |
+
"def pip_install(*pkgs):\n",
|
| 42 |
+
" subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *pkgs], check=False)\n",
|
| 43 |
+
"\n",
|
| 44 |
+
"pip_install(\"gradio==6.17.3\", \"huggingface_hub\", \"librosa\", \"soundfile\",\n",
|
| 45 |
+
" \"speechbrain\", \"loralib\", \"scipy\", \"scikit-learn\", \"pandas\")"
|
| 46 |
+
]
|
| 47 |
+
},
|
| 48 |
+
{
|
| 49 |
+
"cell_type": "markdown",
|
| 50 |
+
"id": "1feff99f",
|
| 51 |
+
"metadata": {},
|
| 52 |
+
"source": [
|
| 53 |
+
"## 2. Kéo code UI (app.py) từ HF Space về Kaggle"
|
| 54 |
+
]
|
| 55 |
+
},
|
| 56 |
+
{
|
| 57 |
+
"cell_type": "code",
|
| 58 |
+
"execution_count": null,
|
| 59 |
+
"id": "ca69f24e",
|
| 60 |
+
"metadata": {},
|
| 61 |
+
"outputs": [],
|
| 62 |
+
"source": [
|
| 63 |
+
"import os\n",
|
| 64 |
+
"from huggingface_hub import snapshot_download\n",
|
| 65 |
+
"\n",
|
| 66 |
+
"SPACE_REPO = \"tranminhtoan140601/voicemos2026-demo\"\n",
|
| 67 |
+
"LOCAL_DIR = \"/kaggle/working/vmc_demo\"\n",
|
| 68 |
+
"\n",
|
| 69 |
+
"# Tải toàn bộ repo Space (app.py + requirements + README) về local\n",
|
| 70 |
+
"snapshot_download(repo_id=SPACE_REPO, repo_type=\"space\", local_dir=LOCAL_DIR)\n",
|
| 71 |
+
"print(\"✅ Đã kéo Space về:\", LOCAL_DIR)\n",
|
| 72 |
+
"print(\"Files:\", os.listdir(LOCAL_DIR))"
|
| 73 |
+
]
|
| 74 |
+
},
|
| 75 |
+
{
|
| 76 |
+
"cell_type": "markdown",
|
| 77 |
+
"id": "613d7d61",
|
| 78 |
+
"metadata": {},
|
| 79 |
+
"source": [
|
| 80 |
+
"## 3. Chạy app.py (Kaggle có GPU → nhanh; app.py tự share=True ra link gradio.live)\n",
|
| 81 |
+
"\n",
|
| 82 |
+
"`app.py` tải checkpoint Track 2 từ HF Models repo, clone URGENT-MOS/SAILER/baseline lúc bấm nút.\n",
|
| 83 |
+
"Cell này sẽ **chạy mãi** (server Gradio) — đợi dòng `Running on public URL: https://....gradio.live`."
|
| 84 |
+
]
|
| 85 |
+
},
|
| 86 |
+
{
|
| 87 |
+
"cell_type": "code",
|
| 88 |
+
"execution_count": null,
|
| 89 |
+
"id": "8e779da3",
|
| 90 |
+
"metadata": {},
|
| 91 |
+
"outputs": [],
|
| 92 |
+
"source": [
|
| 93 |
+
"# Chạy như tiến trình con để giữ log; KHÔNG có SPACE_ID nên app.py tự bật share=True\n",
|
| 94 |
+
"subprocess.run([sys.executable, \"app.py\"], cwd=LOCAL_DIR, check=True)"
|
| 95 |
+
]
|
| 96 |
+
},
|
| 97 |
+
{
|
| 98 |
+
"cell_type": "markdown",
|
| 99 |
+
"id": "288e61cb",
|
| 100 |
+
"metadata": {},
|
| 101 |
+
"source": [
|
| 102 |
+
"## Ghi chú\n",
|
| 103 |
+
"- Đây là cách \"1 nguồn code, chạy nơi có GPU free\": sửa UI thì sửa trên Space HF → chạy lại notebook này.\n",
|
| 104 |
+
"- Nếu muốn chạy bản local trong `kaggle_baseline/demo_all_tracks_gradio` (code inline) thì dùng notebook đó.\n",
|
| 105 |
+
"- Lần đầu bấm nút mỗi track sẽ tải model (WavLM/SAILER/URGENT-MOS/ECAPA) → chờ chút; Kaggle có GPU nên inference nhanh.\n",
|
| 106 |
+
"- Cần **Internet On** (tải code HF + model) + **GPU T4**."
|
| 107 |
+
]
|
| 108 |
+
}
|
| 109 |
+
],
|
| 110 |
+
"metadata": {
|
| 111 |
+
"jupytext": {
|
| 112 |
+
"cell_metadata_filter": "-all",
|
| 113 |
+
"main_language": "python",
|
| 114 |
+
"notebook_metadata_filter": "-all"
|
| 115 |
+
}
|
| 116 |
+
},
|
| 117 |
+
"nbformat": 4,
|
| 118 |
+
"nbformat_minor": 5
|
| 119 |
+
}
|
demo_run_from_hf_pipeline.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %% [markdown]
|
| 2 |
+
# # VMC2026 — Chạy demo Gradio trên KAGGLE bằng cách KÉO code UI từ Hugging Face
|
| 3 |
+
#
|
| 4 |
+
# Chiến lược: **HF = nơi chứa code UI** (Space `tranminhtoan140601/voicemos2026-demo`),
|
| 5 |
+
# **Kaggle = nơi chạy** (GPU T4 free). Notebook này tải `app.py` từ Space về rồi chạy →
|
| 6 |
+
# ra link `*.gradio.live` (sống ~72h) để gửi mentor. KHÔNG tốn GPU trả phí của HF.
|
| 7 |
+
#
|
| 8 |
+
# `app.py` tự nhận môi trường: trên Kaggle → `share=True` (link công khai); checkpoint Track 2
|
| 9 |
+
# tự tải từ HF Models repo `tranminhtoan140601/voicemos2026-track2-emotion`.
|
| 10 |
+
#
|
| 11 |
+
# ### Cách chạy
|
| 12 |
+
# 1. Settings → **GPU T4 + Internet On**.
|
| 13 |
+
# 2. **Run All** → cell cuối in link `*.gradio.live`.
|
| 14 |
+
|
| 15 |
+
# %% [markdown]
|
| 16 |
+
# ## 1. Cài deps (khớp Space) — KHÔNG đụng numpy/torch có sẵn Kaggle
|
| 17 |
+
|
| 18 |
+
# %%
|
| 19 |
+
# !pip install -q gradio==6.17.3 huggingface_hub librosa soundfile speechbrain loralib scipy scikit-learn pandas
|
| 20 |
+
|
| 21 |
+
import subprocess, sys
|
| 22 |
+
|
| 23 |
+
def pip_install(*pkgs):
|
| 24 |
+
subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs], check=False)
|
| 25 |
+
|
| 26 |
+
pip_install("gradio==6.17.3", "huggingface_hub", "librosa", "soundfile",
|
| 27 |
+
"speechbrain", "loralib", "scipy", "scikit-learn", "pandas")
|
| 28 |
+
|
| 29 |
+
# %% [markdown]
|
| 30 |
+
# ## 2. Kéo code UI (app.py) từ HF Space về Kaggle
|
| 31 |
+
|
| 32 |
+
# %%
|
| 33 |
+
import os
|
| 34 |
+
from huggingface_hub import snapshot_download
|
| 35 |
+
|
| 36 |
+
SPACE_REPO = "tranminhtoan140601/voicemos2026-demo"
|
| 37 |
+
LOCAL_DIR = "/kaggle/working/vmc_demo"
|
| 38 |
+
|
| 39 |
+
# Tải toàn bộ repo Space (app.py + requirements + README) về local
|
| 40 |
+
snapshot_download(repo_id=SPACE_REPO, repo_type="space", local_dir=LOCAL_DIR)
|
| 41 |
+
print("✅ Đã kéo Space về:", LOCAL_DIR)
|
| 42 |
+
print("Files:", os.listdir(LOCAL_DIR))
|
| 43 |
+
|
| 44 |
+
# %% [markdown]
|
| 45 |
+
# ## 3. Chạy app.py (Kaggle có GPU → nhanh; app.py tự share=True ra link gradio.live)
|
| 46 |
+
#
|
| 47 |
+
# `app.py` tải checkpoint Track 2 từ HF Models repo, clone URGENT-MOS/SAILER/baseline lúc bấm nút.
|
| 48 |
+
# Cell này sẽ **chạy mãi** (server Gradio) — đợi dòng `Running on public URL: https://....gradio.live`.
|
| 49 |
+
|
| 50 |
+
# %%
|
| 51 |
+
# Chạy như tiến trình con để giữ log; KHÔNG có SPACE_ID nên app.py tự bật share=True
|
| 52 |
+
subprocess.run([sys.executable, "app.py"], cwd=LOCAL_DIR, check=True)
|
| 53 |
+
|
| 54 |
+
# %% [markdown]
|
| 55 |
+
# ## Ghi chú
|
| 56 |
+
# - Đây là cách "1 nguồn code, chạy nơi có GPU free": sửa UI thì sửa trên Space HF → chạy lại notebook này.
|
| 57 |
+
# - Nếu muốn chạy bản local trong `kaggle_baseline/demo_all_tracks_gradio` (code inline) thì dùng notebook đó.
|
| 58 |
+
# - Lần đầu bấm nút mỗi track sẽ tải model (WavLM/SAILER/URGENT-MOS/ECAPA) → chờ chút; Kaggle có GPU nên inference nhanh.
|
| 59 |
+
# - Cần **Internet On** (tải code HF + model) + **GPU T4**.
|
track1/demo_track1_gradio.ipynb
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# VMC2026 Track 1 — Demo Gradio (Speech Enhancement: ACR + CCR)\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"Baseline **URGENT-MOS**. Tải **Audio A** → **ACR** (chất lượng 1–5).\n",
|
| 10 |
+
"Tải thêm **Audio B** → **CCR** (so sánh A vs B, thang −3..+3, >0 nghĩa là A tốt hơn).\n",
|
| 11 |
+
"\n",
|
| 12 |
+
"### Cách dùng trên Kaggle\n",
|
| 13 |
+
"1. Settings → **GPU T4 + Internet On**.\n",
|
| 14 |
+
"2. **Run All** → cell cuối in link `*.gradio.live` (sống ~72h) → gửi mentor."
|
| 15 |
+
]
|
| 16 |
+
},
|
| 17 |
+
{
|
| 18 |
+
"cell_type": "markdown",
|
| 19 |
+
"metadata": {},
|
| 20 |
+
"source": [
|
| 21 |
+
"## 1. Cài đặt + clone URGENT-MOS"
|
| 22 |
+
]
|
| 23 |
+
},
|
| 24 |
+
{
|
| 25 |
+
"cell_type": "code",
|
| 26 |
+
"execution_count": null,
|
| 27 |
+
"metadata": {},
|
| 28 |
+
"outputs": [],
|
| 29 |
+
"source": [
|
| 30 |
+
"!pip install -q gradio librosa soundfile\n",
|
| 31 |
+
"!git clone -q https://github.com/vvwangvv/URGENT-MOS.git /kaggle/working/URGENT-MOS\n",
|
| 32 |
+
"!pip install -q -e /kaggle/working/URGENT-MOS"
|
| 33 |
+
]
|
| 34 |
+
},
|
| 35 |
+
{
|
| 36 |
+
"cell_type": "markdown",
|
| 37 |
+
"metadata": {},
|
| 38 |
+
"source": [
|
| 39 |
+
"## 2. Nạp model + hàm dự đoán"
|
| 40 |
+
]
|
| 41 |
+
},
|
| 42 |
+
{
|
| 43 |
+
"cell_type": "code",
|
| 44 |
+
"execution_count": null,
|
| 45 |
+
"id": "68754b2d",
|
| 46 |
+
"metadata": {},
|
| 47 |
+
"outputs": [],
|
| 48 |
+
"source": [
|
| 49 |
+
"import os, sys, subprocess, librosa\n",
|
| 50 |
+
"\n",
|
| 51 |
+
"DEVICE = \"cuda\"\n",
|
| 52 |
+
"URGENT_REPO = \"/kaggle/working/URGENT-MOS\"\n",
|
| 53 |
+
"URGENT_CKPT = \"urgent-challenge/urgent-mos-f1c1m5dcorpus\" # tự tải từ HuggingFace\n",
|
| 54 |
+
"\n",
|
| 55 |
+
"\n",
|
| 56 |
+
"def _ensure_urgent_mos():\n",
|
| 57 |
+
" \"\"\"Tự clone + cài URGENT-MOS nếu chưa có (phòng khi cell cài chưa chạy).\"\"\"\n",
|
| 58 |
+
" if not os.path.isdir(URGENT_REPO):\n",
|
| 59 |
+
" subprocess.run(f\"git clone -q https://github.com/vvwangvv/URGENT-MOS.git {URGENT_REPO}\",\n",
|
| 60 |
+
" shell=True, check=True)\n",
|
| 61 |
+
" subprocess.run(f\"pip install -q -e {URGENT_REPO}\", shell=True, check=True)\n",
|
| 62 |
+
" if URGENT_REPO not in sys.path: # package nằm ở root repo → thêm vào path là import được\n",
|
| 63 |
+
" sys.path.insert(0, URGENT_REPO)\n",
|
| 64 |
+
"\n",
|
| 65 |
+
"\n",
|
| 66 |
+
"_M = {}\n",
|
| 67 |
+
"\n",
|
| 68 |
+
"def _load():\n",
|
| 69 |
+
" if \"m\" not in _M:\n",
|
| 70 |
+
" _ensure_urgent_mos()\n",
|
| 71 |
+
" import torch\n",
|
| 72 |
+
" from urgent_mos.utils import load_model_from_checkpoint\n",
|
| 73 |
+
" dev = DEVICE if torch.cuda.is_available() else \"cpu\"\n",
|
| 74 |
+
" m = load_model_from_checkpoint(URGENT_CKPT, dev)\n",
|
| 75 |
+
" m.eval()\n",
|
| 76 |
+
" _M[\"m\"] = m\n",
|
| 77 |
+
" return _M[\"m\"]\n",
|
| 78 |
+
"\n",
|
| 79 |
+
"\n",
|
| 80 |
+
"def _scalar(x):\n",
|
| 81 |
+
" return float(x.item()) if hasattr(x, \"item\") else float(x)\n",
|
| 82 |
+
"\n",
|
| 83 |
+
"\n",
|
| 84 |
+
"def predict(audio_a, audio_b):\n",
|
| 85 |
+
" import torch\n",
|
| 86 |
+
" from urgent_mos.api.infer import infer, infer_pairs\n",
|
| 87 |
+
" if not audio_a:\n",
|
| 88 |
+
" return \"⚠️ Hãy tải lên ít nhất Audio A.\"\n",
|
| 89 |
+
" m = _load()\n",
|
| 90 |
+
" wa = torch.from_numpy(librosa.load(audio_a, sr=16000, mono=True)[0]).float()\n",
|
| 91 |
+
" acr_a = max(1.0, min(5.0, _scalar(infer(m, [wa], sample_rate=[16000],\n",
|
| 92 |
+
" batch_frames=None, num_workers=0)[0][\"mos_overall\"])))\n",
|
| 93 |
+
" out = f\"ACR (Audio A): {acr_a:.3f} (chất lượng tuyệt đối, thang 1–5)\"\n",
|
| 94 |
+
" if audio_b:\n",
|
| 95 |
+
" wb = torch.from_numpy(librosa.load(audio_b, sr=16000, mono=True)[0]).float()\n",
|
| 96 |
+
" acr_b = max(1.0, min(5.0, _scalar(infer(m, [wb], sample_rate=[16000],\n",
|
| 97 |
+
" batch_frames=None, num_workers=0)[0][\"mos_overall\"])))\n",
|
| 98 |
+
" ccr = max(-3.0, min(3.0, _scalar(infer_pairs(m, [(wa, wb)], sample_rate=[(16000, 16000)],\n",
|
| 99 |
+
" batch_frames=None, num_workers=0)[0][\"mos_overall\"])))\n",
|
| 100 |
+
" out += (f\"\\nACR (Audio B): {acr_b:.3f}\"\n",
|
| 101 |
+
" f\"\\nCCR (A so với B): {ccr:+.3f} (>0: A tốt hơn B; thang −3..+3)\")\n",
|
| 102 |
+
" return out"
|
| 103 |
+
]
|
| 104 |
+
},
|
| 105 |
+
{
|
| 106 |
+
"cell_type": "markdown",
|
| 107 |
+
"metadata": {},
|
| 108 |
+
"source": [
|
| 109 |
+
"## 3. Giao diện Gradio + launch"
|
| 110 |
+
]
|
| 111 |
+
},
|
| 112 |
+
{
|
| 113 |
+
"cell_type": "code",
|
| 114 |
+
"execution_count": null,
|
| 115 |
+
"metadata": {},
|
| 116 |
+
"outputs": [],
|
| 117 |
+
"source": [
|
| 118 |
+
"import gradio as gr\n",
|
| 119 |
+
"\n",
|
| 120 |
+
"with gr.Blocks(title=\"VMC2026 Track 1 — ACR/CCR\") as demo:\n",
|
| 121 |
+
" gr.Markdown(\"# 🎙️ Track 1 · Speech Enhancement (ACR / CCR)\\n\"\n",
|
| 122 |
+
" \"Tải **Audio A** để có ACR. Tải thêm **Audio B** để so sánh CCR (A vs B).\")\n",
|
| 123 |
+
" a = gr.Audio(type=\"filepath\", label=\"Audio A (bắt buộc)\")\n",
|
| 124 |
+
" b = gr.Audio(type=\"filepath\", label=\"Audio B (tùy chọn — để tính CCR)\")\n",
|
| 125 |
+
" out = gr.Textbox(label=\"Kết quả\", lines=4)\n",
|
| 126 |
+
" gr.Button(\"Dự đoán\", variant=\"primary\").click(predict, [a, b], out)\n",
|
| 127 |
+
"\n",
|
| 128 |
+
"demo.launch(share=True)"
|
| 129 |
+
]
|
| 130 |
+
}
|
| 131 |
+
],
|
| 132 |
+
"metadata": {
|
| 133 |
+
"kernelspec": {
|
| 134 |
+
"display_name": "Python 3",
|
| 135 |
+
"language": "python",
|
| 136 |
+
"name": "python3"
|
| 137 |
+
},
|
| 138 |
+
"language_info": {
|
| 139 |
+
"name": "python"
|
| 140 |
+
}
|
| 141 |
+
},
|
| 142 |
+
"nbformat": 4,
|
| 143 |
+
"nbformat_minor": 5
|
| 144 |
+
}
|
track1/demo_track1_gradio_pipeline.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %% [markdown]
|
| 2 |
+
# # VMC2026 Track 1 — Demo Gradio (Speech Enhancement: ACR + CCR)
|
| 3 |
+
#
|
| 4 |
+
# Baseline **URGENT-MOS**. Tải **Audio A** → **ACR** (chất lượng 1–5).
|
| 5 |
+
# Tải thêm **Audio B** → **CCR** (so sánh A vs B, thang −3..+3, >0 nghĩa là A tốt hơn).
|
| 6 |
+
#
|
| 7 |
+
# ### Cách dùng trên Kaggle
|
| 8 |
+
# 1. Settings → **GPU T4 + Internet On**.
|
| 9 |
+
# 2. **Run All** → cell cuối in link `*.gradio.live` (sống ~72h) → gửi mentor.
|
| 10 |
+
|
| 11 |
+
# %% [markdown]
|
| 12 |
+
# ## 1. Cài đặt + clone URGENT-MOS
|
| 13 |
+
|
| 14 |
+
# %%
|
| 15 |
+
# !pip install -q gradio librosa soundfile
|
| 16 |
+
# !git clone -q https://github.com/vvwangvv/URGENT-MOS.git /kaggle/working/URGENT-MOS
|
| 17 |
+
# !pip install -q -e /kaggle/working/URGENT-MOS
|
| 18 |
+
|
| 19 |
+
# %% [markdown]
|
| 20 |
+
# ## 2. Nạp model + hàm dự đoán
|
| 21 |
+
|
| 22 |
+
# %%
|
| 23 |
+
import os, sys, subprocess, librosa
|
| 24 |
+
|
| 25 |
+
DEVICE = "cuda"
|
| 26 |
+
URGENT_REPO = "/kaggle/working/URGENT-MOS"
|
| 27 |
+
URGENT_CKPT = "urgent-challenge/urgent-mos-f1c1m5dcorpus" # tự tải từ HuggingFace
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _ensure_urgent_mos():
|
| 31 |
+
"""Đảm bảo import được `urgent_mos`: clone repo nếu thiếu, cài deps, thêm vào sys.path."""
|
| 32 |
+
if not os.path.isdir(URGENT_REPO):
|
| 33 |
+
subprocess.run(f"git clone -q https://github.com/vvwangvv/URGENT-MOS.git {URGENT_REPO}",
|
| 34 |
+
shell=True, check=True)
|
| 35 |
+
# package nằm ở ROOT repo → thêm vào path là import được, KHÔNG phụ thuộc pip install thành công
|
| 36 |
+
if URGENT_REPO not in sys.path:
|
| 37 |
+
sys.path.insert(0, URGENT_REPO)
|
| 38 |
+
import importlib
|
| 39 |
+
importlib.invalidate_caches()
|
| 40 |
+
# thử import; nếu thiếu dependency thì cài editable (kéo theo torchcodec, hydra-core, omegaconf...)
|
| 41 |
+
try:
|
| 42 |
+
importlib.import_module("urgent_mos.api.infer")
|
| 43 |
+
except Exception:
|
| 44 |
+
subprocess.run(f"pip install -q -e {URGENT_REPO}", shell=True)
|
| 45 |
+
importlib.invalidate_caches()
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
_M = {}
|
| 49 |
+
|
| 50 |
+
def _load():
|
| 51 |
+
if "m" not in _M:
|
| 52 |
+
_ensure_urgent_mos()
|
| 53 |
+
import torch
|
| 54 |
+
from urgent_mos.utils import load_model_from_checkpoint
|
| 55 |
+
dev = DEVICE if torch.cuda.is_available() else "cpu"
|
| 56 |
+
m = load_model_from_checkpoint(URGENT_CKPT, dev)
|
| 57 |
+
m.eval()
|
| 58 |
+
_M["m"] = m
|
| 59 |
+
return _M["m"]
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _scalar(x):
|
| 63 |
+
return float(x.item()) if hasattr(x, "item") else float(x)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def predict(audio_a, audio_b):
|
| 67 |
+
if not audio_a:
|
| 68 |
+
return "⚠️ Hãy tải lên ít nhất Audio A."
|
| 69 |
+
import torch
|
| 70 |
+
m = _load() # _load() tự ensure repo + sys.path TRƯỚC
|
| 71 |
+
from urgent_mos.api.infer import infer, infer_pairs # giờ mới import được
|
| 72 |
+
wa = torch.from_numpy(librosa.load(audio_a, sr=16000, mono=True)[0]).float()
|
| 73 |
+
acr_a = max(1.0, min(5.0, _scalar(infer(m, [wa], sample_rate=[16000],
|
| 74 |
+
batch_frames=None, num_workers=0)[0]["mos_overall"])))
|
| 75 |
+
out = f"ACR (Audio A): {acr_a:.3f} (chất lượng tuyệt đối, thang 1–5)"
|
| 76 |
+
if audio_b:
|
| 77 |
+
wb = torch.from_numpy(librosa.load(audio_b, sr=16000, mono=True)[0]).float()
|
| 78 |
+
acr_b = max(1.0, min(5.0, _scalar(infer(m, [wb], sample_rate=[16000],
|
| 79 |
+
batch_frames=None, num_workers=0)[0]["mos_overall"])))
|
| 80 |
+
ccr = max(-3.0, min(3.0, _scalar(infer_pairs(m, [(wa, wb)], sample_rate=[(16000, 16000)],
|
| 81 |
+
batch_frames=None, num_workers=0)[0]["mos_overall"])))
|
| 82 |
+
out += (f"\nACR (Audio B): {acr_b:.3f}"
|
| 83 |
+
f"\nCCR (A so với B): {ccr:+.3f} (>0: A tốt hơn B; thang −3..+3)")
|
| 84 |
+
return out
|
| 85 |
+
|
| 86 |
+
# %% [markdown]
|
| 87 |
+
# ## 3. Giao diện Gradio + launch
|
| 88 |
+
|
| 89 |
+
# %%
|
| 90 |
+
import gradio as gr
|
| 91 |
+
|
| 92 |
+
with gr.Blocks(title="VMC2026 Track 1 — ACR/CCR") as demo:
|
| 93 |
+
gr.Markdown("# 🎙️ Track 1 · Speech Enhancement (ACR / CCR)\n"
|
| 94 |
+
"Tải **Audio A** để có ACR. Tải thêm **Audio B** để so sánh CCR (A vs B).")
|
| 95 |
+
a = gr.Audio(type="filepath", label="Audio A (bắt buộc)")
|
| 96 |
+
b = gr.Audio(type="filepath", label="Audio B (tùy chọn — để tính CCR)")
|
| 97 |
+
out = gr.Textbox(label="Kết quả", lines=4)
|
| 98 |
+
gr.Button("Dự đoán", variant="primary").click(predict, [a, b], out)
|
| 99 |
+
|
| 100 |
+
demo.launch(share=True)
|
track1/track1_baseline.ipynb
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# VMC2026 Track 1 — Baseline (URGENT-MOS)\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"Chạy ngay được — data dev công khai trên HuggingFace, checkpoint tự tải.\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"**Trước khi chạy:** Session options → Accelerator = **GPU T4**, Internet = **On** (verify phone nếu cần).\n",
|
| 12 |
+
"\n",
|
| 13 |
+
"Output: `submission_track1.zip` (chứa `predictions.csv`) → nộp Track 1."
|
| 14 |
+
]
|
| 15 |
+
},
|
| 16 |
+
{
|
| 17 |
+
"cell_type": "markdown",
|
| 18 |
+
"metadata": {},
|
| 19 |
+
"source": [
|
| 20 |
+
"## 1. Cài đặt URGENT-MOS"
|
| 21 |
+
]
|
| 22 |
+
},
|
| 23 |
+
{
|
| 24 |
+
"cell_type": "code",
|
| 25 |
+
"execution_count": null,
|
| 26 |
+
"metadata": {},
|
| 27 |
+
"outputs": [],
|
| 28 |
+
"source": [
|
| 29 |
+
"!git clone -q https://github.com/vvwangvv/URGENT-MOS.git /kaggle/working/URGENT-MOS\n",
|
| 30 |
+
"!pip install -q -e /kaggle/working/URGENT-MOS"
|
| 31 |
+
]
|
| 32 |
+
},
|
| 33 |
+
{
|
| 34 |
+
"cell_type": "markdown",
|
| 35 |
+
"metadata": {},
|
| 36 |
+
"source": [
|
| 37 |
+
"## 2. Smoke test (10 mẫu)\n",
|
| 38 |
+
"Kiểm tra môi trường + tải checkpoint `urgent-challenge/urgent-mos-f1c1m5dcorpus` từ HF (lần đầu hơi lâu)."
|
| 39 |
+
]
|
| 40 |
+
},
|
| 41 |
+
{
|
| 42 |
+
"cell_type": "code",
|
| 43 |
+
"execution_count": null,
|
| 44 |
+
"metadata": {},
|
| 45 |
+
"outputs": [],
|
| 46 |
+
"source": [
|
| 47 |
+
"!cd /kaggle/working/URGENT-MOS && python scripts/infer_vmc2026_track1.py --split dev --limit 10 --output /kaggle/working/predictions_smoke.csv\n",
|
| 48 |
+
"import pandas as pd\n",
|
| 49 |
+
"pd.read_csv('/kaggle/working/predictions_smoke.csv').head()"
|
| 50 |
+
]
|
| 51 |
+
},
|
| 52 |
+
{
|
| 53 |
+
"cell_type": "markdown",
|
| 54 |
+
"metadata": {},
|
| 55 |
+
"source": [
|
| 56 |
+
"## 3. Inference đầy đủ (ACR 1008 + CCR 2520)\n",
|
| 57 |
+
"Nếu OOM: thêm `--batch-frames 8000` (hoặc nhỏ hơn)."
|
| 58 |
+
]
|
| 59 |
+
},
|
| 60 |
+
{
|
| 61 |
+
"cell_type": "code",
|
| 62 |
+
"execution_count": null,
|
| 63 |
+
"metadata": {},
|
| 64 |
+
"outputs": [],
|
| 65 |
+
"source": "!cd /kaggle/working/URGENT-MOS && python scripts/infer_vmc2026_track1.py --split dev --output /kaggle/working/predictions.csv"
|
| 66 |
+
},
|
| 67 |
+
{
|
| 68 |
+
"cell_type": "markdown",
|
| 69 |
+
"metadata": {},
|
| 70 |
+
"source": [
|
| 71 |
+
"## 4. Validate"
|
| 72 |
+
]
|
| 73 |
+
},
|
| 74 |
+
{
|
| 75 |
+
"cell_type": "code",
|
| 76 |
+
"execution_count": null,
|
| 77 |
+
"metadata": {},
|
| 78 |
+
"outputs": [],
|
| 79 |
+
"source": "import pandas as pd\ndf = pd.read_csv('/kaggle/working/predictions.csv')\nassert list(df.columns) == ['sample_id', 'pred_score'], df.columns.tolist()\nacr = df[df['sample_id'].str.contains('-acr_')]\nccr = df[df['sample_id'].str.contains('-ccr_')]\nprint(f'Tổng {len(df)} | ACR {len(acr)} | CCR {len(ccr)}')\nprint('ACR:', acr['pred_score'].min(), '→', acr['pred_score'].max(), '(cần [1,5])')\nprint('CCR:', ccr['pred_score'].min(), '→', ccr['pred_score'].max(), '(cần [-3,+3])')"
|
| 80 |
+
},
|
| 81 |
+
{
|
| 82 |
+
"cell_type": "markdown",
|
| 83 |
+
"metadata": {},
|
| 84 |
+
"source": [
|
| 85 |
+
"## 5. Đóng zip nộp"
|
| 86 |
+
]
|
| 87 |
+
},
|
| 88 |
+
{
|
| 89 |
+
"cell_type": "code",
|
| 90 |
+
"execution_count": null,
|
| 91 |
+
"metadata": {},
|
| 92 |
+
"outputs": [],
|
| 93 |
+
"source": "!cd /kaggle/working && zip -j submission_track1.zip predictions.csv && unzip -l submission_track1.zip\nprint('Tải /kaggle/working/submission_track1.zip → nộp My Submissions (chọn Track 1, bỏ chọn track khác)')"
|
| 94 |
+
}
|
| 95 |
+
],
|
| 96 |
+
"metadata": {
|
| 97 |
+
"kernelspec": {
|
| 98 |
+
"display_name": "Python 3",
|
| 99 |
+
"language": "python",
|
| 100 |
+
"name": "python3"
|
| 101 |
+
},
|
| 102 |
+
"language_info": {
|
| 103 |
+
"name": "python"
|
| 104 |
+
}
|
| 105 |
+
},
|
| 106 |
+
"nbformat": 4,
|
| 107 |
+
"nbformat_minor": 5
|
| 108 |
+
}
|
track1/track1_baseline_pipeline.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %% [markdown]
|
| 2 |
+
# # VMC2026 Track 1 — Baseline Pipeline (Kaggle)
|
| 3 |
+
#
|
| 4 |
+
# Baseline = **URGENT-MOS** (checkpoint pre-trained tự tải từ HuggingFace).
|
| 5 |
+
# Repo có sẵn script xuất đúng format nộp → gần như không phải code gì.
|
| 6 |
+
#
|
| 7 |
+
# **Không bị chặn bởi license:** dev data Track 1 công khai trên HuggingFace
|
| 8 |
+
# (`urgent-challenge/vmc2026-track1-dev`, configs `acr` + `ccr`).
|
| 9 |
+
#
|
| 10 |
+
# **Cách dùng:** Notebook → GPU T4 + Internet On → chạy lần lượt các cell.
|
| 11 |
+
# Output: `predictions.csv` (`sample_id,pred_score`) → zip → nộp Track 1.
|
| 12 |
+
# ⚠️ File trong zip BẮT BUỘC tên `predictions.csv` (guideline Track 1) — đừng để `predictions_dev.csv`.
|
| 13 |
+
|
| 14 |
+
# %% [markdown]
|
| 15 |
+
# ## 1. Cài đặt URGENT-MOS
|
| 16 |
+
|
| 17 |
+
# %%
|
| 18 |
+
# !git clone -q https://github.com/vvwangvv/URGENT-MOS.git /kaggle/working/URGENT-MOS
|
| 19 |
+
# !pip install -q -e /kaggle/working/URGENT-MOS
|
| 20 |
+
|
| 21 |
+
# %% [markdown]
|
| 22 |
+
# ## 2. Smoke test (vài mẫu) — kiểm tra môi trường + tải checkpoint
|
| 23 |
+
# Checkpoint `urgent-challenge/urgent-mos-f1c1m5dcorpus` tự tải từ HF lần chạy đầu.
|
| 24 |
+
|
| 25 |
+
# %%
|
| 26 |
+
# !cd /kaggle/working/URGENT-MOS && python scripts/infer_vmc2026_track1.py \
|
| 27 |
+
# --split dev --limit 10 --output /kaggle/working/predictions_smoke.csv
|
| 28 |
+
# import pandas as pd; pd.read_csv("/kaggle/working/predictions_smoke.csv").head()
|
| 29 |
+
|
| 30 |
+
# %% [markdown]
|
| 31 |
+
# ## 3. Inference đầy đủ trên dev set (ACR + CCR)
|
| 32 |
+
# Script tự tải dataset từ HF, chạy cả ACR (1008) + CCR (2520) → 1 file predictions.
|
| 33 |
+
|
| 34 |
+
# %%
|
| 35 |
+
# !cd /kaggle/working/URGENT-MOS && python scripts/infer_vmc2026_track1.py \
|
| 36 |
+
# --split dev --output /kaggle/working/predictions.csv
|
| 37 |
+
# Nếu OOM: thêm --batch-frames <N> để giảm bộ nhớ.
|
| 38 |
+
|
| 39 |
+
# %% [markdown]
|
| 40 |
+
# ## 4. Validate + đóng zip nộp
|
| 41 |
+
|
| 42 |
+
# %%
|
| 43 |
+
import pandas as pd
|
| 44 |
+
|
| 45 |
+
PRED = "/kaggle/working/predictions.csv"
|
| 46 |
+
df = pd.read_csv(PRED)
|
| 47 |
+
assert list(df.columns) == ["sample_id", "pred_score"], f"Header sai: {df.columns.tolist()}"
|
| 48 |
+
|
| 49 |
+
acr = df[df["sample_id"].str.contains("-acr_")]
|
| 50 |
+
ccr = df[df["sample_id"].str.contains("-ccr_")]
|
| 51 |
+
print(f"Tổng {len(df)} dòng | ACR {len(acr)} | CCR {len(ccr)}")
|
| 52 |
+
print("ACR range:", acr["pred_score"].min(), "→", acr["pred_score"].max(), "(cần [1,5])")
|
| 53 |
+
print("CCR range:", ccr["pred_score"].min(), "→", ccr["pred_score"].max(), "(cần [-3,+3])")
|
| 54 |
+
# Kỳ vọng dev: ACR=1008, CCR=2520
|
| 55 |
+
|
| 56 |
+
# %%
|
| 57 |
+
# !cd /kaggle/working && zip -j submission_track1.zip predictions.csv && unzip -l submission_track1.zip
|
| 58 |
+
|
| 59 |
+
# %% [markdown]
|
| 60 |
+
# ## Ghi chú
|
| 61 |
+
# - Nộp: My Submissions → chọn **Track 1**, **bỏ chọn** track khác → upload `submission_track1.zip`.
|
| 62 |
+
# - File nộp Track 1 tên **`predictions.csv`** (KHÁC Track 2/3 dùng `answer.txt`). Script đã xuất đúng cột `sample_id,pred_score`.
|
| 63 |
+
# - Eval phase: đổi `--split test` (sau khi eval data ra 31/7).
|
| 64 |
+
# - GPU khuyến nghị; chỉ inference nên nhẹ, fit T4 16GB.
|
track2/demo_track2_emotion_gradio.ipynb
ADDED
|
@@ -0,0 +1,533 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "73831f26",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"# VMC2026 Track 2 — Demo Gradio \"Emotional TTS Evaluator\" (model TỐT NHẤT = exp08)\n",
|
| 9 |
+
"\n",
|
| 10 |
+
"Demo này dùng **checkpoint cảm xúc tốt nhất** (`ft_emotion_full_20epoch.pt`: WavLM fine-tune warm-start\n",
|
| 11 |
+
"SAILER + audeering frozen) để chấm **5 cột cảm xúc** của 1 file giọng TTS: **EMOS / CAT / VAL / ARO / DOM**.\n",
|
| 12 |
+
"Khác demo cũ (`demo_track2_gradio`) dùng baseline UTMOS+emotion2vec+Gemini — bản này KHÔNG cần API.\n",
|
| 13 |
+
"\n",
|
| 14 |
+
"**2 tab:**\n",
|
| 15 |
+
"1. *Chấm 1 file TTS* — tải audio + chọn cảm xúc target → ra điểm biểu cảm cảm xúc + diễn giải KHỚP/LỆCH.\n",
|
| 16 |
+
"2. *Metric bộ chấm* — tính UTT-SRCC (EMOS/VAD) + CAT-err trên val nội bộ (train.csv) → cho biết độ tin cậy.\n",
|
| 17 |
+
"\n",
|
| 18 |
+
"**Cách chạy Kaggle:** GPU **T4** + Internet **On** → Add Input (1) dataset Track 2, (2) dataset chứa\n",
|
| 19 |
+
"`ft_emotion_full_20epoch.pt` → Run All → cell cuối in link `*.gradio.live`."
|
| 20 |
+
]
|
| 21 |
+
},
|
| 22 |
+
{
|
| 23 |
+
"cell_type": "markdown",
|
| 24 |
+
"id": "d505183e",
|
| 25 |
+
"metadata": {},
|
| 26 |
+
"source": [
|
| 27 |
+
"## 0. Cấu hình — auto-dò DATA_ROOT + checkpoint"
|
| 28 |
+
]
|
| 29 |
+
},
|
| 30 |
+
{
|
| 31 |
+
"cell_type": "code",
|
| 32 |
+
"execution_count": null,
|
| 33 |
+
"id": "20d538e6",
|
| 34 |
+
"metadata": {},
|
| 35 |
+
"outputs": [],
|
| 36 |
+
"source": [
|
| 37 |
+
"import os, glob\n",
|
| 38 |
+
"\n",
|
| 39 |
+
"def find_data_root(search_root=\"/kaggle/input\"):\n",
|
| 40 |
+
" cands = []\n",
|
| 41 |
+
" for train_csv in glob.glob(os.path.join(search_root, \"**\", \"sets\", \"train.csv\"), recursive=True):\n",
|
| 42 |
+
" root = os.path.dirname(os.path.dirname(train_csv))\n",
|
| 43 |
+
" score = os.path.isdir(os.path.join(root, \"wav\")) + os.path.exists(os.path.join(root, \"metadata.csv\"))\n",
|
| 44 |
+
" cands.append((score, root))\n",
|
| 45 |
+
" cands.sort(reverse=True)\n",
|
| 46 |
+
" return cands\n",
|
| 47 |
+
"\n",
|
| 48 |
+
"_cands = find_data_root(\"/kaggle/input\")\n",
|
| 49 |
+
"if _cands:\n",
|
| 50 |
+
" print(\"🔎 Ứng viên DATA_ROOT:\")\n",
|
| 51 |
+
" for sc, r in _cands:\n",
|
| 52 |
+
" print(f\" [{sc}/2] {r}\")\n",
|
| 53 |
+
" DATA_ROOT = _cands[0][1]\n",
|
| 54 |
+
" print(f\"👉 Tự chọn DATA_ROOT = {DATA_ROOT}\")\n",
|
| 55 |
+
"else:\n",
|
| 56 |
+
" DATA_ROOT = \"/kaggle/input/datasets/minhtoan2\" # dự phòng\n",
|
| 57 |
+
" print(f\"❌ Không thấy sets/train.csv → dự phòng {DATA_ROOT} (đã Add Input chưa?)\")\n",
|
| 58 |
+
"\n",
|
| 59 |
+
"WAV_DIR = f\"{DATA_ROOT}/wav\"\n",
|
| 60 |
+
"METADATA_CSV = f\"{DATA_ROOT}/metadata.csv\"\n",
|
| 61 |
+
"TRAIN_CSV = f\"{DATA_ROOT}/sets/train.csv\"\n",
|
| 62 |
+
"\n",
|
| 63 |
+
"# ── Checkpoint cảm xúc exp08 (ưu tiên bản 20 epoch = TỐT NHẤT) ─────────────────\n",
|
| 64 |
+
"CKPT_PATH = \"\" # << \"\" = auto-dò; hoặc trỏ tay \"/kaggle/input/<slug>/ft_emotion_full_20epoch.pt\"\n",
|
| 65 |
+
"\n",
|
| 66 |
+
"def find_ckpt(explicit):\n",
|
| 67 |
+
" if explicit and os.path.exists(explicit):\n",
|
| 68 |
+
" return explicit\n",
|
| 69 |
+
" pats = [\"ft_emotion_full_20epoch*.pt\", \"ft_emotion_full*.pt\"] # ưu tiên bản 20epoch\n",
|
| 70 |
+
" for pat in pats:\n",
|
| 71 |
+
" for base in [\"/kaggle/input\", \"/kaggle/working\"]:\n",
|
| 72 |
+
" hits = sorted(glob.glob(os.path.join(base, \"**\", pat), recursive=True))\n",
|
| 73 |
+
" if hits:\n",
|
| 74 |
+
" return hits[0]\n",
|
| 75 |
+
" return \"\"\n",
|
| 76 |
+
"\n",
|
| 77 |
+
"CKPT_PATH = find_ckpt(CKPT_PATH)\n",
|
| 78 |
+
"assert CKPT_PATH, \"❌ Không thấy ft_emotion_full*.pt. Add Input dataset chứa checkpoint exp08 chưa?\"\n",
|
| 79 |
+
"print(\"✅ Checkpoint:\", CKPT_PATH)\n",
|
| 80 |
+
"\n",
|
| 81 |
+
"# ── Hằng kiến trúc PHẢI khớp exp08 (ckpt không lưu các số này) ────────────────\n",
|
| 82 |
+
"DEVICE = \"cuda\"\n",
|
| 83 |
+
"SR = 16000\n",
|
| 84 |
+
"EMO_MAX_SEC = 8\n",
|
| 85 |
+
"TRUNK_HIDDEN = 512\n",
|
| 86 |
+
"HEAD_HIDDEN = 128\n",
|
| 87 |
+
"DROPOUT = 0.3 # không ảnh hưởng eval\n",
|
| 88 |
+
"USE_AMP = True\n",
|
| 89 |
+
"\n",
|
| 90 |
+
"EMOTIONS5 = [\"angry\", \"happy\", \"neutral\", \"sad\", \"surprised\"]\n",
|
| 91 |
+
"_EMO_ALIAS = {\n",
|
| 92 |
+
" \"angry\": \"angry\", \"anger\": \"angry\",\n",
|
| 93 |
+
" \"happy\": \"happy\", \"happiness\": \"happy\", \"joy\": \"happy\",\n",
|
| 94 |
+
" \"neutral\": \"neutral\", \"calm\": \"neutral\",\n",
|
| 95 |
+
" \"sad\": \"sad\", \"sadness\": \"sad\",\n",
|
| 96 |
+
" \"surprise\": \"surprised\", \"surprised\": \"surprised\", \"surprising\": \"surprised\",\n",
|
| 97 |
+
"}\n",
|
| 98 |
+
"\n",
|
| 99 |
+
"def norm_emotion(label):\n",
|
| 100 |
+
" key = str(label).strip().lower()\n",
|
| 101 |
+
" return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)\n",
|
| 102 |
+
"\n",
|
| 103 |
+
"def stem(p):\n",
|
| 104 |
+
" return os.path.splitext(os.path.basename(str(p)))[0]\n",
|
| 105 |
+
"\n",
|
| 106 |
+
"# Mốc exp08 (val nội bộ / DEV) để so trong tab metric\n",
|
| 107 |
+
"EXP08 = {\"emos\": 0.811, \"cat_err\": 0.133, \"val\": 0.659, \"aro\": 0.793, \"dom\": 0.751}"
|
| 108 |
+
]
|
| 109 |
+
},
|
| 110 |
+
{
|
| 111 |
+
"cell_type": "markdown",
|
| 112 |
+
"id": "daadb3d8",
|
| 113 |
+
"metadata": {},
|
| 114 |
+
"source": [
|
| 115 |
+
"## 1. Cài đặt + clone code SAILER"
|
| 116 |
+
]
|
| 117 |
+
},
|
| 118 |
+
{
|
| 119 |
+
"cell_type": "code",
|
| 120 |
+
"execution_count": null,
|
| 121 |
+
"id": "6303e9cf",
|
| 122 |
+
"metadata": {},
|
| 123 |
+
"outputs": [],
|
| 124 |
+
"source": [
|
| 125 |
+
"import sys, subprocess\n",
|
| 126 |
+
"\n",
|
| 127 |
+
"def pip_install(*pkgs):\n",
|
| 128 |
+
" subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *pkgs], check=True)\n",
|
| 129 |
+
"\n",
|
| 130 |
+
"pip_install(\"gradio\", \"loralib\", \"speechbrain\", \"librosa\", \"soundfile\",\n",
|
| 131 |
+
" \"scipy\", \"scikit-learn\", \"pandas\", \"tqdm\")\n",
|
| 132 |
+
"\n",
|
| 133 |
+
"REPO_DIR = \"/kaggle/working/vox-profile-release\"\n",
|
| 134 |
+
"if not os.path.exists(REPO_DIR):\n",
|
| 135 |
+
" subprocess.run([\"git\", \"clone\", \"--depth\", \"1\",\n",
|
| 136 |
+
" \"https://github.com/tiantiaf0627/vox-profile-release.git\", REPO_DIR], check=True)\n",
|
| 137 |
+
"if REPO_DIR not in sys.path:\n",
|
| 138 |
+
" sys.path.insert(0, REPO_DIR)"
|
| 139 |
+
]
|
| 140 |
+
},
|
| 141 |
+
{
|
| 142 |
+
"cell_type": "markdown",
|
| 143 |
+
"id": "233b6770",
|
| 144 |
+
"metadata": {},
|
| 145 |
+
"source": [
|
| 146 |
+
"## 2. Nạp model exp08 (backbone WavLM ft + audeering frozen + heads) — 1 lần"
|
| 147 |
+
]
|
| 148 |
+
},
|
| 149 |
+
{
|
| 150 |
+
"cell_type": "code",
|
| 151 |
+
"execution_count": null,
|
| 152 |
+
"id": "1881d27c",
|
| 153 |
+
"metadata": {
|
| 154 |
+
"lines_to_next_cell": 1
|
| 155 |
+
},
|
| 156 |
+
"outputs": [],
|
| 157 |
+
"source": [
|
| 158 |
+
"import torch\n",
|
| 159 |
+
"import torch.nn as nn\n",
|
| 160 |
+
"import torch.nn.functional as F\n",
|
| 161 |
+
"import numpy as np\n",
|
| 162 |
+
"import librosa\n",
|
| 163 |
+
"\n",
|
| 164 |
+
"device = DEVICE if torch.cuda.is_available() else \"cpu\"\n",
|
| 165 |
+
"print(\"Device:\", device, (\"✅ \" + torch.cuda.get_device_name(0)) if device == \"cuda\" else \"⚠️ CPU (chậm)\")\n",
|
| 166 |
+
"\n",
|
| 167 |
+
"ckpt = torch.load(CKPT_PATH, map_location=\"cpu\", weights_only=False) # ckpt có numpy → cần False\n",
|
| 168 |
+
"assert \"wavlm\" in ckpt and \"heads\" in ckpt, \"❌ Checkpoint thiếu 'wavlm'/'heads' → cần ft_emotion_full_20epoch.pt đủ.\"\n",
|
| 169 |
+
"AUD_DIM = int(ckpt.get(\"AUD_DIM\", 0))\n",
|
| 170 |
+
"USE_AUDEERING = AUD_DIM > 0\n",
|
| 171 |
+
"print(\"✅ Nạp ckpt | keys:\", list(ckpt.keys()), \"| AUD_DIM:\", AUD_DIM, \"(audeering\", \"ON)\" if USE_AUDEERING else \"OFF)\")\n",
|
| 172 |
+
"\n",
|
| 173 |
+
"def find_hf_backbone(module):\n",
|
| 174 |
+
" cands = []\n",
|
| 175 |
+
" for name, m in module.named_modules():\n",
|
| 176 |
+
" enc = getattr(m, \"encoder\", None)\n",
|
| 177 |
+
" if getattr(m, \"feature_extractor\", None) is not None and enc is not None \\\n",
|
| 178 |
+
" and getattr(enc, \"layers\", None) is not None:\n",
|
| 179 |
+
" cands.append((name, m))\n",
|
| 180 |
+
" if not cands:\n",
|
| 181 |
+
" return None, None\n",
|
| 182 |
+
" cands.sort(key=lambda nm: sum(p.numel() for p in nm[1].parameters()), reverse=True)\n",
|
| 183 |
+
" return cands[0]\n",
|
| 184 |
+
"\n",
|
| 185 |
+
"wavlm = None\n",
|
| 186 |
+
"try:\n",
|
| 187 |
+
" from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402\n",
|
| 188 |
+
" _wrapper = WavLMWrapper.from_pretrained(\"tiantiaf/wavlm-large-categorical-emotion\")\n",
|
| 189 |
+
" _name, wavlm = find_hf_backbone(_wrapper)\n",
|
| 190 |
+
" if wavlm is not None:\n",
|
| 191 |
+
" print(f\"✅ Dựng backbone WavLM từ SAILER wrapper tại '.{_name}'\")\n",
|
| 192 |
+
"except Exception as e:\n",
|
| 193 |
+
" print(\"⚠️ Lỗi nạp SAILER wrapper:\", repr(e), \"→ fallback WavLM trắng.\")\n",
|
| 194 |
+
"if wavlm is None:\n",
|
| 195 |
+
" from transformers import WavLMModel\n",
|
| 196 |
+
" wavlm = WavLMModel.from_pretrained(\"microsoft/wavlm-large\")\n",
|
| 197 |
+
" print(\"ℹ️ Fallback: microsoft/wavlm-large.\")\n",
|
| 198 |
+
"\n",
|
| 199 |
+
"wavlm = wavlm.to(device).eval()\n",
|
| 200 |
+
"WAVLM_DIM = int(wavlm.config.hidden_size)\n",
|
| 201 |
+
"wavlm.config.layerdrop = 0.0\n",
|
| 202 |
+
"_miss, _unexp = wavlm.load_state_dict(ckpt[\"wavlm\"], strict=False)\n",
|
| 203 |
+
"print(f\"🔁 load wavlm: thiếu {len(_miss)} / dư {len(_unexp)} key (kỳ vọng ~0)\")\n",
|
| 204 |
+
"\n",
|
| 205 |
+
"def masked_mean(hidden, attn_mask):\n",
|
| 206 |
+
" if attn_mask is None:\n",
|
| 207 |
+
" return hidden.mean(dim=1)\n",
|
| 208 |
+
" try:\n",
|
| 209 |
+
" fm = wavlm._get_feature_vector_attention_mask(hidden.shape[1], attn_mask)\n",
|
| 210 |
+
" except Exception:\n",
|
| 211 |
+
" return hidden.mean(dim=1)\n",
|
| 212 |
+
" fm = fm.unsqueeze(-1).to(hidden.dtype)\n",
|
| 213 |
+
" return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)\n",
|
| 214 |
+
"\n",
|
| 215 |
+
"@torch.no_grad()\n",
|
| 216 |
+
"def wavlm_embed(input_values, attn_mask):\n",
|
| 217 |
+
" out = wavlm(input_values, attention_mask=attn_mask).last_hidden_state\n",
|
| 218 |
+
" return masked_mean(out, attn_mask)\n",
|
| 219 |
+
"\n",
|
| 220 |
+
"# ── audeering frozen (đặc trưng phụ) — chỉ dựng nếu ckpt có dùng ──\n",
|
| 221 |
+
"aud_backbone = aud_head = aud_proc = None\n",
|
| 222 |
+
"if USE_AUDEERING:\n",
|
| 223 |
+
" from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor\n",
|
| 224 |
+
" from huggingface_hub import hf_hub_download\n",
|
| 225 |
+
" AUD_NAME = \"audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim\"\n",
|
| 226 |
+
" aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)\n",
|
| 227 |
+
" aud_cfg = Wav2Vec2Config.from_pretrained(AUD_NAME)\n",
|
| 228 |
+
" aud_backbone = Wav2Vec2Model(aud_cfg)\n",
|
| 229 |
+
" try:\n",
|
| 230 |
+
" _sd = __import__(\"safetensors.torch\", fromlist=[\"load_file\"]).load_file(\n",
|
| 231 |
+
" hf_hub_download(AUD_NAME, \"model.safetensors\"))\n",
|
| 232 |
+
" except Exception:\n",
|
| 233 |
+
" _sd = torch.load(hf_hub_download(AUD_NAME, \"pytorch_model.bin\"), map_location=\"cpu\")\n",
|
| 234 |
+
" bb_sd = {k[len(\"wav2vec2.\"):]: v for k, v in _sd.items() if k.startswith(\"wav2vec2.\")}\n",
|
| 235 |
+
" aud_backbone.load_state_dict(bb_sd, strict=False)\n",
|
| 236 |
+
" _hid = _sd[\"classifier.dense.weight\"].shape[0]\n",
|
| 237 |
+
" aud_head = nn.Sequential(nn.Linear(_hid, _hid), nn.Tanh(), nn.Linear(_hid, _sd[\"classifier.out_proj.weight\"].shape[0]))\n",
|
| 238 |
+
" aud_head[0].weight.data.copy_(_sd[\"classifier.dense.weight\"]); aud_head[0].bias.data.copy_(_sd[\"classifier.dense.bias\"])\n",
|
| 239 |
+
" aud_head[2].weight.data.copy_(_sd[\"classifier.out_proj.weight\"]); aud_head[2].bias.data.copy_(_sd[\"classifier.out_proj.bias\"])\n",
|
| 240 |
+
" aud_backbone = aud_backbone.to(device).eval()\n",
|
| 241 |
+
" aud_head = aud_head.to(device).eval()\n",
|
| 242 |
+
" assert _hid + 3 == AUD_DIM, f\"⚠️ AUD_DIM dựng ({_hid+3}) ≠ ckpt ({AUD_DIM})\"\n",
|
| 243 |
+
" print(f\"✅ audeering frozen ({AUD_DIM}-D)\")\n",
|
| 244 |
+
"\n",
|
| 245 |
+
"@torch.no_grad()\n",
|
| 246 |
+
"def audeering_feat(wave):\n",
|
| 247 |
+
" x = aud_proc(wave, sampling_rate=SR).input_values[0]\n",
|
| 248 |
+
" x = torch.from_numpy(np.asarray(x, dtype=np.float32)).unsqueeze(0).to(device)\n",
|
| 249 |
+
" h = aud_backbone(x)[0].mean(dim=1)\n",
|
| 250 |
+
" out = aud_head(h)[0].cpu().numpy()\n",
|
| 251 |
+
" vad = np.array([1 + 4 * out[2], 1 + 4 * out[0], 1 + 4 * out[1]], dtype=np.float32) # [VAL,ARO,DOM]\n",
|
| 252 |
+
" return np.concatenate([h[0].cpu().numpy(), vad]).astype(np.float32)\n",
|
| 253 |
+
"\n",
|
| 254 |
+
"# ── EmoHeads (khớp exp08) + nạp trọng số + chuẩn hóa từ ckpt ──\n",
|
| 255 |
+
"N_EMO = len(EMOTIONS5)\n",
|
| 256 |
+
"TRUNK_IN = WAVLM_DIM + (AUD_DIM if USE_AUDEERING else 0)\n",
|
| 257 |
+
"\n",
|
| 258 |
+
"class EmoHeads(nn.Module):\n",
|
| 259 |
+
" def __init__(self, d_in, trunk_h, head_h, p, n_emo):\n",
|
| 260 |
+
" super().__init__()\n",
|
| 261 |
+
" self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),\n",
|
| 262 |
+
" nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))\n",
|
| 263 |
+
" self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))\n",
|
| 264 |
+
" self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))\n",
|
| 265 |
+
" self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))\n",
|
| 266 |
+
" def forward(self, feat, tgt):\n",
|
| 267 |
+
" h = self.trunk(feat)\n",
|
| 268 |
+
" return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)\n",
|
| 269 |
+
"\n",
|
| 270 |
+
"heads = EmoHeads(TRUNK_IN, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device).eval()\n",
|
| 271 |
+
"_hm, _hu = heads.load_state_dict(ckpt[\"heads\"], strict=False)\n",
|
| 272 |
+
"print(f\"🔁 load heads: thiếu {len(_hm)} / dư {len(_hu)} key (kỳ vọng 0)\")\n",
|
| 273 |
+
"\n",
|
| 274 |
+
"emos_mu = float(ckpt[\"emos_mu\"]); emos_sd = float(ckpt[\"emos_sd\"])\n",
|
| 275 |
+
"vad_mu = np.asarray(ckpt[\"vad_mu\"], dtype=np.float32); vad_sd = np.asarray(ckpt[\"vad_sd\"], dtype=np.float32)\n",
|
| 276 |
+
"print(f\"Chuẩn hóa từ ckpt: emos μ={emos_mu:.3f} σ={emos_sd:.3f} | vad μ={np.round(vad_mu,2)}\")\n",
|
| 277 |
+
"\n",
|
| 278 |
+
"def onehot_target(tgt):\n",
|
| 279 |
+
" v = np.zeros(N_EMO, dtype=np.float32)\n",
|
| 280 |
+
" if tgt in EMOTIONS5:\n",
|
| 281 |
+
" v[EMOTIONS5.index(tgt)] = 1.0\n",
|
| 282 |
+
" return v"
|
| 283 |
+
]
|
| 284 |
+
},
|
| 285 |
+
{
|
| 286 |
+
"cell_type": "markdown",
|
| 287 |
+
"id": "3c7ce7b5",
|
| 288 |
+
"metadata": {},
|
| 289 |
+
"source": [
|
| 290 |
+
"## 3. Hàm suy luận lõi (1 wave numpy → emos/cat5/vad3)"
|
| 291 |
+
]
|
| 292 |
+
},
|
| 293 |
+
{
|
| 294 |
+
"cell_type": "code",
|
| 295 |
+
"execution_count": null,
|
| 296 |
+
"id": "daf81b82",
|
| 297 |
+
"metadata": {
|
| 298 |
+
"lines_to_next_cell": 1
|
| 299 |
+
},
|
| 300 |
+
"outputs": [],
|
| 301 |
+
"source": [
|
| 302 |
+
"@torch.no_grad()\n",
|
| 303 |
+
"def infer_wave(wave, target_emotion):\n",
|
| 304 |
+
" \"\"\"wave: numpy float32 (đã 16k mono). target_emotion: str hoặc None. Trả (emos, cat5, vad3).\"\"\"\n",
|
| 305 |
+
" wave = wave[: EMO_MAX_SEC * SR].astype(np.float32)\n",
|
| 306 |
+
" iv = torch.from_numpy(wave).unsqueeze(0).to(device)\n",
|
| 307 |
+
" am = torch.ones((1, len(wave)), dtype=torch.long, device=device)\n",
|
| 308 |
+
" tgt = torch.from_numpy(onehot_target(norm_emotion(target_emotion) if target_emotion else None)).unsqueeze(0).to(device)\n",
|
| 309 |
+
" with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
|
| 310 |
+
" fw = wavlm_embed(iv, am)\n",
|
| 311 |
+
" if USE_AUDEERING:\n",
|
| 312 |
+
" fw = torch.cat([fw, torch.from_numpy(audeering_feat(wave)).unsqueeze(0).to(device)], dim=1)\n",
|
| 313 |
+
" emos_p, cat_l, vad_p = heads(fw, tgt)\n",
|
| 314 |
+
" emos = float(emos_p.item()) * emos_sd + emos_mu\n",
|
| 315 |
+
" cat5 = F.softmax(cat_l, 1)[0].float().cpu().numpy()\n",
|
| 316 |
+
" vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu\n",
|
| 317 |
+
" return emos, cat5, vad3"
|
| 318 |
+
]
|
| 319 |
+
},
|
| 320 |
+
{
|
| 321 |
+
"cell_type": "markdown",
|
| 322 |
+
"id": "768a4de5",
|
| 323 |
+
"metadata": {},
|
| 324 |
+
"source": [
|
| 325 |
+
"## 4. Hàm metric val nội bộ (UTT-SRCC + CAT-err) — đánh giá độ tin cậy bộ chấm"
|
| 326 |
+
]
|
| 327 |
+
},
|
| 328 |
+
{
|
| 329 |
+
"cell_type": "code",
|
| 330 |
+
"execution_count": null,
|
| 331 |
+
"id": "f4900fe6",
|
| 332 |
+
"metadata": {
|
| 333 |
+
"lines_to_next_cell": 1
|
| 334 |
+
},
|
| 335 |
+
"outputs": [],
|
| 336 |
+
"source": [
|
| 337 |
+
"import pandas as pd\n",
|
| 338 |
+
"from scipy.stats import spearmanr\n",
|
| 339 |
+
"from sklearn.model_selection import train_test_split\n",
|
| 340 |
+
"from tqdm.auto import tqdm\n",
|
| 341 |
+
"\n",
|
| 342 |
+
"def parse_emocat_votes(cell):\n",
|
| 343 |
+
" v = np.zeros(N_EMO, dtype=np.float32)\n",
|
| 344 |
+
" for tok in str(cell).replace(\"/\", \",\").replace(\";\", \",\").replace(\"|\", \",\").replace(\" \", \",\").split(\",\"):\n",
|
| 345 |
+
" e = norm_emotion(tok)\n",
|
| 346 |
+
" if e in EMOTIONS5:\n",
|
| 347 |
+
" v[EMOTIONS5.index(e)] += 1.0\n",
|
| 348 |
+
" return v\n",
|
| 349 |
+
"\n",
|
| 350 |
+
"def _col(cols_map, *names, df=None, default_idx=None):\n",
|
| 351 |
+
" for n in names:\n",
|
| 352 |
+
" if n in cols_map:\n",
|
| 353 |
+
" return cols_map[n]\n",
|
| 354 |
+
" return list(df.columns)[default_idx] if default_idx is not None else None\n",
|
| 355 |
+
"\n",
|
| 356 |
+
"def load_train_labels():\n",
|
| 357 |
+
" df = pd.read_csv(TRAIN_CSV, sep=\"|\")\n",
|
| 358 |
+
" cols = {c.lower().strip(): c for c in df.columns}\n",
|
| 359 |
+
" wav_col = _col(cols, \"wavid\", \"wav\", df=df, default_idx=1)\n",
|
| 360 |
+
" emos_col = _col(cols, \"emos\", \"emo\", \"emomos\")\n",
|
| 361 |
+
" val_col = _col(cols, \"val\", \"valence\"); aro_col = _col(cols, \"aro\", \"arousal\"); dom_col = _col(cols, \"dom\", \"dominance\")\n",
|
| 362 |
+
" cat_col = _col(cols, \"emocat\", \"cat\", \"emotion\")\n",
|
| 363 |
+
" df[\"_stem\"] = df[wav_col].map(stem)\n",
|
| 364 |
+
" rows = []\n",
|
| 365 |
+
" for sid, g in df.groupby(\"_stem\"):\n",
|
| 366 |
+
" rec = {\"wavID\": sid, \"emos\": float(g[emos_col].mean())}\n",
|
| 367 |
+
" rec[\"val\"] = float(g[val_col].mean()) if val_col else np.nan\n",
|
| 368 |
+
" rec[\"aro\"] = float(g[aro_col].mean()) if aro_col else np.nan\n",
|
| 369 |
+
" rec[\"dom\"] = float(g[dom_col].mean()) if dom_col else np.nan\n",
|
| 370 |
+
" votes = np.zeros(N_EMO, dtype=np.float32)\n",
|
| 371 |
+
" if cat_col:\n",
|
| 372 |
+
" for cell in g[cat_col]:\n",
|
| 373 |
+
" votes += parse_emocat_votes(cell)\n",
|
| 374 |
+
" s = votes.sum()\n",
|
| 375 |
+
" cat = votes / s if s > 0 else np.full(N_EMO, 0.2, dtype=np.float32)\n",
|
| 376 |
+
" for i in range(N_EMO):\n",
|
| 377 |
+
" rec[f\"cat{i}\"] = float(cat[i])\n",
|
| 378 |
+
" rows.append(rec)\n",
|
| 379 |
+
" return pd.DataFrame(rows)\n",
|
| 380 |
+
"\n",
|
| 381 |
+
"# target cảm xúc theo wav (cho EMOS) từ metadata\n",
|
| 382 |
+
"def load_target_emotions():\n",
|
| 383 |
+
" tgt = {}\n",
|
| 384 |
+
" if os.path.exists(METADATA_CSV):\n",
|
| 385 |
+
" with open(METADATA_CSV, encoding=\"utf-8\") as f:\n",
|
| 386 |
+
" for ln in f:\n",
|
| 387 |
+
" parts = ln.strip().split(\"|\")\n",
|
| 388 |
+
" if len(parts) >= 2:\n",
|
| 389 |
+
" tgt[stem(parts[0])] = norm_emotion(parts[1])\n",
|
| 390 |
+
" return tgt\n",
|
| 391 |
+
"\n",
|
| 392 |
+
"_target_map = None\n",
|
| 393 |
+
"_val_df = None\n",
|
| 394 |
+
"def _prep_eval():\n",
|
| 395 |
+
" \"\"\"Lazy: đọc nhãn + tách 10% val nội bộ (seed 42, khớp exp08).\"\"\"\n",
|
| 396 |
+
" global _target_map, _val_df\n",
|
| 397 |
+
" if _val_df is None:\n",
|
| 398 |
+
" _target_map = load_target_emotions()\n",
|
| 399 |
+
" df = load_train_labels()\n",
|
| 400 |
+
" df = df[df[\"wavID\"].map(lambda s: os.path.exists(os.path.join(WAV_DIR, s + \".wav\")))].reset_index(drop=True)\n",
|
| 401 |
+
" _, va = train_test_split(np.arange(len(df)), test_size=0.10, random_state=42)\n",
|
| 402 |
+
" _val_df = df.iloc[va].reset_index(drop=True)\n",
|
| 403 |
+
" return _target_map, _val_df\n",
|
| 404 |
+
"\n",
|
| 405 |
+
"def eval_metrics(limit):\n",
|
| 406 |
+
" tmap, vdf = _prep_eval()\n",
|
| 407 |
+
" n = min(int(limit), len(vdf))\n",
|
| 408 |
+
" P = {\"emos\": [], \"val\": [], \"aro\": [], \"dom\": []}; Y = {\"emos\": [], \"val\": [], \"aro\": [], \"dom\": []}\n",
|
| 409 |
+
" catP, catY = [], []\n",
|
| 410 |
+
" for i in tqdm(range(n), desc=\"eval\"):\n",
|
| 411 |
+
" r = vdf.iloc[i]; sid = r[\"wavID\"]\n",
|
| 412 |
+
" wav = os.path.join(WAV_DIR, sid + \".wav\")\n",
|
| 413 |
+
" wave, _ = librosa.load(wav, sr=SR, mono=True)\n",
|
| 414 |
+
" emos, cat5, vad3 = infer_wave(wave, tmap.get(sid))\n",
|
| 415 |
+
" P[\"emos\"].append(emos); Y[\"emos\"].append(float(r[\"emos\"]))\n",
|
| 416 |
+
" for j, t in enumerate([\"val\", \"aro\", \"dom\"]):\n",
|
| 417 |
+
" P[t].append(float(vad3[j])); Y[t].append(float(r[t]))\n",
|
| 418 |
+
" catP.append(cat5); catY.append([r[f\"cat{k}\"] for k in range(N_EMO)])\n",
|
| 419 |
+
" rows = []\n",
|
| 420 |
+
" for t in [\"emos\", \"val\", \"aro\", \"dom\"]:\n",
|
| 421 |
+
" srcc = spearmanr(P[t], Y[t]).correlation\n",
|
| 422 |
+
" rows.append([t.upper(), f\"{srcc:.4f}\", f\"{EXP08.get(t, float('nan')):.3f}\"])\n",
|
| 423 |
+
" cat_err = float(np.abs(np.array(catP) - np.array(catY)).sum(1).mean())\n",
|
| 424 |
+
" rows.append([\"CAT-err ↓\", f\"{cat_err:.4f}\", f\"{EXP08['cat_err']:.3f}\"])\n",
|
| 425 |
+
" return rows"
|
| 426 |
+
]
|
| 427 |
+
},
|
| 428 |
+
{
|
| 429 |
+
"cell_type": "markdown",
|
| 430 |
+
"id": "36b32835",
|
| 431 |
+
"metadata": {},
|
| 432 |
+
"source": [
|
| 433 |
+
"## 5. Giao diện Gradio (2 tab)"
|
| 434 |
+
]
|
| 435 |
+
},
|
| 436 |
+
{
|
| 437 |
+
"cell_type": "code",
|
| 438 |
+
"execution_count": null,
|
| 439 |
+
"id": "5ea0f8f6",
|
| 440 |
+
"metadata": {},
|
| 441 |
+
"outputs": [],
|
| 442 |
+
"source": [
|
| 443 |
+
"import gradio as gr\n",
|
| 444 |
+
"\n",
|
| 445 |
+
"def ui_predict(audio, target_emotion):\n",
|
| 446 |
+
" \"\"\"Trả về: verdict(md) · EMOS(number) · CAT(label) · VAL/ARO/DOM(number).\"\"\"\n",
|
| 447 |
+
" if not audio:\n",
|
| 448 |
+
" return \"### ⚠️ Hãy tải audio.\", None, {}, None, None, None\n",
|
| 449 |
+
" wave, _ = librosa.load(audio, sr=SR, mono=True)\n",
|
| 450 |
+
" emos, cat5, vad3 = infer_wave(wave, target_emotion)\n",
|
| 451 |
+
" cat_dict = {e: float(cat5[i]) for i, e in enumerate(EMOTIONS5)}\n",
|
| 452 |
+
" perceived = EMOTIONS5[int(np.argmax(cat5))]\n",
|
| 453 |
+
" if target_emotion:\n",
|
| 454 |
+
" match = \"✅ **KHỚP** target\" if perceived == norm_emotion(target_emotion) else \"⚠️ **LỆCH** target\"\n",
|
| 455 |
+
" band = \"🟢 tốt\" if emos >= 4 else (\"🟡 khá\" if emos >= 3 else \"🔴 yếu\")\n",
|
| 456 |
+
" verdict = (f\"### Kết luận biểu cảm\\n\"\n",
|
| 457 |
+
" f\"- Cảm xúc cảm nhận: **{perceived}** → {match} (`{target_emotion}`)\\n\"\n",
|
| 458 |
+
" f\"- EMOS = **{emos:.2f}/5** → biểu cảm {band}\")\n",
|
| 459 |
+
" else:\n",
|
| 460 |
+
" verdict = (f\"### Kết luận biểu cảm\\n\"\n",
|
| 461 |
+
" f\"- Cảm xúc cảm nhận: **{perceived}**\\n\"\n",
|
| 462 |
+
" f\"- *(Chọn cảm xúc target để bật EMOS — độ khớp ý đồ)*\")\n",
|
| 463 |
+
" emos = None\n",
|
| 464 |
+
" return verdict, (round(emos, 3) if emos is not None else None), cat_dict, \\\n",
|
| 465 |
+
" round(float(vad3[0]), 3), round(float(vad3[1]), 3), round(float(vad3[2]), 3)\n",
|
| 466 |
+
"\n",
|
| 467 |
+
"def ui_eval(limit):\n",
|
| 468 |
+
" return eval_metrics(limit)\n",
|
| 469 |
+
"\n",
|
| 470 |
+
"INTRO = (\n",
|
| 471 |
+
" \"# 🎙️ Emotional TTS Evaluator — VoiceMOS 2026 Track 2\\n\"\n",
|
| 472 |
+
" \"Bộ chấm **độ biểu cảm cảm xúc** của giọng TTS, chạy bằng model tốt nhất (**exp08**: WavLM fine-tune + \"\n",
|
| 473 |
+
" \"audeering). Offline, không cần API.\\n\\n\"\n",
|
| 474 |
+
" \"> **5 output dưới đây CHÍNH LÀ định nghĩa \\\"expressive emotion\\\" của Track 2** — mỗi cái trả lời một câu hỏi:\\n\"\n",
|
| 475 |
+
" \"> **EMOS** = có đúng cảm xúc được yêu cầu không · **CAT** = người nghe cảm nhận cảm xúc nào · \"\n",
|
| 476 |
+
" \"**VAD** = hóa trị / cường độ / chi phối.\"\n",
|
| 477 |
+
")\n",
|
| 478 |
+
"\n",
|
| 479 |
+
"with gr.Blocks(title=\"VMC2026 Track 2 — Emotional TTS Evaluator (exp08)\") as demo:\n",
|
| 480 |
+
" gr.Markdown(INTRO)\n",
|
| 481 |
+
" with gr.Tab(\"🎯 Chấm 1 file TTS\"):\n",
|
| 482 |
+
" with gr.Row():\n",
|
| 483 |
+
" with gr.Column(scale=1):\n",
|
| 484 |
+
" a = gr.Audio(type=\"filepath\", label=\"Audio (giọng TTS)\")\n",
|
| 485 |
+
" tgt = gr.Dropdown(EMOTIONS5, label=\"🎯 Cảm xúc target (cho EMOS)\")\n",
|
| 486 |
+
" btn = gr.Button(\"Chấm cảm xúc\", variant=\"primary\")\n",
|
| 487 |
+
" with gr.Column(scale=2):\n",
|
| 488 |
+
" verdict = gr.Markdown()\n",
|
| 489 |
+
" with gr.Row():\n",
|
| 490 |
+
" emos_o = gr.Number(label=\"EMOS — khớp cảm xúc target (1–5)\", interactive=False)\n",
|
| 491 |
+
" cat_o = gr.Label(label=\"CAT — phân bố cảm xúc cảm nhận (5 lớp)\")\n",
|
| 492 |
+
" gr.Markdown(\"**VAD — toạ độ cảm xúc liên tục (1–5):**\")\n",
|
| 493 |
+
" with gr.Row():\n",
|
| 494 |
+
" val_o = gr.Number(label=\"Valence (tích cực↑)\", interactive=False)\n",
|
| 495 |
+
" aro_o = gr.Number(label=\"Arousal (kích động↑)\", interactive=False)\n",
|
| 496 |
+
" dom_o = gr.Number(label=\"Dominance (chi phối↑)\", interactive=False)\n",
|
| 497 |
+
" btn.click(ui_predict, [a, tgt], [verdict, emos_o, cat_o, val_o, aro_o, dom_o])\n",
|
| 498 |
+
" with gr.Tab(\"📊 Độ tin cậy bộ chấm\"):\n",
|
| 499 |
+
" gr.Markdown(\"Đo model tái lập nhãn người tốt tới đâu trên **val nội bộ** (10% train.csv, seed 42) — \"\n",
|
| 500 |
+
" \"**UTT-SRCC** (EMOS/VAD, cao=tốt) + **CAT-err** (thấp=tốt).\\n\"\n",
|
| 501 |
+
" \"⚠️ Dev label ẩn → đây **KHÔNG** phải điểm leaderboard, chỉ để biết bộ chấm đáng tin cỡ nào.\")\n",
|
| 502 |
+
" lim = gr.Slider(20, 300, value=100, step=20, label=\"Số mẫu val để chấm (nhiều = chậm)\")\n",
|
| 503 |
+
" tbl = gr.Dataframe(headers=[\"Cột\", \"Model (val nội bộ)\", \"Mốc exp08\"],\n",
|
| 504 |
+
" label=\"UTT-SRCC / CAT-err\", interactive=False)\n",
|
| 505 |
+
" gr.Button(\"Chạy đánh giá\", variant=\"primary\").click(ui_eval, [lim], [tbl])\n",
|
| 506 |
+
"\n",
|
| 507 |
+
"demo.launch(share=True)"
|
| 508 |
+
]
|
| 509 |
+
},
|
| 510 |
+
{
|
| 511 |
+
"cell_type": "markdown",
|
| 512 |
+
"id": "fcea6f73",
|
| 513 |
+
"metadata": {},
|
| 514 |
+
"source": [
|
| 515 |
+
"## Ghi chú\n",
|
| 516 |
+
"- Hằng `TRUNK_HIDDEN/HEAD_HIDDEN` PHẢI khớp exp08 (ckpt không lưu) — sai là lệch key/shape.\n",
|
| 517 |
+
"- EMOS cần cảm xúc target → chưa chọn dropdown thì chỉ hiện CAT/VAD.\n",
|
| 518 |
+
"- exp08 = mean-pool (không Mamba) → demo dùng `masked_mean`.\n",
|
| 519 |
+
"- Metric tab chấm trên val nội bộ train.csv (dev ẩn) → con số ~ mốc exp08 nếu trùng tập val.\n",
|
| 520 |
+
"- Cần GPU T4 + Internet On (tải WavLM/SAILER/audeering lần đầu)."
|
| 521 |
+
]
|
| 522 |
+
}
|
| 523 |
+
],
|
| 524 |
+
"metadata": {
|
| 525 |
+
"jupytext": {
|
| 526 |
+
"cell_metadata_filter": "-all",
|
| 527 |
+
"main_language": "python",
|
| 528 |
+
"notebook_metadata_filter": "-all"
|
| 529 |
+
}
|
| 530 |
+
},
|
| 531 |
+
"nbformat": 4,
|
| 532 |
+
"nbformat_minor": 5
|
| 533 |
+
}
|
track2/demo_track2_emotion_gradio_pipeline.py
ADDED
|
@@ -0,0 +1,431 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %% [markdown]
|
| 2 |
+
# # VMC2026 Track 2 — Demo Gradio "Emotional TTS Evaluator" (model TỐT NHẤT = exp08)
|
| 3 |
+
#
|
| 4 |
+
# Demo này dùng **checkpoint cảm xúc tốt nhất** (`ft_emotion_full_20epoch.pt`: WavLM fine-tune warm-start
|
| 5 |
+
# SAILER + audeering frozen) để chấm **5 cột cảm xúc** của 1 file giọng TTS: **EMOS / CAT / VAL / ARO / DOM**.
|
| 6 |
+
# Khác demo cũ (`demo_track2_gradio`) dùng baseline UTMOS+emotion2vec+Gemini — bản này KHÔNG cần API.
|
| 7 |
+
#
|
| 8 |
+
# **2 tab:**
|
| 9 |
+
# 1. *Chấm 1 file TTS* — tải audio + chọn cảm xúc target → ra điểm biểu cảm cảm xúc + diễn giải KHỚP/LỆCH.
|
| 10 |
+
# 2. *Metric bộ chấm* — tính UTT-SRCC (EMOS/VAD) + CAT-err trên val nội bộ (train.csv) → cho biết độ tin cậy.
|
| 11 |
+
#
|
| 12 |
+
# **Cách chạy Kaggle:** GPU **T4** + Internet **On** → Add Input (1) dataset Track 2, (2) dataset chứa
|
| 13 |
+
# `ft_emotion_full_20epoch.pt` → Run All → cell cuối in link `*.gradio.live`.
|
| 14 |
+
|
| 15 |
+
# %% [markdown]
|
| 16 |
+
# ## 0. Cấu hình — auto-dò DATA_ROOT + checkpoint
|
| 17 |
+
|
| 18 |
+
# %%
|
| 19 |
+
import os, glob
|
| 20 |
+
|
| 21 |
+
def find_data_root(search_root="/kaggle/input"):
|
| 22 |
+
cands = []
|
| 23 |
+
for train_csv in glob.glob(os.path.join(search_root, "**", "sets", "train.csv"), recursive=True):
|
| 24 |
+
root = os.path.dirname(os.path.dirname(train_csv))
|
| 25 |
+
score = os.path.isdir(os.path.join(root, "wav")) + os.path.exists(os.path.join(root, "metadata.csv"))
|
| 26 |
+
cands.append((score, root))
|
| 27 |
+
cands.sort(reverse=True)
|
| 28 |
+
return cands
|
| 29 |
+
|
| 30 |
+
_cands = find_data_root("/kaggle/input")
|
| 31 |
+
if _cands:
|
| 32 |
+
print("🔎 Ứng viên DATA_ROOT:")
|
| 33 |
+
for sc, r in _cands:
|
| 34 |
+
print(f" [{sc}/2] {r}")
|
| 35 |
+
DATA_ROOT = _cands[0][1]
|
| 36 |
+
print(f"👉 Tự chọn DATA_ROOT = {DATA_ROOT}")
|
| 37 |
+
else:
|
| 38 |
+
DATA_ROOT = "/kaggle/input/datasets/minhtoan2" # dự phòng
|
| 39 |
+
print(f"❌ Không thấy sets/train.csv → dự phòng {DATA_ROOT} (đã Add Input chưa?)")
|
| 40 |
+
|
| 41 |
+
WAV_DIR = f"{DATA_ROOT}/wav"
|
| 42 |
+
METADATA_CSV = f"{DATA_ROOT}/metadata.csv"
|
| 43 |
+
TRAIN_CSV = f"{DATA_ROOT}/sets/train.csv"
|
| 44 |
+
|
| 45 |
+
# ── Checkpoint cảm xúc exp08 (ưu tiên bản 20 epoch = TỐT NHẤT) ─────────────────
|
| 46 |
+
CKPT_PATH = "" # << "" = auto-dò; hoặc trỏ tay "/kaggle/input/<slug>/ft_emotion_full_20epoch.pt"
|
| 47 |
+
|
| 48 |
+
def find_ckpt(explicit):
|
| 49 |
+
if explicit and os.path.exists(explicit):
|
| 50 |
+
return explicit
|
| 51 |
+
pats = ["ft_emotion_full_20epoch*.pt", "ft_emotion_full*.pt"] # ưu tiên bản 20epoch
|
| 52 |
+
for pat in pats:
|
| 53 |
+
for base in ["/kaggle/input", "/kaggle/working"]:
|
| 54 |
+
hits = sorted(glob.glob(os.path.join(base, "**", pat), recursive=True))
|
| 55 |
+
if hits:
|
| 56 |
+
return hits[0]
|
| 57 |
+
return ""
|
| 58 |
+
|
| 59 |
+
CKPT_PATH = find_ckpt(CKPT_PATH)
|
| 60 |
+
assert CKPT_PATH, "❌ Không thấy ft_emotion_full*.pt. Add Input dataset chứa checkpoint exp08 chưa?"
|
| 61 |
+
print("✅ Checkpoint:", CKPT_PATH)
|
| 62 |
+
|
| 63 |
+
# ── Hằng kiến trúc PHẢI khớp exp08 (ckpt không lưu các số này) ────────────────
|
| 64 |
+
DEVICE = "cuda"
|
| 65 |
+
SR = 16000
|
| 66 |
+
EMO_MAX_SEC = 8
|
| 67 |
+
TRUNK_HIDDEN = 512
|
| 68 |
+
HEAD_HIDDEN = 128
|
| 69 |
+
DROPOUT = 0.3 # không ảnh hưởng eval
|
| 70 |
+
USE_AMP = True
|
| 71 |
+
|
| 72 |
+
EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"]
|
| 73 |
+
_EMO_ALIAS = {
|
| 74 |
+
"angry": "angry", "anger": "angry",
|
| 75 |
+
"happy": "happy", "happiness": "happy", "joy": "happy",
|
| 76 |
+
"neutral": "neutral", "calm": "neutral",
|
| 77 |
+
"sad": "sad", "sadness": "sad",
|
| 78 |
+
"surprise": "surprised", "surprised": "surprised", "surprising": "surprised",
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
def norm_emotion(label):
|
| 82 |
+
key = str(label).strip().lower()
|
| 83 |
+
return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)
|
| 84 |
+
|
| 85 |
+
def stem(p):
|
| 86 |
+
return os.path.splitext(os.path.basename(str(p)))[0]
|
| 87 |
+
|
| 88 |
+
# Mốc exp08 (val nội bộ / DEV) để so trong tab metric
|
| 89 |
+
EXP08 = {"emos": 0.811, "cat_err": 0.133, "val": 0.659, "aro": 0.793, "dom": 0.751}
|
| 90 |
+
|
| 91 |
+
# %% [markdown]
|
| 92 |
+
# ## 1. Cài đặt + clone code SAILER
|
| 93 |
+
|
| 94 |
+
# %%
|
| 95 |
+
import sys, subprocess
|
| 96 |
+
|
| 97 |
+
def pip_install(*pkgs):
|
| 98 |
+
subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs], check=True)
|
| 99 |
+
|
| 100 |
+
pip_install("gradio", "loralib", "speechbrain", "librosa", "soundfile",
|
| 101 |
+
"scipy", "scikit-learn", "pandas", "tqdm")
|
| 102 |
+
|
| 103 |
+
REPO_DIR = "/kaggle/working/vox-profile-release"
|
| 104 |
+
if not os.path.exists(REPO_DIR):
|
| 105 |
+
subprocess.run(["git", "clone", "--depth", "1",
|
| 106 |
+
"https://github.com/tiantiaf0627/vox-profile-release.git", REPO_DIR], check=True)
|
| 107 |
+
if REPO_DIR not in sys.path:
|
| 108 |
+
sys.path.insert(0, REPO_DIR)
|
| 109 |
+
|
| 110 |
+
# %% [markdown]
|
| 111 |
+
# ## 2. Nạp model exp08 (backbone WavLM ft + audeering frozen + heads) — 1 lần
|
| 112 |
+
|
| 113 |
+
# %%
|
| 114 |
+
import torch
|
| 115 |
+
import torch.nn as nn
|
| 116 |
+
import torch.nn.functional as F
|
| 117 |
+
import numpy as np
|
| 118 |
+
import librosa
|
| 119 |
+
|
| 120 |
+
device = DEVICE if torch.cuda.is_available() else "cpu"
|
| 121 |
+
print("Device:", device, ("✅ " + torch.cuda.get_device_name(0)) if device == "cuda" else "⚠️ CPU (chậm)")
|
| 122 |
+
|
| 123 |
+
ckpt = torch.load(CKPT_PATH, map_location="cpu", weights_only=False) # ckpt có numpy → cần False
|
| 124 |
+
assert "wavlm" in ckpt and "heads" in ckpt, "❌ Checkpoint thiếu 'wavlm'/'heads' → cần ft_emotion_full_20epoch.pt đủ."
|
| 125 |
+
AUD_DIM = int(ckpt.get("AUD_DIM", 0))
|
| 126 |
+
USE_AUDEERING = AUD_DIM > 0
|
| 127 |
+
print("✅ Nạp ckpt | keys:", list(ckpt.keys()), "| AUD_DIM:", AUD_DIM, "(audeering", "ON)" if USE_AUDEERING else "OFF)")
|
| 128 |
+
|
| 129 |
+
def find_hf_backbone(module):
|
| 130 |
+
cands = []
|
| 131 |
+
for name, m in module.named_modules():
|
| 132 |
+
enc = getattr(m, "encoder", None)
|
| 133 |
+
if getattr(m, "feature_extractor", None) is not None and enc is not None \
|
| 134 |
+
and getattr(enc, "layers", None) is not None:
|
| 135 |
+
cands.append((name, m))
|
| 136 |
+
if not cands:
|
| 137 |
+
return None, None
|
| 138 |
+
cands.sort(key=lambda nm: sum(p.numel() for p in nm[1].parameters()), reverse=True)
|
| 139 |
+
return cands[0]
|
| 140 |
+
|
| 141 |
+
wavlm = None
|
| 142 |
+
try:
|
| 143 |
+
from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402
|
| 144 |
+
_wrapper = WavLMWrapper.from_pretrained("tiantiaf/wavlm-large-categorical-emotion")
|
| 145 |
+
_name, wavlm = find_hf_backbone(_wrapper)
|
| 146 |
+
if wavlm is not None:
|
| 147 |
+
print(f"✅ Dựng backbone WavLM từ SAILER wrapper tại '.{_name}'")
|
| 148 |
+
except Exception as e:
|
| 149 |
+
print("⚠️ Lỗi nạp SAILER wrapper:", repr(e), "→ fallback WavLM trắng.")
|
| 150 |
+
if wavlm is None:
|
| 151 |
+
from transformers import WavLMModel
|
| 152 |
+
wavlm = WavLMModel.from_pretrained("microsoft/wavlm-large")
|
| 153 |
+
print("ℹ️ Fallback: microsoft/wavlm-large.")
|
| 154 |
+
|
| 155 |
+
wavlm = wavlm.to(device).eval()
|
| 156 |
+
WAVLM_DIM = int(wavlm.config.hidden_size)
|
| 157 |
+
wavlm.config.layerdrop = 0.0
|
| 158 |
+
_miss, _unexp = wavlm.load_state_dict(ckpt["wavlm"], strict=False)
|
| 159 |
+
print(f"🔁 load wavlm: thiếu {len(_miss)} / dư {len(_unexp)} key (kỳ vọng ~0)")
|
| 160 |
+
|
| 161 |
+
def masked_mean(hidden, attn_mask):
|
| 162 |
+
if attn_mask is None:
|
| 163 |
+
return hidden.mean(dim=1)
|
| 164 |
+
try:
|
| 165 |
+
fm = wavlm._get_feature_vector_attention_mask(hidden.shape[1], attn_mask)
|
| 166 |
+
except Exception:
|
| 167 |
+
return hidden.mean(dim=1)
|
| 168 |
+
fm = fm.unsqueeze(-1).to(hidden.dtype)
|
| 169 |
+
return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)
|
| 170 |
+
|
| 171 |
+
@torch.no_grad()
|
| 172 |
+
def wavlm_embed(input_values, attn_mask):
|
| 173 |
+
out = wavlm(input_values, attention_mask=attn_mask).last_hidden_state
|
| 174 |
+
return masked_mean(out, attn_mask)
|
| 175 |
+
|
| 176 |
+
# ── audeering frozen (đặc trưng phụ) — chỉ dựng nếu ckpt có dùng ──
|
| 177 |
+
aud_backbone = aud_head = aud_proc = None
|
| 178 |
+
if USE_AUDEERING:
|
| 179 |
+
from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor
|
| 180 |
+
from huggingface_hub import hf_hub_download
|
| 181 |
+
AUD_NAME = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
|
| 182 |
+
aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)
|
| 183 |
+
aud_cfg = Wav2Vec2Config.from_pretrained(AUD_NAME)
|
| 184 |
+
aud_backbone = Wav2Vec2Model(aud_cfg)
|
| 185 |
+
try:
|
| 186 |
+
_sd = __import__("safetensors.torch", fromlist=["load_file"]).load_file(
|
| 187 |
+
hf_hub_download(AUD_NAME, "model.safetensors"))
|
| 188 |
+
except Exception:
|
| 189 |
+
_sd = torch.load(hf_hub_download(AUD_NAME, "pytorch_model.bin"), map_location="cpu")
|
| 190 |
+
bb_sd = {k[len("wav2vec2."):]: v for k, v in _sd.items() if k.startswith("wav2vec2.")}
|
| 191 |
+
aud_backbone.load_state_dict(bb_sd, strict=False)
|
| 192 |
+
_hid = _sd["classifier.dense.weight"].shape[0]
|
| 193 |
+
aud_head = nn.Sequential(nn.Linear(_hid, _hid), nn.Tanh(), nn.Linear(_hid, _sd["classifier.out_proj.weight"].shape[0]))
|
| 194 |
+
aud_head[0].weight.data.copy_(_sd["classifier.dense.weight"]); aud_head[0].bias.data.copy_(_sd["classifier.dense.bias"])
|
| 195 |
+
aud_head[2].weight.data.copy_(_sd["classifier.out_proj.weight"]); aud_head[2].bias.data.copy_(_sd["classifier.out_proj.bias"])
|
| 196 |
+
aud_backbone = aud_backbone.to(device).eval()
|
| 197 |
+
aud_head = aud_head.to(device).eval()
|
| 198 |
+
assert _hid + 3 == AUD_DIM, f"⚠️ AUD_DIM dựng ({_hid+3}) ≠ ckpt ({AUD_DIM})"
|
| 199 |
+
print(f"✅ audeering frozen ({AUD_DIM}-D)")
|
| 200 |
+
|
| 201 |
+
@torch.no_grad()
|
| 202 |
+
def audeering_feat(wave):
|
| 203 |
+
x = aud_proc(wave, sampling_rate=SR).input_values[0]
|
| 204 |
+
x = torch.from_numpy(np.asarray(x, dtype=np.float32)).unsqueeze(0).to(device)
|
| 205 |
+
h = aud_backbone(x)[0].mean(dim=1)
|
| 206 |
+
out = aud_head(h)[0].cpu().numpy()
|
| 207 |
+
vad = np.array([1 + 4 * out[2], 1 + 4 * out[0], 1 + 4 * out[1]], dtype=np.float32) # [VAL,ARO,DOM]
|
| 208 |
+
return np.concatenate([h[0].cpu().numpy(), vad]).astype(np.float32)
|
| 209 |
+
|
| 210 |
+
# ── EmoHeads (khớp exp08) + nạp trọng số + chuẩn hóa từ ckpt ──
|
| 211 |
+
N_EMO = len(EMOTIONS5)
|
| 212 |
+
TRUNK_IN = WAVLM_DIM + (AUD_DIM if USE_AUDEERING else 0)
|
| 213 |
+
|
| 214 |
+
class EmoHeads(nn.Module):
|
| 215 |
+
def __init__(self, d_in, trunk_h, head_h, p, n_emo):
|
| 216 |
+
super().__init__()
|
| 217 |
+
self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),
|
| 218 |
+
nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))
|
| 219 |
+
self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))
|
| 220 |
+
self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))
|
| 221 |
+
self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))
|
| 222 |
+
def forward(self, feat, tgt):
|
| 223 |
+
h = self.trunk(feat)
|
| 224 |
+
return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)
|
| 225 |
+
|
| 226 |
+
heads = EmoHeads(TRUNK_IN, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device).eval()
|
| 227 |
+
_hm, _hu = heads.load_state_dict(ckpt["heads"], strict=False)
|
| 228 |
+
print(f"🔁 load heads: thiếu {len(_hm)} / dư {len(_hu)} key (kỳ vọng 0)")
|
| 229 |
+
|
| 230 |
+
emos_mu = float(ckpt["emos_mu"]); emos_sd = float(ckpt["emos_sd"])
|
| 231 |
+
vad_mu = np.asarray(ckpt["vad_mu"], dtype=np.float32); vad_sd = np.asarray(ckpt["vad_sd"], dtype=np.float32)
|
| 232 |
+
print(f"Chuẩn hóa từ ckpt: emos μ={emos_mu:.3f} σ={emos_sd:.3f} | vad μ={np.round(vad_mu,2)}")
|
| 233 |
+
|
| 234 |
+
def onehot_target(tgt):
|
| 235 |
+
v = np.zeros(N_EMO, dtype=np.float32)
|
| 236 |
+
if tgt in EMOTIONS5:
|
| 237 |
+
v[EMOTIONS5.index(tgt)] = 1.0
|
| 238 |
+
return v
|
| 239 |
+
|
| 240 |
+
# %% [markdown]
|
| 241 |
+
# ## 3. Hàm suy luận lõi (1 wave numpy → emos/cat5/vad3)
|
| 242 |
+
|
| 243 |
+
# %%
|
| 244 |
+
@torch.no_grad()
|
| 245 |
+
def infer_wave(wave, target_emotion):
|
| 246 |
+
"""wave: numpy float32 (đã 16k mono). target_emotion: str hoặc None. Trả (emos, cat5, vad3)."""
|
| 247 |
+
wave = wave[: EMO_MAX_SEC * SR].astype(np.float32)
|
| 248 |
+
iv = torch.from_numpy(wave).unsqueeze(0).to(device)
|
| 249 |
+
am = torch.ones((1, len(wave)), dtype=torch.long, device=device)
|
| 250 |
+
tgt = torch.from_numpy(onehot_target(norm_emotion(target_emotion) if target_emotion else None)).unsqueeze(0).to(device)
|
| 251 |
+
with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
|
| 252 |
+
fw = wavlm_embed(iv, am)
|
| 253 |
+
if USE_AUDEERING:
|
| 254 |
+
fw = torch.cat([fw, torch.from_numpy(audeering_feat(wave)).unsqueeze(0).to(device)], dim=1)
|
| 255 |
+
emos_p, cat_l, vad_p = heads(fw, tgt)
|
| 256 |
+
emos = float(emos_p.item()) * emos_sd + emos_mu
|
| 257 |
+
cat5 = F.softmax(cat_l, 1)[0].float().cpu().numpy()
|
| 258 |
+
vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu
|
| 259 |
+
return emos, cat5, vad3
|
| 260 |
+
|
| 261 |
+
# %% [markdown]
|
| 262 |
+
# ## 4. Hàm metric val nội bộ (UTT-SRCC + CAT-err) — đánh giá độ tin cậy bộ chấm
|
| 263 |
+
|
| 264 |
+
# %%
|
| 265 |
+
import pandas as pd
|
| 266 |
+
from scipy.stats import spearmanr
|
| 267 |
+
from sklearn.model_selection import train_test_split
|
| 268 |
+
from tqdm.auto import tqdm
|
| 269 |
+
|
| 270 |
+
def parse_emocat_votes(cell):
|
| 271 |
+
v = np.zeros(N_EMO, dtype=np.float32)
|
| 272 |
+
for tok in str(cell).replace("/", ",").replace(";", ",").replace("|", ",").replace(" ", ",").split(","):
|
| 273 |
+
e = norm_emotion(tok)
|
| 274 |
+
if e in EMOTIONS5:
|
| 275 |
+
v[EMOTIONS5.index(e)] += 1.0
|
| 276 |
+
return v
|
| 277 |
+
|
| 278 |
+
def _col(cols_map, *names, df=None, default_idx=None):
|
| 279 |
+
for n in names:
|
| 280 |
+
if n in cols_map:
|
| 281 |
+
return cols_map[n]
|
| 282 |
+
return list(df.columns)[default_idx] if default_idx is not None else None
|
| 283 |
+
|
| 284 |
+
def load_train_labels():
|
| 285 |
+
df = pd.read_csv(TRAIN_CSV, sep="|")
|
| 286 |
+
cols = {c.lower().strip(): c for c in df.columns}
|
| 287 |
+
wav_col = _col(cols, "wavid", "wav", df=df, default_idx=1)
|
| 288 |
+
emos_col = _col(cols, "emos", "emo", "emomos")
|
| 289 |
+
val_col = _col(cols, "val", "valence"); aro_col = _col(cols, "aro", "arousal"); dom_col = _col(cols, "dom", "dominance")
|
| 290 |
+
cat_col = _col(cols, "emocat", "cat", "emotion")
|
| 291 |
+
df["_stem"] = df[wav_col].map(stem)
|
| 292 |
+
rows = []
|
| 293 |
+
for sid, g in df.groupby("_stem"):
|
| 294 |
+
rec = {"wavID": sid, "emos": float(g[emos_col].mean())}
|
| 295 |
+
rec["val"] = float(g[val_col].mean()) if val_col else np.nan
|
| 296 |
+
rec["aro"] = float(g[aro_col].mean()) if aro_col else np.nan
|
| 297 |
+
rec["dom"] = float(g[dom_col].mean()) if dom_col else np.nan
|
| 298 |
+
votes = np.zeros(N_EMO, dtype=np.float32)
|
| 299 |
+
if cat_col:
|
| 300 |
+
for cell in g[cat_col]:
|
| 301 |
+
votes += parse_emocat_votes(cell)
|
| 302 |
+
s = votes.sum()
|
| 303 |
+
cat = votes / s if s > 0 else np.full(N_EMO, 0.2, dtype=np.float32)
|
| 304 |
+
for i in range(N_EMO):
|
| 305 |
+
rec[f"cat{i}"] = float(cat[i])
|
| 306 |
+
rows.append(rec)
|
| 307 |
+
return pd.DataFrame(rows)
|
| 308 |
+
|
| 309 |
+
# target cảm xúc theo wav (cho EMOS) từ metadata
|
| 310 |
+
def load_target_emotions():
|
| 311 |
+
tgt = {}
|
| 312 |
+
if os.path.exists(METADATA_CSV):
|
| 313 |
+
with open(METADATA_CSV, encoding="utf-8") as f:
|
| 314 |
+
for ln in f:
|
| 315 |
+
parts = ln.strip().split("|")
|
| 316 |
+
if len(parts) >= 2:
|
| 317 |
+
tgt[stem(parts[0])] = norm_emotion(parts[1])
|
| 318 |
+
return tgt
|
| 319 |
+
|
| 320 |
+
_target_map = None
|
| 321 |
+
_val_df = None
|
| 322 |
+
def _prep_eval():
|
| 323 |
+
"""Lazy: đọc nhãn + tách 10% val nội bộ (seed 42, khớp exp08)."""
|
| 324 |
+
global _target_map, _val_df
|
| 325 |
+
if _val_df is None:
|
| 326 |
+
_target_map = load_target_emotions()
|
| 327 |
+
df = load_train_labels()
|
| 328 |
+
df = df[df["wavID"].map(lambda s: os.path.exists(os.path.join(WAV_DIR, s + ".wav")))].reset_index(drop=True)
|
| 329 |
+
_, va = train_test_split(np.arange(len(df)), test_size=0.10, random_state=42)
|
| 330 |
+
_val_df = df.iloc[va].reset_index(drop=True)
|
| 331 |
+
return _target_map, _val_df
|
| 332 |
+
|
| 333 |
+
def eval_metrics(limit):
|
| 334 |
+
tmap, vdf = _prep_eval()
|
| 335 |
+
n = min(int(limit), len(vdf))
|
| 336 |
+
P = {"emos": [], "val": [], "aro": [], "dom": []}; Y = {"emos": [], "val": [], "aro": [], "dom": []}
|
| 337 |
+
catP, catY = [], []
|
| 338 |
+
for i in tqdm(range(n), desc="eval"):
|
| 339 |
+
r = vdf.iloc[i]; sid = r["wavID"]
|
| 340 |
+
wav = os.path.join(WAV_DIR, sid + ".wav")
|
| 341 |
+
wave, _ = librosa.load(wav, sr=SR, mono=True)
|
| 342 |
+
emos, cat5, vad3 = infer_wave(wave, tmap.get(sid))
|
| 343 |
+
P["emos"].append(emos); Y["emos"].append(float(r["emos"]))
|
| 344 |
+
for j, t in enumerate(["val", "aro", "dom"]):
|
| 345 |
+
P[t].append(float(vad3[j])); Y[t].append(float(r[t]))
|
| 346 |
+
catP.append(cat5); catY.append([r[f"cat{k}"] for k in range(N_EMO)])
|
| 347 |
+
rows = []
|
| 348 |
+
for t in ["emos", "val", "aro", "dom"]:
|
| 349 |
+
srcc = spearmanr(P[t], Y[t]).correlation
|
| 350 |
+
rows.append([t.upper(), f"{srcc:.4f}", f"{EXP08.get(t, float('nan')):.3f}"])
|
| 351 |
+
cat_err = float(np.abs(np.array(catP) - np.array(catY)).sum(1).mean())
|
| 352 |
+
rows.append(["CAT-err ↓", f"{cat_err:.4f}", f"{EXP08['cat_err']:.3f}"])
|
| 353 |
+
return rows
|
| 354 |
+
|
| 355 |
+
# %% [markdown]
|
| 356 |
+
# ## 5. Giao diện Gradio (2 tab)
|
| 357 |
+
|
| 358 |
+
# %%
|
| 359 |
+
import gradio as gr
|
| 360 |
+
|
| 361 |
+
def ui_predict(audio, target_emotion):
|
| 362 |
+
"""Trả về: verdict(md) · EMOS(number) · CAT(label) · VAL/ARO/DOM(number)."""
|
| 363 |
+
if not audio:
|
| 364 |
+
return "### ⚠️ Hãy tải audio.", None, {}, None, None, None
|
| 365 |
+
wave, _ = librosa.load(audio, sr=SR, mono=True)
|
| 366 |
+
emos, cat5, vad3 = infer_wave(wave, target_emotion)
|
| 367 |
+
cat_dict = {e: float(cat5[i]) for i, e in enumerate(EMOTIONS5)}
|
| 368 |
+
perceived = EMOTIONS5[int(np.argmax(cat5))]
|
| 369 |
+
if target_emotion:
|
| 370 |
+
match = "✅ **KHỚP** target" if perceived == norm_emotion(target_emotion) else "⚠️ **LỆCH** target"
|
| 371 |
+
band = "🟢 tốt" if emos >= 4 else ("🟡 khá" if emos >= 3 else "🔴 yếu")
|
| 372 |
+
verdict = (f"### Kết luận biểu cảm\n"
|
| 373 |
+
f"- Cảm xúc cảm nhận: **{perceived}** → {match} (`{target_emotion}`)\n"
|
| 374 |
+
f"- EMOS = **{emos:.2f}/5** → biểu cảm {band}")
|
| 375 |
+
else:
|
| 376 |
+
verdict = (f"### Kết luận biểu cảm\n"
|
| 377 |
+
f"- Cảm xúc cảm nhận: **{perceived}**\n"
|
| 378 |
+
f"- *(Chọn cảm xúc target để bật EMOS — độ khớp ý đồ)*")
|
| 379 |
+
emos = None
|
| 380 |
+
return verdict, (round(emos, 3) if emos is not None else None), cat_dict, \
|
| 381 |
+
round(float(vad3[0]), 3), round(float(vad3[1]), 3), round(float(vad3[2]), 3)
|
| 382 |
+
|
| 383 |
+
def ui_eval(limit):
|
| 384 |
+
return eval_metrics(limit)
|
| 385 |
+
|
| 386 |
+
INTRO = (
|
| 387 |
+
"# 🎙️ Emotional TTS Evaluator — VoiceMOS 2026 Track 2\n"
|
| 388 |
+
"Bộ chấm **độ biểu cảm cảm xúc** của giọng TTS, chạy bằng model tốt nhất (**exp08**: WavLM fine-tune + "
|
| 389 |
+
"audeering). Offline, không cần API.\n\n"
|
| 390 |
+
"> **5 output dưới đây CHÍNH LÀ định nghĩa \"expressive emotion\" của Track 2** — mỗi cái trả lời một câu hỏi:\n"
|
| 391 |
+
"> **EMOS** = có đúng cảm xúc được yêu cầu không · **CAT** = người nghe cảm nhận cảm xúc nào · "
|
| 392 |
+
"**VAD** = hóa trị / cường độ / chi phối."
|
| 393 |
+
)
|
| 394 |
+
|
| 395 |
+
with gr.Blocks(title="VMC2026 Track 2 — Emotional TTS Evaluator (exp08)") as demo:
|
| 396 |
+
gr.Markdown(INTRO)
|
| 397 |
+
with gr.Tab("🎯 Chấm 1 file TTS"):
|
| 398 |
+
with gr.Row():
|
| 399 |
+
with gr.Column(scale=1):
|
| 400 |
+
a = gr.Audio(type="filepath", label="Audio (giọng TTS)")
|
| 401 |
+
tgt = gr.Dropdown(EMOTIONS5, label="🎯 Cảm xúc target (cho EMOS)")
|
| 402 |
+
btn = gr.Button("Chấm cảm xúc", variant="primary")
|
| 403 |
+
with gr.Column(scale=2):
|
| 404 |
+
verdict = gr.Markdown()
|
| 405 |
+
with gr.Row():
|
| 406 |
+
emos_o = gr.Number(label="EMOS — khớp cảm xúc target (1–5)", interactive=False)
|
| 407 |
+
cat_o = gr.Label(label="CAT — phân bố cảm xúc cảm nhận (5 lớp)")
|
| 408 |
+
gr.Markdown("**VAD — toạ độ cảm xúc liên tục (1–5):**")
|
| 409 |
+
with gr.Row():
|
| 410 |
+
val_o = gr.Number(label="Valence (tích cực↑)", interactive=False)
|
| 411 |
+
aro_o = gr.Number(label="Arousal (kích động↑)", interactive=False)
|
| 412 |
+
dom_o = gr.Number(label="Dominance (chi phối↑)", interactive=False)
|
| 413 |
+
btn.click(ui_predict, [a, tgt], [verdict, emos_o, cat_o, val_o, aro_o, dom_o])
|
| 414 |
+
with gr.Tab("📊 Độ tin cậy bộ chấm"):
|
| 415 |
+
gr.Markdown("Đo model tái lập nhãn người tốt tới đâu trên **val nội bộ** (10% train.csv, seed 42) — "
|
| 416 |
+
"**UTT-SRCC** (EMOS/VAD, cao=tốt) + **CAT-err** (thấp=tốt).\n"
|
| 417 |
+
"⚠️ Dev label ẩn → đây **KHÔNG** phải điểm leaderboard, chỉ để biết bộ chấm đáng tin cỡ nào.")
|
| 418 |
+
lim = gr.Slider(20, 300, value=100, step=20, label="Số mẫu val để chấm (nhiều = chậm)")
|
| 419 |
+
tbl = gr.Dataframe(headers=["Cột", "Model (val nội bộ)", "Mốc exp08"],
|
| 420 |
+
label="UTT-SRCC / CAT-err", interactive=False)
|
| 421 |
+
gr.Button("Chạy đánh giá", variant="primary").click(ui_eval, [lim], [tbl])
|
| 422 |
+
|
| 423 |
+
demo.launch(share=True)
|
| 424 |
+
|
| 425 |
+
# %% [markdown]
|
| 426 |
+
# ## Ghi chú
|
| 427 |
+
# - Hằng `TRUNK_HIDDEN/HEAD_HIDDEN` PHẢI khớp exp08 (ckpt không lưu) — sai là lệch key/shape.
|
| 428 |
+
# - EMOS cần cảm xúc target → chưa chọn dropdown thì chỉ hiện CAT/VAD.
|
| 429 |
+
# - exp08 = mean-pool (không Mamba) → demo dùng `masked_mean`.
|
| 430 |
+
# - Metric tab chấm trên val nội bộ train.csv (dev ẩn) → con số ~ mốc exp08 nếu trùng tập val.
|
| 431 |
+
# - Cần GPU T4 + Internet On (tải WavLM/SAILER/audeering lần đầu).
|
track2/demo_track2_gradio.ipynb
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# VMC2026 Track 2 — Demo Gradio (Emotional TTS: QMOS / CAT / EMOS / VAD)\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"- **QMOS** (UTMOS) + **CAT** (emotion2vec, 5 cảm xúc): chạy ngay, chỉ cần audio.\n",
|
| 10 |
+
"- **EMOS / VAD** (Gemini): tùy chọn — cần dán `GEMINI_API_KEY` + chọn cảm xúc target.\n",
|
| 11 |
+
"\n",
|
| 12 |
+
"### Cách dùng trên Kaggle\n",
|
| 13 |
+
"1. Settings → **GPU T4 + Internet On**.\n",
|
| 14 |
+
"2. **Run All** → cell cuối in link `*.gradio.live` (sống ~72h)."
|
| 15 |
+
]
|
| 16 |
+
},
|
| 17 |
+
{
|
| 18 |
+
"cell_type": "markdown",
|
| 19 |
+
"metadata": {},
|
| 20 |
+
"source": [
|
| 21 |
+
"## 1. Cài đặt"
|
| 22 |
+
]
|
| 23 |
+
},
|
| 24 |
+
{
|
| 25 |
+
"cell_type": "code",
|
| 26 |
+
"execution_count": null,
|
| 27 |
+
"metadata": {},
|
| 28 |
+
"outputs": [],
|
| 29 |
+
"source": [
|
| 30 |
+
"!pip install -q gradio speechmos funasr librosa soundfile google-genai"
|
| 31 |
+
]
|
| 32 |
+
},
|
| 33 |
+
{
|
| 34 |
+
"cell_type": "markdown",
|
| 35 |
+
"metadata": {},
|
| 36 |
+
"source": [
|
| 37 |
+
"## 2. Nạp model + hàm dự đoán"
|
| 38 |
+
]
|
| 39 |
+
},
|
| 40 |
+
{
|
| 41 |
+
"cell_type": "code",
|
| 42 |
+
"execution_count": null,
|
| 43 |
+
"metadata": {},
|
| 44 |
+
"outputs": [],
|
| 45 |
+
"source": [
|
| 46 |
+
"import re, json, librosa\n",
|
| 47 |
+
"\n",
|
| 48 |
+
"GEMINI_MODEL = \"gemini-2.0-flash\"\n",
|
| 49 |
+
"EMOTIONS5 = [\"angry\", \"happy\", \"neutral\", \"sad\", \"surprised\"]\n",
|
| 50 |
+
"\n",
|
| 51 |
+
"_M = {}\n",
|
| 52 |
+
"\n",
|
| 53 |
+
"def _qmos():\n",
|
| 54 |
+
" if \"qmos\" not in _M:\n",
|
| 55 |
+
" import torch\n",
|
| 56 |
+
" _M[\"qmos\"] = torch.hub.load(\"tarepan/SpeechMOS:v1.2.0\", \"utmos22_strong\", trust_repo=True)\n",
|
| 57 |
+
" return _M[\"qmos\"]\n",
|
| 58 |
+
"\n",
|
| 59 |
+
"\n",
|
| 60 |
+
"def _emocat():\n",
|
| 61 |
+
" if \"emocat\" not in _M:\n",
|
| 62 |
+
" from funasr import AutoModel\n",
|
| 63 |
+
" _M[\"emocat\"] = AutoModel(model=\"iic/emotion2vec_plus_large\", hub=\"hf\")\n",
|
| 64 |
+
" return _M[\"emocat\"]\n",
|
| 65 |
+
"\n",
|
| 66 |
+
"\n",
|
| 67 |
+
"def _gemini_emos_vad(audio_path, target_emotion, api_key):\n",
|
| 68 |
+
" \"\"\"EMOS (1-5 độ khớp cảm xúc target) + VAD (val/aro/dom 1-5) — bản demo gọn qua Gemini.\"\"\"\n",
|
| 69 |
+
" from google import genai\n",
|
| 70 |
+
" from google.genai import types\n",
|
| 71 |
+
" client = genai.Client(api_key=api_key)\n",
|
| 72 |
+
" part = types.Part.from_bytes(data=open(audio_path, \"rb\").read(), mime_type=\"audio/wav\")\n",
|
| 73 |
+
" cfg = types.GenerateContentConfig(temperature=0.0)\n",
|
| 74 |
+
"\n",
|
| 75 |
+
" p_emos = (f\"The target emotion is '{target_emotion}'. On a scale of 1 to 5, how well does the \"\n",
|
| 76 |
+
" f\"speaker express that emotion? 5=perfect match, 1=no match. Answer with ONLY one integer 1-5.\")\n",
|
| 77 |
+
" r = client.models.generate_content(model=GEMINI_MODEL, config=cfg, contents=[p_emos, part])\n",
|
| 78 |
+
" mm = re.search(r\"[1-5]\", getattr(r, \"text\", \"\") or \"\")\n",
|
| 79 |
+
" emos = int(mm.group()) if mm else None\n",
|
| 80 |
+
"\n",
|
| 81 |
+
" p_vad = ('Rate this speech on three 1-5 scales: Valence (1=very negative,5=very positive), '\n",
|
| 82 |
+
" 'Arousal (1=very calm,5=very excited), Dominance (1=very submissive,5=very dominant). '\n",
|
| 83 |
+
" 'Answer ONLY as JSON: {\"val\":x,\"aro\":y,\"dom\":z}.')\n",
|
| 84 |
+
" r2 = client.models.generate_content(model=GEMINI_MODEL, config=cfg, contents=[p_vad, part])\n",
|
| 85 |
+
" val = aro = dom = None\n",
|
| 86 |
+
" try:\n",
|
| 87 |
+
" d = json.loads(re.search(r\"\\{.*\\}\", getattr(r2, \"text\", \"\") or \"\", re.S).group())\n",
|
| 88 |
+
" val, aro, dom = d.get(\"val\"), d.get(\"aro\"), d.get(\"dom\")\n",
|
| 89 |
+
" except Exception:\n",
|
| 90 |
+
" pass\n",
|
| 91 |
+
" return emos, (val, aro, dom)\n",
|
| 92 |
+
"\n",
|
| 93 |
+
"\n",
|
| 94 |
+
"def predict(audio, target_emotion, gemini_key):\n",
|
| 95 |
+
" import torch\n",
|
| 96 |
+
" if not audio:\n",
|
| 97 |
+
" return \"⚠️ Hãy tải audio.\", {}\n",
|
| 98 |
+
" wav = librosa.load(audio, sr=16000, mono=True)[0]\n",
|
| 99 |
+
" # QMOS\n",
|
| 100 |
+
" qmos = float(_qmos()(torch.from_numpy(wav).unsqueeze(0), sr=16000).mean().item())\n",
|
| 101 |
+
" # CAT\n",
|
| 102 |
+
" rec = _emocat().generate(audio, granularity=\"utterance\", extract_embedding=False)\n",
|
| 103 |
+
" probs = {e: 0.0 for e in EMOTIONS5}\n",
|
| 104 |
+
" for lab, sc in zip(rec[0][\"labels\"], rec[0][\"scores\"]):\n",
|
| 105 |
+
" name = lab.split(\"/\")[-1]\n",
|
| 106 |
+
" if name in probs:\n",
|
| 107 |
+
" probs[name] = float(sc)\n",
|
| 108 |
+
" tot = sum(probs.values())\n",
|
| 109 |
+
" if tot > 0:\n",
|
| 110 |
+
" probs = {k: v / tot for k, v in probs.items()}\n",
|
| 111 |
+
"\n",
|
| 112 |
+
" lines = [f\"QMOS (chất lượng giọng, 1–5): {qmos:.3f}\"]\n",
|
| 113 |
+
" if gemini_key and target_emotion:\n",
|
| 114 |
+
" try:\n",
|
| 115 |
+
" emos, (val, aro, dom) = _gemini_emos_vad(audio, target_emotion, gemini_key)\n",
|
| 116 |
+
" lines.append(f\"EMOS (độ khớp cảm xúc '{target_emotion}', 1–5): {emos}\")\n",
|
| 117 |
+
" lines.append(f\"VAD — Valence: {val} · Arousal: {aro} · Dominance: {dom}\")\n",
|
| 118 |
+
" except Exception as e:\n",
|
| 119 |
+
" lines.append(f\"(EMOS/VAD lỗi: {e})\")\n",
|
| 120 |
+
" else:\n",
|
| 121 |
+
" lines.append(\"(EMOS/VAD: dán GEMINI_API_KEY + chọn cảm xúc target để bật)\")\n",
|
| 122 |
+
" return \"\\n\".join(lines), probs"
|
| 123 |
+
]
|
| 124 |
+
},
|
| 125 |
+
{
|
| 126 |
+
"cell_type": "markdown",
|
| 127 |
+
"metadata": {},
|
| 128 |
+
"source": [
|
| 129 |
+
"## 3. Giao diện Gradio + launch"
|
| 130 |
+
]
|
| 131 |
+
},
|
| 132 |
+
{
|
| 133 |
+
"cell_type": "code",
|
| 134 |
+
"execution_count": null,
|
| 135 |
+
"metadata": {},
|
| 136 |
+
"outputs": [],
|
| 137 |
+
"source": [
|
| 138 |
+
"import gradio as gr\n",
|
| 139 |
+
"\n",
|
| 140 |
+
"with gr.Blocks(title=\"VMC2026 Track 2 — Emotional TTS\") as demo:\n",
|
| 141 |
+
" gr.Markdown(\"# 🎙️ Track 2 · Emotional TTS (QMOS / CAT / EMOS / VAD)\\n\"\n",
|
| 142 |
+
" \"QMOS + phân bố cảm xúc (CAT) chạy ngay. EMOS/VAD cần Gemini key + cảm xúc target.\")\n",
|
| 143 |
+
" a = gr.Audio(type=\"filepath\", label=\"Audio\")\n",
|
| 144 |
+
" with gr.Row():\n",
|
| 145 |
+
" tgt = gr.Dropdown(EMOTIONS5, label=\"Cảm xúc target (cho EMOS, tùy chọn)\")\n",
|
| 146 |
+
" key = gr.Textbox(label=\"GEMINI_API_KEY (tùy chọn)\", type=\"password\")\n",
|
| 147 |
+
" out = gr.Textbox(label=\"Kết quả số\", lines=5)\n",
|
| 148 |
+
" lbl = gr.Label(label=\"CAT — phân bố cảm xúc cảm nhận\")\n",
|
| 149 |
+
" gr.Button(\"Dự đoán\", variant=\"primary\").click(predict, [a, tgt, key], [out, lbl])\n",
|
| 150 |
+
"\n",
|
| 151 |
+
"demo.launch(share=True)"
|
| 152 |
+
]
|
| 153 |
+
},
|
| 154 |
+
{
|
| 155 |
+
"cell_type": "markdown",
|
| 156 |
+
"metadata": {},
|
| 157 |
+
"source": [
|
| 158 |
+
"## Ghi chú\n",
|
| 159 |
+
"- EMOS/VAD là bản demo gọn (prompt rút gọn) — KHÔNG hoàn toàn giống script baseline gốc, chỉ minh họa."
|
| 160 |
+
]
|
| 161 |
+
}
|
| 162 |
+
],
|
| 163 |
+
"metadata": {
|
| 164 |
+
"kernelspec": {
|
| 165 |
+
"display_name": "Python 3",
|
| 166 |
+
"language": "python",
|
| 167 |
+
"name": "python3"
|
| 168 |
+
},
|
| 169 |
+
"language_info": {
|
| 170 |
+
"name": "python"
|
| 171 |
+
}
|
| 172 |
+
},
|
| 173 |
+
"nbformat": 4,
|
| 174 |
+
"nbformat_minor": 5
|
| 175 |
+
}
|
track2/demo_track2_gradio_pipeline.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %% [markdown]
|
| 2 |
+
# # VMC2026 Track 2 — Demo Gradio (Emotional TTS: QMOS / CAT / EMOS / VAD)
|
| 3 |
+
#
|
| 4 |
+
# - **QMOS** (UTMOS) + **CAT** (emotion2vec, 5 cảm xúc): chạy ngay, chỉ cần audio.
|
| 5 |
+
# - **EMOS / VAD** (Gemini): tùy chọn — cần dán `GEMINI_API_KEY` + chọn cảm xúc target.
|
| 6 |
+
#
|
| 7 |
+
# ### Cách dùng trên Kaggle
|
| 8 |
+
# 1. Settings → **GPU T4 + Internet On**.
|
| 9 |
+
# 2. **Run All** → cell cuối in link `*.gradio.live` (sống ~72h).
|
| 10 |
+
|
| 11 |
+
# %% [markdown]
|
| 12 |
+
# ## 1. Cài đặt
|
| 13 |
+
|
| 14 |
+
# %%
|
| 15 |
+
# !pip install -q gradio speechmos funasr librosa soundfile google-genai
|
| 16 |
+
|
| 17 |
+
# %% [markdown]
|
| 18 |
+
# ## 2. Nạp model + hàm dự đoán
|
| 19 |
+
|
| 20 |
+
# %%
|
| 21 |
+
import re, json, librosa
|
| 22 |
+
|
| 23 |
+
GEMINI_MODEL = "gemini-2.0-flash"
|
| 24 |
+
EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"]
|
| 25 |
+
|
| 26 |
+
_M = {}
|
| 27 |
+
|
| 28 |
+
def _qmos():
|
| 29 |
+
if "qmos" not in _M:
|
| 30 |
+
import torch
|
| 31 |
+
_M["qmos"] = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True)
|
| 32 |
+
return _M["qmos"]
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _emocat():
|
| 36 |
+
if "emocat" not in _M:
|
| 37 |
+
from funasr import AutoModel
|
| 38 |
+
_M["emocat"] = AutoModel(model="iic/emotion2vec_plus_large", hub="hf")
|
| 39 |
+
return _M["emocat"]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _gemini_emos_vad(audio_path, target_emotion, api_key):
|
| 43 |
+
"""EMOS (1-5 độ khớp cảm xúc target) + VAD (val/aro/dom 1-5) — bản demo gọn qua Gemini."""
|
| 44 |
+
from google import genai
|
| 45 |
+
from google.genai import types
|
| 46 |
+
client = genai.Client(api_key=api_key)
|
| 47 |
+
part = types.Part.from_bytes(data=open(audio_path, "rb").read(), mime_type="audio/wav")
|
| 48 |
+
cfg = types.GenerateContentConfig(temperature=0.0)
|
| 49 |
+
|
| 50 |
+
p_emos = (f"The target emotion is '{target_emotion}'. On a scale of 1 to 5, how well does the "
|
| 51 |
+
f"speaker express that emotion? 5=perfect match, 1=no match. Answer with ONLY one integer 1-5.")
|
| 52 |
+
r = client.models.generate_content(model=GEMINI_MODEL, config=cfg, contents=[p_emos, part])
|
| 53 |
+
mm = re.search(r"[1-5]", getattr(r, "text", "") or "")
|
| 54 |
+
emos = int(mm.group()) if mm else None
|
| 55 |
+
|
| 56 |
+
p_vad = ('Rate this speech on three 1-5 scales: Valence (1=very negative,5=very positive), '
|
| 57 |
+
'Arousal (1=very calm,5=very excited), Dominance (1=very submissive,5=very dominant). '
|
| 58 |
+
'Answer ONLY as JSON: {"val":x,"aro":y,"dom":z}.')
|
| 59 |
+
r2 = client.models.generate_content(model=GEMINI_MODEL, config=cfg, contents=[p_vad, part])
|
| 60 |
+
val = aro = dom = None
|
| 61 |
+
try:
|
| 62 |
+
d = json.loads(re.search(r"\{.*\}", getattr(r2, "text", "") or "", re.S).group())
|
| 63 |
+
val, aro, dom = d.get("val"), d.get("aro"), d.get("dom")
|
| 64 |
+
except Exception:
|
| 65 |
+
pass
|
| 66 |
+
return emos, (val, aro, dom)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def predict(audio, target_emotion, gemini_key):
|
| 70 |
+
import torch
|
| 71 |
+
if not audio:
|
| 72 |
+
return "⚠️ Hãy tải audio.", {}
|
| 73 |
+
wav = librosa.load(audio, sr=16000, mono=True)[0]
|
| 74 |
+
# QMOS
|
| 75 |
+
qmos = float(_qmos()(torch.from_numpy(wav).unsqueeze(0), sr=16000).mean().item())
|
| 76 |
+
# CAT
|
| 77 |
+
rec = _emocat().generate(audio, granularity="utterance", extract_embedding=False)
|
| 78 |
+
probs = {e: 0.0 for e in EMOTIONS5}
|
| 79 |
+
for lab, sc in zip(rec[0]["labels"], rec[0]["scores"]):
|
| 80 |
+
name = lab.split("/")[-1]
|
| 81 |
+
if name in probs:
|
| 82 |
+
probs[name] = float(sc)
|
| 83 |
+
tot = sum(probs.values())
|
| 84 |
+
if tot > 0:
|
| 85 |
+
probs = {k: v / tot for k, v in probs.items()}
|
| 86 |
+
|
| 87 |
+
lines = [f"QMOS (chất lượng giọng, 1–5): {qmos:.3f}"]
|
| 88 |
+
if gemini_key and target_emotion:
|
| 89 |
+
try:
|
| 90 |
+
emos, (val, aro, dom) = _gemini_emos_vad(audio, target_emotion, gemini_key)
|
| 91 |
+
lines.append(f"EMOS (độ khớp cảm xúc '{target_emotion}', 1–5): {emos}")
|
| 92 |
+
lines.append(f"VAD — Valence: {val} · Arousal: {aro} · Dominance: {dom}")
|
| 93 |
+
except Exception as e:
|
| 94 |
+
lines.append(f"(EMOS/VAD lỗi: {e})")
|
| 95 |
+
else:
|
| 96 |
+
lines.append("(EMOS/VAD: dán GEMINI_API_KEY + chọn cảm xúc target để bật)")
|
| 97 |
+
return "\n".join(lines), probs
|
| 98 |
+
|
| 99 |
+
# %% [markdown]
|
| 100 |
+
# ## 3. Giao diện Gradio + launch
|
| 101 |
+
|
| 102 |
+
# %%
|
| 103 |
+
import gradio as gr
|
| 104 |
+
|
| 105 |
+
with gr.Blocks(title="VMC2026 Track 2 — Emotional TTS") as demo:
|
| 106 |
+
gr.Markdown("# 🎙️ Track 2 · Emotional TTS (QMOS / CAT / EMOS / VAD)\n"
|
| 107 |
+
"QMOS + phân bố cảm xúc (CAT) chạy ngay. EMOS/VAD cần Gemini key + cảm xúc target.")
|
| 108 |
+
a = gr.Audio(type="filepath", label="Audio")
|
| 109 |
+
with gr.Row():
|
| 110 |
+
tgt = gr.Dropdown(EMOTIONS5, label="Cảm xúc target (cho EMOS, tùy chọn)")
|
| 111 |
+
key = gr.Textbox(label="GEMINI_API_KEY (tùy chọn)", type="password")
|
| 112 |
+
out = gr.Textbox(label="Kết quả số", lines=5)
|
| 113 |
+
lbl = gr.Label(label="CAT — phân bố cảm xúc cảm nhận")
|
| 114 |
+
gr.Button("Dự đoán", variant="primary").click(predict, [a, tgt, key], [out, lbl])
|
| 115 |
+
|
| 116 |
+
demo.launch(share=True)
|
| 117 |
+
|
| 118 |
+
# %% [markdown]
|
| 119 |
+
# ## Ghi chú
|
| 120 |
+
# - EMOS/VAD là bản demo gọn (prompt rút gọn) — KHÔNG hoàn toàn giống script baseline gốc, chỉ minh họa.
|
track2/exp02_train_emos.ipynb
ADDED
|
@@ -0,0 +1,542 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "b886ed3a",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"# VMC2026 Track 2 — exp02 (EMOS có train) — Kaggle\n",
|
| 9 |
+
"\n",
|
| 10 |
+
"**Mục tiêu:** train một model dự đoán **EMOS** (độ khớp cảm xúc target) từ ~12.746 mẫu\n",
|
| 11 |
+
"có nhãn người nghe trong `sets/train.csv`, kỳ vọng **vượt baseline 0.194** (exp01 offline).\n",
|
| 12 |
+
"\n",
|
| 13 |
+
"## Ý tưởng (đọc 1 lần cho hiểu)\n",
|
| 14 |
+
"EMOS phụ thuộc **cả audio LẪN cảm xúc target** (cùng audio \"vui\": target=happy → điểm cao,\n",
|
| 15 |
+
"target=sad → điểm thấp). Vì vậy model phải nhận vào cả hai:\n",
|
| 16 |
+
"\n",
|
| 17 |
+
"```\n",
|
| 18 |
+
"mỗi wav ─► emotion2vec ─► (a) embedding ~D chiều ┐\n",
|
| 19 |
+
" (b) xác suất 5 cảm xúc ├─► nối ─► MLP head ─► EMOS (1–5)\n",
|
| 20 |
+
" target emotion ───► one-hot 5 chiều ┘ (CÁI MÌNH TRAIN)\n",
|
| 21 |
+
"```\n",
|
| 22 |
+
"\n",
|
| 23 |
+
"- **Backbone emotion2vec ĐÓNG BĂNG** (không train lại) → chỉ trích đặc trưng. Nhẹ GPU, ít data vẫn ổn.\n",
|
| 24 |
+
"- **Chỉ train MLP head nhỏ** → học ánh xạ `(đặc trưng + target) → điểm người chấm`.\n",
|
| 25 |
+
"- **Nhãn vàng** = trung bình `eMOS` của mọi listener trên cùng 1 wav (gộp theo `wavID`).\n",
|
| 26 |
+
"- Embedding **trích 1 lần → cache .npz** (12.746 file rất lâu, chạy lại tốn giờ GPU).\n",
|
| 27 |
+
"- Tách 10% train làm **validation nội bộ** → đo SRCC trong lúc train (DEV không có nhãn để tự chấm).\n",
|
| 28 |
+
"- Cuối cùng xuất `answer.txt` **đầy đủ**: QMOS=SpeechMOS · CAT=emotion2vec · **EMOS=head vừa train** → nộp được ngay.\n",
|
| 29 |
+
"\n",
|
| 30 |
+
"**Cách chạy trên Kaggle:** Settings → Accelerator = **GPU T4**, Internet = **On** → + Add Input dataset\n",
|
| 31 |
+
"Track 2 (15.477 wav, có `sets/train.csv`) → sửa `DATA_ROOT` ở cell 0 → Run All."
|
| 32 |
+
]
|
| 33 |
+
},
|
| 34 |
+
{
|
| 35 |
+
"cell_type": "markdown",
|
| 36 |
+
"id": "d0fadc26",
|
| 37 |
+
"metadata": {},
|
| 38 |
+
"source": [
|
| 39 |
+
"## 0. Cấu hình — SỬA Ở ĐÂY"
|
| 40 |
+
]
|
| 41 |
+
},
|
| 42 |
+
{
|
| 43 |
+
"cell_type": "code",
|
| 44 |
+
"execution_count": null,
|
| 45 |
+
"id": "7fac05b4",
|
| 46 |
+
"metadata": {},
|
| 47 |
+
"outputs": [],
|
| 48 |
+
"source": [
|
| 49 |
+
"import os, glob, json, time\n",
|
| 50 |
+
"\n",
|
| 51 |
+
"# ── Data Track 2 (dataset 15.477 wav đã ráp, có sets/train.csv) ──────────────\n",
|
| 52 |
+
"DATA_ROOT = \"/kaggle/input/vmc2026-track2-full/vmc2026-track2\" # << SỬA slug cho khớp Add Input\n",
|
| 53 |
+
"WAV_DIR = f\"{DATA_ROOT}/wav\"\n",
|
| 54 |
+
"METADATA_CSV = f\"{DATA_ROOT}/metadata.csv\" # wavID|emotion|transcript (KHÔNG header) → target emotion\n",
|
| 55 |
+
"TRAIN_CSV = f\"{DATA_ROOT}/sets/train.csv\" # nhãn người nghe: lisID,wavID,qMOS,emoCat,eMOS,val,dom,aro\n",
|
| 56 |
+
"DEV_SCP = f\"{DATA_ROOT}/sets/dev.scp\" # danh sách wav tập DEV (tập cần nộp ở training phase)\n",
|
| 57 |
+
"\n",
|
| 58 |
+
"OUT_DIR = \"/kaggle/working\"\n",
|
| 59 |
+
"CACHE_DIR = \"/kaggle/working/emb_cache\" # nơi lưu embedding đã trích (tái dùng giữa các lần chạy)\n",
|
| 60 |
+
"os.makedirs(CACHE_DIR, exist_ok=True)\n",
|
| 61 |
+
"\n",
|
| 62 |
+
"# ── Siêu tham số train (đổi nếu muốn thử nghiệm) ─────────────────────────────\n",
|
| 63 |
+
"DEVICE = \"cuda\" # \"cuda\" trên Kaggle GPU; \"cpu\" nếu không có GPU\n",
|
| 64 |
+
"HIDDEN = 256 # số neuron lớp ẩn của MLP head\n",
|
| 65 |
+
"DROPOUT = 0.3\n",
|
| 66 |
+
"LR = 1e-3\n",
|
| 67 |
+
"EPOCHS = 60\n",
|
| 68 |
+
"BATCH = 64\n",
|
| 69 |
+
"VAL_FRAC = 0.10 # 10% train → validation nội bộ (đo SRCC)\n",
|
| 70 |
+
"PATIENCE = 12 # early stop: dừng nếu val-SRCC không cải thiện sau N epoch\n",
|
| 71 |
+
"SEED = 42\n",
|
| 72 |
+
"\n",
|
| 73 |
+
"LIMIT_TRAIN = None # đặt số nhỏ (vd 300) để chạy thử nhanh; None = full\n",
|
| 74 |
+
"USE_CLASSPROB = True # thêm 5 xác suất cảm xúc của emotion2vec vào feature (tín hiệu exp01)\n",
|
| 75 |
+
"\n",
|
| 76 |
+
"EMOTIONS5 = [\"angry\", \"happy\", \"neutral\", \"sad\", \"surprised\"]\n",
|
| 77 |
+
"\n",
|
| 78 |
+
"_EMO_ALIAS = {\n",
|
| 79 |
+
" \"angry\": \"angry\", \"anger\": \"angry\",\n",
|
| 80 |
+
" \"happy\": \"happy\", \"happiness\": \"happy\", \"joy\": \"happy\",\n",
|
| 81 |
+
" \"neutral\": \"neutral\", \"calm\": \"neutral\",\n",
|
| 82 |
+
" \"sad\": \"sad\", \"sadness\": \"sad\",\n",
|
| 83 |
+
" \"surprise\": \"surprised\", \"surprised\": \"surprised\", \"surprising\": \"surprised\",\n",
|
| 84 |
+
"}\n",
|
| 85 |
+
"\n",
|
| 86 |
+
"def norm_emotion(label):\n",
|
| 87 |
+
" \"\"\"Đưa nhãn cảm xúc bất kỳ về 1 trong EMOTIONS5; None nếu không khớp.\"\"\"\n",
|
| 88 |
+
" key = str(label).strip().lower()\n",
|
| 89 |
+
" return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)\n",
|
| 90 |
+
"\n",
|
| 91 |
+
"def stem(path_or_name):\n",
|
| 92 |
+
" \"\"\"Lấy tên file không đuôi, để khớp wavID giữa train.csv / metadata / dev.scp.\"\"\"\n",
|
| 93 |
+
" return os.path.splitext(os.path.basename(str(path_or_name)))[0]\n",
|
| 94 |
+
"\n",
|
| 95 |
+
"print(\"DATA_ROOT:\", DATA_ROOT)\n",
|
| 96 |
+
"for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP]:\n",
|
| 97 |
+
" print((\" ✅ \" if os.path.exists(p) else \" ❌ THIẾU \") + p)"
|
| 98 |
+
]
|
| 99 |
+
},
|
| 100 |
+
{
|
| 101 |
+
"cell_type": "markdown",
|
| 102 |
+
"id": "b947aceb",
|
| 103 |
+
"metadata": {},
|
| 104 |
+
"source": [
|
| 105 |
+
"## 1. Cài đặt"
|
| 106 |
+
]
|
| 107 |
+
},
|
| 108 |
+
{
|
| 109 |
+
"cell_type": "code",
|
| 110 |
+
"execution_count": null,
|
| 111 |
+
"id": "49850676",
|
| 112 |
+
"metadata": {},
|
| 113 |
+
"outputs": [],
|
| 114 |
+
"source": [
|
| 115 |
+
"!pip install -q speechmos funasr librosa soundfile pandas scipy scikit-learn tqdm"
|
| 116 |
+
]
|
| 117 |
+
},
|
| 118 |
+
{
|
| 119 |
+
"cell_type": "markdown",
|
| 120 |
+
"id": "b4e67d6a",
|
| 121 |
+
"metadata": {},
|
| 122 |
+
"source": [
|
| 123 |
+
"## 2. Đọc & gộp nhãn\n",
|
| 124 |
+
"- `train.csv`: mỗi dòng = 1 listener chấm 1 wav → **gộp trung bình eMOS theo wavID** = nhãn vàng.\n",
|
| 125 |
+
"- `metadata.csv`: lấy **cảm xúc target** cho mỗi wav (chuẩn hóa về 5 lớp)."
|
| 126 |
+
]
|
| 127 |
+
},
|
| 128 |
+
{
|
| 129 |
+
"cell_type": "code",
|
| 130 |
+
"execution_count": null,
|
| 131 |
+
"id": "88b2fb84",
|
| 132 |
+
"metadata": {},
|
| 133 |
+
"outputs": [],
|
| 134 |
+
"source": [
|
| 135 |
+
"import pandas as pd\n",
|
| 136 |
+
"\n",
|
| 137 |
+
"def load_target_emotions():\n",
|
| 138 |
+
" \"\"\"metadata.csv (wavID|emotion|transcript, KHÔNG header) → {stem: emotion_chuẩn|None}.\"\"\"\n",
|
| 139 |
+
" tgt = {}\n",
|
| 140 |
+
" with open(METADATA_CSV, encoding=\"utf-8\") as f:\n",
|
| 141 |
+
" for ln in f:\n",
|
| 142 |
+
" parts = ln.strip().split(\"|\")\n",
|
| 143 |
+
" if len(parts) < 2:\n",
|
| 144 |
+
" continue\n",
|
| 145 |
+
" tgt[stem(parts[0])] = norm_emotion(parts[1])\n",
|
| 146 |
+
" return tgt\n",
|
| 147 |
+
"\n",
|
| 148 |
+
"def load_train_labels():\n",
|
| 149 |
+
" \"\"\"train.csv → DataFrame [wavID(stem), emos] đã gộp trung bình theo wav.\"\"\"\n",
|
| 150 |
+
" df = pd.read_csv(TRAIN_CSV)\n",
|
| 151 |
+
" # Chuẩn hóa tên cột (phòng khi viết hoa/thường khác nhau)\n",
|
| 152 |
+
" cols = {c.lower().strip(): c for c in df.columns}\n",
|
| 153 |
+
" wav_col = cols.get(\"wavid\") or cols.get(\"wav\") or list(df.columns)[1]\n",
|
| 154 |
+
" emos_col = cols.get(\"emos\") or cols.get(\"emo\") or cols.get(\"emomos\")\n",
|
| 155 |
+
" assert emos_col, f\"Không thấy cột eMOS trong train.csv (cột hiện có: {list(df.columns)})\"\n",
|
| 156 |
+
" g = df.groupby(df[wav_col].map(stem))[emos_col].mean()\n",
|
| 157 |
+
" out = g.reset_index()\n",
|
| 158 |
+
" out.columns = [\"wavID\", \"emos\"]\n",
|
| 159 |
+
" return out\n",
|
| 160 |
+
"\n",
|
| 161 |
+
"target_map = load_target_emotions()\n",
|
| 162 |
+
"train_df = load_train_labels()\n",
|
| 163 |
+
"print(f\"Target emotions: {len(target_map)} | wav train (đã gộp): {len(train_df)}\")\n",
|
| 164 |
+
"print(\"eMOS thống kê:\", train_df[\"emos\"].describe()[[\"mean\", \"std\", \"min\", \"max\"]].to_dict())\n",
|
| 165 |
+
"train_df.head()"
|
| 166 |
+
]
|
| 167 |
+
},
|
| 168 |
+
{
|
| 169 |
+
"cell_type": "markdown",
|
| 170 |
+
"id": "06dea3ef",
|
| 171 |
+
"metadata": {},
|
| 172 |
+
"source": [
|
| 173 |
+
"## 3. Trích đặc trưng emotion2vec (có cache)\n",
|
| 174 |
+
"Mỗi wav → 1 lần `generate(extract_embedding=True)` cho ra **embedding** (cho EMOS) +\n",
|
| 175 |
+
"**xác suất 5 lớp** (cho CAT và làm feature). Lưu cache `.npz` để lần sau khỏi chạy lại."
|
| 176 |
+
]
|
| 177 |
+
},
|
| 178 |
+
{
|
| 179 |
+
"cell_type": "code",
|
| 180 |
+
"execution_count": null,
|
| 181 |
+
"id": "c63cc1f5",
|
| 182 |
+
"metadata": {
|
| 183 |
+
"lines_to_next_cell": 1
|
| 184 |
+
},
|
| 185 |
+
"outputs": [],
|
| 186 |
+
"source": [
|
| 187 |
+
"import numpy as np\n",
|
| 188 |
+
"\n",
|
| 189 |
+
"_e2v_model = None\n",
|
| 190 |
+
"def get_e2v():\n",
|
| 191 |
+
" global _e2v_model\n",
|
| 192 |
+
" if _e2v_model is None:\n",
|
| 193 |
+
" from funasr import AutoModel\n",
|
| 194 |
+
" _e2v_model = AutoModel(model=\"iic/emotion2vec_plus_large\", hub=\"hf\")\n",
|
| 195 |
+
" return _e2v_model\n",
|
| 196 |
+
"\n",
|
| 197 |
+
"def extract_one(wav_path):\n",
|
| 198 |
+
" \"\"\"→ (emb: np.float32[D], probs5: np.float32[5] tổng=1). None nếu lỗi/thiếu file.\"\"\"\n",
|
| 199 |
+
" if not os.path.exists(wav_path):\n",
|
| 200 |
+
" return None\n",
|
| 201 |
+
" rec = get_e2v().generate(wav_path, granularity=\"utterance\", extract_embedding=True)\n",
|
| 202 |
+
" r = rec[0]\n",
|
| 203 |
+
" emb = np.asarray(r[\"feats\"], dtype=np.float32).reshape(-1)\n",
|
| 204 |
+
" probs = {e: 0.0 for e in EMOTIONS5}\n",
|
| 205 |
+
" for lab, sc in zip(r[\"labels\"], r[\"scores\"]):\n",
|
| 206 |
+
" name = lab.split(\"/\")[-1]\n",
|
| 207 |
+
" if name in probs:\n",
|
| 208 |
+
" probs[name] = float(sc)\n",
|
| 209 |
+
" tot = sum(probs.values())\n",
|
| 210 |
+
" if tot > 0:\n",
|
| 211 |
+
" probs = {k: v / tot for k, v in probs.items()}\n",
|
| 212 |
+
" probs5 = np.array([probs[e] for e in EMOTIONS5], dtype=np.float32)\n",
|
| 213 |
+
" return emb, probs5\n",
|
| 214 |
+
"\n",
|
| 215 |
+
"def extract_set(stems, tag):\n",
|
| 216 |
+
" \"\"\"Trích (hoặc nạp cache) cho danh sách stem. Trả về dict {stem: (emb, probs5)}.\n",
|
| 217 |
+
" Cache lưu tại CACHE_DIR/<tag>.npz; tự b�� qua stem đã có để chạy nối tiếp được.\"\"\"\n",
|
| 218 |
+
" from tqdm.auto import tqdm\n",
|
| 219 |
+
" cache_path = os.path.join(CACHE_DIR, f\"{tag}.npz\")\n",
|
| 220 |
+
" store = {}\n",
|
| 221 |
+
" if os.path.exists(cache_path):\n",
|
| 222 |
+
" z = np.load(cache_path, allow_pickle=True)\n",
|
| 223 |
+
" store = {k: z[k] for k in z.files}\n",
|
| 224 |
+
" print(f\"[{tag}] nạp cache: {len(store)} mẫu\")\n",
|
| 225 |
+
" todo = [s for s in stems if s not in store]\n",
|
| 226 |
+
" if not todo:\n",
|
| 227 |
+
" print(f\"[{tag}] đủ cache, bỏ qua trích.\")\n",
|
| 228 |
+
" else:\n",
|
| 229 |
+
" miss = 0\n",
|
| 230 |
+
" for i, s in enumerate(tqdm(todo, desc=f\"trích {tag}\")):\n",
|
| 231 |
+
" res = extract_one(os.path.join(WAV_DIR, s + \".wav\"))\n",
|
| 232 |
+
" if res is None:\n",
|
| 233 |
+
" miss += 1\n",
|
| 234 |
+
" continue\n",
|
| 235 |
+
" emb, probs5 = res\n",
|
| 236 |
+
" store[s] = np.concatenate([emb, probs5]).astype(np.float32) # [D + 5]\n",
|
| 237 |
+
" if (i + 1) % 500 == 0: # lưu cache định kỳ phòng ngắt session\n",
|
| 238 |
+
" np.savez(cache_path, **store)\n",
|
| 239 |
+
" np.savez(cache_path, **store)\n",
|
| 240 |
+
" if miss:\n",
|
| 241 |
+
" print(f\"[{tag}] {miss} file thiếu/ lỗi → bỏ qua.\")\n",
|
| 242 |
+
" print(f\"[{tag}] tổng cache: {len(store)} mẫu → {cache_path}\")\n",
|
| 243 |
+
" # tách lại thành (emb, probs5)\n",
|
| 244 |
+
" out = {}\n",
|
| 245 |
+
" for s, vec in store.items():\n",
|
| 246 |
+
" out[s] = (vec[:-5], vec[-5:])\n",
|
| 247 |
+
" return out\n",
|
| 248 |
+
"\n",
|
| 249 |
+
"# Trích cho tập train\n",
|
| 250 |
+
"train_stems = list(train_df[\"wavID\"])\n",
|
| 251 |
+
"if LIMIT_TRAIN:\n",
|
| 252 |
+
" train_stems = train_stems[:LIMIT_TRAIN]\n",
|
| 253 |
+
"train_feat = extract_set(train_stems, \"train\")\n",
|
| 254 |
+
"EMB_DIM = next(iter(train_feat.values()))[0].shape[0]\n",
|
| 255 |
+
"print(\"EMB_DIM =\", EMB_DIM)"
|
| 256 |
+
]
|
| 257 |
+
},
|
| 258 |
+
{
|
| 259 |
+
"cell_type": "markdown",
|
| 260 |
+
"id": "015834c1",
|
| 261 |
+
"metadata": {},
|
| 262 |
+
"source": [
|
| 263 |
+
"## 4. Dựng feature + nhãn cho train\n",
|
| 264 |
+
"Feature mỗi wav = `[embedding | (probs5 nếu bật) | one-hot target(5)]`. Bỏ wav thiếu target/feature."
|
| 265 |
+
]
|
| 266 |
+
},
|
| 267 |
+
{
|
| 268 |
+
"cell_type": "code",
|
| 269 |
+
"execution_count": null,
|
| 270 |
+
"id": "62e048bd",
|
| 271 |
+
"metadata": {},
|
| 272 |
+
"outputs": [],
|
| 273 |
+
"source": [
|
| 274 |
+
"def onehot_target(tgt):\n",
|
| 275 |
+
" v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
|
| 276 |
+
" if tgt in EMOTIONS5:\n",
|
| 277 |
+
" v[EMOTIONS5.index(tgt)] = 1.0\n",
|
| 278 |
+
" return v\n",
|
| 279 |
+
"\n",
|
| 280 |
+
"def build_feature(stem_id, feat_map):\n",
|
| 281 |
+
" pack = feat_map.get(stem_id)\n",
|
| 282 |
+
" if pack is None:\n",
|
| 283 |
+
" return None\n",
|
| 284 |
+
" emb, probs5 = pack\n",
|
| 285 |
+
" tgt = target_map.get(stem_id)\n",
|
| 286 |
+
" if tgt is None: # không biết cảm xúc target → không train được mẫu này\n",
|
| 287 |
+
" return None\n",
|
| 288 |
+
" parts = [emb]\n",
|
| 289 |
+
" if USE_CLASSPROB:\n",
|
| 290 |
+
" parts.append(probs5)\n",
|
| 291 |
+
" parts.append(onehot_target(tgt))\n",
|
| 292 |
+
" return np.concatenate(parts).astype(np.float32)\n",
|
| 293 |
+
"\n",
|
| 294 |
+
"emos_label = dict(zip(train_df[\"wavID\"], train_df[\"emos\"]))\n",
|
| 295 |
+
"X, y = [], []\n",
|
| 296 |
+
"for s in train_stems:\n",
|
| 297 |
+
" f = build_feature(s, train_feat)\n",
|
| 298 |
+
" if f is None or s not in emos_label:\n",
|
| 299 |
+
" continue\n",
|
| 300 |
+
" X.append(f); y.append(emos_label[s])\n",
|
| 301 |
+
"X = np.stack(X); y = np.array(y, dtype=np.float32)\n",
|
| 302 |
+
"FEAT_DIM = X.shape[1]\n",
|
| 303 |
+
"print(f\"Train: X={X.shape} y={y.shape} FEAT_DIM={FEAT_DIM}\")\n",
|
| 304 |
+
"\n",
|
| 305 |
+
"# Chuẩn hóa feature (z-score) — lưu mean/std để áp dụng y hệt lúc dự đoán DEV.\n",
|
| 306 |
+
"feat_mean = X.mean(0, keepdims=True)\n",
|
| 307 |
+
"feat_std = X.std(0, keepdims=True) + 1e-6\n",
|
| 308 |
+
"Xn = (X - feat_mean) / feat_std"
|
| 309 |
+
]
|
| 310 |
+
},
|
| 311 |
+
{
|
| 312 |
+
"cell_type": "markdown",
|
| 313 |
+
"id": "02718889",
|
| 314 |
+
"metadata": {},
|
| 315 |
+
"source": [
|
| 316 |
+
"## 5. Model (MLP head) + train loop\n",
|
| 317 |
+
"Loss = MSE. Theo dõi **SRCC** trên validation nội bộ; lưu model tốt nhất (early stopping)."
|
| 318 |
+
]
|
| 319 |
+
},
|
| 320 |
+
{
|
| 321 |
+
"cell_type": "code",
|
| 322 |
+
"execution_count": null,
|
| 323 |
+
"id": "f5c52fa4",
|
| 324 |
+
"metadata": {
|
| 325 |
+
"lines_to_next_cell": 1
|
| 326 |
+
},
|
| 327 |
+
"outputs": [],
|
| 328 |
+
"source": [
|
| 329 |
+
"import torch, torch.nn as nn\n",
|
| 330 |
+
"from scipy.stats import spearmanr\n",
|
| 331 |
+
"from sklearn.model_selection import train_test_split\n",
|
| 332 |
+
"\n",
|
| 333 |
+
"torch.manual_seed(SEED); np.random.seed(SEED)\n",
|
| 334 |
+
"device = DEVICE if torch.cuda.is_available() else \"cpu\"\n",
|
| 335 |
+
"print(\"Device:\", device)\n",
|
| 336 |
+
"\n",
|
| 337 |
+
"Xtr, Xva, ytr, yva = train_test_split(Xn, y, test_size=VAL_FRAC, random_state=SEED)\n",
|
| 338 |
+
"Xtr_t = torch.tensor(Xtr, device=device); ytr_t = torch.tensor(ytr, device=device).unsqueeze(1)\n",
|
| 339 |
+
"Xva_t = torch.tensor(Xva, device=device); yva_t = torch.tensor(yva, device=device).unsqueeze(1)\n",
|
| 340 |
+
"\n",
|
| 341 |
+
"class EmosHead(nn.Module):\n",
|
| 342 |
+
" def __init__(self, d_in, hidden, p):\n",
|
| 343 |
+
" super().__init__()\n",
|
| 344 |
+
" self.net = nn.Sequential(\n",
|
| 345 |
+
" nn.Linear(d_in, hidden), nn.ReLU(), nn.Dropout(p),\n",
|
| 346 |
+
" nn.Linear(hidden, hidden // 2), nn.ReLU(), nn.Dropout(p),\n",
|
| 347 |
+
" nn.Linear(hidden // 2, 1),\n",
|
| 348 |
+
" )\n",
|
| 349 |
+
" def forward(self, x):\n",
|
| 350 |
+
" return self.net(x)\n",
|
| 351 |
+
"\n",
|
| 352 |
+
"model = EmosHead(FEAT_DIM, HIDDEN, DROPOUT).to(device)\n",
|
| 353 |
+
"opt = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-5)\n",
|
| 354 |
+
"lossf = nn.MSELoss()\n",
|
| 355 |
+
"\n",
|
| 356 |
+
"def val_srcc():\n",
|
| 357 |
+
" model.eval()\n",
|
| 358 |
+
" with torch.no_grad():\n",
|
| 359 |
+
" pred = model(Xva_t).cpu().numpy().ravel()\n",
|
| 360 |
+
" return spearmanr(pred, yva).correlation\n",
|
| 361 |
+
"\n",
|
| 362 |
+
"best_srcc, best_state, bad = -1.0, None, 0\n",
|
| 363 |
+
"n = Xtr_t.shape[0]\n",
|
| 364 |
+
"for ep in range(1, EPOCHS + 1):\n",
|
| 365 |
+
" model.train()\n",
|
| 366 |
+
" perm = torch.randperm(n, device=device)\n",
|
| 367 |
+
" tot = 0.0\n",
|
| 368 |
+
" for i in range(0, n, BATCH):\n",
|
| 369 |
+
" idx = perm[i:i + BATCH]\n",
|
| 370 |
+
" opt.zero_grad()\n",
|
| 371 |
+
" out = model(Xtr_t[idx])\n",
|
| 372 |
+
" loss = lossf(out, ytr_t[idx])\n",
|
| 373 |
+
" loss.backward(); opt.step()\n",
|
| 374 |
+
" tot += loss.item() * len(idx)\n",
|
| 375 |
+
" srcc = val_srcc()\n",
|
| 376 |
+
" if srcc > best_srcc:\n",
|
| 377 |
+
" best_srcc, best_state, bad = srcc, {k: v.cpu().clone() for k, v in model.state_dict().items()}, 0\n",
|
| 378 |
+
" else:\n",
|
| 379 |
+
" bad += 1\n",
|
| 380 |
+
" if ep % 5 == 0 or ep == 1:\n",
|
| 381 |
+
" print(f\"epoch {ep:3d} | train MSE {tot/n:.4f} | val SRCC {srcc:.4f} | best {best_srcc:.4f}\")\n",
|
| 382 |
+
" if bad >= PATIENCE:\n",
|
| 383 |
+
" print(f\"Early stop ở epoch {ep} (val SRCC không tăng {PATIENCE} epoch).\")\n",
|
| 384 |
+
" break\n",
|
| 385 |
+
"\n",
|
| 386 |
+
"model.load_state_dict(best_state)\n",
|
| 387 |
+
"print(f\"\\n✅ VAL SRCC tốt nhất = {best_srcc:.4f} (baseline exp01 ≈ 0.194 — so ở đây)\")\n",
|
| 388 |
+
"\n",
|
| 389 |
+
"# Lưu model + tham số chuẩn hóa để tái dùng / mô tả hệ thống.\n",
|
| 390 |
+
"torch.save({\"state\": best_state, \"feat_mean\": feat_mean, \"feat_std\": feat_std,\n",
|
| 391 |
+
" \"EMB_DIM\": EMB_DIM, \"FEAT_DIM\": FEAT_DIM, \"USE_CLASSPROB\": USE_CLASSPROB,\n",
|
| 392 |
+
" \"EMOTIONS5\": EMOTIONS5, \"val_srcc\": float(best_srcc)},\n",
|
| 393 |
+
" os.path.join(OUT_DIR, \"emos_head.pt\"))\n",
|
| 394 |
+
"print(\"Đã lưu\", os.path.join(OUT_DIR, \"emos_head.pt\"))"
|
| 395 |
+
]
|
| 396 |
+
},
|
| 397 |
+
{
|
| 398 |
+
"cell_type": "markdown",
|
| 399 |
+
"id": "aef6f3ee",
|
| 400 |
+
"metadata": {},
|
| 401 |
+
"source": [
|
| 402 |
+
"## 6. Dự đoán DEV → `answer.txt` đầy đủ\n",
|
| 403 |
+
"- **EMOS** = head vừa train (cần embedding + target của từng wav DEV).\n",
|
| 404 |
+
"- **CAT** = xác suất 5 lớp emotion2vec (đã có sẵn khi trích đặc trưng).\n",
|
| 405 |
+
"- **QMOS** = SpeechMOS (UTMOS) — bắt buộc, chạy thêm ở đây để answer.txt hợp lệ."
|
| 406 |
+
]
|
| 407 |
+
},
|
| 408 |
+
{
|
| 409 |
+
"cell_type": "code",
|
| 410 |
+
"execution_count": null,
|
| 411 |
+
"id": "8f94d7ab",
|
| 412 |
+
"metadata": {
|
| 413 |
+
"lines_to_next_cell": 1
|
| 414 |
+
},
|
| 415 |
+
"outputs": [],
|
| 416 |
+
"source": [
|
| 417 |
+
"def list_dev():\n",
|
| 418 |
+
" with open(DEV_SCP) as f:\n",
|
| 419 |
+
" return [ln.strip() for ln in f if ln.strip()] # tên file .wav\n",
|
| 420 |
+
"\n",
|
| 421 |
+
"dev_names = list_dev()\n",
|
| 422 |
+
"dev_stems = [stem(n) for n in dev_names]\n",
|
| 423 |
+
"print(\"DEV:\", len(dev_names), \"mẫu\")\n",
|
| 424 |
+
"\n",
|
| 425 |
+
"# 6a. Trích đặc trưng emotion2vec cho DEV (cache riêng)\n",
|
| 426 |
+
"dev_feat = extract_set(dev_stems, \"dev\")\n",
|
| 427 |
+
"\n",
|
| 428 |
+
"# 6b. EMOS từ head đã train\n",
|
| 429 |
+
"def predict_emos(stem_id):\n",
|
| 430 |
+
" f = build_feature(stem_id, dev_feat)\n",
|
| 431 |
+
" if f is None:\n",
|
| 432 |
+
" return None\n",
|
| 433 |
+
" fn = (f[None, :] - feat_mean) / feat_std\n",
|
| 434 |
+
" model.eval()\n",
|
| 435 |
+
" with torch.no_grad():\n",
|
| 436 |
+
" return float(model(torch.tensor(fn, dtype=torch.float32, device=device)).item())\n",
|
| 437 |
+
"\n",
|
| 438 |
+
"# 6c. QMOS = SpeechMOS\n",
|
| 439 |
+
"def run_qmos(names):\n",
|
| 440 |
+
" import librosa\n",
|
| 441 |
+
" predictor = torch.hub.load(\"tarepan/SpeechMOS:v1.2.0\", \"utmos22_strong\", trust_repo=True)\n",
|
| 442 |
+
" out = {}\n",
|
| 443 |
+
" from tqdm.auto import tqdm\n",
|
| 444 |
+
" for n in tqdm(names, desc=\"QMOS\"):\n",
|
| 445 |
+
" p = os.path.join(WAV_DIR, n)\n",
|
| 446 |
+
" if not os.path.exists(p):\n",
|
| 447 |
+
" continue\n",
|
| 448 |
+
" wave, _ = librosa.load(p, sr=16000, mono=True)\n",
|
| 449 |
+
" out[n] = float(predictor(torch.from_numpy(wave).unsqueeze(0), sr=16000).mean().item())\n",
|
| 450 |
+
" return out\n",
|
| 451 |
+
"\n",
|
| 452 |
+
"qmos_scores = run_qmos(dev_names)"
|
| 453 |
+
]
|
| 454 |
+
},
|
| 455 |
+
{
|
| 456 |
+
"cell_type": "code",
|
| 457 |
+
"execution_count": null,
|
| 458 |
+
"id": "6a6680d0",
|
| 459 |
+
"metadata": {
|
| 460 |
+
"lines_to_next_cell": 1
|
| 461 |
+
},
|
| 462 |
+
"outputs": [],
|
| 463 |
+
"source": [
|
| 464 |
+
"def fmt_cat(probs5):\n",
|
| 465 |
+
" return \"|\".join(f\"{e}:{probs5[i]:.6g}\" for i, e in enumerate(EMOTIONS5))\n",
|
| 466 |
+
"\n",
|
| 467 |
+
"def build_answer(out_path):\n",
|
| 468 |
+
" n_emos = n_default = 0\n",
|
| 469 |
+
" with open(out_path, \"w\") as f:\n",
|
| 470 |
+
" f.write(\"wav,QMOS,EMOS,CAT\\n\")\n",
|
| 471 |
+
" for name in dev_names:\n",
|
| 472 |
+
" sid = stem(name)\n",
|
| 473 |
+
" emos = predict_emos(sid)\n",
|
| 474 |
+
" if emos is None:\n",
|
| 475 |
+
" emos = 3.0; n_default += 1\n",
|
| 476 |
+
" else:\n",
|
| 477 |
+
" n_emos += 1\n",
|
| 478 |
+
" qmos = qmos_scores.get(name, 3.0)\n",
|
| 479 |
+
" probs5 = dev_feat[sid][1] if sid in dev_feat else np.full(5, 0.2, dtype=np.float32)\n",
|
| 480 |
+
" f.write(f\"{name},{qmos:.6g},{emos:.6g},{fmt_cat(probs5)}\\n\")\n",
|
| 481 |
+
" print(f\"Ghi {len(dev_names)} dòng → {out_path} | EMOS thật {n_emos}, mặc định {n_default}\")\n",
|
| 482 |
+
"\n",
|
| 483 |
+
"answer_path = os.path.join(OUT_DIR, \"answer.txt\")\n",
|
| 484 |
+
"build_answer(answer_path)"
|
| 485 |
+
]
|
| 486 |
+
},
|
| 487 |
+
{
|
| 488 |
+
"cell_type": "markdown",
|
| 489 |
+
"id": "6179873f",
|
| 490 |
+
"metadata": {},
|
| 491 |
+
"source": [
|
| 492 |
+
"## 7. Validate + đóng zip"
|
| 493 |
+
]
|
| 494 |
+
},
|
| 495 |
+
{
|
| 496 |
+
"cell_type": "code",
|
| 497 |
+
"execution_count": null,
|
| 498 |
+
"id": "30ee8626",
|
| 499 |
+
"metadata": {},
|
| 500 |
+
"outputs": [],
|
| 501 |
+
"source": [
|
| 502 |
+
"def validate(path):\n",
|
| 503 |
+
" import csv\n",
|
| 504 |
+
" with open(path) as f:\n",
|
| 505 |
+
" rows = list(csv.reader(f))\n",
|
| 506 |
+
" header = rows[0]\n",
|
| 507 |
+
" assert header[0] == \"wav\" and \"QMOS\" in header and \"EMOS\" in header, \"Header sai\"\n",
|
| 508 |
+
" for i, r in enumerate(rows[1:], 2):\n",
|
| 509 |
+
" assert len(r) == len(header), f\"Dòng {i} sai số cột\"\n",
|
| 510 |
+
" print(f\"OK: {len(rows)-1} dòng, header = {header}\")\n",
|
| 511 |
+
"\n",
|
| 512 |
+
"validate(answer_path)\n",
|
| 513 |
+
"!cd /kaggle/working && zip -j submission_track2_exp02.zip answer.txt && unzip -l submission_track2_exp02.zip\n",
|
| 514 |
+
"print(\"Sẵn sàng nộp: /kaggle/working/submission_track2_exp02.zip\")"
|
| 515 |
+
]
|
| 516 |
+
},
|
| 517 |
+
{
|
| 518 |
+
"cell_type": "markdown",
|
| 519 |
+
"id": "316e3e1f",
|
| 520 |
+
"metadata": {},
|
| 521 |
+
"source": [
|
| 522 |
+
"## Ghi chú\n",
|
| 523 |
+
"- **VAL SRCC** in ở mục 5 là ước lượng nội bộ (10% train) — so với baseline 0.194 để biết có khá hơn không.\n",
|
| 524 |
+
" Điểm DEV thật phải nộp lên CodaBench mới biết (My Submissions → Track 2, bỏ chọn track khác).\n",
|
| 525 |
+
"- Muốn thử nhanh: đặt `LIMIT_TRAIN = 300` ở cell 0.\n",
|
| 526 |
+
"- Embedding đã cache trong `/kaggle/working/emb_cache/` → **Save Version** để giữ, lần sau train head khỏi trích lại.\n",
|
| 527 |
+
"- Hướng cải tiến tiếp: thêm head QMOS/CAT/VAD dùng chung backbone (exp02 multi-task đầy đủ);\n",
|
| 528 |
+
" thử backbone wav2vec2/WavLM; thêm ranking loss; fine-tune nhẹ backbone.\n",
|
| 529 |
+
"- Nhớ ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (mục exp02)."
|
| 530 |
+
]
|
| 531 |
+
}
|
| 532 |
+
],
|
| 533 |
+
"metadata": {
|
| 534 |
+
"jupytext": {
|
| 535 |
+
"cell_metadata_filter": "-all",
|
| 536 |
+
"main_language": "python",
|
| 537 |
+
"notebook_metadata_filter": "-all"
|
| 538 |
+
}
|
| 539 |
+
},
|
| 540 |
+
"nbformat": 4,
|
| 541 |
+
"nbformat_minor": 5
|
| 542 |
+
}
|
track2/exp02_train_emos_pipeline.py
ADDED
|
@@ -0,0 +1,407 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %% [markdown]
|
| 2 |
+
# # VMC2026 Track 2 — exp02 (EMOS có train) — Kaggle
|
| 3 |
+
#
|
| 4 |
+
# **Mục tiêu:** train một model dự đoán **EMOS** (độ khớp cảm xúc target) từ ~12.746 mẫu
|
| 5 |
+
# có nhãn người nghe trong `sets/train.csv`, kỳ vọng **vượt baseline 0.194** (exp01 offline).
|
| 6 |
+
#
|
| 7 |
+
# ## Ý tưởng (đọc 1 lần cho hiểu)
|
| 8 |
+
# EMOS phụ thuộc **cả audio LẪN cảm xúc target** (cùng audio "vui": target=happy → điểm cao,
|
| 9 |
+
# target=sad → điểm thấp). Vì vậy model phải nhận vào cả hai:
|
| 10 |
+
#
|
| 11 |
+
# ```
|
| 12 |
+
# mỗi wav ─► emotion2vec ─► (a) embedding ~D chiều ┐
|
| 13 |
+
# (b) xác suất 5 cảm xúc ├─► nối ─► MLP head ─► EMOS (1–5)
|
| 14 |
+
# target emotion ───► one-hot 5 chiều ┘ (CÁI MÌNH TRAIN)
|
| 15 |
+
# ```
|
| 16 |
+
#
|
| 17 |
+
# - **Backbone emotion2vec ĐÓNG BĂNG** (không train lại) → chỉ trích đặc trưng. Nhẹ GPU, ít data vẫn ổn.
|
| 18 |
+
# - **Chỉ train MLP head nhỏ** → học ánh xạ `(đặc trưng + target) → điểm người chấm`.
|
| 19 |
+
# - **Nhãn vàng** = trung bình `eMOS` của mọi listener trên cùng 1 wav (gộp theo `wavID`).
|
| 20 |
+
# - Embedding **trích 1 lần → cache .npz** (12.746 file rất lâu, chạy lại tốn giờ GPU).
|
| 21 |
+
# - Tách 10% train làm **validation nội bộ** → đo SRCC trong lúc train (DEV không có nhãn để tự chấm).
|
| 22 |
+
# - Cuối cùng xuất `answer.txt` **đầy đủ**: QMOS=SpeechMOS · CAT=emotion2vec · **EMOS=head vừa train** → nộp được ngay.
|
| 23 |
+
#
|
| 24 |
+
# **Cách chạy trên Kaggle:** Settings → Accelerator = **GPU T4**, Internet = **On** → + Add Input dataset
|
| 25 |
+
# Track 2 (15.477 wav, có `sets/train.csv`) → sửa `DATA_ROOT` ở cell 0 → Run All.
|
| 26 |
+
|
| 27 |
+
# %% [markdown]
|
| 28 |
+
# ## 0. Cấu hình — SỬA Ở ĐÂY
|
| 29 |
+
|
| 30 |
+
# %%
|
| 31 |
+
import os, glob, json, time
|
| 32 |
+
|
| 33 |
+
# ── Data Track 2 (dataset 15.477 wav đã ráp, có sets/train.csv) ──────────────
|
| 34 |
+
DATA_ROOT = "/kaggle/input/vmc2026-track2-full/vmc2026-track2" # << SỬA slug cho khớp Add Input
|
| 35 |
+
WAV_DIR = f"{DATA_ROOT}/wav"
|
| 36 |
+
METADATA_CSV = f"{DATA_ROOT}/metadata.csv" # wavID|emotion|transcript (KHÔNG header) → target emotion
|
| 37 |
+
TRAIN_CSV = f"{DATA_ROOT}/sets/train.csv" # nhãn người nghe: lisID,wavID,qMOS,emoCat,eMOS,val,dom,aro
|
| 38 |
+
DEV_SCP = f"{DATA_ROOT}/sets/dev.scp" # danh sách wav tập DEV (tập cần nộp ở training phase)
|
| 39 |
+
|
| 40 |
+
OUT_DIR = "/kaggle/working"
|
| 41 |
+
CACHE_DIR = "/kaggle/working/emb_cache" # nơi lưu embedding đã trích (tái dùng giữa các lần chạy)
|
| 42 |
+
os.makedirs(CACHE_DIR, exist_ok=True)
|
| 43 |
+
|
| 44 |
+
# ── Siêu tham số train (đổi nếu muốn thử nghiệm) ─────────────────────────────
|
| 45 |
+
DEVICE = "cuda" # "cuda" trên Kaggle GPU; "cpu" nếu không có GPU
|
| 46 |
+
HIDDEN = 256 # số neuron lớp ẩn của MLP head
|
| 47 |
+
DROPOUT = 0.3
|
| 48 |
+
LR = 1e-3
|
| 49 |
+
EPOCHS = 60
|
| 50 |
+
BATCH = 64
|
| 51 |
+
VAL_FRAC = 0.10 # 10% train → validation nội bộ (đo SRCC)
|
| 52 |
+
PATIENCE = 12 # early stop: dừng nếu val-SRCC không cải thiện sau N epoch
|
| 53 |
+
SEED = 42
|
| 54 |
+
|
| 55 |
+
LIMIT_TRAIN = None # đặt số nhỏ (vd 300) để chạy thử nhanh; None = full
|
| 56 |
+
USE_CLASSPROB = True # thêm 5 xác suất cảm xúc của emotion2vec vào feature (tín hiệu exp01)
|
| 57 |
+
|
| 58 |
+
EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"]
|
| 59 |
+
|
| 60 |
+
_EMO_ALIAS = {
|
| 61 |
+
"angry": "angry", "anger": "angry",
|
| 62 |
+
"happy": "happy", "happiness": "happy", "joy": "happy",
|
| 63 |
+
"neutral": "neutral", "calm": "neutral",
|
| 64 |
+
"sad": "sad", "sadness": "sad",
|
| 65 |
+
"surprise": "surprised", "surprised": "surprised", "surprising": "surprised",
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
def norm_emotion(label):
|
| 69 |
+
"""Đưa nhãn cảm xúc bất kỳ về 1 trong EMOTIONS5; None nếu không khớp."""
|
| 70 |
+
key = str(label).strip().lower()
|
| 71 |
+
return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)
|
| 72 |
+
|
| 73 |
+
def stem(path_or_name):
|
| 74 |
+
"""Lấy tên file không đuôi, để khớp wavID giữa train.csv / metadata / dev.scp."""
|
| 75 |
+
return os.path.splitext(os.path.basename(str(path_or_name)))[0]
|
| 76 |
+
|
| 77 |
+
print("DATA_ROOT:", DATA_ROOT)
|
| 78 |
+
for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP]:
|
| 79 |
+
print((" ✅ " if os.path.exists(p) else " ❌ THIẾU ") + p)
|
| 80 |
+
|
| 81 |
+
# %% [markdown]
|
| 82 |
+
# ## 1. Cài đặt
|
| 83 |
+
|
| 84 |
+
# %%
|
| 85 |
+
# !pip install -q speechmos funasr librosa soundfile pandas scipy scikit-learn tqdm
|
| 86 |
+
|
| 87 |
+
# %% [markdown]
|
| 88 |
+
# ## 2. Đọc & gộp nhãn
|
| 89 |
+
# - `train.csv`: mỗi dòng = 1 listener chấm 1 wav → **gộp trung bình eMOS theo wavID** = nhãn vàng.
|
| 90 |
+
# - `metadata.csv`: lấy **cảm xúc target** cho mỗi wav (chuẩn hóa về 5 lớp).
|
| 91 |
+
|
| 92 |
+
# %%
|
| 93 |
+
import pandas as pd
|
| 94 |
+
|
| 95 |
+
def load_target_emotions():
|
| 96 |
+
"""metadata.csv (wavID|emotion|transcript, KHÔNG header) → {stem: emotion_chuẩn|None}."""
|
| 97 |
+
tgt = {}
|
| 98 |
+
with open(METADATA_CSV, encoding="utf-8") as f:
|
| 99 |
+
for ln in f:
|
| 100 |
+
parts = ln.strip().split("|")
|
| 101 |
+
if len(parts) < 2:
|
| 102 |
+
continue
|
| 103 |
+
tgt[stem(parts[0])] = norm_emotion(parts[1])
|
| 104 |
+
return tgt
|
| 105 |
+
|
| 106 |
+
def load_train_labels():
|
| 107 |
+
"""train.csv → DataFrame [wavID(stem), emos] đã gộp trung bình theo wav."""
|
| 108 |
+
df = pd.read_csv(TRAIN_CSV)
|
| 109 |
+
# Chuẩn hóa tên cột (phòng khi viết hoa/thường khác nhau)
|
| 110 |
+
cols = {c.lower().strip(): c for c in df.columns}
|
| 111 |
+
wav_col = cols.get("wavid") or cols.get("wav") or list(df.columns)[1]
|
| 112 |
+
emos_col = cols.get("emos") or cols.get("emo") or cols.get("emomos")
|
| 113 |
+
assert emos_col, f"Không thấy cột eMOS trong train.csv (cột hiện có: {list(df.columns)})"
|
| 114 |
+
g = df.groupby(df[wav_col].map(stem))[emos_col].mean()
|
| 115 |
+
out = g.reset_index()
|
| 116 |
+
out.columns = ["wavID", "emos"]
|
| 117 |
+
return out
|
| 118 |
+
|
| 119 |
+
target_map = load_target_emotions()
|
| 120 |
+
train_df = load_train_labels()
|
| 121 |
+
print(f"Target emotions: {len(target_map)} | wav train (đã gộp): {len(train_df)}")
|
| 122 |
+
print("eMOS thống kê:", train_df["emos"].describe()[["mean", "std", "min", "max"]].to_dict())
|
| 123 |
+
train_df.head()
|
| 124 |
+
|
| 125 |
+
# %% [markdown]
|
| 126 |
+
# ## 3. Trích đặc trưng emotion2vec (có cache)
|
| 127 |
+
# Mỗi wav → 1 lần `generate(extract_embedding=True)` cho ra **embedding** (cho EMOS) +
|
| 128 |
+
# **xác suất 5 lớp** (cho CAT và làm feature). Lưu cache `.npz` để lần sau khỏi chạy lại.
|
| 129 |
+
|
| 130 |
+
# %%
|
| 131 |
+
import numpy as np
|
| 132 |
+
|
| 133 |
+
_e2v_model = None
|
| 134 |
+
def get_e2v():
|
| 135 |
+
global _e2v_model
|
| 136 |
+
if _e2v_model is None:
|
| 137 |
+
from funasr import AutoModel
|
| 138 |
+
_e2v_model = AutoModel(model="iic/emotion2vec_plus_large", hub="hf")
|
| 139 |
+
return _e2v_model
|
| 140 |
+
|
| 141 |
+
def extract_one(wav_path):
|
| 142 |
+
"""→ (emb: np.float32[D], probs5: np.float32[5] tổng=1). None nếu lỗi/thiếu file."""
|
| 143 |
+
if not os.path.exists(wav_path):
|
| 144 |
+
return None
|
| 145 |
+
rec = get_e2v().generate(wav_path, granularity="utterance", extract_embedding=True)
|
| 146 |
+
r = rec[0]
|
| 147 |
+
emb = np.asarray(r["feats"], dtype=np.float32).reshape(-1)
|
| 148 |
+
probs = {e: 0.0 for e in EMOTIONS5}
|
| 149 |
+
for lab, sc in zip(r["labels"], r["scores"]):
|
| 150 |
+
name = lab.split("/")[-1]
|
| 151 |
+
if name in probs:
|
| 152 |
+
probs[name] = float(sc)
|
| 153 |
+
tot = sum(probs.values())
|
| 154 |
+
if tot > 0:
|
| 155 |
+
probs = {k: v / tot for k, v in probs.items()}
|
| 156 |
+
probs5 = np.array([probs[e] for e in EMOTIONS5], dtype=np.float32)
|
| 157 |
+
return emb, probs5
|
| 158 |
+
|
| 159 |
+
def extract_set(stems, tag):
|
| 160 |
+
"""Trích (hoặc nạp cache) cho danh sách stem. Trả về dict {stem: (emb, probs5)}.
|
| 161 |
+
Cache lưu tại CACHE_DIR/<tag>.npz; tự bỏ qua stem đã có để chạy nối tiếp được."""
|
| 162 |
+
from tqdm.auto import tqdm
|
| 163 |
+
cache_path = os.path.join(CACHE_DIR, f"{tag}.npz")
|
| 164 |
+
store = {}
|
| 165 |
+
if os.path.exists(cache_path):
|
| 166 |
+
z = np.load(cache_path, allow_pickle=True)
|
| 167 |
+
store = {k: z[k] for k in z.files}
|
| 168 |
+
print(f"[{tag}] nạp cache: {len(store)} mẫu")
|
| 169 |
+
todo = [s for s in stems if s not in store]
|
| 170 |
+
if not todo:
|
| 171 |
+
print(f"[{tag}] đủ cache, bỏ qua trích.")
|
| 172 |
+
else:
|
| 173 |
+
miss = 0
|
| 174 |
+
for i, s in enumerate(tqdm(todo, desc=f"trích {tag}")):
|
| 175 |
+
res = extract_one(os.path.join(WAV_DIR, s + ".wav"))
|
| 176 |
+
if res is None:
|
| 177 |
+
miss += 1
|
| 178 |
+
continue
|
| 179 |
+
emb, probs5 = res
|
| 180 |
+
store[s] = np.concatenate([emb, probs5]).astype(np.float32) # [D + 5]
|
| 181 |
+
if (i + 1) % 500 == 0: # lưu cache định kỳ phòng ngắt session
|
| 182 |
+
np.savez(cache_path, **store)
|
| 183 |
+
np.savez(cache_path, **store)
|
| 184 |
+
if miss:
|
| 185 |
+
print(f"[{tag}] {miss} file thiếu/ lỗi → bỏ qua.")
|
| 186 |
+
print(f"[{tag}] tổng cache: {len(store)} mẫu → {cache_path}")
|
| 187 |
+
# tách lại thành (emb, probs5)
|
| 188 |
+
out = {}
|
| 189 |
+
for s, vec in store.items():
|
| 190 |
+
out[s] = (vec[:-5], vec[-5:])
|
| 191 |
+
return out
|
| 192 |
+
|
| 193 |
+
# Trích cho tập train
|
| 194 |
+
train_stems = list(train_df["wavID"])
|
| 195 |
+
if LIMIT_TRAIN:
|
| 196 |
+
train_stems = train_stems[:LIMIT_TRAIN]
|
| 197 |
+
train_feat = extract_set(train_stems, "train")
|
| 198 |
+
EMB_DIM = next(iter(train_feat.values()))[0].shape[0]
|
| 199 |
+
print("EMB_DIM =", EMB_DIM)
|
| 200 |
+
|
| 201 |
+
# %% [markdown]
|
| 202 |
+
# ## 4. Dựng feature + nhãn cho train
|
| 203 |
+
# Feature mỗi wav = `[embedding | (probs5 nếu bật) | one-hot target(5)]`. Bỏ wav thiếu target/feature.
|
| 204 |
+
|
| 205 |
+
# %%
|
| 206 |
+
def onehot_target(tgt):
|
| 207 |
+
v = np.zeros(len(EMOTIONS5), dtype=np.float32)
|
| 208 |
+
if tgt in EMOTIONS5:
|
| 209 |
+
v[EMOTIONS5.index(tgt)] = 1.0
|
| 210 |
+
return v
|
| 211 |
+
|
| 212 |
+
def build_feature(stem_id, feat_map):
|
| 213 |
+
pack = feat_map.get(stem_id)
|
| 214 |
+
if pack is None:
|
| 215 |
+
return None
|
| 216 |
+
emb, probs5 = pack
|
| 217 |
+
tgt = target_map.get(stem_id)
|
| 218 |
+
if tgt is None: # không biết cảm xúc target → không train được mẫu này
|
| 219 |
+
return None
|
| 220 |
+
parts = [emb]
|
| 221 |
+
if USE_CLASSPROB:
|
| 222 |
+
parts.append(probs5)
|
| 223 |
+
parts.append(onehot_target(tgt))
|
| 224 |
+
return np.concatenate(parts).astype(np.float32)
|
| 225 |
+
|
| 226 |
+
emos_label = dict(zip(train_df["wavID"], train_df["emos"]))
|
| 227 |
+
X, y = [], []
|
| 228 |
+
for s in train_stems:
|
| 229 |
+
f = build_feature(s, train_feat)
|
| 230 |
+
if f is None or s not in emos_label:
|
| 231 |
+
continue
|
| 232 |
+
X.append(f); y.append(emos_label[s])
|
| 233 |
+
X = np.stack(X); y = np.array(y, dtype=np.float32)
|
| 234 |
+
FEAT_DIM = X.shape[1]
|
| 235 |
+
print(f"Train: X={X.shape} y={y.shape} FEAT_DIM={FEAT_DIM}")
|
| 236 |
+
|
| 237 |
+
# Chuẩn hóa feature (z-score) — lưu mean/std để áp dụng y hệt lúc dự đoán DEV.
|
| 238 |
+
feat_mean = X.mean(0, keepdims=True)
|
| 239 |
+
feat_std = X.std(0, keepdims=True) + 1e-6
|
| 240 |
+
Xn = (X - feat_mean) / feat_std
|
| 241 |
+
|
| 242 |
+
# %% [markdown]
|
| 243 |
+
# ## 5. Model (MLP head) + train loop
|
| 244 |
+
# Loss = MSE. Theo dõi **SRCC** trên validation nội bộ; lưu model tốt nhất (early stopping).
|
| 245 |
+
|
| 246 |
+
# %%
|
| 247 |
+
import torch, torch.nn as nn
|
| 248 |
+
from scipy.stats import spearmanr
|
| 249 |
+
from sklearn.model_selection import train_test_split
|
| 250 |
+
|
| 251 |
+
torch.manual_seed(SEED); np.random.seed(SEED)
|
| 252 |
+
device = DEVICE if torch.cuda.is_available() else "cpu"
|
| 253 |
+
print("Device:", device)
|
| 254 |
+
|
| 255 |
+
Xtr, Xva, ytr, yva = train_test_split(Xn, y, test_size=VAL_FRAC, random_state=SEED)
|
| 256 |
+
Xtr_t = torch.tensor(Xtr, device=device); ytr_t = torch.tensor(ytr, device=device).unsqueeze(1)
|
| 257 |
+
Xva_t = torch.tensor(Xva, device=device); yva_t = torch.tensor(yva, device=device).unsqueeze(1)
|
| 258 |
+
|
| 259 |
+
class EmosHead(nn.Module):
|
| 260 |
+
def __init__(self, d_in, hidden, p):
|
| 261 |
+
super().__init__()
|
| 262 |
+
self.net = nn.Sequential(
|
| 263 |
+
nn.Linear(d_in, hidden), nn.ReLU(), nn.Dropout(p),
|
| 264 |
+
nn.Linear(hidden, hidden // 2), nn.ReLU(), nn.Dropout(p),
|
| 265 |
+
nn.Linear(hidden // 2, 1),
|
| 266 |
+
)
|
| 267 |
+
def forward(self, x):
|
| 268 |
+
return self.net(x)
|
| 269 |
+
|
| 270 |
+
model = EmosHead(FEAT_DIM, HIDDEN, DROPOUT).to(device)
|
| 271 |
+
opt = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-5)
|
| 272 |
+
lossf = nn.MSELoss()
|
| 273 |
+
|
| 274 |
+
def val_srcc():
|
| 275 |
+
model.eval()
|
| 276 |
+
with torch.no_grad():
|
| 277 |
+
pred = model(Xva_t).cpu().numpy().ravel()
|
| 278 |
+
return spearmanr(pred, yva).correlation
|
| 279 |
+
|
| 280 |
+
best_srcc, best_state, bad = -1.0, None, 0
|
| 281 |
+
n = Xtr_t.shape[0]
|
| 282 |
+
for ep in range(1, EPOCHS + 1):
|
| 283 |
+
model.train()
|
| 284 |
+
perm = torch.randperm(n, device=device)
|
| 285 |
+
tot = 0.0
|
| 286 |
+
for i in range(0, n, BATCH):
|
| 287 |
+
idx = perm[i:i + BATCH]
|
| 288 |
+
opt.zero_grad()
|
| 289 |
+
out = model(Xtr_t[idx])
|
| 290 |
+
loss = lossf(out, ytr_t[idx])
|
| 291 |
+
loss.backward(); opt.step()
|
| 292 |
+
tot += loss.item() * len(idx)
|
| 293 |
+
srcc = val_srcc()
|
| 294 |
+
if srcc > best_srcc:
|
| 295 |
+
best_srcc, best_state, bad = srcc, {k: v.cpu().clone() for k, v in model.state_dict().items()}, 0
|
| 296 |
+
else:
|
| 297 |
+
bad += 1
|
| 298 |
+
if ep % 5 == 0 or ep == 1:
|
| 299 |
+
print(f"epoch {ep:3d} | train MSE {tot/n:.4f} | val SRCC {srcc:.4f} | best {best_srcc:.4f}")
|
| 300 |
+
if bad >= PATIENCE:
|
| 301 |
+
print(f"Early stop ở epoch {ep} (val SRCC không tăng {PATIENCE} epoch).")
|
| 302 |
+
break
|
| 303 |
+
|
| 304 |
+
model.load_state_dict(best_state)
|
| 305 |
+
print(f"\n✅ VAL SRCC tốt nhất = {best_srcc:.4f} (baseline exp01 ≈ 0.194 — so ở đây)")
|
| 306 |
+
|
| 307 |
+
# Lưu model + tham số chuẩn hóa để tái dùng / mô tả hệ thống.
|
| 308 |
+
torch.save({"state": best_state, "feat_mean": feat_mean, "feat_std": feat_std,
|
| 309 |
+
"EMB_DIM": EMB_DIM, "FEAT_DIM": FEAT_DIM, "USE_CLASSPROB": USE_CLASSPROB,
|
| 310 |
+
"EMOTIONS5": EMOTIONS5, "val_srcc": float(best_srcc)},
|
| 311 |
+
os.path.join(OUT_DIR, "emos_head.pt"))
|
| 312 |
+
print("Đã lưu", os.path.join(OUT_DIR, "emos_head.pt"))
|
| 313 |
+
|
| 314 |
+
# %% [markdown]
|
| 315 |
+
# ## 6. Dự đoán DEV → `answer.txt` đầy đủ
|
| 316 |
+
# - **EMOS** = head vừa train (cần embedding + target của từng wav DEV).
|
| 317 |
+
# - **CAT** = xác suất 5 lớp emotion2vec (đã có sẵn khi trích đặc trưng).
|
| 318 |
+
# - **QMOS** = SpeechMOS (UTMOS) — bắt buộc, chạy thêm ở đây để answer.txt hợp lệ.
|
| 319 |
+
|
| 320 |
+
# %%
|
| 321 |
+
def list_dev():
|
| 322 |
+
with open(DEV_SCP) as f:
|
| 323 |
+
return [ln.strip() for ln in f if ln.strip()] # tên file .wav
|
| 324 |
+
|
| 325 |
+
dev_names = list_dev()
|
| 326 |
+
dev_stems = [stem(n) for n in dev_names]
|
| 327 |
+
print("DEV:", len(dev_names), "mẫu")
|
| 328 |
+
|
| 329 |
+
# 6a. Trích đặc trưng emotion2vec cho DEV (cache riêng)
|
| 330 |
+
dev_feat = extract_set(dev_stems, "dev")
|
| 331 |
+
|
| 332 |
+
# 6b. EMOS từ head đã train
|
| 333 |
+
def predict_emos(stem_id):
|
| 334 |
+
f = build_feature(stem_id, dev_feat)
|
| 335 |
+
if f is None:
|
| 336 |
+
return None
|
| 337 |
+
fn = (f[None, :] - feat_mean) / feat_std
|
| 338 |
+
model.eval()
|
| 339 |
+
with torch.no_grad():
|
| 340 |
+
return float(model(torch.tensor(fn, dtype=torch.float32, device=device)).item())
|
| 341 |
+
|
| 342 |
+
# 6c. QMOS = SpeechMOS
|
| 343 |
+
def run_qmos(names):
|
| 344 |
+
import librosa
|
| 345 |
+
predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True)
|
| 346 |
+
out = {}
|
| 347 |
+
from tqdm.auto import tqdm
|
| 348 |
+
for n in tqdm(names, desc="QMOS"):
|
| 349 |
+
p = os.path.join(WAV_DIR, n)
|
| 350 |
+
if not os.path.exists(p):
|
| 351 |
+
continue
|
| 352 |
+
wave, _ = librosa.load(p, sr=16000, mono=True)
|
| 353 |
+
out[n] = float(predictor(torch.from_numpy(wave).unsqueeze(0), sr=16000).mean().item())
|
| 354 |
+
return out
|
| 355 |
+
|
| 356 |
+
qmos_scores = run_qmos(dev_names)
|
| 357 |
+
|
| 358 |
+
# %%
|
| 359 |
+
def fmt_cat(probs5):
|
| 360 |
+
return "|".join(f"{e}:{probs5[i]:.6g}" for i, e in enumerate(EMOTIONS5))
|
| 361 |
+
|
| 362 |
+
def build_answer(out_path):
|
| 363 |
+
n_emos = n_default = 0
|
| 364 |
+
with open(out_path, "w") as f:
|
| 365 |
+
f.write("wav,QMOS,EMOS,CAT\n")
|
| 366 |
+
for name in dev_names:
|
| 367 |
+
sid = stem(name)
|
| 368 |
+
emos = predict_emos(sid)
|
| 369 |
+
if emos is None:
|
| 370 |
+
emos = 3.0; n_default += 1
|
| 371 |
+
else:
|
| 372 |
+
n_emos += 1
|
| 373 |
+
qmos = qmos_scores.get(name, 3.0)
|
| 374 |
+
probs5 = dev_feat[sid][1] if sid in dev_feat else np.full(5, 0.2, dtype=np.float32)
|
| 375 |
+
f.write(f"{name},{qmos:.6g},{emos:.6g},{fmt_cat(probs5)}\n")
|
| 376 |
+
print(f"Ghi {len(dev_names)} dòng → {out_path} | EMOS thật {n_emos}, mặc định {n_default}")
|
| 377 |
+
|
| 378 |
+
answer_path = os.path.join(OUT_DIR, "answer.txt")
|
| 379 |
+
build_answer(answer_path)
|
| 380 |
+
|
| 381 |
+
# %% [markdown]
|
| 382 |
+
# ## 7. Validate + đóng zip
|
| 383 |
+
|
| 384 |
+
# %%
|
| 385 |
+
def validate(path):
|
| 386 |
+
import csv
|
| 387 |
+
with open(path) as f:
|
| 388 |
+
rows = list(csv.reader(f))
|
| 389 |
+
header = rows[0]
|
| 390 |
+
assert header[0] == "wav" and "QMOS" in header and "EMOS" in header, "Header sai"
|
| 391 |
+
for i, r in enumerate(rows[1:], 2):
|
| 392 |
+
assert len(r) == len(header), f"Dòng {i} sai số cột"
|
| 393 |
+
print(f"OK: {len(rows)-1} dòng, header = {header}")
|
| 394 |
+
|
| 395 |
+
validate(answer_path)
|
| 396 |
+
# !cd /kaggle/working && zip -j submission_track2_exp02.zip answer.txt && unzip -l submission_track2_exp02.zip
|
| 397 |
+
print("Sẵn sàng nộp: /kaggle/working/submission_track2_exp02.zip")
|
| 398 |
+
|
| 399 |
+
# %% [markdown]
|
| 400 |
+
# ## Ghi chú
|
| 401 |
+
# - **VAL SRCC** in ở mục 5 là ước lượng nội bộ (10% train) — so với baseline 0.194 để biết có khá hơn không.
|
| 402 |
+
# Điểm DEV thật phải nộp lên CodaBench mới biết (My Submissions → Track 2, bỏ chọn track khác).
|
| 403 |
+
# - Muốn thử nhanh: đặt `LIMIT_TRAIN = 300` ở cell 0.
|
| 404 |
+
# - Embedding đã cache trong `/kaggle/working/emb_cache/` → **Save Version** để giữ, lần sau train head khỏi trích lại.
|
| 405 |
+
# - Hướng cải tiến tiếp: thêm head QMOS/CAT/VAD dùng chung backbone (exp02 multi-task đầy đủ);
|
| 406 |
+
# thử backbone wav2vec2/WavLM; thêm ranking loss; fine-tune nhẹ backbone.
|
| 407 |
+
# - Nhớ ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (mục exp02).
|
track2/exp03_emos_sailer.ipynb
ADDED
|
@@ -0,0 +1,392 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "a6ae46f8",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"# VMC2026 Track 2 — exp03 (EMOS bằng SAILER, offline) — Kaggle\n",
|
| 9 |
+
"\n",
|
| 10 |
+
"**Mục tiêu:** chấm **EMOS** (độ khớp cảm xúc target) bằng model **SAILER**\n",
|
| 11 |
+
"(`tiantiaf/wavlm-large-categorical-emotion`, vô địch Interspeech 2025 SER),\n",
|
| 12 |
+
"thay cho emotion2vec — KHÔNG train, chỉ lấy xác suất lớp cảm xúc target.\n",
|
| 13 |
+
"\n",
|
| 14 |
+
"## Ý tưởng (đọc 1 lần cho hiểu)\n",
|
| 15 |
+
"SAILER nhận 1 wav → xuất **logits 9 lớp cảm xúc** → softmax → **xác suất từng lớp**.\n",
|
| 16 |
+
"EMOS = mức khớp cảm xúc target → lấy thẳng **P(cảm xúc target)** rồi kéo về thang 1–5:\n",
|
| 17 |
+
"\n",
|
| 18 |
+
"```\n",
|
| 19 |
+
"mỗi wav ─► SAILER (WavLM-large) ─► softmax 9 lớp ─┬─► P(target) ─► EMOS = 1 + 4·P\n",
|
| 20 |
+
" └─► 5 lớp (renorm) ─► CAT\n",
|
| 21 |
+
" target emotion (metadata.csv) ─────────────────┘\n",
|
| 22 |
+
"```\n",
|
| 23 |
+
"\n",
|
| 24 |
+
"- **9 lớp SAILER:** `Anger, Contempt, Disgust, Fear, Happiness, Neutral, Sadness, Surprise, Other`.\n",
|
| 25 |
+
" → đủ cả 5 lớp challenge (angry/happy/neutral/sad/surprised).\n",
|
| 26 |
+
"- **EMOS** = `1 + 4·P(target)` (scale [0,1]→[1,5]); SRCC bất biến với scale tuyến tính.\n",
|
| 27 |
+
"- **CAT** = lấy xác suất 5 lớp challenge từ chính SAILER (renormalize tổng=1).\n",
|
| 28 |
+
"- **VAD** = arousal/valence/dominance SAILER xuất sẵn (sigmoid 0–1 → 1–5) → 1 model lo EMOS+CAT+VAD!\n",
|
| 29 |
+
"- **QMOS** = SpeechMOS (UTMOS) — bắt buộc để `answer.txt` hợp lệ.\n",
|
| 30 |
+
"- KHÔNG train → nộp được ngay. So điểm EMOS với baseline emotion2vec (0.194) và exp01.\n",
|
| 31 |
+
"\n",
|
| 32 |
+
"**Cách chạy trên Kaggle:** Settings → Accelerator = **GPU T4**, Internet = **On**\n",
|
| 33 |
+
"→ + Add Input dataset Track 2 (15.477 wav, có `sets/dev.scp`, `metadata.csv`)\n",
|
| 34 |
+
"→ sửa `DATA_ROOT` ở cell 0 → Run All.\n",
|
| 35 |
+
"\n",
|
| 36 |
+
"⚠️ License SAILER = **Open RAIL** (phi thương mại) → phải khai báo trong `docs/12_system_description.md`."
|
| 37 |
+
]
|
| 38 |
+
},
|
| 39 |
+
{
|
| 40 |
+
"cell_type": "markdown",
|
| 41 |
+
"id": "f25f6ed7",
|
| 42 |
+
"metadata": {},
|
| 43 |
+
"source": [
|
| 44 |
+
"## 0. Cấu hình — SỬA Ở ĐÂY"
|
| 45 |
+
]
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"cell_type": "code",
|
| 49 |
+
"execution_count": null,
|
| 50 |
+
"id": "4cb1fd8c",
|
| 51 |
+
"metadata": {},
|
| 52 |
+
"outputs": [],
|
| 53 |
+
"source": [
|
| 54 |
+
"import os\n",
|
| 55 |
+
"\n",
|
| 56 |
+
"# ── Data Track 2 (dataset 15.477 wav đã ráp) ────────────────────────────────\n",
|
| 57 |
+
"DATA_ROOT = \"/kaggle/input/vmc2026-track2-full/vmc2026-track2\" # << SỬA slug cho khớp Add Input\n",
|
| 58 |
+
"WAV_DIR = f\"{DATA_ROOT}/wav\"\n",
|
| 59 |
+
"METADATA_CSV = f\"{DATA_ROOT}/metadata.csv\" # wavID|emotion|transcript (KHÔNG header) → target emotion\n",
|
| 60 |
+
"DEV_SCP = f\"{DATA_ROOT}/sets/dev.scp\" # danh sách wav tập DEV (tập cần nộp ở training phase)\n",
|
| 61 |
+
"\n",
|
| 62 |
+
"OUT_DIR = \"/kaggle/working\"\n",
|
| 63 |
+
"\n",
|
| 64 |
+
"DEVICE = \"cuda\" # \"cuda\" trên Kaggle GPU; \"cpu\" nếu không có GPU\n",
|
| 65 |
+
"MAX_SECONDS = 15 # SAILER nhận tối đa 15s (giới hạn của model)\n",
|
| 66 |
+
"SR = 16000 # SAILER cần 16kHz mono\n",
|
| 67 |
+
"LIMIT = None # đặt số nhỏ (vd 20) để chạy thử nhanh; None = full DEV\n",
|
| 68 |
+
"\n",
|
| 69 |
+
"# 5 lớp cảm xúc challenge (thứ tự cố định cho cột CAT)\n",
|
| 70 |
+
"EMOTIONS5 = [\"angry\", \"happy\", \"neutral\", \"sad\", \"surprised\"]\n",
|
| 71 |
+
"\n",
|
| 72 |
+
"# 9 lớp SAILER (đúng thứ tự model xuất) + chỉ số của 5 lớp challenge trong đó\n",
|
| 73 |
+
"SAILER9 = [\"Anger\", \"Contempt\", \"Disgust\", \"Fear\", \"Happiness\", \"Neutral\", \"Sadness\", \"Surprise\", \"Other\"]\n",
|
| 74 |
+
"EMO2SAILER = {\"angry\": 0, \"happy\": 4, \"neutral\": 5, \"sad\": 6, \"surprised\": 7} # EMOTIONS5 → index trong SAILER9\n",
|
| 75 |
+
"\n",
|
| 76 |
+
"_EMO_ALIAS = {\n",
|
| 77 |
+
" \"angry\": \"angry\", \"anger\": \"angry\",\n",
|
| 78 |
+
" \"happy\": \"happy\", \"happiness\": \"happy\", \"joy\": \"happy\",\n",
|
| 79 |
+
" \"neutral\": \"neutral\", \"calm\": \"neutral\",\n",
|
| 80 |
+
" \"sad\": \"sad\", \"sadness\": \"sad\",\n",
|
| 81 |
+
" \"surprise\": \"surprised\", \"surprised\": \"surprised\", \"surprising\": \"surprised\",\n",
|
| 82 |
+
"}\n",
|
| 83 |
+
"\n",
|
| 84 |
+
"def norm_emotion(label):\n",
|
| 85 |
+
" \"\"\"Đưa nhãn cảm xúc bất kỳ về 1 trong EMOTIONS5; None nếu không khớp.\"\"\"\n",
|
| 86 |
+
" key = str(label).strip().lower()\n",
|
| 87 |
+
" return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)\n",
|
| 88 |
+
"\n",
|
| 89 |
+
"def stem(path_or_name):\n",
|
| 90 |
+
" return os.path.splitext(os.path.basename(str(path_or_name)))[0]\n",
|
| 91 |
+
"\n",
|
| 92 |
+
"print(\"DATA_ROOT:\", DATA_ROOT)\n",
|
| 93 |
+
"for p in [WAV_DIR, METADATA_CSV, DEV_SCP]:\n",
|
| 94 |
+
" print((\" ✅ \" if os.path.exists(p) else \" ❌ THIẾU \") + p)"
|
| 95 |
+
]
|
| 96 |
+
},
|
| 97 |
+
{
|
| 98 |
+
"cell_type": "markdown",
|
| 99 |
+
"id": "18c48274",
|
| 100 |
+
"metadata": {},
|
| 101 |
+
"source": [
|
| 102 |
+
"## 1. Cài đặt + tải code SAILER\n",
|
| 103 |
+
"SAILER cần file `WavLMWrapper` trong repo `vox-profile-release`.\n",
|
| 104 |
+
"⚠️ **KHÔNG** `pip install -e .` (build wheel của repo hay lỗi trên Kaggle). Thay vào đó:\n",
|
| 105 |
+
"chỉ **clone + thêm repo vào `sys.path`** rồi cài đúng vài thư viện model cần\n",
|
| 106 |
+
"(`transformers/torch/huggingface_hub` Kaggle đã có sẵn; chỉ thiếu `loralib`, `speechbrain`)."
|
| 107 |
+
]
|
| 108 |
+
},
|
| 109 |
+
{
|
| 110 |
+
"cell_type": "code",
|
| 111 |
+
"execution_count": null,
|
| 112 |
+
"id": "bd8f98d9",
|
| 113 |
+
"metadata": {},
|
| 114 |
+
"outputs": [],
|
| 115 |
+
"source": [
|
| 116 |
+
"import sys, subprocess\n",
|
| 117 |
+
"\n",
|
| 118 |
+
"def pip_install(*pkgs):\n",
|
| 119 |
+
" subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *pkgs], check=True)\n",
|
| 120 |
+
"\n",
|
| 121 |
+
"REPO_DIR = \"/kaggle/working/vox-profile-release\"\n",
|
| 122 |
+
"if not os.path.exists(REPO_DIR):\n",
|
| 123 |
+
" subprocess.run([\"git\", \"clone\", \"--depth\", \"1\",\n",
|
| 124 |
+
" \"https://github.com/tiantiaf0627/vox-profile-release.git\", REPO_DIR], check=True)\n",
|
| 125 |
+
"\n",
|
| 126 |
+
"# Deps mà WavLMWrapper cần (xem import trong src/model/emotion/wavlm_emotion.py) + thư viện chấm QMOS.\n",
|
| 127 |
+
"pip_install(\"loralib\", \"speechbrain\", \"speechmos\", \"librosa\", \"soundfile\", \"scipy\", \"tqdm\")\n",
|
| 128 |
+
"\n",
|
| 129 |
+
"if REPO_DIR not in sys.path:\n",
|
| 130 |
+
" sys.path.insert(0, REPO_DIR) # để `from src.model.emotion... import WavLMWrapper` chạy được"
|
| 131 |
+
]
|
| 132 |
+
},
|
| 133 |
+
{
|
| 134 |
+
"cell_type": "markdown",
|
| 135 |
+
"id": "00a49544",
|
| 136 |
+
"metadata": {},
|
| 137 |
+
"source": [
|
| 138 |
+
"## 2. Nạp model SAILER"
|
| 139 |
+
]
|
| 140 |
+
},
|
| 141 |
+
{
|
| 142 |
+
"cell_type": "code",
|
| 143 |
+
"execution_count": null,
|
| 144 |
+
"id": "2756567f",
|
| 145 |
+
"metadata": {
|
| 146 |
+
"lines_to_next_cell": 1
|
| 147 |
+
},
|
| 148 |
+
"outputs": [],
|
| 149 |
+
"source": [
|
| 150 |
+
"import torch\n",
|
| 151 |
+
"import torch.nn.functional as F\n",
|
| 152 |
+
"\n",
|
| 153 |
+
"device = DEVICE if torch.cuda.is_available() else \"cpu\"\n",
|
| 154 |
+
"print(\"Device:\", device)\n",
|
| 155 |
+
"\n",
|
| 156 |
+
"from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402\n",
|
| 157 |
+
"\n",
|
| 158 |
+
"sailer = WavLMWrapper.from_pretrained(\"tiantiaf/wavlm-large-categorical-emotion\").to(device)\n",
|
| 159 |
+
"sailer.eval()\n",
|
| 160 |
+
"print(\"✅ Đã nạp SAILER (wavlm-large-categorical-emotion)\")"
|
| 161 |
+
]
|
| 162 |
+
},
|
| 163 |
+
{
|
| 164 |
+
"cell_type": "markdown",
|
| 165 |
+
"id": "6c7b6e84",
|
| 166 |
+
"metadata": {},
|
| 167 |
+
"source": [
|
| 168 |
+
"## 3. Đọc cảm xúc target cho mỗi wav (từ metadata.csv)"
|
| 169 |
+
]
|
| 170 |
+
},
|
| 171 |
+
{
|
| 172 |
+
"cell_type": "code",
|
| 173 |
+
"execution_count": null,
|
| 174 |
+
"id": "8e6f4b36",
|
| 175 |
+
"metadata": {},
|
| 176 |
+
"outputs": [],
|
| 177 |
+
"source": [
|
| 178 |
+
"def load_target_emotions():\n",
|
| 179 |
+
" \"\"\"metadata.csv (wavID|emotion|transcript, KHÔNG header) → {stem: emotion_chuẩn|None}.\"\"\"\n",
|
| 180 |
+
" tgt = {}\n",
|
| 181 |
+
" with open(METADATA_CSV, encoding=\"utf-8\") as f:\n",
|
| 182 |
+
" for ln in f:\n",
|
| 183 |
+
" parts = ln.strip().split(\"|\")\n",
|
| 184 |
+
" if len(parts) < 2:\n",
|
| 185 |
+
" continue\n",
|
| 186 |
+
" tgt[stem(parts[0])] = norm_emotion(parts[1])\n",
|
| 187 |
+
" return tgt\n",
|
| 188 |
+
"\n",
|
| 189 |
+
"target_map = load_target_emotions()\n",
|
| 190 |
+
"print(f\"Target emotions: {len(target_map)} wav | ví dụ:\", dict(list(target_map.items())[:3]))"
|
| 191 |
+
]
|
| 192 |
+
},
|
| 193 |
+
{
|
| 194 |
+
"cell_type": "markdown",
|
| 195 |
+
"id": "7467cfc6",
|
| 196 |
+
"metadata": {},
|
| 197 |
+
"source": [
|
| 198 |
+
"## 4. Hàm chấm 1 wav bằng SAILER → xác suất 9 lớp + VAD\n",
|
| 199 |
+
"WavLMWrapper khi `return_feature=True` trả **6 giá trị**:\n",
|
| 200 |
+
"`predicted(logits 9 lớp), features, detailed_logits, arousal, valence, dominance` (VAD sigmoid 0–1).\n",
|
| 201 |
+
"→ 1 model lo cả **EMOS** (P target), **CAT** (5 lớp renorm) **và VAD** (mở 3 cột đang trống!)."
|
| 202 |
+
]
|
| 203 |
+
},
|
| 204 |
+
{
|
| 205 |
+
"cell_type": "code",
|
| 206 |
+
"execution_count": null,
|
| 207 |
+
"id": "d3bb3d01",
|
| 208 |
+
"metadata": {
|
| 209 |
+
"lines_to_next_cell": 1
|
| 210 |
+
},
|
| 211 |
+
"outputs": [],
|
| 212 |
+
"source": [
|
| 213 |
+
"import numpy as np\n",
|
| 214 |
+
"import librosa\n",
|
| 215 |
+
"\n",
|
| 216 |
+
"@torch.no_grad()\n",
|
| 217 |
+
"def sailer_infer(wav_path):\n",
|
| 218 |
+
" \"\"\"→ (probs9: float32[9], vad3: float32[3] theo thứ tự [VAL,ARO,DOM] thang 1–5);\n",
|
| 219 |
+
" None nếu thiếu/lỗi file.\"\"\"\n",
|
| 220 |
+
" if not os.path.exists(wav_path):\n",
|
| 221 |
+
" return None\n",
|
| 222 |
+
" wave, _ = librosa.load(wav_path, sr=SR, mono=True)\n",
|
| 223 |
+
" wave = wave[: MAX_SECONDS * SR] # cắt tối đa 15s\n",
|
| 224 |
+
" data = torch.from_numpy(wave).float().unsqueeze(0).to(device)\n",
|
| 225 |
+
" logits, _feat, _det, arousal, valence, dominance = sailer(data, return_feature=True)\n",
|
| 226 |
+
" probs9 = F.softmax(logits, dim=1)[0].detach().cpu().numpy().astype(np.float32)\n",
|
| 227 |
+
" # VAD sigmoid [0,1] → thang 1–5 cho khớp ví dụ BTC (SRCC bất biến với scale tuyến tính)\n",
|
| 228 |
+
" v, a, d = float(valence.item()), float(arousal.item()), float(dominance.item())\n",
|
| 229 |
+
" vad3 = np.array([1 + 4 * v, 1 + 4 * a, 1 + 4 * d], dtype=np.float32) # [VAL, ARO, DOM]\n",
|
| 230 |
+
" return probs9, vad3\n",
|
| 231 |
+
"\n",
|
| 232 |
+
"def emos_from_probs(probs9, target):\n",
|
| 233 |
+
" \"\"\"EMOS = 1 + 4·P(target). None nếu không biết target → để caller xử lý mặc định.\"\"\"\n",
|
| 234 |
+
" if target is None or target not in EMO2SAILER:\n",
|
| 235 |
+
" return None\n",
|
| 236 |
+
" return 1.0 + 4.0 * float(probs9[EMO2SAILER[target]])\n",
|
| 237 |
+
"\n",
|
| 238 |
+
"def cat5_from_probs(probs9):\n",
|
| 239 |
+
" \"\"\"Lấy 5 lớp challenge từ 9 lớp SAILER rồi renormalize tổng=1.\"\"\"\n",
|
| 240 |
+
" v = np.array([probs9[EMO2SAILER[e]] for e in EMOTIONS5], dtype=np.float32)\n",
|
| 241 |
+
" s = v.sum()\n",
|
| 242 |
+
" return v / s if s > 0 else np.full(5, 0.2, dtype=np.float32)"
|
| 243 |
+
]
|
| 244 |
+
},
|
| 245 |
+
{
|
| 246 |
+
"cell_type": "markdown",
|
| 247 |
+
"id": "14f8e54a",
|
| 248 |
+
"metadata": {},
|
| 249 |
+
"source": [
|
| 250 |
+
"## 5. QMOS = SpeechMOS (UTMOS) — bắt buộc cho answer.txt"
|
| 251 |
+
]
|
| 252 |
+
},
|
| 253 |
+
{
|
| 254 |
+
"cell_type": "code",
|
| 255 |
+
"execution_count": null,
|
| 256 |
+
"id": "992bd84b",
|
| 257 |
+
"metadata": {
|
| 258 |
+
"lines_to_next_cell": 1
|
| 259 |
+
},
|
| 260 |
+
"outputs": [],
|
| 261 |
+
"source": [
|
| 262 |
+
"@torch.no_grad()\n",
|
| 263 |
+
"def run_qmos(names):\n",
|
| 264 |
+
" predictor = torch.hub.load(\"tarepan/SpeechMOS:v1.2.0\", \"utmos22_strong\", trust_repo=True).to(device).eval()\n",
|
| 265 |
+
" from tqdm.auto import tqdm\n",
|
| 266 |
+
" out = {}\n",
|
| 267 |
+
" for n in tqdm(names, desc=\"QMOS\"):\n",
|
| 268 |
+
" p = os.path.join(WAV_DIR, n)\n",
|
| 269 |
+
" if not os.path.exists(p):\n",
|
| 270 |
+
" continue\n",
|
| 271 |
+
" wave, _ = librosa.load(p, sr=SR, mono=True)\n",
|
| 272 |
+
" x = torch.from_numpy(wave).unsqueeze(0).to(device) # đẩy input lên GPU\n",
|
| 273 |
+
" out[n] = float(predictor(x, sr=SR).mean().item())\n",
|
| 274 |
+
" return out"
|
| 275 |
+
]
|
| 276 |
+
},
|
| 277 |
+
{
|
| 278 |
+
"cell_type": "markdown",
|
| 279 |
+
"id": "58afec1d",
|
| 280 |
+
"metadata": {},
|
| 281 |
+
"source": [
|
| 282 |
+
"## 6. Chạy trên DEV → `answer.txt` đầy đủ (QMOS, EMOS, CAT, VAL, ARO, DOM)"
|
| 283 |
+
]
|
| 284 |
+
},
|
| 285 |
+
{
|
| 286 |
+
"cell_type": "code",
|
| 287 |
+
"execution_count": null,
|
| 288 |
+
"id": "77b1eb8c",
|
| 289 |
+
"metadata": {
|
| 290 |
+
"lines_to_next_cell": 1
|
| 291 |
+
},
|
| 292 |
+
"outputs": [],
|
| 293 |
+
"source": [
|
| 294 |
+
"def list_dev():\n",
|
| 295 |
+
" with open(DEV_SCP) as f:\n",
|
| 296 |
+
" return [ln.strip() for ln in f if ln.strip()]\n",
|
| 297 |
+
"\n",
|
| 298 |
+
"dev_names = list_dev()\n",
|
| 299 |
+
"if LIMIT:\n",
|
| 300 |
+
" dev_names = dev_names[:LIMIT]\n",
|
| 301 |
+
"print(\"DEV:\", len(dev_names), \"mẫu\")\n",
|
| 302 |
+
"\n",
|
| 303 |
+
"qmos_scores = run_qmos(dev_names)\n",
|
| 304 |
+
"\n",
|
| 305 |
+
"def fmt_cat(probs5):\n",
|
| 306 |
+
" return \"|\".join(f\"{e}:{probs5[i]:.6g}\" for i, e in enumerate(EMOTIONS5))\n",
|
| 307 |
+
"\n",
|
| 308 |
+
"def build_answer(out_path):\n",
|
| 309 |
+
" from tqdm.auto import tqdm\n",
|
| 310 |
+
" n_emos = n_default = 0\n",
|
| 311 |
+
" with open(out_path, \"w\") as f:\n",
|
| 312 |
+
" f.write(\"wav,QMOS,EMOS,CAT,VAL,ARO,DOM\\n\")\n",
|
| 313 |
+
" for name in tqdm(dev_names, desc=\"SAILER EMOS/CAT/VAD\"):\n",
|
| 314 |
+
" sid = stem(name)\n",
|
| 315 |
+
" out = sailer_infer(os.path.join(WAV_DIR, name))\n",
|
| 316 |
+
" if out is None:\n",
|
| 317 |
+
" emos, cat5 = 3.0, np.full(5, 0.2, dtype=np.float32)\n",
|
| 318 |
+
" vad3 = np.array([3.0, 3.0, 3.0], dtype=np.float32)\n",
|
| 319 |
+
" n_default += 1\n",
|
| 320 |
+
" else:\n",
|
| 321 |
+
" probs9, vad3 = out\n",
|
| 322 |
+
" emos = emos_from_probs(probs9, target_map.get(sid))\n",
|
| 323 |
+
" if emos is None:\n",
|
| 324 |
+
" emos = 3.0; n_default += 1\n",
|
| 325 |
+
" else:\n",
|
| 326 |
+
" n_emos += 1\n",
|
| 327 |
+
" cat5 = cat5_from_probs(probs9)\n",
|
| 328 |
+
" qmos = qmos_scores.get(name, 3.0)\n",
|
| 329 |
+
" f.write(f\"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},\"\n",
|
| 330 |
+
" f\"{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\\n\")\n",
|
| 331 |
+
" print(f\"Ghi {len(dev_names)} dòng → {out_path} | EMOS thật {n_emos}, mặc định {n_default}\")\n",
|
| 332 |
+
"\n",
|
| 333 |
+
"answer_path = os.path.join(OUT_DIR, \"answer.txt\")\n",
|
| 334 |
+
"build_answer(answer_path)"
|
| 335 |
+
]
|
| 336 |
+
},
|
| 337 |
+
{
|
| 338 |
+
"cell_type": "markdown",
|
| 339 |
+
"id": "5e3efd8d",
|
| 340 |
+
"metadata": {},
|
| 341 |
+
"source": [
|
| 342 |
+
"## 7. Validate + đóng zip"
|
| 343 |
+
]
|
| 344 |
+
},
|
| 345 |
+
{
|
| 346 |
+
"cell_type": "code",
|
| 347 |
+
"execution_count": null,
|
| 348 |
+
"id": "f816cbb2",
|
| 349 |
+
"metadata": {},
|
| 350 |
+
"outputs": [],
|
| 351 |
+
"source": [
|
| 352 |
+
"def validate(path):\n",
|
| 353 |
+
" import csv\n",
|
| 354 |
+
" with open(path) as f:\n",
|
| 355 |
+
" rows = list(csv.reader(f))\n",
|
| 356 |
+
" header = rows[0]\n",
|
| 357 |
+
" assert header[0] == \"wav\" and \"QMOS\" in header and \"EMOS\" in header, \"Header sai\"\n",
|
| 358 |
+
" for i, r in enumerate(rows[1:], 2):\n",
|
| 359 |
+
" assert len(r) == len(header), f\"Dòng {i} sai số cột\"\n",
|
| 360 |
+
" print(f\"OK: {len(rows)-1} dòng, header = {header}\")\n",
|
| 361 |
+
"\n",
|
| 362 |
+
"validate(answer_path)\n",
|
| 363 |
+
"os.system(f\"cd {OUT_DIR} && zip -j submission_track2_exp03_sailer.zip answer.txt && unzip -l submission_track2_exp03_sailer.zip\")\n",
|
| 364 |
+
"print(\"Sẵn sàng nộp:\", os.path.join(OUT_DIR, \"submission_track2_exp03_sailer.zip\"))"
|
| 365 |
+
]
|
| 366 |
+
},
|
| 367 |
+
{
|
| 368 |
+
"cell_type": "markdown",
|
| 369 |
+
"id": "8ca84ef6",
|
| 370 |
+
"metadata": {},
|
| 371 |
+
"source": [
|
| 372 |
+
"## Ghi chú\n",
|
| 373 |
+
"- **Chưa chạy thật bao giờ** → lần đầu đặt `LIMIT = 20` ở cell 0 để bắt lỗi setup (clone repo / import / model).\n",
|
| 374 |
+
"- Điểm DEV thật phải nộp lên CodaBench mới biết (My Submissions → Track 2, bỏ chọn track khác).\n",
|
| 375 |
+
"- Notebook này đổi **EMOS + CAT + VAD** sang SAILER (1 model lo 6 cột metric). QMOS vẫn SpeechMOS cũ.\n",
|
| 376 |
+
" Muốn ablation EMOS sạch (giữ CAT=emotion2vec) thì chỉ lấy cột EMOS từ đây, ghép với CAT của `track2_baseline`.\n",
|
| 377 |
+
"- Rủi ro setup duy nhất = import `src.model.emotion.wavlm_emotion` (cần repo vox-profile-release).\n",
|
| 378 |
+
" Nếu lỗi import: kiểm tra `REPO_DIR` đã clone + `sys.path` đã thêm REPO_DIR (KHÔNG dùng pip install -e .).\n",
|
| 379 |
+
"- Nhớ ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (mục exp03)."
|
| 380 |
+
]
|
| 381 |
+
}
|
| 382 |
+
],
|
| 383 |
+
"metadata": {
|
| 384 |
+
"jupytext": {
|
| 385 |
+
"cell_metadata_filter": "-all",
|
| 386 |
+
"main_language": "python",
|
| 387 |
+
"notebook_metadata_filter": "-all"
|
| 388 |
+
}
|
| 389 |
+
},
|
| 390 |
+
"nbformat": 4,
|
| 391 |
+
"nbformat_minor": 5
|
| 392 |
+
}
|
track2/exp03_emos_sailer_pipeline.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %% [markdown]
|
| 2 |
+
# # VMC2026 Track 2 — exp03 (EMOS bằng SAILER, offline) — Kaggle
|
| 3 |
+
#
|
| 4 |
+
# **Mục tiêu:** chấm **EMOS** (độ khớp cảm xúc target) bằng model **SAILER**
|
| 5 |
+
# (`tiantiaf/wavlm-large-categorical-emotion`, vô địch Interspeech 2025 SER),
|
| 6 |
+
# thay cho emotion2vec — KHÔNG train, chỉ lấy xác suất lớp cảm xúc target.
|
| 7 |
+
#
|
| 8 |
+
# ## Ý tưởng (đọc 1 lần cho hiểu)
|
| 9 |
+
# SAILER nhận 1 wav → xuất **logits 9 lớp cảm xúc** → softmax → **xác suất từng lớp**.
|
| 10 |
+
# EMOS = mức khớp cảm xúc target → lấy thẳng **P(cảm xúc target)** rồi kéo về thang 1–5:
|
| 11 |
+
#
|
| 12 |
+
# ```
|
| 13 |
+
# mỗi wav ─► SAILER (WavLM-large) ─► softmax 9 lớp ─┬─► P(target) ─► EMOS = 1 + 4·P
|
| 14 |
+
# └─► 5 lớp (renorm) ─► CAT
|
| 15 |
+
# target emotion (metadata.csv) ─────────────────┘
|
| 16 |
+
# ```
|
| 17 |
+
#
|
| 18 |
+
# - **9 lớp SAILER:** `Anger, Contempt, Disgust, Fear, Happiness, Neutral, Sadness, Surprise, Other`.
|
| 19 |
+
# → đủ cả 5 lớp challenge (angry/happy/neutral/sad/surprised).
|
| 20 |
+
# - **EMOS** = `1 + 4·P(target)` (scale [0,1]→[1,5]); SRCC bất biến với scale tuyến tính.
|
| 21 |
+
# - **CAT** = lấy xác suất 5 lớp challenge từ chính SAILER (renormalize tổng=1).
|
| 22 |
+
# - **VAD** = arousal/valence/dominance SAILER xuất sẵn (sigmoid 0–1 → 1–5) → 1 model lo EMOS+CAT+VAD!
|
| 23 |
+
# - **QMOS** = SpeechMOS (UTMOS) — bắt buộc để `answer.txt` hợp lệ.
|
| 24 |
+
# - KHÔNG train → nộp được ngay. So điểm EMOS với baseline emotion2vec (0.194) và exp01.
|
| 25 |
+
#
|
| 26 |
+
# **Cách chạy trên Kaggle:** Settings → Accelerator = **GPU T4**, Internet = **On**
|
| 27 |
+
# → + Add Input dataset Track 2 (15.477 wav, có `sets/dev.scp`, `metadata.csv`)
|
| 28 |
+
# → sửa `DATA_ROOT` ở cell 0 → Run All.
|
| 29 |
+
#
|
| 30 |
+
# ⚠️ License SAILER = **Open RAIL** (phi thương mại) → phải khai báo trong `docs/12_system_description.md`.
|
| 31 |
+
|
| 32 |
+
# %% [markdown]
|
| 33 |
+
# ## 0. Cấu hình — SỬA Ở ĐÂY
|
| 34 |
+
|
| 35 |
+
# %%
|
| 36 |
+
import os
|
| 37 |
+
|
| 38 |
+
# ── Data Track 2 (dataset 15.477 wav đã ráp) ────────────────────────────────
|
| 39 |
+
DATA_ROOT = "/kaggle/input/vmc2026-track2-full/vmc2026-track2" # << SỬA slug cho khớp Add Input
|
| 40 |
+
WAV_DIR = f"{DATA_ROOT}/wav"
|
| 41 |
+
METADATA_CSV = f"{DATA_ROOT}/metadata.csv" # wavID|emotion|transcript (KHÔNG header) → target emotion
|
| 42 |
+
DEV_SCP = f"{DATA_ROOT}/sets/dev.scp" # danh sách wav tập DEV (tập cần nộp ở training phase)
|
| 43 |
+
|
| 44 |
+
OUT_DIR = "/kaggle/working"
|
| 45 |
+
|
| 46 |
+
DEVICE = "cuda" # "cuda" trên Kaggle GPU; "cpu" nếu không có GPU
|
| 47 |
+
MAX_SECONDS = 15 # SAILER nhận tối đa 15s (giới hạn của model)
|
| 48 |
+
SR = 16000 # SAILER cần 16kHz mono
|
| 49 |
+
LIMIT = None # đặt số nhỏ (vd 20) để chạy thử nhanh; None = full DEV
|
| 50 |
+
|
| 51 |
+
# 5 lớp cảm xúc challenge (thứ tự cố định cho cột CAT)
|
| 52 |
+
EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"]
|
| 53 |
+
|
| 54 |
+
# 9 lớp SAILER (đúng thứ tự model xuất) + chỉ số của 5 lớp challenge trong đó
|
| 55 |
+
SAILER9 = ["Anger", "Contempt", "Disgust", "Fear", "Happiness", "Neutral", "Sadness", "Surprise", "Other"]
|
| 56 |
+
EMO2SAILER = {"angry": 0, "happy": 4, "neutral": 5, "sad": 6, "surprised": 7} # EMOTIONS5 → index trong SAILER9
|
| 57 |
+
|
| 58 |
+
_EMO_ALIAS = {
|
| 59 |
+
"angry": "angry", "anger": "angry",
|
| 60 |
+
"happy": "happy", "happiness": "happy", "joy": "happy",
|
| 61 |
+
"neutral": "neutral", "calm": "neutral",
|
| 62 |
+
"sad": "sad", "sadness": "sad",
|
| 63 |
+
"surprise": "surprised", "surprised": "surprised", "surprising": "surprised",
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
def norm_emotion(label):
|
| 67 |
+
"""Đưa nhãn cảm xúc bất kỳ về 1 trong EMOTIONS5; None nếu không khớp."""
|
| 68 |
+
key = str(label).strip().lower()
|
| 69 |
+
return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)
|
| 70 |
+
|
| 71 |
+
def stem(path_or_name):
|
| 72 |
+
return os.path.splitext(os.path.basename(str(path_or_name)))[0]
|
| 73 |
+
|
| 74 |
+
print("DATA_ROOT:", DATA_ROOT)
|
| 75 |
+
for p in [WAV_DIR, METADATA_CSV, DEV_SCP]:
|
| 76 |
+
print((" ✅ " if os.path.exists(p) else " ❌ THIẾU ") + p)
|
| 77 |
+
|
| 78 |
+
# %% [markdown]
|
| 79 |
+
# ## 1. Cài đặt + tải code SAILER
|
| 80 |
+
# SAILER cần file `WavLMWrapper` trong repo `vox-profile-release`.
|
| 81 |
+
# ⚠️ **KHÔNG** `pip install -e .` (build wheel của repo hay lỗi trên Kaggle). Thay vào đó:
|
| 82 |
+
# chỉ **clone + thêm repo vào `sys.path`** rồi cài đúng vài thư viện model cần
|
| 83 |
+
# (`transformers/torch/huggingface_hub` Kaggle đã có sẵn; chỉ thiếu `loralib`, `speechbrain`).
|
| 84 |
+
|
| 85 |
+
# %%
|
| 86 |
+
import sys, subprocess
|
| 87 |
+
|
| 88 |
+
def pip_install(*pkgs):
|
| 89 |
+
subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs], check=True)
|
| 90 |
+
|
| 91 |
+
REPO_DIR = "/kaggle/working/vox-profile-release"
|
| 92 |
+
if not os.path.exists(REPO_DIR):
|
| 93 |
+
subprocess.run(["git", "clone", "--depth", "1",
|
| 94 |
+
"https://github.com/tiantiaf0627/vox-profile-release.git", REPO_DIR], check=True)
|
| 95 |
+
|
| 96 |
+
# Deps mà WavLMWrapper cần (xem import trong src/model/emotion/wavlm_emotion.py) + thư viện chấm QMOS.
|
| 97 |
+
pip_install("loralib", "speechbrain", "speechmos", "librosa", "soundfile", "scipy", "tqdm")
|
| 98 |
+
|
| 99 |
+
if REPO_DIR not in sys.path:
|
| 100 |
+
sys.path.insert(0, REPO_DIR) # để `from src.model.emotion... import WavLMWrapper` chạy được
|
| 101 |
+
|
| 102 |
+
# %% [markdown]
|
| 103 |
+
# ## 2. Nạp model SAILER
|
| 104 |
+
|
| 105 |
+
# %%
|
| 106 |
+
import torch
|
| 107 |
+
import torch.nn.functional as F
|
| 108 |
+
|
| 109 |
+
device = DEVICE if torch.cuda.is_available() else "cpu"
|
| 110 |
+
print("Device:", device)
|
| 111 |
+
|
| 112 |
+
from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402
|
| 113 |
+
|
| 114 |
+
sailer = WavLMWrapper.from_pretrained("tiantiaf/wavlm-large-categorical-emotion").to(device)
|
| 115 |
+
sailer.eval()
|
| 116 |
+
print("✅ Đã nạp SAILER (wavlm-large-categorical-emotion)")
|
| 117 |
+
|
| 118 |
+
# %% [markdown]
|
| 119 |
+
# ## 3. Đọc cảm xúc target cho mỗi wav (từ metadata.csv)
|
| 120 |
+
|
| 121 |
+
# %%
|
| 122 |
+
def load_target_emotions():
|
| 123 |
+
"""metadata.csv (wavID|emotion|transcript, KHÔNG header) → {stem: emotion_chuẩn|None}."""
|
| 124 |
+
tgt = {}
|
| 125 |
+
with open(METADATA_CSV, encoding="utf-8") as f:
|
| 126 |
+
for ln in f:
|
| 127 |
+
parts = ln.strip().split("|")
|
| 128 |
+
if len(parts) < 2:
|
| 129 |
+
continue
|
| 130 |
+
tgt[stem(parts[0])] = norm_emotion(parts[1])
|
| 131 |
+
return tgt
|
| 132 |
+
|
| 133 |
+
target_map = load_target_emotions()
|
| 134 |
+
print(f"Target emotions: {len(target_map)} wav | ví dụ:", dict(list(target_map.items())[:3]))
|
| 135 |
+
|
| 136 |
+
# %% [markdown]
|
| 137 |
+
# ## 4. Hàm chấm 1 wav bằng SAILER → xác suất 9 lớp + VAD
|
| 138 |
+
# WavLMWrapper khi `return_feature=True` trả **6 giá trị**:
|
| 139 |
+
# `predicted(logits 9 lớp), features, detailed_logits, arousal, valence, dominance` (VAD sigmoid 0–1).
|
| 140 |
+
# → 1 model lo cả **EMOS** (P target), **CAT** (5 lớp renorm) **và VAD** (mở 3 cột đang trống!).
|
| 141 |
+
|
| 142 |
+
# %%
|
| 143 |
+
import numpy as np
|
| 144 |
+
import librosa
|
| 145 |
+
|
| 146 |
+
@torch.no_grad()
|
| 147 |
+
def sailer_infer(wav_path):
|
| 148 |
+
"""→ (probs9: float32[9], vad3: float32[3] theo thứ tự [VAL,ARO,DOM] thang 1–5);
|
| 149 |
+
None nếu thiếu/lỗi file."""
|
| 150 |
+
if not os.path.exists(wav_path):
|
| 151 |
+
return None
|
| 152 |
+
wave, _ = librosa.load(wav_path, sr=SR, mono=True)
|
| 153 |
+
wave = wave[: MAX_SECONDS * SR] # cắt tối đa 15s
|
| 154 |
+
data = torch.from_numpy(wave).float().unsqueeze(0).to(device)
|
| 155 |
+
logits, _feat, _det, arousal, valence, dominance = sailer(data, return_feature=True)
|
| 156 |
+
probs9 = F.softmax(logits, dim=1)[0].detach().cpu().numpy().astype(np.float32)
|
| 157 |
+
# VAD sigmoid [0,1] → thang 1–5 cho khớp ví dụ BTC (SRCC bất biến với scale tuyến tính)
|
| 158 |
+
v, a, d = float(valence.item()), float(arousal.item()), float(dominance.item())
|
| 159 |
+
vad3 = np.array([1 + 4 * v, 1 + 4 * a, 1 + 4 * d], dtype=np.float32) # [VAL, ARO, DOM]
|
| 160 |
+
return probs9, vad3
|
| 161 |
+
|
| 162 |
+
def emos_from_probs(probs9, target):
|
| 163 |
+
"""EMOS = 1 + 4·P(target). None nếu không biết target → để caller xử lý mặc định."""
|
| 164 |
+
if target is None or target not in EMO2SAILER:
|
| 165 |
+
return None
|
| 166 |
+
return 1.0 + 4.0 * float(probs9[EMO2SAILER[target]])
|
| 167 |
+
|
| 168 |
+
def cat5_from_probs(probs9):
|
| 169 |
+
"""Lấy 5 lớp challenge từ 9 lớp SAILER rồi renormalize tổng=1."""
|
| 170 |
+
v = np.array([probs9[EMO2SAILER[e]] for e in EMOTIONS5], dtype=np.float32)
|
| 171 |
+
s = v.sum()
|
| 172 |
+
return v / s if s > 0 else np.full(5, 0.2, dtype=np.float32)
|
| 173 |
+
|
| 174 |
+
# %% [markdown]
|
| 175 |
+
# ## 5. QMOS = SpeechMOS (UTMOS) — bắt buộc cho answer.txt
|
| 176 |
+
|
| 177 |
+
# %%
|
| 178 |
+
@torch.no_grad()
|
| 179 |
+
def run_qmos(names):
|
| 180 |
+
predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True).to(device).eval()
|
| 181 |
+
from tqdm.auto import tqdm
|
| 182 |
+
out = {}
|
| 183 |
+
for n in tqdm(names, desc="QMOS"):
|
| 184 |
+
p = os.path.join(WAV_DIR, n)
|
| 185 |
+
if not os.path.exists(p):
|
| 186 |
+
continue
|
| 187 |
+
wave, _ = librosa.load(p, sr=SR, mono=True)
|
| 188 |
+
x = torch.from_numpy(wave).unsqueeze(0).to(device) # đẩy input lên GPU
|
| 189 |
+
out[n] = float(predictor(x, sr=SR).mean().item())
|
| 190 |
+
return out
|
| 191 |
+
|
| 192 |
+
# %% [markdown]
|
| 193 |
+
# ## 6. Chạy trên DEV → `answer.txt` đầy đủ (QMOS, EMOS, CAT, VAL, ARO, DOM)
|
| 194 |
+
|
| 195 |
+
# %%
|
| 196 |
+
def list_dev():
|
| 197 |
+
with open(DEV_SCP) as f:
|
| 198 |
+
return [ln.strip() for ln in f if ln.strip()]
|
| 199 |
+
|
| 200 |
+
dev_names = list_dev()
|
| 201 |
+
if LIMIT:
|
| 202 |
+
dev_names = dev_names[:LIMIT]
|
| 203 |
+
print("DEV:", len(dev_names), "mẫu")
|
| 204 |
+
|
| 205 |
+
qmos_scores = run_qmos(dev_names)
|
| 206 |
+
|
| 207 |
+
def fmt_cat(probs5):
|
| 208 |
+
return "|".join(f"{e}:{probs5[i]:.6g}" for i, e in enumerate(EMOTIONS5))
|
| 209 |
+
|
| 210 |
+
def build_answer(out_path):
|
| 211 |
+
from tqdm.auto import tqdm
|
| 212 |
+
n_emos = n_default = 0
|
| 213 |
+
with open(out_path, "w") as f:
|
| 214 |
+
f.write("wav,QMOS,EMOS,CAT,VAL,ARO,DOM\n")
|
| 215 |
+
for name in tqdm(dev_names, desc="SAILER EMOS/CAT/VAD"):
|
| 216 |
+
sid = stem(name)
|
| 217 |
+
out = sailer_infer(os.path.join(WAV_DIR, name))
|
| 218 |
+
if out is None:
|
| 219 |
+
emos, cat5 = 3.0, np.full(5, 0.2, dtype=np.float32)
|
| 220 |
+
vad3 = np.array([3.0, 3.0, 3.0], dtype=np.float32)
|
| 221 |
+
n_default += 1
|
| 222 |
+
else:
|
| 223 |
+
probs9, vad3 = out
|
| 224 |
+
emos = emos_from_probs(probs9, target_map.get(sid))
|
| 225 |
+
if emos is None:
|
| 226 |
+
emos = 3.0; n_default += 1
|
| 227 |
+
else:
|
| 228 |
+
n_emos += 1
|
| 229 |
+
cat5 = cat5_from_probs(probs9)
|
| 230 |
+
qmos = qmos_scores.get(name, 3.0)
|
| 231 |
+
f.write(f"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},"
|
| 232 |
+
f"{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\n")
|
| 233 |
+
print(f"Ghi {len(dev_names)} dòng → {out_path} | EMOS thật {n_emos}, mặc định {n_default}")
|
| 234 |
+
|
| 235 |
+
answer_path = os.path.join(OUT_DIR, "answer.txt")
|
| 236 |
+
build_answer(answer_path)
|
| 237 |
+
|
| 238 |
+
# %% [markdown]
|
| 239 |
+
# ## 7. Validate + đóng zip
|
| 240 |
+
|
| 241 |
+
# %%
|
| 242 |
+
def validate(path):
|
| 243 |
+
import csv
|
| 244 |
+
with open(path) as f:
|
| 245 |
+
rows = list(csv.reader(f))
|
| 246 |
+
header = rows[0]
|
| 247 |
+
assert header[0] == "wav" and "QMOS" in header and "EMOS" in header, "Header sai"
|
| 248 |
+
for i, r in enumerate(rows[1:], 2):
|
| 249 |
+
assert len(r) == len(header), f"Dòng {i} sai số cột"
|
| 250 |
+
print(f"OK: {len(rows)-1} dòng, header = {header}")
|
| 251 |
+
|
| 252 |
+
validate(answer_path)
|
| 253 |
+
os.system(f"cd {OUT_DIR} && zip -j submission_track2_exp03_sailer.zip answer.txt && unzip -l submission_track2_exp03_sailer.zip")
|
| 254 |
+
print("Sẵn sàng nộp:", os.path.join(OUT_DIR, "submission_track2_exp03_sailer.zip"))
|
| 255 |
+
|
| 256 |
+
# %% [markdown]
|
| 257 |
+
# ## Ghi chú
|
| 258 |
+
# - **Chưa chạy thật bao giờ** → lần đầu đặt `LIMIT = 20` ở cell 0 để bắt lỗi setup (clone repo / import / model).
|
| 259 |
+
# - Điểm DEV thật phải nộp lên CodaBench mới biết (My Submissions → Track 2, bỏ chọn track khác).
|
| 260 |
+
# - Notebook này đổi **EMOS + CAT + VAD** sang SAILER (1 model lo 6 cột metric). QMOS vẫn SpeechMOS cũ.
|
| 261 |
+
# Muốn ablation EMOS sạch (giữ CAT=emotion2vec) thì chỉ lấy cột EMOS từ đây, ghép với CAT của `track2_baseline`.
|
| 262 |
+
# - Rủi ro setup duy nhất = import `src.model.emotion.wavlm_emotion` (cần repo vox-profile-release).
|
| 263 |
+
# Nếu lỗi import: kiểm tra `REPO_DIR` đã clone + `sys.path` đã thêm REPO_DIR (KHÔNG dùng pip install -e .).
|
| 264 |
+
# - Nhớ ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (mục exp03).
|
track2/exp04_fusion.ipynb
ADDED
|
@@ -0,0 +1,790 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "d85dcf89",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"# VMC2026 Track 2 — exp04 (FUSION multi-task) — Kaggle\n",
|
| 9 |
+
"\n",
|
| 10 |
+
"**Mục tiêu:** gộp 2 backbone bổ sung nhau (**emotion2vec** thắng EMOS · **SAILER/WavLM** thắng VAD)\n",
|
| 11 |
+
"thành **1 model multi-task** dự đoán chung 5 đầu ra cảm xúc: **EMOS · CAT · VAL · ARO · DOM**.\n",
|
| 12 |
+
"QMOS để **riêng** (giữ SpeechMOS) — đúng thiết kế đã chốt: *\"QMOS riêng + 5 cảm xúc chung\"*.\n",
|
| 13 |
+
"\n",
|
| 14 |
+
"## Ý tưởng (đọc 1 lần cho hiểu)\n",
|
| 15 |
+
"Bằng chứng để fusion (từ exp01 & exp03): emotion2vec đứng đầu **EMOS** (0.637), SAILER đứng đầu\n",
|
| 16 |
+
"**VAD** (ARO 0.712 / DOM 0.630). Hai model \"nhìn\" cảm xúc theo cách khác nhau → **nối đặc trưng**\n",
|
| 17 |
+
"của cả hai rồi cho một mạng nhỏ học → kỳ vọng mạnh hơn từng model lẻ.\n",
|
| 18 |
+
"\n",
|
| 19 |
+
"```\n",
|
| 20 |
+
" ┌─ emotion2vec ─► embedding ~D1 + xác suất 5 lớp ─┐\n",
|
| 21 |
+
" mỗi wav ──────►│ ├─► NỐI ─► TRUNK chung\n",
|
| 22 |
+
" └─ SAILER(WavLM) ► embedding ~D2 + 9 lớp + VAD3 ─┘ (Linear+ReLU)\n",
|
| 23 |
+
" │\n",
|
| 24 |
+
" ┌───────────────────────────────────────────────┤\n",
|
| 25 |
+
" target emotion(one-hot)│ │\n",
|
| 26 |
+
" ▼ ▼\n",
|
| 27 |
+
" [EMOS head] [CAT head] [VAD head]\n",
|
| 28 |
+
" (cần target) (5 lớp) (VAL/ARO/DOM)\n",
|
| 29 |
+
"```\n",
|
| 30 |
+
"\n",
|
| 31 |
+
"- **Cả 2 backbone ĐÓNG BĂNG** → chỉ trích đặc trưng (cache `.npz`), **chỉ train phần trunk + head nhỏ**\n",
|
| 32 |
+
" → nhẹ GPU, train vài phút, hợp T4. (Né fine-tune end-to-end lúc đầu.)\n",
|
| 33 |
+
"- **EMOS phụ thuộc target** (cùng audio, target khác → điểm khác) → EMOS head nhận thêm one-hot target.\n",
|
| 34 |
+
" **CAT/VAD** là cảm nhận về chính audio → chỉ cần trunk (không cần target).\n",
|
| 35 |
+
"- **Nhãn vàng** gộp theo `wavID` từ `sets/train.csv`:\n",
|
| 36 |
+
" EMOS = TB `eMOS` · VAL/ARO/DOM = TB `val/aro/dom` · CAT = **tỉ lệ vote 5 lớp** của `emoCat`.\n",
|
| 37 |
+
"- **Cân loss = uncertainty weighting** (Kendall 2018): mỗi task có 1 trọng số σ **tự học**\n",
|
| 38 |
+
" → không phải dò tay. Có cờ `USE_UNCERTAINTY=False` để quay về trọng số cố định khi cần debug.\n",
|
| 39 |
+
"- Cuối cùng xuất `answer.txt` **đủ 7 cột**: `wav,QMOS,EMOS,CAT,VAL,ARO,DOM`\n",
|
| 40 |
+
" (QMOS=SpeechMOS · 5 cột còn lại = model fusion) → nộp được ngay. So mốc: EMOS 0.637 · VAD ARO 0.712.\n",
|
| 41 |
+
"\n",
|
| 42 |
+
"**Cách chạy trên Kaggle:** Settings → Accelerator = **GPU T4**, Internet = **On**\n",
|
| 43 |
+
"→ + Add Input dataset Track 2 (15.477 wav, có `sets/train.csv`, `sets/dev.scp`, `metadata.csv`)\n",
|
| 44 |
+
"→ sửa `DATA_ROOT` ở cell 0 → Run All. Lần đầu nên đặt `LIMIT_TRAIN = 300`, `LIMIT_DEV = 20` để bắt lỗi setup."
|
| 45 |
+
]
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"cell_type": "markdown",
|
| 49 |
+
"id": "5101bb4e",
|
| 50 |
+
"metadata": {},
|
| 51 |
+
"source": [
|
| 52 |
+
"## 0. Cấu hình — SỬA Ở ĐÂY"
|
| 53 |
+
]
|
| 54 |
+
},
|
| 55 |
+
{
|
| 56 |
+
"cell_type": "code",
|
| 57 |
+
"execution_count": null,
|
| 58 |
+
"id": "3fee9b16",
|
| 59 |
+
"metadata": {},
|
| 60 |
+
"outputs": [],
|
| 61 |
+
"source": [
|
| 62 |
+
"import os\n",
|
| 63 |
+
"\n",
|
| 64 |
+
"# ── Data Track 2 (dataset 15.477 wav đã ráp, có sets/train.csv) ──────────────\n",
|
| 65 |
+
"DATA_ROOT = \"/kaggle/input/vmc2026-track2-full/vmc2026-track2\" # << SỬA slug cho khớp Add Input\n",
|
| 66 |
+
"WAV_DIR = f\"{DATA_ROOT}/wav\"\n",
|
| 67 |
+
"METADATA_CSV = f\"{DATA_ROOT}/metadata.csv\" # wavID|emotion|transcript (KHÔNG header) → target emotion\n",
|
| 68 |
+
"TRAIN_CSV = f\"{DATA_ROOT}/sets/train.csv\" # nhãn người nghe: lisID,wavID,qMOS,emoCat,eMOS,val,dom,aro\n",
|
| 69 |
+
"DEV_SCP = f\"{DATA_ROOT}/sets/dev.scp\" # danh sách wav tập DEV (tập cần nộp ở training phase)\n",
|
| 70 |
+
"\n",
|
| 71 |
+
"OUT_DIR = \"/kaggle/working\"\n",
|
| 72 |
+
"CACHE_DIR = \"/kaggle/working/fusion_cache\" # cache embedding 2 backbone (tái dùng giữa các lần chạy)\n",
|
| 73 |
+
"os.makedirs(CACHE_DIR, exist_ok=True)\n",
|
| 74 |
+
"\n",
|
| 75 |
+
"# ── Siêu tham số train ───────────────────────────────────────────────────────\n",
|
| 76 |
+
"DEVICE = \"cuda\" # \"cuda\" trên Kaggle GPU; \"cpu\" nếu không có GPU\n",
|
| 77 |
+
"TRUNK_HIDDEN = 512 # số neuron lớp trunk chung\n",
|
| 78 |
+
"HEAD_HIDDEN = 128 # số neuron lớp ẩn mỗi head\n",
|
| 79 |
+
"DROPOUT = 0.3\n",
|
| 80 |
+
"LR = 1e-3\n",
|
| 81 |
+
"EPOCHS = 80\n",
|
| 82 |
+
"BATCH = 64\n",
|
| 83 |
+
"VAL_FRAC = 0.10 # 10% train → validation nội bộ (đo SRCC từng task)\n",
|
| 84 |
+
"PATIENCE = 15 # early stop theo điểm tổng val (xem SCORE_FOR_STOP)\n",
|
| 85 |
+
"SEED = 42\n",
|
| 86 |
+
"\n",
|
| 87 |
+
"USE_UNCERTAINTY = True # True = tự cân loss (Kendall); False = dùng LOSS_W cố định bên dưới\n",
|
| 88 |
+
"LOSS_W = {\"emos\": 1.0, \"cat\": 1.0, \"val\": 1.0, \"aro\": 1.0, \"dom\": 1.0} # chỉ dùng khi tắt uncertainty\n",
|
| 89 |
+
"USE_E2V = True # bật/tắt nhánh emotion2vec trong fusion (để ablation)\n",
|
| 90 |
+
"USE_SAILER = True # bật/tắt nhánh SAILER trong fusion (để ablation)\n",
|
| 91 |
+
"USE_CLASSPROB = True # thêm xác suất lớp (e2v 5 + sailer 9) + VAD3 của SAILER vào feature\n",
|
| 92 |
+
"\n",
|
| 93 |
+
"LIMIT_TRAIN = None # đặt số nhỏ (vd 300) để chạy thử nhanh; None = full\n",
|
| 94 |
+
"LIMIT_DEV = None # đặt số nhỏ (vd 20) để chạy thử nhanh; None = full\n",
|
| 95 |
+
"\n",
|
| 96 |
+
"EMOTIONS5 = [\"angry\", \"happy\", \"neutral\", \"sad\", \"surprised\"]\n",
|
| 97 |
+
"\n",
|
| 98 |
+
"# 9 lớp SAILER (đúng thứ tự model xuất) + chỉ số của 5 lớp challenge trong đó\n",
|
| 99 |
+
"SAILER9 = [\"Anger\", \"Contempt\", \"Disgust\", \"Fear\", \"Happiness\", \"Neutral\", \"Sadness\", \"Surprise\", \"Other\"]\n",
|
| 100 |
+
"EMO2SAILER = {\"angry\": 0, \"happy\": 4, \"neutral\": 5, \"sad\": 6, \"surprised\": 7} # EMOTIONS5 → index trong SAILER9\n",
|
| 101 |
+
"\n",
|
| 102 |
+
"_EMO_ALIAS = {\n",
|
| 103 |
+
" \"angry\": \"angry\", \"anger\": \"angry\",\n",
|
| 104 |
+
" \"happy\": \"happy\", \"happiness\": \"happy\", \"joy\": \"happy\",\n",
|
| 105 |
+
" \"neutral\": \"neutral\", \"calm\": \"neutral\",\n",
|
| 106 |
+
" \"sad\": \"sad\", \"sadness\": \"sad\",\n",
|
| 107 |
+
" \"surprise\": \"surprised\", \"surprised\": \"surprised\", \"surprising\": \"surprised\",\n",
|
| 108 |
+
"}\n",
|
| 109 |
+
"\n",
|
| 110 |
+
"def norm_emotion(label):\n",
|
| 111 |
+
" \"\"\"Đưa nhãn cảm xúc bất kỳ về 1 trong EMOTIONS5; None nếu không khớp.\"\"\"\n",
|
| 112 |
+
" key = str(label).strip().lower()\n",
|
| 113 |
+
" return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)\n",
|
| 114 |
+
"\n",
|
| 115 |
+
"def stem(path_or_name):\n",
|
| 116 |
+
" \"\"\"Lấy tên file không đuôi, để khớp wavID giữa train.csv / metadata / dev.scp.\"\"\"\n",
|
| 117 |
+
" return os.path.splitext(os.path.basename(str(path_or_name)))[0]\n",
|
| 118 |
+
"\n",
|
| 119 |
+
"assert USE_E2V or USE_SAILER, \"Phải bật ít nhất 1 backbone (USE_E2V hoặc USE_SAILER).\"\n",
|
| 120 |
+
"print(\"DATA_ROOT:\", DATA_ROOT)\n",
|
| 121 |
+
"for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP]:\n",
|
| 122 |
+
" print((\" ✅ \" if os.path.exists(p) else \" ❌ THIẾU \") + p)"
|
| 123 |
+
]
|
| 124 |
+
},
|
| 125 |
+
{
|
| 126 |
+
"cell_type": "markdown",
|
| 127 |
+
"id": "580854fb",
|
| 128 |
+
"metadata": {},
|
| 129 |
+
"source": [
|
| 130 |
+
"## 1. Cài đặt + tải code SAILER\n",
|
| 131 |
+
"emotion2vec qua `funasr` (offline). SAILER cần `WavLMWrapper` trong repo `vox-profile-release`\n",
|
| 132 |
+
"→ **clone + sys.path** (KHÔNG `pip install -e .` vì build wheel hay lỗi trên Kaggle)."
|
| 133 |
+
]
|
| 134 |
+
},
|
| 135 |
+
{
|
| 136 |
+
"cell_type": "code",
|
| 137 |
+
"execution_count": null,
|
| 138 |
+
"id": "a0ea1faa",
|
| 139 |
+
"metadata": {},
|
| 140 |
+
"outputs": [],
|
| 141 |
+
"source": [
|
| 142 |
+
"import sys, subprocess\n",
|
| 143 |
+
"\n",
|
| 144 |
+
"def pip_install(*pkgs):\n",
|
| 145 |
+
" subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *pkgs], check=True)\n",
|
| 146 |
+
"\n",
|
| 147 |
+
"pip_install(\"speechmos\", \"funasr\", \"librosa\", \"soundfile\", \"pandas\", \"scipy\", \"scikit-learn\", \"tqdm\")\n",
|
| 148 |
+
"\n",
|
| 149 |
+
"if USE_SAILER:\n",
|
| 150 |
+
" pip_install(\"loralib\", \"speechbrain\") # deps WavLMWrapper cần\n",
|
| 151 |
+
" REPO_DIR = \"/kaggle/working/vox-profile-release\"\n",
|
| 152 |
+
" if not os.path.exists(REPO_DIR):\n",
|
| 153 |
+
" subprocess.run([\"git\", \"clone\", \"--depth\", \"1\",\n",
|
| 154 |
+
" \"https://github.com/tiantiaf0627/vox-profile-release.git\", REPO_DIR], check=True)\n",
|
| 155 |
+
" if REPO_DIR not in sys.path:\n",
|
| 156 |
+
" sys.path.insert(0, REPO_DIR)"
|
| 157 |
+
]
|
| 158 |
+
},
|
| 159 |
+
{
|
| 160 |
+
"cell_type": "markdown",
|
| 161 |
+
"id": "43033a70",
|
| 162 |
+
"metadata": {},
|
| 163 |
+
"source": [
|
| 164 |
+
"## 2. Đọc & gộp nhãn (gộp theo wavID)\n",
|
| 165 |
+
"- `train.csv`: mỗi dòng = 1 listener chấm 1 wav → gộp **theo wavID**:\n",
|
| 166 |
+
" EMOS=TB `eMOS` · VAL/ARO/DOM=TB `val/aro/dom` · CAT=**tỉ lệ vote 5 lớp** của `emoCat`.\n",
|
| 167 |
+
"- `metadata.csv`: lấy **cảm xúc target** cho mỗi wav (để feed EMOS head)."
|
| 168 |
+
]
|
| 169 |
+
},
|
| 170 |
+
{
|
| 171 |
+
"cell_type": "code",
|
| 172 |
+
"execution_count": null,
|
| 173 |
+
"id": "d4051547",
|
| 174 |
+
"metadata": {},
|
| 175 |
+
"outputs": [],
|
| 176 |
+
"source": [
|
| 177 |
+
"import numpy as np\n",
|
| 178 |
+
"import pandas as pd\n",
|
| 179 |
+
"\n",
|
| 180 |
+
"def load_target_emotions():\n",
|
| 181 |
+
" \"\"\"metadata.csv (wavID|emotion|transcript, KHÔNG header) → {stem: emotion_chuẩn|None}.\"\"\"\n",
|
| 182 |
+
" tgt = {}\n",
|
| 183 |
+
" with open(METADATA_CSV, encoding=\"utf-8\") as f:\n",
|
| 184 |
+
" for ln in f:\n",
|
| 185 |
+
" parts = ln.strip().split(\"|\")\n",
|
| 186 |
+
" if len(parts) < 2:\n",
|
| 187 |
+
" continue\n",
|
| 188 |
+
" tgt[stem(parts[0])] = norm_emotion(parts[1])\n",
|
| 189 |
+
" return tgt\n",
|
| 190 |
+
"\n",
|
| 191 |
+
"def _col(cols_map, *names, default_idx=None, df=None):\n",
|
| 192 |
+
" for n in names:\n",
|
| 193 |
+
" if n in cols_map:\n",
|
| 194 |
+
" return cols_map[n]\n",
|
| 195 |
+
" return list(df.columns)[default_idx] if default_idx is not None else None\n",
|
| 196 |
+
"\n",
|
| 197 |
+
"def parse_emocat_votes(cell):\n",
|
| 198 |
+
" \"\"\"1 ô emoCat (có thể đa nhãn, vd 'happy;surprised') → vector đếm 5 lớp (chưa chuẩn hóa).\"\"\"\n",
|
| 199 |
+
" v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
|
| 200 |
+
" for tok in str(cell).replace(\"/\", \",\").replace(\";\", \",\").replace(\"|\", \",\").replace(\" \", \",\").split(\",\"):\n",
|
| 201 |
+
" e = norm_emotion(tok)\n",
|
| 202 |
+
" if e in EMOTIONS5:\n",
|
| 203 |
+
" v[EMOTIONS5.index(e)] += 1.0\n",
|
| 204 |
+
" return v\n",
|
| 205 |
+
"\n",
|
| 206 |
+
"def load_train_labels():\n",
|
| 207 |
+
" \"\"\"train.csv → DataFrame [wavID, emos, val, aro, dom, cat0..cat4] gộp theo wav.\n",
|
| 208 |
+
" CAT = tỉ lệ vote 5 lớp (tổng=1); nếu wav không có vote hợp lệ → phân phối đều.\"\"\"\n",
|
| 209 |
+
" # train.csv phân tách bằng \"|\"; cột emoCat đa nhãn dùng \",\" bên trong (vd \"Angry,Surprised\").\n",
|
| 210 |
+
" df = pd.read_csv(TRAIN_CSV, sep=\"|\")\n",
|
| 211 |
+
" cols = {c.lower().strip(): c for c in df.columns}\n",
|
| 212 |
+
" wav_col = _col(cols, \"wavid\", \"wav\", default_idx=1, df=df)\n",
|
| 213 |
+
" emos_col = _col(cols, \"emos\", \"emo\", \"emomos\")\n",
|
| 214 |
+
" val_col = _col(cols, \"val\", \"valence\")\n",
|
| 215 |
+
" aro_col = _col(cols, \"aro\", \"arousal\")\n",
|
| 216 |
+
" dom_col = _col(cols, \"dom\", \"dominance\")\n",
|
| 217 |
+
" cat_col = _col(cols, \"emocat\", \"cat\", \"emotion\")\n",
|
| 218 |
+
" assert emos_col, f\"Không thấy cột eMOS trong train.csv (cột: {list(df.columns)})\"\n",
|
| 219 |
+
"\n",
|
| 220 |
+
" df[\"_stem\"] = df[wav_col].map(stem)\n",
|
| 221 |
+
" rows = []\n",
|
| 222 |
+
" for sid, g in df.groupby(\"_stem\"):\n",
|
| 223 |
+
" rec = {\"wavID\": sid, \"emos\": float(g[emos_col].mean())}\n",
|
| 224 |
+
" rec[\"val\"] = float(g[val_col].mean()) if val_col else np.nan\n",
|
| 225 |
+
" rec[\"aro\"] = float(g[aro_col].mean()) if aro_col else np.nan\n",
|
| 226 |
+
" rec[\"dom\"] = float(g[dom_col].mean()) if dom_col else np.nan\n",
|
| 227 |
+
" votes = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
|
| 228 |
+
" if cat_col:\n",
|
| 229 |
+
" for cell in g[cat_col]:\n",
|
| 230 |
+
" votes += parse_emocat_votes(cell)\n",
|
| 231 |
+
" s = votes.sum()\n",
|
| 232 |
+
" cat = votes / s if s > 0 else np.full(len(EMOTIONS5), 1.0 / len(EMOTIONS5), dtype=np.float32)\n",
|
| 233 |
+
" for i in range(len(EMOTIONS5)):\n",
|
| 234 |
+
" rec[f\"cat{i}\"] = float(cat[i])\n",
|
| 235 |
+
" rows.append(rec)\n",
|
| 236 |
+
" return pd.DataFrame(rows)\n",
|
| 237 |
+
"\n",
|
| 238 |
+
"target_map = load_target_emotions()\n",
|
| 239 |
+
"train_df = load_train_labels()\n",
|
| 240 |
+
"HAS_VAD = bool(train_df[\"val\"].notna().any())\n",
|
| 241 |
+
"print(f\"Target emotions: {len(target_map)} | wav train (gộp): {len(train_df)} | có nhãn VAD: {HAS_VAD}\")\n",
|
| 242 |
+
"print(\"eMOS:\", train_df[\"emos\"].describe()[[\"mean\", \"std\", \"min\", \"max\"]].to_dict())\n",
|
| 243 |
+
"train_df.head()"
|
| 244 |
+
]
|
| 245 |
+
},
|
| 246 |
+
{
|
| 247 |
+
"cell_type": "markdown",
|
| 248 |
+
"id": "6c5e27b9",
|
| 249 |
+
"metadata": {},
|
| 250 |
+
"source": [
|
| 251 |
+
"## 3. Trích đặc trưng 2 backbone (có cache riêng từng model)\n",
|
| 252 |
+
"- **emotion2vec** → embedding + xác suất 5 lớp (như exp02).\n",
|
| 253 |
+
"- **SAILER** → embedding (features) + xác suất 9 lớp + VAD3 (như exp03).\n",
|
| 254 |
+
"Mỗi backbone cache riêng (`e2v_<tag>.npz`, `sailer_<tag>.npz`) → chạy nối tiếp được, đổi 1 backbone\n",
|
| 255 |
+
"không phải trích lại cái kia. Trích xong **giải phóng GPU** rồi mới nạp backbone sau."
|
| 256 |
+
]
|
| 257 |
+
},
|
| 258 |
+
{
|
| 259 |
+
"cell_type": "code",
|
| 260 |
+
"execution_count": null,
|
| 261 |
+
"id": "ebaac593",
|
| 262 |
+
"metadata": {
|
| 263 |
+
"lines_to_next_cell": 1
|
| 264 |
+
},
|
| 265 |
+
"outputs": [],
|
| 266 |
+
"source": [
|
| 267 |
+
"import torch\n",
|
| 268 |
+
"import torch.nn.functional as F\n",
|
| 269 |
+
"\n",
|
| 270 |
+
"device = DEVICE if torch.cuda.is_available() else \"cpu\"\n",
|
| 271 |
+
"print(\"Device:\", device)\n",
|
| 272 |
+
"if device == \"cuda\":\n",
|
| 273 |
+
" print(\" ✅ GPU:\", torch.cuda.get_device_name(0))\n",
|
| 274 |
+
"else:\n",
|
| 275 |
+
" print(\" ⚠️ KHÔNG thấy GPU! Trích đặc trưng ~15k file trên CPU rất lâu.\")\n",
|
| 276 |
+
" print(\" → Settings → Accelerator = GPU T4 rồi chạy lại.\")\n",
|
| 277 |
+
"\n",
|
| 278 |
+
"# ---- emotion2vec ----\n",
|
| 279 |
+
"def extract_e2v(stems, tag):\n",
|
| 280 |
+
" \"\"\"→ dict {stem: (emb[D1], probs5[5])}. Cache CACHE_DIR/e2v_<tag>.npz.\"\"\"\n",
|
| 281 |
+
" from tqdm.auto import tqdm\n",
|
| 282 |
+
" cache_path = os.path.join(CACHE_DIR, f\"e2v_{tag}.npz\")\n",
|
| 283 |
+
" store = {}\n",
|
| 284 |
+
" if os.path.exists(cache_path):\n",
|
| 285 |
+
" z = np.load(cache_path, allow_pickle=True)\n",
|
| 286 |
+
" store = {k: z[k] for k in z.files}\n",
|
| 287 |
+
" print(f\"[e2v/{tag}] nạp cache: {len(store)}\")\n",
|
| 288 |
+
" todo = [s for s in stems if s not in store]\n",
|
| 289 |
+
" if todo:\n",
|
| 290 |
+
" import logging\n",
|
| 291 |
+
" logging.getLogger(\"funasr\").setLevel(logging.ERROR) # bớt log ồn của funasr\n",
|
| 292 |
+
" from funasr import AutoModel\n",
|
| 293 |
+
" m = AutoModel(model=\"iic/emotion2vec_plus_large\", hub=\"hf\", device=device,\n",
|
| 294 |
+
" disable_update=True, disable_pbar=True, disable_log=True) # ép GPU + tắt log\n",
|
| 295 |
+
" miss = 0\n",
|
| 296 |
+
" for i, s in enumerate(tqdm(todo, desc=f\"e2v {tag}\")):\n",
|
| 297 |
+
" wav = os.path.join(WAV_DIR, s + \".wav\")\n",
|
| 298 |
+
" if not os.path.exists(wav):\n",
|
| 299 |
+
" miss += 1; continue\n",
|
| 300 |
+
" r = m.generate(wav, granularity=\"utterance\", extract_embedding=True)[0]\n",
|
| 301 |
+
" emb = np.asarray(r[\"feats\"], dtype=np.float32).reshape(-1)\n",
|
| 302 |
+
" probs = {e: 0.0 for e in EMOTIONS5}\n",
|
| 303 |
+
" for lab, sc in zip(r[\"labels\"], r[\"scores\"]):\n",
|
| 304 |
+
" name = lab.split(\"/\")[-1]\n",
|
| 305 |
+
" if name in probs:\n",
|
| 306 |
+
" probs[name] = float(sc)\n",
|
| 307 |
+
" tot = sum(probs.values())\n",
|
| 308 |
+
" p5 = np.array([probs[e] / tot if tot > 0 else 0.2 for e in EMOTIONS5], dtype=np.float32)\n",
|
| 309 |
+
" store[s] = np.concatenate([emb, p5]).astype(np.float32) # [D1 + 5]\n",
|
| 310 |
+
" if (i + 1) % 500 == 0:\n",
|
| 311 |
+
" np.savez(cache_path, **store)\n",
|
| 312 |
+
" np.savez(cache_path, **store)\n",
|
| 313 |
+
" del m\n",
|
| 314 |
+
" torch.cuda.empty_cache() if device == \"cuda\" else None\n",
|
| 315 |
+
" if miss:\n",
|
| 316 |
+
" print(f\"[e2v/{tag}] {miss} file thiếu → bỏ qua.\")\n",
|
| 317 |
+
" return {s: (v[:-5], v[-5:]) for s, v in store.items()}\n",
|
| 318 |
+
"\n",
|
| 319 |
+
"# ---- SAILER ----\n",
|
| 320 |
+
"def _pool_feat(features):\n",
|
| 321 |
+
" \"\"\"features (tensor) → vector 1 chiều (mean-pool nếu còn chiều thời gian).\"\"\"\n",
|
| 322 |
+
" f = features.detach().cpu().numpy()\n",
|
| 323 |
+
" if f.ndim <= 1:\n",
|
| 324 |
+
" return f.reshape(-1).astype(np.float32)\n",
|
| 325 |
+
" return f.mean(axis=tuple(range(f.ndim - 1))).reshape(-1).astype(np.float32)\n",
|
| 326 |
+
"\n",
|
| 327 |
+
"def extract_sailer(stems, tag):\n",
|
| 328 |
+
" \"\"\"→ dict {stem: (emb[D2], probs9[9], vad3[3] thang 1–5)}. Cache CACHE_DIR/sailer_<tag>.npz.\n",
|
| 329 |
+
" Mỗi mẫu lưu vector [emb | probs9(9) | vad3(3)] → cắt lại khi nạp.\"\"\"\n",
|
| 330 |
+
" import librosa\n",
|
| 331 |
+
" from tqdm.auto import tqdm\n",
|
| 332 |
+
" cache_path = os.path.join(CACHE_DIR, f\"sailer_{tag}.npz\")\n",
|
| 333 |
+
" store = {}\n",
|
| 334 |
+
" if os.path.exists(cache_path):\n",
|
| 335 |
+
" z = np.load(cache_path, allow_pickle=True)\n",
|
| 336 |
+
" store = {k: z[k] for k in z.files}\n",
|
| 337 |
+
" print(f\"[sailer/{tag}] nạp cache: {len(store)}\")\n",
|
| 338 |
+
" todo = [s for s in stems if s not in store]\n",
|
| 339 |
+
" if todo:\n",
|
| 340 |
+
" from src.model.emotion.wavlm_emotion import WavLMWrapper\n",
|
| 341 |
+
" sailer = WavLMWrapper.from_pretrained(\"tiantiaf/wavlm-large-categorical-emotion\").to(device).eval()\n",
|
| 342 |
+
" miss = 0\n",
|
| 343 |
+
" with torch.no_grad():\n",
|
| 344 |
+
" for i, s in enumerate(tqdm(todo, desc=f\"sailer {tag}\")):\n",
|
| 345 |
+
" wav = os.path.join(WAV_DIR, s + \".wav\")\n",
|
| 346 |
+
" if not os.path.exists(wav):\n",
|
| 347 |
+
" miss += 1; continue\n",
|
| 348 |
+
" wave, _ = librosa.load(wav, sr=16000, mono=True)\n",
|
| 349 |
+
" wave = wave[: 15 * 16000]\n",
|
| 350 |
+
" data = torch.from_numpy(wave).float().unsqueeze(0).to(device)\n",
|
| 351 |
+
" logits, feat, _det, arousal, valence, dominance = sailer(data, return_feature=True)\n",
|
| 352 |
+
" emb = _pool_feat(feat)\n",
|
| 353 |
+
" p9 = F.softmax(logits, dim=1)[0].detach().cpu().numpy().astype(np.float32)\n",
|
| 354 |
+
" vad3 = np.array([1 + 4 * float(valence.item()),\n",
|
| 355 |
+
" 1 + 4 * float(arousal.item()),\n",
|
| 356 |
+
" 1 + 4 * float(dominance.item())], dtype=np.float32) # [VAL,ARO,DOM]\n",
|
| 357 |
+
" store[s] = np.concatenate([emb, p9, vad3]).astype(np.float32) # [D2 + 9 + 3]\n",
|
| 358 |
+
" if (i + 1) % 500 == 0:\n",
|
| 359 |
+
" np.savez(cache_path, **store)\n",
|
| 360 |
+
" np.savez(cache_path, **store)\n",
|
| 361 |
+
" del sailer\n",
|
| 362 |
+
" torch.cuda.empty_cache() if device == \"cuda\" else None\n",
|
| 363 |
+
" if miss:\n",
|
| 364 |
+
" print(f\"[sailer/{tag}] {miss} file thiếu → bỏ qua.\")\n",
|
| 365 |
+
" return {s: (v[:-12], v[-12:-3], v[-3:]) for s, v in store.items()}"
|
| 366 |
+
]
|
| 367 |
+
},
|
| 368 |
+
{
|
| 369 |
+
"cell_type": "markdown",
|
| 370 |
+
"id": "751d646c",
|
| 371 |
+
"metadata": {},
|
| 372 |
+
"source": [
|
| 373 |
+
"## 4. Dựng feature + nhãn cho train\n",
|
| 374 |
+
"Feature audio (KHÔNG gồm target) = nối các phần đang bật:\n",
|
| 375 |
+
"`[e2v_emb | e2v_probs5 | sailer_emb | sailer_probs9 | sailer_vad3]`.\n",
|
| 376 |
+
"One-hot target để **riêng** (chỉ EMOS head dùng). Bỏ wav thiếu feature."
|
| 377 |
+
]
|
| 378 |
+
},
|
| 379 |
+
{
|
| 380 |
+
"cell_type": "code",
|
| 381 |
+
"execution_count": null,
|
| 382 |
+
"id": "005cdf2f",
|
| 383 |
+
"metadata": {},
|
| 384 |
+
"outputs": [],
|
| 385 |
+
"source": [
|
| 386 |
+
"train_stems = list(train_df[\"wavID\"])\n",
|
| 387 |
+
"if LIMIT_TRAIN:\n",
|
| 388 |
+
" train_stems = train_stems[:LIMIT_TRAIN]\n",
|
| 389 |
+
"\n",
|
| 390 |
+
"e2v_tr = extract_e2v(train_stems, \"train\") if USE_E2V else {}\n",
|
| 391 |
+
"sailer_tr = extract_sailer(train_stems, \"train\") if USE_SAILER else {}\n",
|
| 392 |
+
"\n",
|
| 393 |
+
"def audio_feature(sid, e2v_map, sailer_map):\n",
|
| 394 |
+
" \"\"\"Nối đặc trưng audio cho 1 wav. None nếu thiếu phần bắt buộc.\"\"\"\n",
|
| 395 |
+
" parts = []\n",
|
| 396 |
+
" if USE_E2V:\n",
|
| 397 |
+
" pk = e2v_map.get(sid)\n",
|
| 398 |
+
" if pk is None:\n",
|
| 399 |
+
" return None\n",
|
| 400 |
+
" emb, p5 = pk\n",
|
| 401 |
+
" parts.append(emb)\n",
|
| 402 |
+
" if USE_CLASSPROB:\n",
|
| 403 |
+
" parts.append(p5)\n",
|
| 404 |
+
" if USE_SAILER:\n",
|
| 405 |
+
" pk = sailer_map.get(sid)\n",
|
| 406 |
+
" if pk is None:\n",
|
| 407 |
+
" return None\n",
|
| 408 |
+
" emb, p9, vad3 = pk\n",
|
| 409 |
+
" parts.append(emb)\n",
|
| 410 |
+
" if USE_CLASSPROB:\n",
|
| 411 |
+
" parts.append(p9); parts.append(vad3)\n",
|
| 412 |
+
" return np.concatenate(parts).astype(np.float32)\n",
|
| 413 |
+
"\n",
|
| 414 |
+
"def onehot_target(tgt):\n",
|
| 415 |
+
" v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
|
| 416 |
+
" if tgt in EMOTIONS5:\n",
|
| 417 |
+
" v[EMOTIONS5.index(tgt)] = 1.0\n",
|
| 418 |
+
" return v\n",
|
| 419 |
+
"\n",
|
| 420 |
+
"lab = train_df.set_index(\"wavID\")\n",
|
| 421 |
+
"X, T, y_emos, y_vad, y_cat = [], [], [], [], []\n",
|
| 422 |
+
"for s in train_stems:\n",
|
| 423 |
+
" f = audio_feature(s, e2v_tr, sailer_tr)\n",
|
| 424 |
+
" tgt = target_map.get(s)\n",
|
| 425 |
+
" if f is None or tgt is None or s not in lab.index:\n",
|
| 426 |
+
" continue\n",
|
| 427 |
+
" X.append(f)\n",
|
| 428 |
+
" T.append(onehot_target(tgt))\n",
|
| 429 |
+
" y_emos.append(lab.loc[s, \"emos\"])\n",
|
| 430 |
+
" y_vad.append([lab.loc[s, \"val\"], lab.loc[s, \"aro\"], lab.loc[s, \"dom\"]])\n",
|
| 431 |
+
" y_cat.append([lab.loc[s, f\"cat{i}\"] for i in range(len(EMOTIONS5))])\n",
|
| 432 |
+
"\n",
|
| 433 |
+
"X = np.stack(X).astype(np.float32)\n",
|
| 434 |
+
"T = np.stack(T).astype(np.float32)\n",
|
| 435 |
+
"y_emos = np.array(y_emos, dtype=np.float32)\n",
|
| 436 |
+
"y_vad = np.array(y_vad, dtype=np.float32) # [N,3] (VAL,ARO,DOM) — có thể toàn NaN nếu thiếu nhãn\n",
|
| 437 |
+
"y_cat = np.array(y_cat, dtype=np.float32) # [N,5] phân phối tổng=1\n",
|
| 438 |
+
"FEAT_DIM = X.shape[1]\n",
|
| 439 |
+
"print(f\"Train: X={X.shape} target={T.shape} emos={y_emos.shape} vad={y_vad.shape} cat={y_cat.shape}\")\n",
|
| 440 |
+
"\n",
|
| 441 |
+
"# Chuẩn hóa feature audio (z-score) — lưu mean/std để áp dụng y hệt lúc dự đoán DEV.\n",
|
| 442 |
+
"feat_mean = X.mean(0, keepdims=True)\n",
|
| 443 |
+
"feat_std = X.std(0, keepdims=True) + 1e-6\n",
|
| 444 |
+
"Xn = (X - feat_mean) / feat_std\n",
|
| 445 |
+
"\n",
|
| 446 |
+
"# Chuẩn hóa nhãn liên tục (eMOS, VAD) về z-score → các MSE cùng thang (uncertainty weighting ổn định hơn).\n",
|
| 447 |
+
"# SRCC bất biến với scale → khi xuất answer.txt chỉ cần đảo z-score về thang gốc cho đẹp.\n",
|
| 448 |
+
"emos_mu, emos_sd = float(y_emos.mean()), float(y_emos.std() + 1e-6)\n",
|
| 449 |
+
"y_emos_z = (y_emos - emos_mu) / emos_sd\n",
|
| 450 |
+
"if HAS_VAD:\n",
|
| 451 |
+
" vad_mu = np.nanmean(y_vad, axis=0)\n",
|
| 452 |
+
" vad_sd = np.nanstd(y_vad, axis=0) + 1e-6\n",
|
| 453 |
+
" y_vad_z = (y_vad - vad_mu) / vad_sd\n",
|
| 454 |
+
"else:\n",
|
| 455 |
+
" vad_mu = np.zeros(3, dtype=np.float32); vad_sd = np.ones(3, dtype=np.float32)\n",
|
| 456 |
+
" y_vad_z = np.zeros_like(y_vad)"
|
| 457 |
+
]
|
| 458 |
+
},
|
| 459 |
+
{
|
| 460 |
+
"cell_type": "markdown",
|
| 461 |
+
"id": "f41faa42",
|
| 462 |
+
"metadata": {},
|
| 463 |
+
"source": [
|
| 464 |
+
"## 5. Model fusion multi-task + train loop\n",
|
| 465 |
+
"- **Trunk** chung: `Linear(FEAT_DIM→TRUNK_HIDDEN)+ReLU+Dropout` (×2).\n",
|
| 466 |
+
"- **EMOS head**: nối `[trunk | one-hot target]` → MLP → 1 (vì EMOS phụ thuộc target).\n",
|
| 467 |
+
"- **CAT head**: trunk → 5 logits → softmax (dự đoán phân phối vote). Loss = soft-CE (KL).\n",
|
| 468 |
+
"- **VAD head**: trunk → 3 (VAL/ARO/DOM). Loss = MSE (bỏ qua nếu thiếu nhãn VAD).\n",
|
| 469 |
+
"- **Cân loss**: uncertainty weighting — tổng `Σ exp(-sᵢ)·Lᵢ + sᵢ`, `sᵢ=log σᵢ²` **học được**."
|
| 470 |
+
]
|
| 471 |
+
},
|
| 472 |
+
{
|
| 473 |
+
"cell_type": "code",
|
| 474 |
+
"execution_count": null,
|
| 475 |
+
"id": "dc5e0242",
|
| 476 |
+
"metadata": {
|
| 477 |
+
"lines_to_next_cell": 1
|
| 478 |
+
},
|
| 479 |
+
"outputs": [],
|
| 480 |
+
"source": [
|
| 481 |
+
"import torch.nn as nn\n",
|
| 482 |
+
"from scipy.stats import spearmanr\n",
|
| 483 |
+
"from sklearn.model_selection import train_test_split\n",
|
| 484 |
+
"\n",
|
| 485 |
+
"torch.manual_seed(SEED); np.random.seed(SEED)\n",
|
| 486 |
+
"N_EMO = len(EMOTIONS5)\n",
|
| 487 |
+
"\n",
|
| 488 |
+
"idx_all = np.arange(X.shape[0])\n",
|
| 489 |
+
"tr_idx, va_idx = train_test_split(idx_all, test_size=VAL_FRAC, random_state=SEED)\n",
|
| 490 |
+
"\n",
|
| 491 |
+
"def to_t(a):\n",
|
| 492 |
+
" return torch.tensor(a, dtype=torch.float32, device=device)\n",
|
| 493 |
+
"\n",
|
| 494 |
+
"Xn_t, T_t = to_t(Xn), to_t(T)\n",
|
| 495 |
+
"emos_t = to_t(y_emos_z).unsqueeze(1)\n",
|
| 496 |
+
"vad_t = to_t(y_vad_z)\n",
|
| 497 |
+
"cat_t = to_t(y_cat)\n",
|
| 498 |
+
"\n",
|
| 499 |
+
"class FusionMTL(nn.Module):\n",
|
| 500 |
+
" def __init__(self, d_in, trunk_h, head_h, p, n_emo):\n",
|
| 501 |
+
" super().__init__()\n",
|
| 502 |
+
" self.trunk = nn.Sequential(\n",
|
| 503 |
+
" nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),\n",
|
| 504 |
+
" nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p),\n",
|
| 505 |
+
" )\n",
|
| 506 |
+
" self.emos = nn.Sequential( # nhận [trunk | target]\n",
|
| 507 |
+
" nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))\n",
|
| 508 |
+
" self.cat = nn.Sequential(\n",
|
| 509 |
+
" nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))\n",
|
| 510 |
+
" self.vad = nn.Sequential(\n",
|
| 511 |
+
" nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))\n",
|
| 512 |
+
"\n",
|
| 513 |
+
" def forward(self, x, tgt):\n",
|
| 514 |
+
" h = self.trunk(x)\n",
|
| 515 |
+
" emos = self.emos(torch.cat([h, tgt], dim=1))\n",
|
| 516 |
+
" cat_logits = self.cat(h)\n",
|
| 517 |
+
" vad = self.vad(h)\n",
|
| 518 |
+
" return emos, cat_logits, vad\n",
|
| 519 |
+
"\n",
|
| 520 |
+
"model = FusionMTL(FEAT_DIM, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device)\n",
|
| 521 |
+
"\n",
|
| 522 |
+
"# Trọng số bất định (log σ²) cho 5 task: emos, cat, val, aro, dom.\n",
|
| 523 |
+
"TASKS = [\"emos\", \"cat\", \"val\", \"aro\", \"dom\"]\n",
|
| 524 |
+
"log_var = nn.Parameter(torch.zeros(len(TASKS), device=device))\n",
|
| 525 |
+
"params = list(model.parameters()) + ([log_var] if USE_UNCERTAINTY else [])\n",
|
| 526 |
+
"opt = torch.optim.Adam(params, lr=LR, weight_decay=1e-5)\n",
|
| 527 |
+
"\n",
|
| 528 |
+
"mse = nn.MSELoss(reduction=\"none\")\n",
|
| 529 |
+
"\n",
|
| 530 |
+
"def soft_ce(logits, target_dist):\n",
|
| 531 |
+
" \"\"\"Cross-entropy với nhãn mềm (phân phối): −Σ p·log q.\"\"\"\n",
|
| 532 |
+
" logq = F.log_softmax(logits, dim=1)\n",
|
| 533 |
+
" return -(target_dist * logq).sum(dim=1)\n",
|
| 534 |
+
"\n",
|
| 535 |
+
"def task_losses(emos_p, cat_logits, vad_p, b):\n",
|
| 536 |
+
" \"\"\"Trả về dict loss TB từng task cho 1 batch (chỉ số b).\"\"\"\n",
|
| 537 |
+
" L = {}\n",
|
| 538 |
+
" L[\"emos\"] = mse(emos_p, emos_t[b]).mean()\n",
|
| 539 |
+
" L[\"cat\"] = soft_ce(cat_logits, cat_t[b]).mean()\n",
|
| 540 |
+
" if HAS_VAD:\n",
|
| 541 |
+
" L[\"val\"] = mse(vad_p[:, 0:1], vad_t[b, 0:1]).mean()\n",
|
| 542 |
+
" L[\"aro\"] = mse(vad_p[:, 1:2], vad_t[b, 1:2]).mean()\n",
|
| 543 |
+
" L[\"dom\"] = mse(vad_p[:, 2:3], vad_t[b, 2:3]).mean()\n",
|
| 544 |
+
" else:\n",
|
| 545 |
+
" z = torch.zeros((), device=device)\n",
|
| 546 |
+
" L[\"val\"] = L[\"aro\"] = L[\"dom\"] = z\n",
|
| 547 |
+
" return L\n",
|
| 548 |
+
"\n",
|
| 549 |
+
"def combine(L):\n",
|
| 550 |
+
" \"\"\"Gộp 5 loss thành 1 số: uncertainty weighting hoặc trọng số cố định.\"\"\"\n",
|
| 551 |
+
" if USE_UNCERTAINTY:\n",
|
| 552 |
+
" tot = 0.0\n",
|
| 553 |
+
" for i, t in enumerate(TASKS):\n",
|
| 554 |
+
" tot = tot + torch.exp(-log_var[i]) * L[t] + log_var[i]\n",
|
| 555 |
+
" return tot\n",
|
| 556 |
+
" return sum(LOSS_W[t] * L[t] for t in TASKS)\n",
|
| 557 |
+
"\n",
|
| 558 |
+
"@torch.no_grad()\n",
|
| 559 |
+
"def eval_val():\n",
|
| 560 |
+
" \"\"\"SRCC từng task trên tập val nội bộ (CAT báo bằng −KL để 'cao=tốt' cho early-stop).\"\"\"\n",
|
| 561 |
+
" model.eval()\n",
|
| 562 |
+
" ep, cl, vp = model(Xn_t[va_idx], T_t[va_idx])\n",
|
| 563 |
+
" ep = ep.cpu().numpy().ravel()\n",
|
| 564 |
+
" out = {\"emos\": spearmanr(ep, y_emos[va_idx]).correlation}\n",
|
| 565 |
+
" if HAS_VAD:\n",
|
| 566 |
+
" vp = vp.cpu().numpy()\n",
|
| 567 |
+
" for j, t in enumerate([\"val\", \"aro\", \"dom\"]):\n",
|
| 568 |
+
" out[t] = spearmanr(vp[:, j], y_vad[va_idx, j]).correlation\n",
|
| 569 |
+
" # CAT: dùng −KL(p‖q) trung bình (càng gần 0 càng tốt) → đổi dấu để hợp early-stop\n",
|
| 570 |
+
" q = F.softmax(cl, dim=1).cpu().numpy()\n",
|
| 571 |
+
" p = y_cat[va_idx]\n",
|
| 572 |
+
" kl = (p * (np.log(p + 1e-9) - np.log(q + 1e-9))).sum(1).mean()\n",
|
| 573 |
+
" out[\"cat_negkl\"] = float(-kl)\n",
|
| 574 |
+
" return out\n",
|
| 575 |
+
"\n",
|
| 576 |
+
"def val_score(m):\n",
|
| 577 |
+
" \"\"\"Điểm tổng để early-stop = TB SRCC các task liên tục có nhãn.\"\"\"\n",
|
| 578 |
+
" keys = [\"emos\"] + ([\"val\", \"aro\", \"dom\"] if HAS_VAD else [])\n",
|
| 579 |
+
" return float(np.mean([m[k] for k in keys]))\n",
|
| 580 |
+
"\n",
|
| 581 |
+
"best_score, best_state, bad = -1e9, None, 0\n",
|
| 582 |
+
"tr_t = torch.tensor(tr_idx, device=device)\n",
|
| 583 |
+
"for ep in range(1, EPOCHS + 1):\n",
|
| 584 |
+
" model.train()\n",
|
| 585 |
+
" perm = tr_t[torch.randperm(len(tr_t), device=device)]\n",
|
| 586 |
+
" run = 0.0\n",
|
| 587 |
+
" for i in range(0, len(perm), BATCH):\n",
|
| 588 |
+
" b = perm[i:i + BATCH]\n",
|
| 589 |
+
" opt.zero_grad()\n",
|
| 590 |
+
" emos_p, cat_logits, vad_p = model(Xn_t[b], T_t[b])\n",
|
| 591 |
+
" L = task_losses(emos_p, cat_logits, vad_p, b)\n",
|
| 592 |
+
" loss = combine(L)\n",
|
| 593 |
+
" loss.backward(); opt.step()\n",
|
| 594 |
+
" run += loss.item() * len(b)\n",
|
| 595 |
+
" m = eval_val()\n",
|
| 596 |
+
" sc = val_score(m)\n",
|
| 597 |
+
" if sc > best_score:\n",
|
| 598 |
+
" best_score = sc\n",
|
| 599 |
+
" best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}\n",
|
| 600 |
+
" bad = 0\n",
|
| 601 |
+
" else:\n",
|
| 602 |
+
" bad += 1\n",
|
| 603 |
+
" if ep % 5 == 0 or ep == 1:\n",
|
| 604 |
+
" msg = \" \".join(f\"{k}={m[k]:.3f}\" for k in m)\n",
|
| 605 |
+
" print(f\"epoch {ep:3d} | loss {run/len(perm):.4f} | {msg} | best {best_score:.4f}\")\n",
|
| 606 |
+
" if bad >= PATIENCE:\n",
|
| 607 |
+
" print(f\"Early stop ở epoch {ep}.\")\n",
|
| 608 |
+
" break\n",
|
| 609 |
+
"\n",
|
| 610 |
+
"model.load_state_dict(best_state)\n",
|
| 611 |
+
"final = eval_val()\n",
|
| 612 |
+
"print(\"\\n✅ VAL (nội bộ) tốt nhất:\")\n",
|
| 613 |
+
"print(f\" EMOS SRCC = {final['emos']:.4f} (so mốc exp01 emotion2vec = 0.637)\")\n",
|
| 614 |
+
"if HAS_VAD:\n",
|
| 615 |
+
" print(f\" VAL/ARO/DOM SRCC = {final['val']:.4f} / {final['aro']:.4f} / {final['dom']:.4f}\"\n",
|
| 616 |
+
" f\" (so mốc SAILER = 0.341 / 0.712 / 0.630)\")\n",
|
| 617 |
+
"if USE_UNCERTAINTY:\n",
|
| 618 |
+
" print(\" log σ² mỗi task:\", {t: round(float(log_var[i]), 3) for i, t in enumerate(TASKS)})\n",
|
| 619 |
+
"\n",
|
| 620 |
+
"# Lưu model + tham số chuẩn hóa.\n",
|
| 621 |
+
"torch.save({\"state\": best_state, \"feat_mean\": feat_mean, \"feat_std\": feat_std,\n",
|
| 622 |
+
" \"emos_mu\": emos_mu, \"emos_sd\": emos_sd, \"vad_mu\": vad_mu, \"vad_sd\": vad_sd,\n",
|
| 623 |
+
" \"FEAT_DIM\": FEAT_DIM, \"EMOTIONS5\": EMOTIONS5, \"HAS_VAD\": HAS_VAD,\n",
|
| 624 |
+
" \"USE_E2V\": USE_E2V, \"USE_SAILER\": USE_SAILER, \"USE_CLASSPROB\": USE_CLASSPROB,\n",
|
| 625 |
+
" \"TRUNK_HIDDEN\": TRUNK_HIDDEN, \"HEAD_HIDDEN\": HEAD_HIDDEN, \"val_score\": best_score},\n",
|
| 626 |
+
" os.path.join(OUT_DIR, \"fusion_mtl.pt\"))\n",
|
| 627 |
+
"print(\"Đã lưu\", os.path.join(OUT_DIR, \"fusion_mtl.pt\"))"
|
| 628 |
+
]
|
| 629 |
+
},
|
| 630 |
+
{
|
| 631 |
+
"cell_type": "markdown",
|
| 632 |
+
"id": "39e3c014",
|
| 633 |
+
"metadata": {},
|
| 634 |
+
"source": [
|
| 635 |
+
"## 6. Dự đoán DEV → `answer.txt` đầy đủ 7 cột\n",
|
| 636 |
+
"- **EMOS/CAT/VAD** = model fusion (đảo z-score về thang gốc cho EMOS/VAD; CAT = softmax 5 lớp).\n",
|
| 637 |
+
"- **QMOS** = SpeechMOS (UTMOS) — để riêng, đúng thiết kế."
|
| 638 |
+
]
|
| 639 |
+
},
|
| 640 |
+
{
|
| 641 |
+
"cell_type": "code",
|
| 642 |
+
"execution_count": null,
|
| 643 |
+
"id": "c9d06ec4",
|
| 644 |
+
"metadata": {
|
| 645 |
+
"lines_to_next_cell": 1
|
| 646 |
+
},
|
| 647 |
+
"outputs": [],
|
| 648 |
+
"source": [
|
| 649 |
+
"def list_dev():\n",
|
| 650 |
+
" with open(DEV_SCP) as f:\n",
|
| 651 |
+
" return [ln.strip() for ln in f if ln.strip()] # tên file .wav\n",
|
| 652 |
+
"\n",
|
| 653 |
+
"dev_names = list_dev()\n",
|
| 654 |
+
"if LIMIT_DEV:\n",
|
| 655 |
+
" dev_names = dev_names[:LIMIT_DEV]\n",
|
| 656 |
+
"dev_stems = [stem(n) for n in dev_names]\n",
|
| 657 |
+
"print(\"DEV:\", len(dev_names), \"mẫu\")\n",
|
| 658 |
+
"\n",
|
| 659 |
+
"# 6a. Trích đặc trưng 2 backbone cho DEV (cache riêng)\n",
|
| 660 |
+
"e2v_dev = extract_e2v(dev_stems, \"dev\") if USE_E2V else {}\n",
|
| 661 |
+
"sailer_dev = extract_sailer(dev_stems, \"dev\") if USE_SAILER else {}\n",
|
| 662 |
+
"\n",
|
| 663 |
+
"# 6b. Dự đoán 5 cột cảm xúc bằng model fusion\n",
|
| 664 |
+
"@torch.no_grad()\n",
|
| 665 |
+
"def predict_emotion(sid):\n",
|
| 666 |
+
" f = audio_feature(sid, e2v_dev, sailer_dev)\n",
|
| 667 |
+
" if f is None:\n",
|
| 668 |
+
" return None\n",
|
| 669 |
+
" fn = (f[None, :] - feat_mean) / feat_std\n",
|
| 670 |
+
" tgt = onehot_target(target_map.get(sid))[None, :]\n",
|
| 671 |
+
" model.eval()\n",
|
| 672 |
+
" emos_p, cat_logits, vad_p = model(to_t(fn), to_t(tgt))\n",
|
| 673 |
+
" emos = float(emos_p.item()) * emos_sd + emos_mu # đảo z-score\n",
|
| 674 |
+
" cat5 = F.softmax(cat_logits, dim=1)[0].cpu().numpy()\n",
|
| 675 |
+
" vad3 = vad_p[0].cpu().numpy() * vad_sd + vad_mu # [VAL,ARO,DOM]\n",
|
| 676 |
+
" return emos, cat5, vad3\n",
|
| 677 |
+
"\n",
|
| 678 |
+
"# 6c. QMOS = SpeechMOS (để riêng)\n",
|
| 679 |
+
"@torch.no_grad()\n",
|
| 680 |
+
"def run_qmos(names):\n",
|
| 681 |
+
" import librosa\n",
|
| 682 |
+
" from tqdm.auto import tqdm\n",
|
| 683 |
+
" predictor = torch.hub.load(\"tarepan/SpeechMOS:v1.2.0\", \"utmos22_strong\", trust_repo=True).to(device).eval()\n",
|
| 684 |
+
" out = {}\n",
|
| 685 |
+
" for n in tqdm(names, desc=\"QMOS\"):\n",
|
| 686 |
+
" p = os.path.join(WAV_DIR, n)\n",
|
| 687 |
+
" if not os.path.exists(p):\n",
|
| 688 |
+
" continue\n",
|
| 689 |
+
" wave, _ = librosa.load(p, sr=16000, mono=True)\n",
|
| 690 |
+
" out[n] = float(predictor(torch.from_numpy(wave).unsqueeze(0).to(device), sr=16000).mean().item())\n",
|
| 691 |
+
" return out\n",
|
| 692 |
+
"\n",
|
| 693 |
+
"qmos_scores = run_qmos(dev_names)"
|
| 694 |
+
]
|
| 695 |
+
},
|
| 696 |
+
{
|
| 697 |
+
"cell_type": "code",
|
| 698 |
+
"execution_count": null,
|
| 699 |
+
"id": "999f19fc",
|
| 700 |
+
"metadata": {
|
| 701 |
+
"lines_to_next_cell": 1
|
| 702 |
+
},
|
| 703 |
+
"outputs": [],
|
| 704 |
+
"source": [
|
| 705 |
+
"def fmt_cat(probs5):\n",
|
| 706 |
+
" return \"|\".join(f\"{e}:{probs5[i]:.6g}\" for i, e in enumerate(EMOTIONS5))\n",
|
| 707 |
+
"\n",
|
| 708 |
+
"def build_answer(out_path):\n",
|
| 709 |
+
" from tqdm.auto import tqdm\n",
|
| 710 |
+
" n_real = n_default = 0\n",
|
| 711 |
+
" with open(out_path, \"w\") as f:\n",
|
| 712 |
+
" f.write(\"wav,QMOS,EMOS,CAT,VAL,ARO,DOM\\n\")\n",
|
| 713 |
+
" for name in tqdm(dev_names, desc=\"answer\"):\n",
|
| 714 |
+
" sid = stem(name)\n",
|
| 715 |
+
" pred = predict_emotion(sid)\n",
|
| 716 |
+
" if pred is None:\n",
|
| 717 |
+
" emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0])\n",
|
| 718 |
+
" n_default += 1\n",
|
| 719 |
+
" else:\n",
|
| 720 |
+
" emos, cat5, vad3 = pred\n",
|
| 721 |
+
" n_real += 1\n",
|
| 722 |
+
" qmos = qmos_scores.get(name, 3.0)\n",
|
| 723 |
+
" f.write(f\"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},\"\n",
|
| 724 |
+
" f\"{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\\n\")\n",
|
| 725 |
+
" print(f\"Ghi {len(dev_names)} dòng → {out_path} | fusion thật {n_real}, mặc định {n_default}\")\n",
|
| 726 |
+
"\n",
|
| 727 |
+
"answer_path = os.path.join(OUT_DIR, \"answer.txt\")\n",
|
| 728 |
+
"build_answer(answer_path)"
|
| 729 |
+
]
|
| 730 |
+
},
|
| 731 |
+
{
|
| 732 |
+
"cell_type": "markdown",
|
| 733 |
+
"id": "708acd7a",
|
| 734 |
+
"metadata": {},
|
| 735 |
+
"source": [
|
| 736 |
+
"## 7. Validate + đóng zip"
|
| 737 |
+
]
|
| 738 |
+
},
|
| 739 |
+
{
|
| 740 |
+
"cell_type": "code",
|
| 741 |
+
"execution_count": null,
|
| 742 |
+
"id": "ba406750",
|
| 743 |
+
"metadata": {},
|
| 744 |
+
"outputs": [],
|
| 745 |
+
"source": [
|
| 746 |
+
"def validate(path):\n",
|
| 747 |
+
" import csv\n",
|
| 748 |
+
" with open(path) as f:\n",
|
| 749 |
+
" rows = list(csv.reader(f))\n",
|
| 750 |
+
" header = rows[0]\n",
|
| 751 |
+
" assert header[0] == \"wav\" and \"QMOS\" in header and \"EMOS\" in header, \"Header sai\"\n",
|
| 752 |
+
" for i, r in enumerate(rows[1:], 2):\n",
|
| 753 |
+
" assert len(r) == len(header), f\"Dòng {i} sai số cột\"\n",
|
| 754 |
+
" print(f\"OK: {len(rows)-1} dòng, header = {header}\")\n",
|
| 755 |
+
"\n",
|
| 756 |
+
"validate(answer_path)\n",
|
| 757 |
+
"os.system(f\"cd {OUT_DIR} && zip -j submission_track2_exp04_fusion.zip answer.txt && unzip -l submission_track2_exp04_fusion.zip\")\n",
|
| 758 |
+
"print(\"Sẵn sàng nộp:\", os.path.join(OUT_DIR, \"submission_track2_exp04_fusion.zip\"))"
|
| 759 |
+
]
|
| 760 |
+
},
|
| 761 |
+
{
|
| 762 |
+
"cell_type": "markdown",
|
| 763 |
+
"id": "c0f4e2ae",
|
| 764 |
+
"metadata": {},
|
| 765 |
+
"source": [
|
| 766 |
+
"## Ghi chú\n",
|
| 767 |
+
"- **Lần đầu**: đặt `LIMIT_TRAIN=300`, `LIMIT_DEV=20` ở cell 0 để bắt lỗi setup (clone repo / import / model).\n",
|
| 768 |
+
" Chạy OK rồi đặt `None` chạy full.\n",
|
| 769 |
+
"- **VAL SRCC** ở mục 5 là ước lượng nội bộ (10% train) → so mốc EMOS 0.637 / ARO 0.712. Điểm DEV thật\n",
|
| 770 |
+
" phải nộp CodaBench mới biết (My Submissions → Track 2, bỏ chọn track khác).\n",
|
| 771 |
+
"- Embedding đã cache trong `/kaggle/working/fusion_cache/` → **Save Version** để giữ; lần sau đổi\n",
|
| 772 |
+
" siêu tham số/đổi cách cân loss chỉ train lại head (vài phút), khỏi trích lại.\n",
|
| 773 |
+
"- **Ablation cho paper** (đổi cờ ở cell 0, train lại head):\n",
|
| 774 |
+
" `USE_E2V=False` (chỉ SAILER) · `USE_SAILER=False` (chỉ emotion2vec) · `USE_UNCERTAINTY=False` (trọng số tay)\n",
|
| 775 |
+
" · `USE_CLASSPROB=False` (chỉ embedding) → điền bảng ablation `docs/04_experiments_log.md`.\n",
|
| 776 |
+
"- License SAILER = **Open RAIL (phi thương mại)** → nhắc trong `docs/12_system_description.md`.\n",
|
| 777 |
+
"- Nhớ ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (mục exp04)."
|
| 778 |
+
]
|
| 779 |
+
}
|
| 780 |
+
],
|
| 781 |
+
"metadata": {
|
| 782 |
+
"jupytext": {
|
| 783 |
+
"cell_metadata_filter": "-all",
|
| 784 |
+
"main_language": "python",
|
| 785 |
+
"notebook_metadata_filter": "-all"
|
| 786 |
+
}
|
| 787 |
+
},
|
| 788 |
+
"nbformat": 4,
|
| 789 |
+
"nbformat_minor": 5
|
| 790 |
+
}
|
track2/exp04_fusion_pipeline.py
ADDED
|
@@ -0,0 +1,652 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %% [markdown]
|
| 2 |
+
# # VMC2026 Track 2 — exp04 (FUSION multi-task) — Kaggle
|
| 3 |
+
#
|
| 4 |
+
# **Mục tiêu:** gộp 2 backbone bổ sung nhau (**emotion2vec** thắng EMOS · **SAILER/WavLM** thắng VAD)
|
| 5 |
+
# thành **1 model multi-task** dự đoán chung 5 đầu ra cảm xúc: **EMOS · CAT · VAL · ARO · DOM**.
|
| 6 |
+
# QMOS để **riêng** (giữ SpeechMOS) — đúng thiết kế đã chốt: *"QMOS riêng + 5 cảm xúc chung"*.
|
| 7 |
+
#
|
| 8 |
+
# ## Ý tưởng (đọc 1 lần cho hiểu)
|
| 9 |
+
# Bằng chứng để fusion (từ exp01 & exp03): emotion2vec đứng đầu **EMOS** (0.637), SAILER đứng đầu
|
| 10 |
+
# **VAD** (ARO 0.712 / DOM 0.630). Hai model "nhìn" cảm xúc theo cách khác nhau → **nối đặc trưng**
|
| 11 |
+
# của cả hai rồi cho một mạng nhỏ học → kỳ vọng mạnh hơn từng model lẻ.
|
| 12 |
+
#
|
| 13 |
+
# ```
|
| 14 |
+
# ┌─ emotion2vec ─► embedding ~D1 + xác suất 5 lớp ─┐
|
| 15 |
+
# mỗi wav ──────►│ ├─► NỐI ─► TRUNK chung
|
| 16 |
+
# └─ SAILER(WavLM) ► embedding ~D2 + 9 lớp + VAD3 ─┘ (Linear+ReLU)
|
| 17 |
+
# │
|
| 18 |
+
# ┌───────────────────────────────────────────────┤
|
| 19 |
+
# target emotion(one-hot)│ │
|
| 20 |
+
# ▼ ▼
|
| 21 |
+
# [EMOS head] [CAT head] [VAD head]
|
| 22 |
+
# (cần target) (5 lớp) (VAL/ARO/DOM)
|
| 23 |
+
# ```
|
| 24 |
+
#
|
| 25 |
+
# - **Cả 2 backbone ĐÓNG BĂNG** → chỉ trích đặc trưng (cache `.npz`), **chỉ train phần trunk + head nhỏ**
|
| 26 |
+
# → nhẹ GPU, train vài phút, hợp T4. (Né fine-tune end-to-end lúc đầu.)
|
| 27 |
+
# - **EMOS phụ thuộc target** (cùng audio, target khác → điểm khác) → EMOS head nhận thêm one-hot target.
|
| 28 |
+
# **CAT/VAD** là cảm nhận về chính audio → chỉ cần trunk (không cần target).
|
| 29 |
+
# - **Nhãn vàng** gộp theo `wavID` từ `sets/train.csv`:
|
| 30 |
+
# EMOS = TB `eMOS` · VAL/ARO/DOM = TB `val/aro/dom` · CAT = **tỉ lệ vote 5 lớp** của `emoCat`.
|
| 31 |
+
# - **Cân loss = uncertainty weighting** (Kendall 2018): mỗi task có 1 trọng số σ **tự học**
|
| 32 |
+
# → không phải dò tay. Có cờ `USE_UNCERTAINTY=False` để quay về trọng số cố định khi cần debug.
|
| 33 |
+
# - Cuối cùng xuất `answer.txt` **đủ 7 cột**: `wav,QMOS,EMOS,CAT,VAL,ARO,DOM`
|
| 34 |
+
# (QMOS=SpeechMOS · 5 cột còn lại = model fusion) → nộp được ngay. So mốc: EMOS 0.637 · VAD ARO 0.712.
|
| 35 |
+
#
|
| 36 |
+
# **Cách chạy trên Kaggle:** Settings → Accelerator = **GPU T4**, Internet = **On**
|
| 37 |
+
# → + Add Input dataset Track 2 (15.477 wav, có `sets/train.csv`, `sets/dev.scp`, `metadata.csv`)
|
| 38 |
+
# → sửa `DATA_ROOT` ở cell 0 → Run All. Lần đầu nên đặt `LIMIT_TRAIN = 300`, `LIMIT_DEV = 20` để bắt lỗi setup.
|
| 39 |
+
|
| 40 |
+
# %% [markdown]
|
| 41 |
+
# ## 0. Cấu hình — SỬA Ở ĐÂY
|
| 42 |
+
|
| 43 |
+
# %%
|
| 44 |
+
import os
|
| 45 |
+
|
| 46 |
+
# ── Data Track 2 (dataset 15.477 wav đã ráp, có sets/train.csv) ──────────────
|
| 47 |
+
DATA_ROOT = "/kaggle/input/vmc2026-track2-full/vmc2026-track2" # << SỬA slug cho khớp Add Input
|
| 48 |
+
WAV_DIR = f"{DATA_ROOT}/wav"
|
| 49 |
+
METADATA_CSV = f"{DATA_ROOT}/metadata.csv" # wavID|emotion|transcript (KHÔNG header) → target emotion
|
| 50 |
+
TRAIN_CSV = f"{DATA_ROOT}/sets/train.csv" # nhãn người nghe: lisID,wavID,qMOS,emoCat,eMOS,val,dom,aro
|
| 51 |
+
DEV_SCP = f"{DATA_ROOT}/sets/dev.scp" # danh sách wav tập DEV (tập cần nộp ở training phase)
|
| 52 |
+
|
| 53 |
+
OUT_DIR = "/kaggle/working"
|
| 54 |
+
CACHE_DIR = "/kaggle/working/fusion_cache" # cache embedding 2 backbone (tái dùng giữa các lần chạy)
|
| 55 |
+
os.makedirs(CACHE_DIR, exist_ok=True)
|
| 56 |
+
|
| 57 |
+
# ── Siêu tham số train ───────────────────────────────────────────────────────
|
| 58 |
+
DEVICE = "cuda" # "cuda" trên Kaggle GPU; "cpu" nếu không có GPU
|
| 59 |
+
TRUNK_HIDDEN = 512 # số neuron lớp trunk chung
|
| 60 |
+
HEAD_HIDDEN = 128 # số neuron lớp ẩn mỗi head
|
| 61 |
+
DROPOUT = 0.3
|
| 62 |
+
LR = 1e-3
|
| 63 |
+
EPOCHS = 80
|
| 64 |
+
BATCH = 64
|
| 65 |
+
VAL_FRAC = 0.10 # 10% train → validation nội bộ (đo SRCC từng task)
|
| 66 |
+
PATIENCE = 15 # early stop theo điểm tổng val (xem SCORE_FOR_STOP)
|
| 67 |
+
SEED = 42
|
| 68 |
+
|
| 69 |
+
USE_UNCERTAINTY = True # True = tự cân loss (Kendall); False = dùng LOSS_W cố định bên dưới
|
| 70 |
+
LOSS_W = {"emos": 1.0, "cat": 1.0, "val": 1.0, "aro": 1.0, "dom": 1.0} # chỉ dùng khi tắt uncertainty
|
| 71 |
+
USE_E2V = True # bật/tắt nhánh emotion2vec trong fusion (để ablation)
|
| 72 |
+
USE_SAILER = True # bật/tắt nhánh SAILER trong fusion (để ablation)
|
| 73 |
+
USE_CLASSPROB = True # thêm xác suất lớp (e2v 5 + sailer 9) + VAD3 của SAILER vào feature
|
| 74 |
+
|
| 75 |
+
LIMIT_TRAIN = None # đặt số nhỏ (vd 300) để chạy thử nhanh; None = full
|
| 76 |
+
LIMIT_DEV = None # đặt số nhỏ (vd 20) để chạy thử nhanh; None = full
|
| 77 |
+
|
| 78 |
+
EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"]
|
| 79 |
+
|
| 80 |
+
# 9 lớp SAILER (đúng thứ tự model xuất) + chỉ số của 5 lớp challenge trong đó
|
| 81 |
+
SAILER9 = ["Anger", "Contempt", "Disgust", "Fear", "Happiness", "Neutral", "Sadness", "Surprise", "Other"]
|
| 82 |
+
EMO2SAILER = {"angry": 0, "happy": 4, "neutral": 5, "sad": 6, "surprised": 7} # EMOTIONS5 → index trong SAILER9
|
| 83 |
+
|
| 84 |
+
_EMO_ALIAS = {
|
| 85 |
+
"angry": "angry", "anger": "angry",
|
| 86 |
+
"happy": "happy", "happiness": "happy", "joy": "happy",
|
| 87 |
+
"neutral": "neutral", "calm": "neutral",
|
| 88 |
+
"sad": "sad", "sadness": "sad",
|
| 89 |
+
"surprise": "surprised", "surprised": "surprised", "surprising": "surprised",
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
def norm_emotion(label):
|
| 93 |
+
"""Đưa nhãn cảm xúc bất kỳ về 1 trong EMOTIONS5; None nếu không khớp."""
|
| 94 |
+
key = str(label).strip().lower()
|
| 95 |
+
return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)
|
| 96 |
+
|
| 97 |
+
def stem(path_or_name):
|
| 98 |
+
"""Lấy tên file không đuôi, để khớp wavID giữa train.csv / metadata / dev.scp."""
|
| 99 |
+
return os.path.splitext(os.path.basename(str(path_or_name)))[0]
|
| 100 |
+
|
| 101 |
+
assert USE_E2V or USE_SAILER, "Phải bật ít nhất 1 backbone (USE_E2V hoặc USE_SAILER)."
|
| 102 |
+
print("DATA_ROOT:", DATA_ROOT)
|
| 103 |
+
for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP]:
|
| 104 |
+
print((" ✅ " if os.path.exists(p) else " ❌ THIẾU ") + p)
|
| 105 |
+
|
| 106 |
+
# %% [markdown]
|
| 107 |
+
# ## 1. Cài đặt + tải code SAILER
|
| 108 |
+
# emotion2vec qua `funasr` (offline). SAILER cần `WavLMWrapper` trong repo `vox-profile-release`
|
| 109 |
+
# → **clone + sys.path** (KHÔNG `pip install -e .` vì build wheel hay lỗi trên Kaggle).
|
| 110 |
+
|
| 111 |
+
# %%
|
| 112 |
+
import sys, subprocess
|
| 113 |
+
|
| 114 |
+
def pip_install(*pkgs):
|
| 115 |
+
subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs], check=True)
|
| 116 |
+
|
| 117 |
+
pip_install("speechmos", "funasr", "librosa", "soundfile", "pandas", "scipy", "scikit-learn", "tqdm")
|
| 118 |
+
|
| 119 |
+
if USE_SAILER:
|
| 120 |
+
pip_install("loralib", "speechbrain") # deps WavLMWrapper cần
|
| 121 |
+
REPO_DIR = "/kaggle/working/vox-profile-release"
|
| 122 |
+
if not os.path.exists(REPO_DIR):
|
| 123 |
+
subprocess.run(["git", "clone", "--depth", "1",
|
| 124 |
+
"https://github.com/tiantiaf0627/vox-profile-release.git", REPO_DIR], check=True)
|
| 125 |
+
if REPO_DIR not in sys.path:
|
| 126 |
+
sys.path.insert(0, REPO_DIR)
|
| 127 |
+
|
| 128 |
+
# %% [markdown]
|
| 129 |
+
# ## 2. Đọc & gộp nhãn (gộp theo wavID)
|
| 130 |
+
# - `train.csv`: mỗi dòng = 1 listener chấm 1 wav → gộp **theo wavID**:
|
| 131 |
+
# EMOS=TB `eMOS` · VAL/ARO/DOM=TB `val/aro/dom` · CAT=**tỉ lệ vote 5 lớp** của `emoCat`.
|
| 132 |
+
# - `metadata.csv`: lấy **cảm xúc target** cho mỗi wav (để feed EMOS head).
|
| 133 |
+
|
| 134 |
+
# %%
|
| 135 |
+
import numpy as np
|
| 136 |
+
import pandas as pd
|
| 137 |
+
|
| 138 |
+
def load_target_emotions():
|
| 139 |
+
"""metadata.csv (wavID|emotion|transcript, KHÔNG header) → {stem: emotion_chuẩn|None}."""
|
| 140 |
+
tgt = {}
|
| 141 |
+
with open(METADATA_CSV, encoding="utf-8") as f:
|
| 142 |
+
for ln in f:
|
| 143 |
+
parts = ln.strip().split("|")
|
| 144 |
+
if len(parts) < 2:
|
| 145 |
+
continue
|
| 146 |
+
tgt[stem(parts[0])] = norm_emotion(parts[1])
|
| 147 |
+
return tgt
|
| 148 |
+
|
| 149 |
+
def _col(cols_map, *names, default_idx=None, df=None):
|
| 150 |
+
for n in names:
|
| 151 |
+
if n in cols_map:
|
| 152 |
+
return cols_map[n]
|
| 153 |
+
return list(df.columns)[default_idx] if default_idx is not None else None
|
| 154 |
+
|
| 155 |
+
def parse_emocat_votes(cell):
|
| 156 |
+
"""1 ô emoCat (có thể đa nhãn, vd 'happy;surprised') → vector đếm 5 lớp (chưa chuẩn hóa)."""
|
| 157 |
+
v = np.zeros(len(EMOTIONS5), dtype=np.float32)
|
| 158 |
+
for tok in str(cell).replace("/", ",").replace(";", ",").replace("|", ",").replace(" ", ",").split(","):
|
| 159 |
+
e = norm_emotion(tok)
|
| 160 |
+
if e in EMOTIONS5:
|
| 161 |
+
v[EMOTIONS5.index(e)] += 1.0
|
| 162 |
+
return v
|
| 163 |
+
|
| 164 |
+
def load_train_labels():
|
| 165 |
+
"""train.csv → DataFrame [wavID, emos, val, aro, dom, cat0..cat4] gộp theo wav.
|
| 166 |
+
CAT = tỉ lệ vote 5 lớp (tổng=1); nếu wav không có vote hợp lệ → phân phối đều."""
|
| 167 |
+
# train.csv phân tách bằng "|"; cột emoCat đa nhãn dùng "," bên trong (vd "Angry,Surprised").
|
| 168 |
+
df = pd.read_csv(TRAIN_CSV, sep="|")
|
| 169 |
+
cols = {c.lower().strip(): c for c in df.columns}
|
| 170 |
+
wav_col = _col(cols, "wavid", "wav", default_idx=1, df=df)
|
| 171 |
+
emos_col = _col(cols, "emos", "emo", "emomos")
|
| 172 |
+
val_col = _col(cols, "val", "valence")
|
| 173 |
+
aro_col = _col(cols, "aro", "arousal")
|
| 174 |
+
dom_col = _col(cols, "dom", "dominance")
|
| 175 |
+
cat_col = _col(cols, "emocat", "cat", "emotion")
|
| 176 |
+
assert emos_col, f"Không thấy cột eMOS trong train.csv (cột: {list(df.columns)})"
|
| 177 |
+
|
| 178 |
+
df["_stem"] = df[wav_col].map(stem)
|
| 179 |
+
rows = []
|
| 180 |
+
for sid, g in df.groupby("_stem"):
|
| 181 |
+
rec = {"wavID": sid, "emos": float(g[emos_col].mean())}
|
| 182 |
+
rec["val"] = float(g[val_col].mean()) if val_col else np.nan
|
| 183 |
+
rec["aro"] = float(g[aro_col].mean()) if aro_col else np.nan
|
| 184 |
+
rec["dom"] = float(g[dom_col].mean()) if dom_col else np.nan
|
| 185 |
+
votes = np.zeros(len(EMOTIONS5), dtype=np.float32)
|
| 186 |
+
if cat_col:
|
| 187 |
+
for cell in g[cat_col]:
|
| 188 |
+
votes += parse_emocat_votes(cell)
|
| 189 |
+
s = votes.sum()
|
| 190 |
+
cat = votes / s if s > 0 else np.full(len(EMOTIONS5), 1.0 / len(EMOTIONS5), dtype=np.float32)
|
| 191 |
+
for i in range(len(EMOTIONS5)):
|
| 192 |
+
rec[f"cat{i}"] = float(cat[i])
|
| 193 |
+
rows.append(rec)
|
| 194 |
+
return pd.DataFrame(rows)
|
| 195 |
+
|
| 196 |
+
target_map = load_target_emotions()
|
| 197 |
+
train_df = load_train_labels()
|
| 198 |
+
HAS_VAD = bool(train_df["val"].notna().any())
|
| 199 |
+
print(f"Target emotions: {len(target_map)} | wav train (gộp): {len(train_df)} | có nhãn VAD: {HAS_VAD}")
|
| 200 |
+
print("eMOS:", train_df["emos"].describe()[["mean", "std", "min", "max"]].to_dict())
|
| 201 |
+
train_df.head()
|
| 202 |
+
|
| 203 |
+
# %% [markdown]
|
| 204 |
+
# ## 3. Trích đặc trưng 2 backbone (có cache riêng từng model)
|
| 205 |
+
# - **emotion2vec** → embedding + xác suất 5 lớp (như exp02).
|
| 206 |
+
# - **SAILER** → embedding (features) + xác suất 9 lớp + VAD3 (như exp03).
|
| 207 |
+
# Mỗi backbone cache riêng (`e2v_<tag>.npz`, `sailer_<tag>.npz`) → chạy nối tiếp được, đổi 1 backbone
|
| 208 |
+
# không phải trích lại cái kia. Trích xong **giải phóng GPU** rồi mới nạp backbone sau.
|
| 209 |
+
|
| 210 |
+
# %%
|
| 211 |
+
import torch
|
| 212 |
+
import torch.nn.functional as F
|
| 213 |
+
|
| 214 |
+
device = DEVICE if torch.cuda.is_available() else "cpu"
|
| 215 |
+
print("Device:", device)
|
| 216 |
+
if device == "cuda":
|
| 217 |
+
print(" ✅ GPU:", torch.cuda.get_device_name(0))
|
| 218 |
+
else:
|
| 219 |
+
print(" ⚠️ KHÔNG thấy GPU! Trích đặc trưng ~15k file trên CPU rất lâu.")
|
| 220 |
+
print(" → Settings → Accelerator = GPU T4 rồi chạy lại.")
|
| 221 |
+
|
| 222 |
+
# ---- emotion2vec ----
|
| 223 |
+
def extract_e2v(stems, tag):
|
| 224 |
+
"""→ dict {stem: (emb[D1], probs5[5])}. Cache CACHE_DIR/e2v_<tag>.npz."""
|
| 225 |
+
from tqdm.auto import tqdm
|
| 226 |
+
cache_path = os.path.join(CACHE_DIR, f"e2v_{tag}.npz")
|
| 227 |
+
store = {}
|
| 228 |
+
if os.path.exists(cache_path):
|
| 229 |
+
z = np.load(cache_path, allow_pickle=True)
|
| 230 |
+
store = {k: z[k] for k in z.files}
|
| 231 |
+
print(f"[e2v/{tag}] nạp cache: {len(store)}")
|
| 232 |
+
todo = [s for s in stems if s not in store]
|
| 233 |
+
if todo:
|
| 234 |
+
from funasr import AutoModel
|
| 235 |
+
m = AutoModel(model="iic/emotion2vec_plus_large", hub="hf", device=device) # ép GPU
|
| 236 |
+
miss = 0
|
| 237 |
+
for i, s in enumerate(tqdm(todo, desc=f"e2v {tag}")):
|
| 238 |
+
wav = os.path.join(WAV_DIR, s + ".wav")
|
| 239 |
+
if not os.path.exists(wav):
|
| 240 |
+
miss += 1; continue
|
| 241 |
+
r = m.generate(wav, granularity="utterance", extract_embedding=True)[0]
|
| 242 |
+
emb = np.asarray(r["feats"], dtype=np.float32).reshape(-1)
|
| 243 |
+
probs = {e: 0.0 for e in EMOTIONS5}
|
| 244 |
+
for lab, sc in zip(r["labels"], r["scores"]):
|
| 245 |
+
name = lab.split("/")[-1]
|
| 246 |
+
if name in probs:
|
| 247 |
+
probs[name] = float(sc)
|
| 248 |
+
tot = sum(probs.values())
|
| 249 |
+
p5 = np.array([probs[e] / tot if tot > 0 else 0.2 for e in EMOTIONS5], dtype=np.float32)
|
| 250 |
+
store[s] = np.concatenate([emb, p5]).astype(np.float32) # [D1 + 5]
|
| 251 |
+
if (i + 1) % 500 == 0:
|
| 252 |
+
np.savez(cache_path, **store)
|
| 253 |
+
np.savez(cache_path, **store)
|
| 254 |
+
del m
|
| 255 |
+
torch.cuda.empty_cache() if device == "cuda" else None
|
| 256 |
+
if miss:
|
| 257 |
+
print(f"[e2v/{tag}] {miss} file thiếu → bỏ qua.")
|
| 258 |
+
return {s: (v[:-5], v[-5:]) for s, v in store.items()}
|
| 259 |
+
|
| 260 |
+
# ---- SAILER ----
|
| 261 |
+
def _pool_feat(features):
|
| 262 |
+
"""features (tensor) → vector 1 chiều (mean-pool nếu còn chiều thời gian)."""
|
| 263 |
+
f = features.detach().cpu().numpy()
|
| 264 |
+
if f.ndim <= 1:
|
| 265 |
+
return f.reshape(-1).astype(np.float32)
|
| 266 |
+
return f.mean(axis=tuple(range(f.ndim - 1))).reshape(-1).astype(np.float32)
|
| 267 |
+
|
| 268 |
+
def extract_sailer(stems, tag):
|
| 269 |
+
"""→ dict {stem: (emb[D2], probs9[9], vad3[3] thang 1–5)}. Cache CACHE_DIR/sailer_<tag>.npz.
|
| 270 |
+
Mỗi mẫu lưu vector [emb | probs9(9) | vad3(3)] → cắt lại khi nạp."""
|
| 271 |
+
import librosa
|
| 272 |
+
from tqdm.auto import tqdm
|
| 273 |
+
cache_path = os.path.join(CACHE_DIR, f"sailer_{tag}.npz")
|
| 274 |
+
store = {}
|
| 275 |
+
if os.path.exists(cache_path):
|
| 276 |
+
z = np.load(cache_path, allow_pickle=True)
|
| 277 |
+
store = {k: z[k] for k in z.files}
|
| 278 |
+
print(f"[sailer/{tag}] nạp cache: {len(store)}")
|
| 279 |
+
todo = [s for s in stems if s not in store]
|
| 280 |
+
if todo:
|
| 281 |
+
from src.model.emotion.wavlm_emotion import WavLMWrapper
|
| 282 |
+
sailer = WavLMWrapper.from_pretrained("tiantiaf/wavlm-large-categorical-emotion").to(device).eval()
|
| 283 |
+
miss = 0
|
| 284 |
+
with torch.no_grad():
|
| 285 |
+
for i, s in enumerate(tqdm(todo, desc=f"sailer {tag}")):
|
| 286 |
+
wav = os.path.join(WAV_DIR, s + ".wav")
|
| 287 |
+
if not os.path.exists(wav):
|
| 288 |
+
miss += 1; continue
|
| 289 |
+
wave, _ = librosa.load(wav, sr=16000, mono=True)
|
| 290 |
+
wave = wave[: 15 * 16000]
|
| 291 |
+
data = torch.from_numpy(wave).float().unsqueeze(0).to(device)
|
| 292 |
+
logits, feat, _det, arousal, valence, dominance = sailer(data, return_feature=True)
|
| 293 |
+
emb = _pool_feat(feat)
|
| 294 |
+
p9 = F.softmax(logits, dim=1)[0].detach().cpu().numpy().astype(np.float32)
|
| 295 |
+
vad3 = np.array([1 + 4 * float(valence.item()),
|
| 296 |
+
1 + 4 * float(arousal.item()),
|
| 297 |
+
1 + 4 * float(dominance.item())], dtype=np.float32) # [VAL,ARO,DOM]
|
| 298 |
+
store[s] = np.concatenate([emb, p9, vad3]).astype(np.float32) # [D2 + 9 + 3]
|
| 299 |
+
if (i + 1) % 500 == 0:
|
| 300 |
+
np.savez(cache_path, **store)
|
| 301 |
+
np.savez(cache_path, **store)
|
| 302 |
+
del sailer
|
| 303 |
+
torch.cuda.empty_cache() if device == "cuda" else None
|
| 304 |
+
if miss:
|
| 305 |
+
print(f"[sailer/{tag}] {miss} file thiếu → bỏ qua.")
|
| 306 |
+
return {s: (v[:-12], v[-12:-3], v[-3:]) for s, v in store.items()}
|
| 307 |
+
|
| 308 |
+
# %% [markdown]
|
| 309 |
+
# ## 4. Dựng feature + nhãn cho train
|
| 310 |
+
# Feature audio (KHÔNG gồm target) = nối các phần đang bật:
|
| 311 |
+
# `[e2v_emb | e2v_probs5 | sailer_emb | sailer_probs9 | sailer_vad3]`.
|
| 312 |
+
# One-hot target để **riêng** (chỉ EMOS head dùng). Bỏ wav thiếu feature.
|
| 313 |
+
|
| 314 |
+
# %%
|
| 315 |
+
train_stems = list(train_df["wavID"])
|
| 316 |
+
if LIMIT_TRAIN:
|
| 317 |
+
train_stems = train_stems[:LIMIT_TRAIN]
|
| 318 |
+
|
| 319 |
+
e2v_tr = extract_e2v(train_stems, "train") if USE_E2V else {}
|
| 320 |
+
sailer_tr = extract_sailer(train_stems, "train") if USE_SAILER else {}
|
| 321 |
+
|
| 322 |
+
def audio_feature(sid, e2v_map, sailer_map):
|
| 323 |
+
"""Nối đặc trưng audio cho 1 wav. None nếu thiếu phần bắt buộc."""
|
| 324 |
+
parts = []
|
| 325 |
+
if USE_E2V:
|
| 326 |
+
pk = e2v_map.get(sid)
|
| 327 |
+
if pk is None:
|
| 328 |
+
return None
|
| 329 |
+
emb, p5 = pk
|
| 330 |
+
parts.append(emb)
|
| 331 |
+
if USE_CLASSPROB:
|
| 332 |
+
parts.append(p5)
|
| 333 |
+
if USE_SAILER:
|
| 334 |
+
pk = sailer_map.get(sid)
|
| 335 |
+
if pk is None:
|
| 336 |
+
return None
|
| 337 |
+
emb, p9, vad3 = pk
|
| 338 |
+
parts.append(emb)
|
| 339 |
+
if USE_CLASSPROB:
|
| 340 |
+
parts.append(p9); parts.append(vad3)
|
| 341 |
+
return np.concatenate(parts).astype(np.float32)
|
| 342 |
+
|
| 343 |
+
def onehot_target(tgt):
|
| 344 |
+
v = np.zeros(len(EMOTIONS5), dtype=np.float32)
|
| 345 |
+
if tgt in EMOTIONS5:
|
| 346 |
+
v[EMOTIONS5.index(tgt)] = 1.0
|
| 347 |
+
return v
|
| 348 |
+
|
| 349 |
+
lab = train_df.set_index("wavID")
|
| 350 |
+
X, T, y_emos, y_vad, y_cat = [], [], [], [], []
|
| 351 |
+
for s in train_stems:
|
| 352 |
+
f = audio_feature(s, e2v_tr, sailer_tr)
|
| 353 |
+
tgt = target_map.get(s)
|
| 354 |
+
if f is None or tgt is None or s not in lab.index:
|
| 355 |
+
continue
|
| 356 |
+
X.append(f)
|
| 357 |
+
T.append(onehot_target(tgt))
|
| 358 |
+
y_emos.append(lab.loc[s, "emos"])
|
| 359 |
+
y_vad.append([lab.loc[s, "val"], lab.loc[s, "aro"], lab.loc[s, "dom"]])
|
| 360 |
+
y_cat.append([lab.loc[s, f"cat{i}"] for i in range(len(EMOTIONS5))])
|
| 361 |
+
|
| 362 |
+
X = np.stack(X).astype(np.float32)
|
| 363 |
+
T = np.stack(T).astype(np.float32)
|
| 364 |
+
y_emos = np.array(y_emos, dtype=np.float32)
|
| 365 |
+
y_vad = np.array(y_vad, dtype=np.float32) # [N,3] (VAL,ARO,DOM) — có thể toàn NaN nếu thiếu nhãn
|
| 366 |
+
y_cat = np.array(y_cat, dtype=np.float32) # [N,5] phân phối tổng=1
|
| 367 |
+
FEAT_DIM = X.shape[1]
|
| 368 |
+
print(f"Train: X={X.shape} target={T.shape} emos={y_emos.shape} vad={y_vad.shape} cat={y_cat.shape}")
|
| 369 |
+
|
| 370 |
+
# Chuẩn hóa feature audio (z-score) — lưu mean/std để áp dụng y hệt lúc dự đoán DEV.
|
| 371 |
+
feat_mean = X.mean(0, keepdims=True)
|
| 372 |
+
feat_std = X.std(0, keepdims=True) + 1e-6
|
| 373 |
+
Xn = (X - feat_mean) / feat_std
|
| 374 |
+
|
| 375 |
+
# Chuẩn hóa nhãn liên tục (eMOS, VAD) về z-score → các MSE cùng thang (uncertainty weighting ổn định hơn).
|
| 376 |
+
# SRCC bất biến với scale → khi xuất answer.txt chỉ cần đảo z-score về thang gốc cho đẹp.
|
| 377 |
+
emos_mu, emos_sd = float(y_emos.mean()), float(y_emos.std() + 1e-6)
|
| 378 |
+
y_emos_z = (y_emos - emos_mu) / emos_sd
|
| 379 |
+
if HAS_VAD:
|
| 380 |
+
vad_mu = np.nanmean(y_vad, axis=0)
|
| 381 |
+
vad_sd = np.nanstd(y_vad, axis=0) + 1e-6
|
| 382 |
+
y_vad_z = (y_vad - vad_mu) / vad_sd
|
| 383 |
+
else:
|
| 384 |
+
vad_mu = np.zeros(3, dtype=np.float32); vad_sd = np.ones(3, dtype=np.float32)
|
| 385 |
+
y_vad_z = np.zeros_like(y_vad)
|
| 386 |
+
|
| 387 |
+
# %% [markdown]
|
| 388 |
+
# ## 5. Model fusion multi-task + train loop
|
| 389 |
+
# - **Trunk** chung: `Linear(FEAT_DIM→TRUNK_HIDDEN)+ReLU+Dropout` (×2).
|
| 390 |
+
# - **EMOS head**: nối `[trunk | one-hot target]` → MLP → 1 (vì EMOS phụ thuộc target).
|
| 391 |
+
# - **CAT head**: trunk → 5 logits → softmax (dự đoán phân phối vote). Loss = soft-CE (KL).
|
| 392 |
+
# - **VAD head**: trunk → 3 (VAL/ARO/DOM). Loss = MSE (bỏ qua nếu thiếu nhãn VAD).
|
| 393 |
+
# - **Cân loss**: uncertainty weighting — tổng `Σ exp(-sᵢ)·Lᵢ + sᵢ`, `sᵢ=log σᵢ²` **học được**.
|
| 394 |
+
|
| 395 |
+
# %%
|
| 396 |
+
import torch.nn as nn
|
| 397 |
+
from scipy.stats import spearmanr
|
| 398 |
+
from sklearn.model_selection import train_test_split
|
| 399 |
+
|
| 400 |
+
torch.manual_seed(SEED); np.random.seed(SEED)
|
| 401 |
+
N_EMO = len(EMOTIONS5)
|
| 402 |
+
|
| 403 |
+
idx_all = np.arange(X.shape[0])
|
| 404 |
+
tr_idx, va_idx = train_test_split(idx_all, test_size=VAL_FRAC, random_state=SEED)
|
| 405 |
+
|
| 406 |
+
def to_t(a):
|
| 407 |
+
return torch.tensor(a, dtype=torch.float32, device=device)
|
| 408 |
+
|
| 409 |
+
Xn_t, T_t = to_t(Xn), to_t(T)
|
| 410 |
+
emos_t = to_t(y_emos_z).unsqueeze(1)
|
| 411 |
+
vad_t = to_t(y_vad_z)
|
| 412 |
+
cat_t = to_t(y_cat)
|
| 413 |
+
|
| 414 |
+
class FusionMTL(nn.Module):
|
| 415 |
+
def __init__(self, d_in, trunk_h, head_h, p, n_emo):
|
| 416 |
+
super().__init__()
|
| 417 |
+
self.trunk = nn.Sequential(
|
| 418 |
+
nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),
|
| 419 |
+
nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p),
|
| 420 |
+
)
|
| 421 |
+
self.emos = nn.Sequential( # nhận [trunk | target]
|
| 422 |
+
nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))
|
| 423 |
+
self.cat = nn.Sequential(
|
| 424 |
+
nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))
|
| 425 |
+
self.vad = nn.Sequential(
|
| 426 |
+
nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))
|
| 427 |
+
|
| 428 |
+
def forward(self, x, tgt):
|
| 429 |
+
h = self.trunk(x)
|
| 430 |
+
emos = self.emos(torch.cat([h, tgt], dim=1))
|
| 431 |
+
cat_logits = self.cat(h)
|
| 432 |
+
vad = self.vad(h)
|
| 433 |
+
return emos, cat_logits, vad
|
| 434 |
+
|
| 435 |
+
model = FusionMTL(FEAT_DIM, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device)
|
| 436 |
+
|
| 437 |
+
# Trọng số bất định (log σ²) cho 5 task: emos, cat, val, aro, dom.
|
| 438 |
+
TASKS = ["emos", "cat", "val", "aro", "dom"]
|
| 439 |
+
log_var = nn.Parameter(torch.zeros(len(TASKS), device=device))
|
| 440 |
+
params = list(model.parameters()) + ([log_var] if USE_UNCERTAINTY else [])
|
| 441 |
+
opt = torch.optim.Adam(params, lr=LR, weight_decay=1e-5)
|
| 442 |
+
|
| 443 |
+
mse = nn.MSELoss(reduction="none")
|
| 444 |
+
|
| 445 |
+
def soft_ce(logits, target_dist):
|
| 446 |
+
"""Cross-entropy với nhãn mềm (phân phối): −Σ p·log q."""
|
| 447 |
+
logq = F.log_softmax(logits, dim=1)
|
| 448 |
+
return -(target_dist * logq).sum(dim=1)
|
| 449 |
+
|
| 450 |
+
def task_losses(emos_p, cat_logits, vad_p, b):
|
| 451 |
+
"""Trả về dict loss TB từng task cho 1 batch (chỉ số b)."""
|
| 452 |
+
L = {}
|
| 453 |
+
L["emos"] = mse(emos_p, emos_t[b]).mean()
|
| 454 |
+
L["cat"] = soft_ce(cat_logits, cat_t[b]).mean()
|
| 455 |
+
if HAS_VAD:
|
| 456 |
+
L["val"] = mse(vad_p[:, 0:1], vad_t[b, 0:1]).mean()
|
| 457 |
+
L["aro"] = mse(vad_p[:, 1:2], vad_t[b, 1:2]).mean()
|
| 458 |
+
L["dom"] = mse(vad_p[:, 2:3], vad_t[b, 2:3]).mean()
|
| 459 |
+
else:
|
| 460 |
+
z = torch.zeros((), device=device)
|
| 461 |
+
L["val"] = L["aro"] = L["dom"] = z
|
| 462 |
+
return L
|
| 463 |
+
|
| 464 |
+
def combine(L):
|
| 465 |
+
"""Gộp 5 loss thành 1 số: uncertainty weighting hoặc trọng số cố định."""
|
| 466 |
+
if USE_UNCERTAINTY:
|
| 467 |
+
tot = 0.0
|
| 468 |
+
for i, t in enumerate(TASKS):
|
| 469 |
+
tot = tot + torch.exp(-log_var[i]) * L[t] + log_var[i]
|
| 470 |
+
return tot
|
| 471 |
+
return sum(LOSS_W[t] * L[t] for t in TASKS)
|
| 472 |
+
|
| 473 |
+
@torch.no_grad()
|
| 474 |
+
def eval_val():
|
| 475 |
+
"""SRCC từng task trên tập val nội bộ (CAT báo bằng −KL để 'cao=tốt' cho early-stop)."""
|
| 476 |
+
model.eval()
|
| 477 |
+
ep, cl, vp = model(Xn_t[va_idx], T_t[va_idx])
|
| 478 |
+
ep = ep.cpu().numpy().ravel()
|
| 479 |
+
out = {"emos": spearmanr(ep, y_emos[va_idx]).correlation}
|
| 480 |
+
if HAS_VAD:
|
| 481 |
+
vp = vp.cpu().numpy()
|
| 482 |
+
for j, t in enumerate(["val", "aro", "dom"]):
|
| 483 |
+
out[t] = spearmanr(vp[:, j], y_vad[va_idx, j]).correlation
|
| 484 |
+
# CAT: dùng −KL(p‖q) trung bình (càng gần 0 càng tốt) → đổi dấu để hợp early-stop
|
| 485 |
+
q = F.softmax(cl, dim=1).cpu().numpy()
|
| 486 |
+
p = y_cat[va_idx]
|
| 487 |
+
kl = (p * (np.log(p + 1e-9) - np.log(q + 1e-9))).sum(1).mean()
|
| 488 |
+
out["cat_negkl"] = float(-kl)
|
| 489 |
+
return out
|
| 490 |
+
|
| 491 |
+
def val_score(m):
|
| 492 |
+
"""Điểm tổng để early-stop = TB SRCC các task liên tục có nhãn."""
|
| 493 |
+
keys = ["emos"] + (["val", "aro", "dom"] if HAS_VAD else [])
|
| 494 |
+
return float(np.mean([m[k] for k in keys]))
|
| 495 |
+
|
| 496 |
+
best_score, best_state, bad = -1e9, None, 0
|
| 497 |
+
tr_t = torch.tensor(tr_idx, device=device)
|
| 498 |
+
for ep in range(1, EPOCHS + 1):
|
| 499 |
+
model.train()
|
| 500 |
+
perm = tr_t[torch.randperm(len(tr_t), device=device)]
|
| 501 |
+
run = 0.0
|
| 502 |
+
for i in range(0, len(perm), BATCH):
|
| 503 |
+
b = perm[i:i + BATCH]
|
| 504 |
+
opt.zero_grad()
|
| 505 |
+
emos_p, cat_logits, vad_p = model(Xn_t[b], T_t[b])
|
| 506 |
+
L = task_losses(emos_p, cat_logits, vad_p, b)
|
| 507 |
+
loss = combine(L)
|
| 508 |
+
loss.backward(); opt.step()
|
| 509 |
+
run += loss.item() * len(b)
|
| 510 |
+
m = eval_val()
|
| 511 |
+
sc = val_score(m)
|
| 512 |
+
if sc > best_score:
|
| 513 |
+
best_score = sc
|
| 514 |
+
best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
|
| 515 |
+
bad = 0
|
| 516 |
+
else:
|
| 517 |
+
bad += 1
|
| 518 |
+
if ep % 5 == 0 or ep == 1:
|
| 519 |
+
msg = " ".join(f"{k}={m[k]:.3f}" for k in m)
|
| 520 |
+
print(f"epoch {ep:3d} | loss {run/len(perm):.4f} | {msg} | best {best_score:.4f}")
|
| 521 |
+
if bad >= PATIENCE:
|
| 522 |
+
print(f"Early stop ở epoch {ep}.")
|
| 523 |
+
break
|
| 524 |
+
|
| 525 |
+
model.load_state_dict(best_state)
|
| 526 |
+
final = eval_val()
|
| 527 |
+
print("\n✅ VAL (nội bộ) tốt nhất:")
|
| 528 |
+
print(f" EMOS SRCC = {final['emos']:.4f} (so mốc exp01 emotion2vec = 0.637)")
|
| 529 |
+
if HAS_VAD:
|
| 530 |
+
print(f" VAL/ARO/DOM SRCC = {final['val']:.4f} / {final['aro']:.4f} / {final['dom']:.4f}"
|
| 531 |
+
f" (so mốc SAILER = 0.341 / 0.712 / 0.630)")
|
| 532 |
+
if USE_UNCERTAINTY:
|
| 533 |
+
print(" log σ² mỗi task:", {t: round(float(log_var[i]), 3) for i, t in enumerate(TASKS)})
|
| 534 |
+
|
| 535 |
+
# Lưu model + tham số chuẩn hóa.
|
| 536 |
+
torch.save({"state": best_state, "feat_mean": feat_mean, "feat_std": feat_std,
|
| 537 |
+
"emos_mu": emos_mu, "emos_sd": emos_sd, "vad_mu": vad_mu, "vad_sd": vad_sd,
|
| 538 |
+
"FEAT_DIM": FEAT_DIM, "EMOTIONS5": EMOTIONS5, "HAS_VAD": HAS_VAD,
|
| 539 |
+
"USE_E2V": USE_E2V, "USE_SAILER": USE_SAILER, "USE_CLASSPROB": USE_CLASSPROB,
|
| 540 |
+
"TRUNK_HIDDEN": TRUNK_HIDDEN, "HEAD_HIDDEN": HEAD_HIDDEN, "val_score": best_score},
|
| 541 |
+
os.path.join(OUT_DIR, "fusion_mtl.pt"))
|
| 542 |
+
print("Đã lưu", os.path.join(OUT_DIR, "fusion_mtl.pt"))
|
| 543 |
+
|
| 544 |
+
# %% [markdown]
|
| 545 |
+
# ## 6. Dự đoán DEV → `answer.txt` đầy đủ 7 cột
|
| 546 |
+
# - **EMOS/CAT/VAD** = model fusion (đảo z-score về thang gốc cho EMOS/VAD; CAT = softmax 5 lớp).
|
| 547 |
+
# - **QMOS** = SpeechMOS (UTMOS) — để riêng, đúng thiết kế.
|
| 548 |
+
|
| 549 |
+
# %%
|
| 550 |
+
def list_dev():
|
| 551 |
+
with open(DEV_SCP) as f:
|
| 552 |
+
return [ln.strip() for ln in f if ln.strip()] # tên file .wav
|
| 553 |
+
|
| 554 |
+
dev_names = list_dev()
|
| 555 |
+
if LIMIT_DEV:
|
| 556 |
+
dev_names = dev_names[:LIMIT_DEV]
|
| 557 |
+
dev_stems = [stem(n) for n in dev_names]
|
| 558 |
+
print("DEV:", len(dev_names), "mẫu")
|
| 559 |
+
|
| 560 |
+
# 6a. Trích đặc trưng 2 backbone cho DEV (cache riêng)
|
| 561 |
+
e2v_dev = extract_e2v(dev_stems, "dev") if USE_E2V else {}
|
| 562 |
+
sailer_dev = extract_sailer(dev_stems, "dev") if USE_SAILER else {}
|
| 563 |
+
|
| 564 |
+
# 6b. Dự đoán 5 cột cảm xúc bằng model fusion
|
| 565 |
+
@torch.no_grad()
|
| 566 |
+
def predict_emotion(sid):
|
| 567 |
+
f = audio_feature(sid, e2v_dev, sailer_dev)
|
| 568 |
+
if f is None:
|
| 569 |
+
return None
|
| 570 |
+
fn = (f[None, :] - feat_mean) / feat_std
|
| 571 |
+
tgt = onehot_target(target_map.get(sid))[None, :]
|
| 572 |
+
model.eval()
|
| 573 |
+
emos_p, cat_logits, vad_p = model(to_t(fn), to_t(tgt))
|
| 574 |
+
emos = float(emos_p.item()) * emos_sd + emos_mu # đảo z-score
|
| 575 |
+
cat5 = F.softmax(cat_logits, dim=1)[0].cpu().numpy()
|
| 576 |
+
vad3 = vad_p[0].cpu().numpy() * vad_sd + vad_mu # [VAL,ARO,DOM]
|
| 577 |
+
return emos, cat5, vad3
|
| 578 |
+
|
| 579 |
+
# 6c. QMOS = SpeechMOS (để riêng)
|
| 580 |
+
@torch.no_grad()
|
| 581 |
+
def run_qmos(names):
|
| 582 |
+
import librosa
|
| 583 |
+
from tqdm.auto import tqdm
|
| 584 |
+
predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True).to(device).eval()
|
| 585 |
+
out = {}
|
| 586 |
+
for n in tqdm(names, desc="QMOS"):
|
| 587 |
+
p = os.path.join(WAV_DIR, n)
|
| 588 |
+
if not os.path.exists(p):
|
| 589 |
+
continue
|
| 590 |
+
wave, _ = librosa.load(p, sr=16000, mono=True)
|
| 591 |
+
out[n] = float(predictor(torch.from_numpy(wave).unsqueeze(0).to(device), sr=16000).mean().item())
|
| 592 |
+
return out
|
| 593 |
+
|
| 594 |
+
qmos_scores = run_qmos(dev_names)
|
| 595 |
+
|
| 596 |
+
# %%
|
| 597 |
+
def fmt_cat(probs5):
|
| 598 |
+
return "|".join(f"{e}:{probs5[i]:.6g}" for i, e in enumerate(EMOTIONS5))
|
| 599 |
+
|
| 600 |
+
def build_answer(out_path):
|
| 601 |
+
from tqdm.auto import tqdm
|
| 602 |
+
n_real = n_default = 0
|
| 603 |
+
with open(out_path, "w") as f:
|
| 604 |
+
f.write("wav,QMOS,EMOS,CAT,VAL,ARO,DOM\n")
|
| 605 |
+
for name in tqdm(dev_names, desc="answer"):
|
| 606 |
+
sid = stem(name)
|
| 607 |
+
pred = predict_emotion(sid)
|
| 608 |
+
if pred is None:
|
| 609 |
+
emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0])
|
| 610 |
+
n_default += 1
|
| 611 |
+
else:
|
| 612 |
+
emos, cat5, vad3 = pred
|
| 613 |
+
n_real += 1
|
| 614 |
+
qmos = qmos_scores.get(name, 3.0)
|
| 615 |
+
f.write(f"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},"
|
| 616 |
+
f"{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\n")
|
| 617 |
+
print(f"Ghi {len(dev_names)} dòng → {out_path} | fusion thật {n_real}, mặc định {n_default}")
|
| 618 |
+
|
| 619 |
+
answer_path = os.path.join(OUT_DIR, "answer.txt")
|
| 620 |
+
build_answer(answer_path)
|
| 621 |
+
|
| 622 |
+
# %% [markdown]
|
| 623 |
+
# ## 7. Validate + đóng zip
|
| 624 |
+
|
| 625 |
+
# %%
|
| 626 |
+
def validate(path):
|
| 627 |
+
import csv
|
| 628 |
+
with open(path) as f:
|
| 629 |
+
rows = list(csv.reader(f))
|
| 630 |
+
header = rows[0]
|
| 631 |
+
assert header[0] == "wav" and "QMOS" in header and "EMOS" in header, "Header sai"
|
| 632 |
+
for i, r in enumerate(rows[1:], 2):
|
| 633 |
+
assert len(r) == len(header), f"Dòng {i} sai số cột"
|
| 634 |
+
print(f"OK: {len(rows)-1} dòng, header = {header}")
|
| 635 |
+
|
| 636 |
+
validate(answer_path)
|
| 637 |
+
os.system(f"cd {OUT_DIR} && zip -j submission_track2_exp04_fusion.zip answer.txt && unzip -l submission_track2_exp04_fusion.zip")
|
| 638 |
+
print("Sẵn sàng nộp:", os.path.join(OUT_DIR, "submission_track2_exp04_fusion.zip"))
|
| 639 |
+
|
| 640 |
+
# %% [markdown]
|
| 641 |
+
# ## Ghi chú
|
| 642 |
+
# - **Lần đầu**: đặt `LIMIT_TRAIN=300`, `LIMIT_DEV=20` ở cell 0 để bắt lỗi setup (clone repo / import / model).
|
| 643 |
+
# Chạy OK rồi đặt `None` chạy full.
|
| 644 |
+
# - **VAL SRCC** ở mục 5 là ước lượng nội bộ (10% train) → so mốc EMOS 0.637 / ARO 0.712. Điểm DEV thật
|
| 645 |
+
# phải nộp CodaBench mới biết (My Submissions → Track 2, bỏ chọn track khác).
|
| 646 |
+
# - Embedding đã cache trong `/kaggle/working/fusion_cache/` → **Save Version** để giữ; lần sau đổi
|
| 647 |
+
# siêu tham số/đổi cách cân loss chỉ train lại head (vài phút), khỏi trích lại.
|
| 648 |
+
# - **Ablation cho paper** (đổi cờ ở cell 0, train lại head):
|
| 649 |
+
# `USE_E2V=False` (chỉ SAILER) · `USE_SAILER=False` (chỉ emotion2vec) · `USE_UNCERTAINTY=False` (trọng số tay)
|
| 650 |
+
# · `USE_CLASSPROB=False` (chỉ embedding) → điền bảng ablation `docs/04_experiments_log.md`.
|
| 651 |
+
# - License SAILER = **Open RAIL (phi thương mại)** → nhắc trong `docs/12_system_description.md`.
|
| 652 |
+
# - Nhớ ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (mục exp04).
|
track2/exp05_vad_audeering.ipynb
ADDED
|
@@ -0,0 +1,443 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "40a15eae",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"# VMC2026 Track 2 — exp05 (VAD bằng audeering MSP-dim) — Kaggle\n",
|
| 9 |
+
"\n",
|
| 10 |
+
"**Mục tiêu:** đẩy **VAL** (SAILER chỉ 0.341 — thấp nhất) bằng model VAD chuyên\n",
|
| 11 |
+
"`audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim` (dimensional, xuất thẳng\n",
|
| 12 |
+
"arousal/dominance/valence ∈ [0,1]). **Thay cả 3 cột VAD** bằng audeering.\n",
|
| 13 |
+
"\n",
|
| 14 |
+
"## Phân công model (giữ cái tốt của exp03, chỉ đổi VAD)\n",
|
| 15 |
+
"```\n",
|
| 16 |
+
"QMOS ← SpeechMOS (UTMOS) (để riêng)\n",
|
| 17 |
+
"EMOS ← SAILER (1 + 4·P(target)) ┐ giữ nguyên exp03\n",
|
| 18 |
+
"CAT ← SAILER (5 lớp renorm) ┘\n",
|
| 19 |
+
"VAL ← audeering ┐\n",
|
| 20 |
+
"ARO ← audeering ├─ THAY cả 3 (model VAD chuyên)\n",
|
| 21 |
+
"DOM ← audeering ┘\n",
|
| 22 |
+
"```\n",
|
| 23 |
+
"- Mỗi wav chạy **2 forward**: SAILER (EMOS+CAT) + audeering (VAD). KHÔNG train.\n",
|
| 24 |
+
"- So với exp03 (VAD từ SAILER: VAL 0.341 / ARO 0.712 / DOM 0.630) → nộp để A/B từng cột.\n",
|
| 25 |
+
"\n",
|
| 26 |
+
"**Cách chạy Kaggle:** GPU **T4** + Internet **On** → + Add Input dataset Track 2 (có `sets/dev.scp`,\n",
|
| 27 |
+
"`metadata.csv`) → sửa `DATA_ROOT` → lần đầu `LIMIT = 20` kiểm tra VAD ra 1–5 hợp lý → rồi `None`.\n",
|
| 28 |
+
"\n",
|
| 29 |
+
"⚠️ License **SAILER = Open RAIL** · **audeering = CC BY-NC-SA 4.0** (đều phi thương mại) → khai báo `docs/12_`."
|
| 30 |
+
]
|
| 31 |
+
},
|
| 32 |
+
{
|
| 33 |
+
"cell_type": "markdown",
|
| 34 |
+
"id": "2c098aff",
|
| 35 |
+
"metadata": {},
|
| 36 |
+
"source": [
|
| 37 |
+
"## 0. Cấu hình — SỬA Ở ĐÂY"
|
| 38 |
+
]
|
| 39 |
+
},
|
| 40 |
+
{
|
| 41 |
+
"cell_type": "code",
|
| 42 |
+
"execution_count": null,
|
| 43 |
+
"id": "fa143e27",
|
| 44 |
+
"metadata": {},
|
| 45 |
+
"outputs": [],
|
| 46 |
+
"source": [
|
| 47 |
+
"import os\n",
|
| 48 |
+
"\n",
|
| 49 |
+
"DATA_ROOT = \"/kaggle/input/vmc2026-track2-full/vmc2026-track2\" # << SỬA slug cho khớp Add Input\n",
|
| 50 |
+
"WAV_DIR = f\"{DATA_ROOT}/wav\"\n",
|
| 51 |
+
"METADATA_CSV = f\"{DATA_ROOT}/metadata.csv\" # wavID|emotion|transcript → target emotion (cho EMOS)\n",
|
| 52 |
+
"DEV_SCP = f\"{DATA_ROOT}/sets/dev.scp\" # danh sách wav tập DEV\n",
|
| 53 |
+
"\n",
|
| 54 |
+
"OUT_DIR = \"/kaggle/working\"\n",
|
| 55 |
+
"\n",
|
| 56 |
+
"DEVICE = \"cuda\"\n",
|
| 57 |
+
"MAX_SECONDS = 15\n",
|
| 58 |
+
"SR = 16000\n",
|
| 59 |
+
"LIMIT = None # đặt 20 để chạy thử nhanh; None = full DEV\n",
|
| 60 |
+
"\n",
|
| 61 |
+
"EMOTIONS5 = [\"angry\", \"happy\", \"neutral\", \"sad\", \"surprised\"]\n",
|
| 62 |
+
"SAILER9 = [\"Anger\", \"Contempt\", \"Disgust\", \"Fear\", \"Happiness\", \"Neutral\", \"Sadness\", \"Surprise\", \"Other\"]\n",
|
| 63 |
+
"EMO2SAILER = {\"angry\": 0, \"happy\": 4, \"neutral\": 5, \"sad\": 6, \"surprised\": 7}\n",
|
| 64 |
+
"\n",
|
| 65 |
+
"_EMO_ALIAS = {\n",
|
| 66 |
+
" \"angry\": \"angry\", \"anger\": \"angry\",\n",
|
| 67 |
+
" \"happy\": \"happy\", \"happiness\": \"happy\", \"joy\": \"happy\",\n",
|
| 68 |
+
" \"neutral\": \"neutral\", \"calm\": \"neutral\",\n",
|
| 69 |
+
" \"sad\": \"sad\", \"sadness\": \"sad\",\n",
|
| 70 |
+
" \"surprise\": \"surprised\", \"surprised\": \"surprised\", \"surprising\": \"surprised\",\n",
|
| 71 |
+
"}\n",
|
| 72 |
+
"\n",
|
| 73 |
+
"def norm_emotion(label):\n",
|
| 74 |
+
" key = str(label).strip().lower()\n",
|
| 75 |
+
" return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)\n",
|
| 76 |
+
"\n",
|
| 77 |
+
"def stem(path_or_name):\n",
|
| 78 |
+
" return os.path.splitext(os.path.basename(str(path_or_name)))[0]\n",
|
| 79 |
+
"\n",
|
| 80 |
+
"print(\"DATA_ROOT:\", DATA_ROOT)\n",
|
| 81 |
+
"for p in [WAV_DIR, METADATA_CSV, DEV_SCP]:\n",
|
| 82 |
+
" print((\" ✅ \" if os.path.exists(p) else \" ❌ THIẾU \") + p)"
|
| 83 |
+
]
|
| 84 |
+
},
|
| 85 |
+
{
|
| 86 |
+
"cell_type": "markdown",
|
| 87 |
+
"id": "f2d1dd91",
|
| 88 |
+
"metadata": {},
|
| 89 |
+
"source": [
|
| 90 |
+
"## 1. Cài đặt + tải code SAILER (clone + sys.path, KHÔNG pip install -e .)"
|
| 91 |
+
]
|
| 92 |
+
},
|
| 93 |
+
{
|
| 94 |
+
"cell_type": "code",
|
| 95 |
+
"execution_count": null,
|
| 96 |
+
"id": "d426b50b",
|
| 97 |
+
"metadata": {},
|
| 98 |
+
"outputs": [],
|
| 99 |
+
"source": [
|
| 100 |
+
"import sys, subprocess\n",
|
| 101 |
+
"\n",
|
| 102 |
+
"def pip_install(*pkgs):\n",
|
| 103 |
+
" subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *pkgs], check=True)\n",
|
| 104 |
+
"\n",
|
| 105 |
+
"REPO_DIR = \"/kaggle/working/vox-profile-release\"\n",
|
| 106 |
+
"if not os.path.exists(REPO_DIR):\n",
|
| 107 |
+
" subprocess.run([\"git\", \"clone\", \"--depth\", \"1\",\n",
|
| 108 |
+
" \"https://github.com/tiantiaf0627/vox-profile-release.git\", REPO_DIR], check=True)\n",
|
| 109 |
+
"\n",
|
| 110 |
+
"pip_install(\"loralib\", \"speechbrain\", \"speechmos\", \"librosa\", \"soundfile\", \"scipy\", \"tqdm\")\n",
|
| 111 |
+
"\n",
|
| 112 |
+
"if REPO_DIR not in sys.path:\n",
|
| 113 |
+
" sys.path.insert(0, REPO_DIR)"
|
| 114 |
+
]
|
| 115 |
+
},
|
| 116 |
+
{
|
| 117 |
+
"cell_type": "markdown",
|
| 118 |
+
"id": "798ad5ef",
|
| 119 |
+
"metadata": {},
|
| 120 |
+
"source": [
|
| 121 |
+
"## 2. Nạp model SAILER (cho EMOS + CAT)"
|
| 122 |
+
]
|
| 123 |
+
},
|
| 124 |
+
{
|
| 125 |
+
"cell_type": "code",
|
| 126 |
+
"execution_count": null,
|
| 127 |
+
"id": "5d9ffc83",
|
| 128 |
+
"metadata": {},
|
| 129 |
+
"outputs": [],
|
| 130 |
+
"source": [
|
| 131 |
+
"import torch\n",
|
| 132 |
+
"import torch.nn.functional as F\n",
|
| 133 |
+
"\n",
|
| 134 |
+
"device = DEVICE if torch.cuda.is_available() else \"cpu\"\n",
|
| 135 |
+
"print(\"Device:\", device)\n",
|
| 136 |
+
"if device == \"cuda\":\n",
|
| 137 |
+
" print(\" ✅ GPU:\", torch.cuda.get_device_name(0))\n",
|
| 138 |
+
"else:\n",
|
| 139 |
+
" print(\" ⚠️ KHÔNG thấy GPU → Settings → Accelerator = GPU T4 rồi chạy lại.\")\n",
|
| 140 |
+
"\n",
|
| 141 |
+
"from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402\n",
|
| 142 |
+
"\n",
|
| 143 |
+
"sailer = WavLMWrapper.from_pretrained(\"tiantiaf/wavlm-large-categorical-emotion\").to(device)\n",
|
| 144 |
+
"sailer.eval()\n",
|
| 145 |
+
"print(\"✅ Đã nạp SAILER (wavlm-large-categorical-emotion)\")"
|
| 146 |
+
]
|
| 147 |
+
},
|
| 148 |
+
{
|
| 149 |
+
"cell_type": "markdown",
|
| 150 |
+
"id": "6c18fa40",
|
| 151 |
+
"metadata": {},
|
| 152 |
+
"source": [
|
| 153 |
+
"## 2b. Nạp model VAD chuyên: audeering wav2vec2 MSP-dim\n",
|
| 154 |
+
"⚠️ Kế thừa `Wav2Vec2PreTrainedModel` (theo model card) hay dính lỗi version transformers\n",
|
| 155 |
+
"(thiếu `__file__` / `all_tied_weights_keys`...). Cách dứt điểm: CHỈ dùng `Wav2Vec2Model` (backbone\n",
|
| 156 |
+
"được hỗ trợ tốt) + **tự nạp tay** trọng số regression head từ checkpoint → không đụng tie-weights/experts.\n",
|
| 157 |
+
"⚠️ Model xuất thứ tự **[arousal, dominance, valence]** ∈ [0,1] → đổi về [VAL,ARO,DOM] thang 1–5 khi ghi."
|
| 158 |
+
]
|
| 159 |
+
},
|
| 160 |
+
{
|
| 161 |
+
"cell_type": "code",
|
| 162 |
+
"execution_count": null,
|
| 163 |
+
"id": "ddf569cd",
|
| 164 |
+
"metadata": {},
|
| 165 |
+
"outputs": [],
|
| 166 |
+
"source": [
|
| 167 |
+
"import torch.nn as nn\n",
|
| 168 |
+
"from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor\n",
|
| 169 |
+
"from huggingface_hub import hf_hub_download\n",
|
| 170 |
+
"\n",
|
| 171 |
+
"AUD_NAME = \"audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim\"\n",
|
| 172 |
+
"aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)\n",
|
| 173 |
+
"\n",
|
| 174 |
+
"# 1) backbone wav2vec2 (load chuẩn, không subclass)\n",
|
| 175 |
+
"aud_cfg = Wav2Vec2Config.from_pretrained(AUD_NAME)\n",
|
| 176 |
+
"aud_backbone = Wav2Vec2Model(aud_cfg)\n",
|
| 177 |
+
"\n",
|
| 178 |
+
"# 2) tải state_dict gốc của checkpoint (ưu tiên safetensors)\n",
|
| 179 |
+
"try:\n",
|
| 180 |
+
" _sd = __import__(\"safetensors.torch\", fromlist=[\"load_file\"]).load_file(\n",
|
| 181 |
+
" hf_hub_download(AUD_NAME, \"model.safetensors\"))\n",
|
| 182 |
+
"except Exception:\n",
|
| 183 |
+
" _sd = torch.load(hf_hub_download(AUD_NAME, \"pytorch_model.bin\"), map_location=\"cpu\")\n",
|
| 184 |
+
"\n",
|
| 185 |
+
"# 3) nạp phần backbone (key có tiền tố \"wav2vec2.\") vào Wav2Vec2Model\n",
|
| 186 |
+
"bb_sd = {k[len(\"wav2vec2.\"):]: v for k, v in _sd.items() if k.startswith(\"wav2vec2.\")}\n",
|
| 187 |
+
"missing, unexpected = aud_backbone.load_state_dict(bb_sd, strict=False)\n",
|
| 188 |
+
"print(f\" backbone: thiếu {len(missing)} key, dư {len(unexpected)} key (strict=False)\")\n",
|
| 189 |
+
"\n",
|
| 190 |
+
"# 4) dựng regression head theo đúng shape trong checkpoint rồi nạp trọng số \"classifier.*\"\n",
|
| 191 |
+
"_hid = _sd[\"classifier.dense.weight\"].shape[0]\n",
|
| 192 |
+
"_out = _sd[\"classifier.out_proj.weight\"].shape[0] # = 3 (arousal, dominance, valence)\n",
|
| 193 |
+
"aud_head = nn.Sequential(nn.Linear(_hid, _hid), nn.Tanh(), nn.Linear(_hid, _out))\n",
|
| 194 |
+
"aud_head[0].weight.data.copy_(_sd[\"classifier.dense.weight\"])\n",
|
| 195 |
+
"aud_head[0].bias.data.copy_(_sd[\"classifier.dense.bias\"])\n",
|
| 196 |
+
"aud_head[2].weight.data.copy_(_sd[\"classifier.out_proj.weight\"])\n",
|
| 197 |
+
"aud_head[2].bias.data.copy_(_sd[\"classifier.out_proj.bias\"])\n",
|
| 198 |
+
"\n",
|
| 199 |
+
"aud_backbone = aud_backbone.to(device).eval()\n",
|
| 200 |
+
"aud_head = aud_head.to(device).eval()\n",
|
| 201 |
+
"print(f\"✅ Đã nạp audeering MSP-dim (backbone + head {_hid}→{_out}) — model VAD chuyên\")"
|
| 202 |
+
]
|
| 203 |
+
},
|
| 204 |
+
{
|
| 205 |
+
"cell_type": "markdown",
|
| 206 |
+
"id": "849f083a",
|
| 207 |
+
"metadata": {},
|
| 208 |
+
"source": [
|
| 209 |
+
"## 3. Đọc cảm xúc target cho mỗi wav (cho EMOS của SAILER)"
|
| 210 |
+
]
|
| 211 |
+
},
|
| 212 |
+
{
|
| 213 |
+
"cell_type": "code",
|
| 214 |
+
"execution_count": null,
|
| 215 |
+
"id": "546df027",
|
| 216 |
+
"metadata": {
|
| 217 |
+
"lines_to_next_cell": 1
|
| 218 |
+
},
|
| 219 |
+
"outputs": [],
|
| 220 |
+
"source": [
|
| 221 |
+
"import numpy as np\n",
|
| 222 |
+
"import librosa\n",
|
| 223 |
+
"\n",
|
| 224 |
+
"def load_target_emotions():\n",
|
| 225 |
+
" tgt = {}\n",
|
| 226 |
+
" with open(METADATA_CSV, encoding=\"utf-8\") as f:\n",
|
| 227 |
+
" for ln in f:\n",
|
| 228 |
+
" parts = ln.strip().split(\"|\")\n",
|
| 229 |
+
" if len(parts) < 2:\n",
|
| 230 |
+
" continue\n",
|
| 231 |
+
" tgt[stem(parts[0])] = norm_emotion(parts[1])\n",
|
| 232 |
+
" return tgt\n",
|
| 233 |
+
"\n",
|
| 234 |
+
"target_map = load_target_emotions()\n",
|
| 235 |
+
"print(f\"Target emotions: {len(target_map)} wav | ví dụ:\", dict(list(target_map.items())[:3]))"
|
| 236 |
+
]
|
| 237 |
+
},
|
| 238 |
+
{
|
| 239 |
+
"cell_type": "markdown",
|
| 240 |
+
"id": "1fee644a",
|
| 241 |
+
"metadata": {},
|
| 242 |
+
"source": [
|
| 243 |
+
"## 4. Hàm chấm: SAILER (EMOS+CAT) + audeering (VAD)"
|
| 244 |
+
]
|
| 245 |
+
},
|
| 246 |
+
{
|
| 247 |
+
"cell_type": "code",
|
| 248 |
+
"execution_count": null,
|
| 249 |
+
"id": "54a8ad31",
|
| 250 |
+
"metadata": {
|
| 251 |
+
"lines_to_next_cell": 1
|
| 252 |
+
},
|
| 253 |
+
"outputs": [],
|
| 254 |
+
"source": [
|
| 255 |
+
"@torch.no_grad()\n",
|
| 256 |
+
"def sailer_probs(wav_path):\n",
|
| 257 |
+
" \"\"\"→ probs9 (float32[9]); None nếu thiếu/lỗi. Chỉ lấy 9 lớp (EMOS+CAT), bỏ VAD của SAILER.\"\"\"\n",
|
| 258 |
+
" if not os.path.exists(wav_path):\n",
|
| 259 |
+
" return None\n",
|
| 260 |
+
" wave, _ = librosa.load(wav_path, sr=SR, mono=True)\n",
|
| 261 |
+
" wave = wave[: MAX_SECONDS * SR]\n",
|
| 262 |
+
" data = torch.from_numpy(wave).float().unsqueeze(0).to(device)\n",
|
| 263 |
+
" logits, _feat, _det, _aro, _val, _dom = sailer(data, return_feature=True)\n",
|
| 264 |
+
" return F.softmax(logits, dim=1)[0].detach().cpu().numpy().astype(np.float32)\n",
|
| 265 |
+
"\n",
|
| 266 |
+
"def emos_from_probs(probs9, target):\n",
|
| 267 |
+
" if target is None or target not in EMO2SAILER:\n",
|
| 268 |
+
" return None\n",
|
| 269 |
+
" return 1.0 + 4.0 * float(probs9[EMO2SAILER[target]])\n",
|
| 270 |
+
"\n",
|
| 271 |
+
"def cat5_from_probs(probs9):\n",
|
| 272 |
+
" v = np.array([probs9[EMO2SAILER[e]] for e in EMOTIONS5], dtype=np.float32)\n",
|
| 273 |
+
" s = v.sum()\n",
|
| 274 |
+
" return v / s if s > 0 else np.full(5, 0.2, dtype=np.float32)\n",
|
| 275 |
+
"\n",
|
| 276 |
+
"@torch.no_grad()\n",
|
| 277 |
+
"def audeering_vad(wav_path):\n",
|
| 278 |
+
" \"\"\"VAD bằng audeering → [VAL, ARO, DOM] thang 1–5; None nếu thiếu/lỗi.\n",
|
| 279 |
+
" Model xuất [arousal, dominance, valence] ∈ [0,1].\"\"\"\n",
|
| 280 |
+
" if not os.path.exists(wav_path):\n",
|
| 281 |
+
" return None\n",
|
| 282 |
+
" wave, _ = librosa.load(wav_path, sr=SR, mono=True)\n",
|
| 283 |
+
" wave = wave[: MAX_SECONDS * SR]\n",
|
| 284 |
+
" x = aud_proc(wave, sampling_rate=SR).input_values[0]\n",
|
| 285 |
+
" x = torch.from_numpy(np.asarray(x, dtype=np.float32)).unsqueeze(0).to(device)\n",
|
| 286 |
+
" h = aud_backbone(x)[0].mean(dim=1) # mean-pool theo thời gian\n",
|
| 287 |
+
" out = aud_head(h)[0].detach().cpu().numpy() # [arousal, dominance, valence]\n",
|
| 288 |
+
" aro, dom, val = float(out[0]), float(out[1]), float(out[2])\n",
|
| 289 |
+
" return np.array([1 + 4 * val, 1 + 4 * aro, 1 + 4 * dom], dtype=np.float32) # [VAL,ARO,DOM]"
|
| 290 |
+
]
|
| 291 |
+
},
|
| 292 |
+
{
|
| 293 |
+
"cell_type": "markdown",
|
| 294 |
+
"id": "e662c05e",
|
| 295 |
+
"metadata": {},
|
| 296 |
+
"source": [
|
| 297 |
+
"## 5. QMOS = SpeechMOS (UTMOS) — bắt buộc cho answer.txt"
|
| 298 |
+
]
|
| 299 |
+
},
|
| 300 |
+
{
|
| 301 |
+
"cell_type": "code",
|
| 302 |
+
"execution_count": null,
|
| 303 |
+
"id": "aacc9e34",
|
| 304 |
+
"metadata": {
|
| 305 |
+
"lines_to_next_cell": 1
|
| 306 |
+
},
|
| 307 |
+
"outputs": [],
|
| 308 |
+
"source": [
|
| 309 |
+
"@torch.no_grad()\n",
|
| 310 |
+
"def run_qmos(names):\n",
|
| 311 |
+
" predictor = torch.hub.load(\"tarepan/SpeechMOS:v1.2.0\", \"utmos22_strong\", trust_repo=True).to(device).eval()\n",
|
| 312 |
+
" from tqdm.auto import tqdm\n",
|
| 313 |
+
" out = {}\n",
|
| 314 |
+
" for n in tqdm(names, desc=\"QMOS\"):\n",
|
| 315 |
+
" p = os.path.join(WAV_DIR, n)\n",
|
| 316 |
+
" if not os.path.exists(p):\n",
|
| 317 |
+
" continue\n",
|
| 318 |
+
" wave, _ = librosa.load(p, sr=SR, mono=True)\n",
|
| 319 |
+
" x = torch.from_numpy(wave).unsqueeze(0).to(device)\n",
|
| 320 |
+
" out[n] = float(predictor(x, sr=SR).mean().item())\n",
|
| 321 |
+
" return out"
|
| 322 |
+
]
|
| 323 |
+
},
|
| 324 |
+
{
|
| 325 |
+
"cell_type": "markdown",
|
| 326 |
+
"id": "d6712414",
|
| 327 |
+
"metadata": {},
|
| 328 |
+
"source": [
|
| 329 |
+
"## 6. Chạy trên DEV → `answer.txt` (QMOS, EMOS, CAT ← SAILER/UTMOS · VAL,ARO,DOM ← audeering)"
|
| 330 |
+
]
|
| 331 |
+
},
|
| 332 |
+
{
|
| 333 |
+
"cell_type": "code",
|
| 334 |
+
"execution_count": null,
|
| 335 |
+
"id": "011b2530",
|
| 336 |
+
"metadata": {
|
| 337 |
+
"lines_to_next_cell": 1
|
| 338 |
+
},
|
| 339 |
+
"outputs": [],
|
| 340 |
+
"source": [
|
| 341 |
+
"def list_dev():\n",
|
| 342 |
+
" with open(DEV_SCP) as f:\n",
|
| 343 |
+
" return [ln.strip() for ln in f if ln.strip()]\n",
|
| 344 |
+
"\n",
|
| 345 |
+
"dev_names = list_dev()\n",
|
| 346 |
+
"if LIMIT:\n",
|
| 347 |
+
" dev_names = dev_names[:LIMIT]\n",
|
| 348 |
+
"print(\"DEV:\", len(dev_names), \"mẫu\")\n",
|
| 349 |
+
"\n",
|
| 350 |
+
"qmos_scores = run_qmos(dev_names)\n",
|
| 351 |
+
"\n",
|
| 352 |
+
"def fmt_cat(probs5):\n",
|
| 353 |
+
" return \"|\".join(f\"{e}:{probs5[i]:.6g}\" for i, e in enumerate(EMOTIONS5))\n",
|
| 354 |
+
"\n",
|
| 355 |
+
"def build_answer(out_path):\n",
|
| 356 |
+
" from tqdm.auto import tqdm\n",
|
| 357 |
+
" n_emos = n_default = n_vad_def = 0\n",
|
| 358 |
+
" with open(out_path, \"w\") as f:\n",
|
| 359 |
+
" f.write(\"wav,QMOS,EMOS,CAT,VAL,ARO,DOM\\n\")\n",
|
| 360 |
+
" for name in tqdm(dev_names, desc=\"EMOS/CAT(SAILER)+VAD(audeering)\"):\n",
|
| 361 |
+
" sid = stem(name)\n",
|
| 362 |
+
" wav = os.path.join(WAV_DIR, name)\n",
|
| 363 |
+
" # EMOS + CAT từ SAILER\n",
|
| 364 |
+
" probs9 = sailer_probs(wav)\n",
|
| 365 |
+
" if probs9 is None:\n",
|
| 366 |
+
" emos, cat5 = 3.0, np.full(5, 0.2, dtype=np.float32); n_default += 1\n",
|
| 367 |
+
" else:\n",
|
| 368 |
+
" emos = emos_from_probs(probs9, target_map.get(sid))\n",
|
| 369 |
+
" if emos is None:\n",
|
| 370 |
+
" emos = 3.0; n_default += 1\n",
|
| 371 |
+
" else:\n",
|
| 372 |
+
" n_emos += 1\n",
|
| 373 |
+
" cat5 = cat5_from_probs(probs9)\n",
|
| 374 |
+
" # VAD từ audeering\n",
|
| 375 |
+
" vad3 = audeering_vad(wav)\n",
|
| 376 |
+
" if vad3 is None:\n",
|
| 377 |
+
" vad3 = np.array([3.0, 3.0, 3.0], dtype=np.float32); n_vad_def += 1\n",
|
| 378 |
+
" qmos = qmos_scores.get(name, 3.0)\n",
|
| 379 |
+
" f.write(f\"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},\"\n",
|
| 380 |
+
" f\"{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\\n\")\n",
|
| 381 |
+
" print(f\"Ghi {len(dev_names)} dòng → {out_path} | EMOS thật {n_emos}, mặc định {n_default} | VAD mặc định {n_vad_def}\")\n",
|
| 382 |
+
"\n",
|
| 383 |
+
"answer_path = os.path.join(OUT_DIR, \"answer.txt\")\n",
|
| 384 |
+
"build_answer(answer_path)"
|
| 385 |
+
]
|
| 386 |
+
},
|
| 387 |
+
{
|
| 388 |
+
"cell_type": "markdown",
|
| 389 |
+
"id": "6afa397f",
|
| 390 |
+
"metadata": {},
|
| 391 |
+
"source": [
|
| 392 |
+
"## 7. Validate + đóng zip"
|
| 393 |
+
]
|
| 394 |
+
},
|
| 395 |
+
{
|
| 396 |
+
"cell_type": "code",
|
| 397 |
+
"execution_count": null,
|
| 398 |
+
"id": "749f1366",
|
| 399 |
+
"metadata": {},
|
| 400 |
+
"outputs": [],
|
| 401 |
+
"source": [
|
| 402 |
+
"def validate(path):\n",
|
| 403 |
+
" import csv\n",
|
| 404 |
+
" with open(path) as f:\n",
|
| 405 |
+
" rows = list(csv.reader(f))\n",
|
| 406 |
+
" header = rows[0]\n",
|
| 407 |
+
" assert header[0] == \"wav\" and \"QMOS\" in header and \"EMOS\" in header, \"Header sai\"\n",
|
| 408 |
+
" for i, r in enumerate(rows[1:], 2):\n",
|
| 409 |
+
" assert len(r) == len(header), f\"Dòng {i} sai số cột\"\n",
|
| 410 |
+
" print(f\"OK: {len(rows)-1} dòng, header = {header}\")\n",
|
| 411 |
+
"\n",
|
| 412 |
+
"validate(answer_path)\n",
|
| 413 |
+
"os.system(f\"cd {OUT_DIR} && zip -j submission_track2_exp05_vad-audeering.zip answer.txt && unzip -l submission_track2_exp05_vad-audeering.zip\")\n",
|
| 414 |
+
"print(\"Sẵn sàng nộp:\", os.path.join(OUT_DIR, \"submission_track2_exp05_vad-audeering.zip\"))"
|
| 415 |
+
]
|
| 416 |
+
},
|
| 417 |
+
{
|
| 418 |
+
"cell_type": "markdown",
|
| 419 |
+
"id": "69fb16b7",
|
| 420 |
+
"metadata": {},
|
| 421 |
+
"source": [
|
| 422 |
+
"## Ghi chú\n",
|
| 423 |
+
"- **Quan hệ với exp03:** exp03 = SAILER lo cả EMOS+CAT+VAD (giữ nguyên, file `exp03_emos_sailer`).\n",
|
| 424 |
+
" exp05 (file này) chỉ **đổi VAD sang audeering**, EMOS/CAT vẫn SAILER → nộp 2 bản để A/B từng cột VAD.\n",
|
| 425 |
+
"- **Lần đầu** đặt `LIMIT = 20`, kiểm tra VAL/ARO/DOM ∈ [1,5] hợp lý (không toàn 3 / không âm).\n",
|
| 426 |
+
" Nếu giá trị lệch → có thể sai thứ tự arousal/dominance/valence, báo lại để chỉnh.\n",
|
| 427 |
+
"- Khi chạy để ý dòng `backbone: thiếu N key, dư M key`: thiếu/dư vài key phụ là bình thường;\n",
|
| 428 |
+
" thiếu hàng trăm key = sai tiền tố → báo lại.\n",
|
| 429 |
+
"- Nếu audeering thắng VAL nhưng thua ARO/DOM so SAILER → bản tối ưu = trộn cột\n",
|
| 430 |
+
" (VAL từ audeering, ARO/DOM từ exp03). Ghi kết quả vào `docs/04_experiments_log.md` (exp05)."
|
| 431 |
+
]
|
| 432 |
+
}
|
| 433 |
+
],
|
| 434 |
+
"metadata": {
|
| 435 |
+
"jupytext": {
|
| 436 |
+
"cell_metadata_filter": "-all",
|
| 437 |
+
"main_language": "python",
|
| 438 |
+
"notebook_metadata_filter": "-all"
|
| 439 |
+
}
|
| 440 |
+
},
|
| 441 |
+
"nbformat": 4,
|
| 442 |
+
"nbformat_minor": 5
|
| 443 |
+
}
|
track2/exp05_vad_audeering_pipeline.py
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %% [markdown]
|
| 2 |
+
# # VMC2026 Track 2 — exp05 (VAD bằng audeering MSP-dim) — Kaggle
|
| 3 |
+
#
|
| 4 |
+
# **Mục tiêu:** đẩy **VAL** (SAILER chỉ 0.341 — thấp nhất) bằng model VAD chuyên
|
| 5 |
+
# `audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim` (dimensional, xuất thẳng
|
| 6 |
+
# arousal/dominance/valence ∈ [0,1]). **Thay cả 3 cột VAD** bằng audeering.
|
| 7 |
+
#
|
| 8 |
+
# ## Phân công model (giữ cái tốt của exp03, chỉ đổi VAD)
|
| 9 |
+
# ```
|
| 10 |
+
# QMOS ← SpeechMOS (UTMOS) (để riêng)
|
| 11 |
+
# EMOS ← SAILER (1 + 4·P(target)) ┐ giữ nguyên exp03
|
| 12 |
+
# CAT ← SAILER (5 lớp renorm) ┘
|
| 13 |
+
# VAL ← audeering ┐
|
| 14 |
+
# ARO ← audeering ├─ THAY cả 3 (model VAD chuyên)
|
| 15 |
+
# DOM ← audeering ┘
|
| 16 |
+
# ```
|
| 17 |
+
# - Mỗi wav chạy **2 forward**: SAILER (EMOS+CAT) + audeering (VAD). KHÔNG train.
|
| 18 |
+
# - So với exp03 (VAD từ SAILER: VAL 0.341 / ARO 0.712 / DOM 0.630) → nộp để A/B từng cột.
|
| 19 |
+
#
|
| 20 |
+
# **Cách chạy Kaggle:** GPU **T4** + Internet **On** → + Add Input dataset Track 2 (có `sets/dev.scp`,
|
| 21 |
+
# `metadata.csv`) → sửa `DATA_ROOT` → lần đầu `LIMIT = 20` kiểm tra VAD ra 1–5 hợp lý → rồi `None`.
|
| 22 |
+
#
|
| 23 |
+
# ⚠️ License **SAILER = Open RAIL** · **audeering = CC BY-NC-SA 4.0** (đều phi thương mại) → khai báo `docs/12_`.
|
| 24 |
+
|
| 25 |
+
# %% [markdown]
|
| 26 |
+
# ## 0. Cấu hình — SỬA Ở ĐÂY
|
| 27 |
+
|
| 28 |
+
# %%
|
| 29 |
+
import os
|
| 30 |
+
|
| 31 |
+
DATA_ROOT = "/kaggle/input/vmc2026-track2-full/vmc2026-track2" # << SỬA slug cho khớp Add Input
|
| 32 |
+
WAV_DIR = f"{DATA_ROOT}/wav"
|
| 33 |
+
METADATA_CSV = f"{DATA_ROOT}/metadata.csv" # wavID|emotion|transcript → target emotion (cho EMOS)
|
| 34 |
+
DEV_SCP = f"{DATA_ROOT}/sets/dev.scp" # danh sách wav tập DEV
|
| 35 |
+
|
| 36 |
+
OUT_DIR = "/kaggle/working"
|
| 37 |
+
|
| 38 |
+
DEVICE = "cuda"
|
| 39 |
+
MAX_SECONDS = 15
|
| 40 |
+
SR = 16000
|
| 41 |
+
LIMIT = None # đặt 20 để chạy thử nhanh; None = full DEV
|
| 42 |
+
|
| 43 |
+
EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"]
|
| 44 |
+
SAILER9 = ["Anger", "Contempt", "Disgust", "Fear", "Happiness", "Neutral", "Sadness", "Surprise", "Other"]
|
| 45 |
+
EMO2SAILER = {"angry": 0, "happy": 4, "neutral": 5, "sad": 6, "surprised": 7}
|
| 46 |
+
|
| 47 |
+
_EMO_ALIAS = {
|
| 48 |
+
"angry": "angry", "anger": "angry",
|
| 49 |
+
"happy": "happy", "happiness": "happy", "joy": "happy",
|
| 50 |
+
"neutral": "neutral", "calm": "neutral",
|
| 51 |
+
"sad": "sad", "sadness": "sad",
|
| 52 |
+
"surprise": "surprised", "surprised": "surprised", "surprising": "surprised",
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
def norm_emotion(label):
|
| 56 |
+
key = str(label).strip().lower()
|
| 57 |
+
return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)
|
| 58 |
+
|
| 59 |
+
def stem(path_or_name):
|
| 60 |
+
return os.path.splitext(os.path.basename(str(path_or_name)))[0]
|
| 61 |
+
|
| 62 |
+
print("DATA_ROOT:", DATA_ROOT)
|
| 63 |
+
for p in [WAV_DIR, METADATA_CSV, DEV_SCP]:
|
| 64 |
+
print((" ✅ " if os.path.exists(p) else " ❌ THIẾU ") + p)
|
| 65 |
+
|
| 66 |
+
# %% [markdown]
|
| 67 |
+
# ## 1. Cài đặt + tải code SAILER (clone + sys.path, KHÔNG pip install -e .)
|
| 68 |
+
|
| 69 |
+
# %%
|
| 70 |
+
import sys, subprocess
|
| 71 |
+
|
| 72 |
+
def pip_install(*pkgs):
|
| 73 |
+
subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs], check=True)
|
| 74 |
+
|
| 75 |
+
REPO_DIR = "/kaggle/working/vox-profile-release"
|
| 76 |
+
if not os.path.exists(REPO_DIR):
|
| 77 |
+
subprocess.run(["git", "clone", "--depth", "1",
|
| 78 |
+
"https://github.com/tiantiaf0627/vox-profile-release.git", REPO_DIR], check=True)
|
| 79 |
+
|
| 80 |
+
pip_install("loralib", "speechbrain", "speechmos", "librosa", "soundfile", "scipy", "tqdm")
|
| 81 |
+
|
| 82 |
+
if REPO_DIR not in sys.path:
|
| 83 |
+
sys.path.insert(0, REPO_DIR)
|
| 84 |
+
|
| 85 |
+
# %% [markdown]
|
| 86 |
+
# ## 2. Nạp model SAILER (cho EMOS + CAT)
|
| 87 |
+
|
| 88 |
+
# %%
|
| 89 |
+
import torch
|
| 90 |
+
import torch.nn.functional as F
|
| 91 |
+
|
| 92 |
+
device = DEVICE if torch.cuda.is_available() else "cpu"
|
| 93 |
+
print("Device:", device)
|
| 94 |
+
if device == "cuda":
|
| 95 |
+
print(" ✅ GPU:", torch.cuda.get_device_name(0))
|
| 96 |
+
else:
|
| 97 |
+
print(" ⚠️ KHÔNG thấy GPU → Settings → Accelerator = GPU T4 rồi chạy lại.")
|
| 98 |
+
|
| 99 |
+
from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402
|
| 100 |
+
|
| 101 |
+
sailer = WavLMWrapper.from_pretrained("tiantiaf/wavlm-large-categorical-emotion").to(device)
|
| 102 |
+
sailer.eval()
|
| 103 |
+
print("✅ Đã nạp SAILER (wavlm-large-categorical-emotion)")
|
| 104 |
+
|
| 105 |
+
# %% [markdown]
|
| 106 |
+
# ## 2b. Nạp model VAD chuyên: audeering wav2vec2 MSP-dim
|
| 107 |
+
# ⚠️ Kế thừa `Wav2Vec2PreTrainedModel` (theo model card) hay dính lỗi version transformers
|
| 108 |
+
# (thiếu `__file__` / `all_tied_weights_keys`...). Cách dứt điểm: CHỈ dùng `Wav2Vec2Model` (backbone
|
| 109 |
+
# được hỗ trợ tốt) + **tự nạp tay** trọng số regression head từ checkpoint → không đụng tie-weights/experts.
|
| 110 |
+
# ⚠️ Model xuất thứ tự **[arousal, dominance, valence]** ∈ [0,1] → đổi về [VAL,ARO,DOM] thang 1–5 khi ghi.
|
| 111 |
+
|
| 112 |
+
# %%
|
| 113 |
+
import torch.nn as nn
|
| 114 |
+
from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor
|
| 115 |
+
from huggingface_hub import hf_hub_download
|
| 116 |
+
|
| 117 |
+
AUD_NAME = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
|
| 118 |
+
aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)
|
| 119 |
+
|
| 120 |
+
# 1) backbone wav2vec2 (load chuẩn, không subclass)
|
| 121 |
+
aud_cfg = Wav2Vec2Config.from_pretrained(AUD_NAME)
|
| 122 |
+
aud_backbone = Wav2Vec2Model(aud_cfg)
|
| 123 |
+
|
| 124 |
+
# 2) tải state_dict gốc của checkpoint (ưu tiên safetensors)
|
| 125 |
+
try:
|
| 126 |
+
_sd = __import__("safetensors.torch", fromlist=["load_file"]).load_file(
|
| 127 |
+
hf_hub_download(AUD_NAME, "model.safetensors"))
|
| 128 |
+
except Exception:
|
| 129 |
+
_sd = torch.load(hf_hub_download(AUD_NAME, "pytorch_model.bin"), map_location="cpu")
|
| 130 |
+
|
| 131 |
+
# 3) nạp phần backbone (key có tiền tố "wav2vec2.") vào Wav2Vec2Model
|
| 132 |
+
bb_sd = {k[len("wav2vec2."):]: v for k, v in _sd.items() if k.startswith("wav2vec2.")}
|
| 133 |
+
missing, unexpected = aud_backbone.load_state_dict(bb_sd, strict=False)
|
| 134 |
+
print(f" backbone: thiếu {len(missing)} key, dư {len(unexpected)} key (strict=False)")
|
| 135 |
+
|
| 136 |
+
# 4) dựng regression head theo đúng shape trong checkpoint rồi nạp trọng số "classifier.*"
|
| 137 |
+
_hid = _sd["classifier.dense.weight"].shape[0]
|
| 138 |
+
_out = _sd["classifier.out_proj.weight"].shape[0] # = 3 (arousal, dominance, valence)
|
| 139 |
+
aud_head = nn.Sequential(nn.Linear(_hid, _hid), nn.Tanh(), nn.Linear(_hid, _out))
|
| 140 |
+
aud_head[0].weight.data.copy_(_sd["classifier.dense.weight"])
|
| 141 |
+
aud_head[0].bias.data.copy_(_sd["classifier.dense.bias"])
|
| 142 |
+
aud_head[2].weight.data.copy_(_sd["classifier.out_proj.weight"])
|
| 143 |
+
aud_head[2].bias.data.copy_(_sd["classifier.out_proj.bias"])
|
| 144 |
+
|
| 145 |
+
aud_backbone = aud_backbone.to(device).eval()
|
| 146 |
+
aud_head = aud_head.to(device).eval()
|
| 147 |
+
print(f"✅ Đã nạp audeering MSP-dim (backbone + head {_hid}→{_out}) — model VAD chuyên")
|
| 148 |
+
|
| 149 |
+
# %% [markdown]
|
| 150 |
+
# ## 3. Đọc cảm xúc target cho mỗi wav (cho EMOS của SAILER)
|
| 151 |
+
|
| 152 |
+
# %%
|
| 153 |
+
import numpy as np
|
| 154 |
+
import librosa
|
| 155 |
+
|
| 156 |
+
def load_target_emotions():
|
| 157 |
+
tgt = {}
|
| 158 |
+
with open(METADATA_CSV, encoding="utf-8") as f:
|
| 159 |
+
for ln in f:
|
| 160 |
+
parts = ln.strip().split("|")
|
| 161 |
+
if len(parts) < 2:
|
| 162 |
+
continue
|
| 163 |
+
tgt[stem(parts[0])] = norm_emotion(parts[1])
|
| 164 |
+
return tgt
|
| 165 |
+
|
| 166 |
+
target_map = load_target_emotions()
|
| 167 |
+
print(f"Target emotions: {len(target_map)} wav | ví dụ:", dict(list(target_map.items())[:3]))
|
| 168 |
+
|
| 169 |
+
# %% [markdown]
|
| 170 |
+
# ## 4. Hàm chấm: SAILER (EMOS+CAT) + audeering (VAD)
|
| 171 |
+
|
| 172 |
+
# %%
|
| 173 |
+
@torch.no_grad()
|
| 174 |
+
def sailer_probs(wav_path):
|
| 175 |
+
"""→ probs9 (float32[9]); None nếu thiếu/lỗi. Chỉ lấy 9 lớp (EMOS+CAT), bỏ VAD của SAILER."""
|
| 176 |
+
if not os.path.exists(wav_path):
|
| 177 |
+
return None
|
| 178 |
+
wave, _ = librosa.load(wav_path, sr=SR, mono=True)
|
| 179 |
+
wave = wave[: MAX_SECONDS * SR]
|
| 180 |
+
data = torch.from_numpy(wave).float().unsqueeze(0).to(device)
|
| 181 |
+
logits, _feat, _det, _aro, _val, _dom = sailer(data, return_feature=True)
|
| 182 |
+
return F.softmax(logits, dim=1)[0].detach().cpu().numpy().astype(np.float32)
|
| 183 |
+
|
| 184 |
+
def emos_from_probs(probs9, target):
|
| 185 |
+
if target is None or target not in EMO2SAILER:
|
| 186 |
+
return None
|
| 187 |
+
return 1.0 + 4.0 * float(probs9[EMO2SAILER[target]])
|
| 188 |
+
|
| 189 |
+
def cat5_from_probs(probs9):
|
| 190 |
+
v = np.array([probs9[EMO2SAILER[e]] for e in EMOTIONS5], dtype=np.float32)
|
| 191 |
+
s = v.sum()
|
| 192 |
+
return v / s if s > 0 else np.full(5, 0.2, dtype=np.float32)
|
| 193 |
+
|
| 194 |
+
@torch.no_grad()
|
| 195 |
+
def audeering_vad(wav_path):
|
| 196 |
+
"""VAD bằng audeering → [VAL, ARO, DOM] thang 1–5; None nếu thiếu/lỗi.
|
| 197 |
+
Model xuất [arousal, dominance, valence] ∈ [0,1]."""
|
| 198 |
+
if not os.path.exists(wav_path):
|
| 199 |
+
return None
|
| 200 |
+
wave, _ = librosa.load(wav_path, sr=SR, mono=True)
|
| 201 |
+
wave = wave[: MAX_SECONDS * SR]
|
| 202 |
+
x = aud_proc(wave, sampling_rate=SR).input_values[0]
|
| 203 |
+
x = torch.from_numpy(np.asarray(x, dtype=np.float32)).unsqueeze(0).to(device)
|
| 204 |
+
h = aud_backbone(x)[0].mean(dim=1) # mean-pool theo thời gian
|
| 205 |
+
out = aud_head(h)[0].detach().cpu().numpy() # [arousal, dominance, valence]
|
| 206 |
+
aro, dom, val = float(out[0]), float(out[1]), float(out[2])
|
| 207 |
+
return np.array([1 + 4 * val, 1 + 4 * aro, 1 + 4 * dom], dtype=np.float32) # [VAL,ARO,DOM]
|
| 208 |
+
|
| 209 |
+
# %% [markdown]
|
| 210 |
+
# ## 5. QMOS = SpeechMOS (UTMOS) — bắt buộc cho answer.txt
|
| 211 |
+
|
| 212 |
+
# %%
|
| 213 |
+
@torch.no_grad()
|
| 214 |
+
def run_qmos(names):
|
| 215 |
+
predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True).to(device).eval()
|
| 216 |
+
from tqdm.auto import tqdm
|
| 217 |
+
out = {}
|
| 218 |
+
for n in tqdm(names, desc="QMOS"):
|
| 219 |
+
p = os.path.join(WAV_DIR, n)
|
| 220 |
+
if not os.path.exists(p):
|
| 221 |
+
continue
|
| 222 |
+
wave, _ = librosa.load(p, sr=SR, mono=True)
|
| 223 |
+
x = torch.from_numpy(wave).unsqueeze(0).to(device)
|
| 224 |
+
out[n] = float(predictor(x, sr=SR).mean().item())
|
| 225 |
+
return out
|
| 226 |
+
|
| 227 |
+
# %% [markdown]
|
| 228 |
+
# ## 6. Chạy trên DEV → `answer.txt` (QMOS, EMOS, CAT ← SAILER/UTMOS · VAL,ARO,DOM ← audeering)
|
| 229 |
+
|
| 230 |
+
# %%
|
| 231 |
+
def list_dev():
|
| 232 |
+
with open(DEV_SCP) as f:
|
| 233 |
+
return [ln.strip() for ln in f if ln.strip()]
|
| 234 |
+
|
| 235 |
+
dev_names = list_dev()
|
| 236 |
+
if LIMIT:
|
| 237 |
+
dev_names = dev_names[:LIMIT]
|
| 238 |
+
print("DEV:", len(dev_names), "mẫu")
|
| 239 |
+
|
| 240 |
+
qmos_scores = run_qmos(dev_names)
|
| 241 |
+
|
| 242 |
+
def fmt_cat(probs5):
|
| 243 |
+
return "|".join(f"{e}:{probs5[i]:.6g}" for i, e in enumerate(EMOTIONS5))
|
| 244 |
+
|
| 245 |
+
def build_answer(out_path):
|
| 246 |
+
from tqdm.auto import tqdm
|
| 247 |
+
n_emos = n_default = n_vad_def = 0
|
| 248 |
+
with open(out_path, "w") as f:
|
| 249 |
+
f.write("wav,QMOS,EMOS,CAT,VAL,ARO,DOM\n")
|
| 250 |
+
for name in tqdm(dev_names, desc="EMOS/CAT(SAILER)+VAD(audeering)"):
|
| 251 |
+
sid = stem(name)
|
| 252 |
+
wav = os.path.join(WAV_DIR, name)
|
| 253 |
+
# EMOS + CAT từ SAILER
|
| 254 |
+
probs9 = sailer_probs(wav)
|
| 255 |
+
if probs9 is None:
|
| 256 |
+
emos, cat5 = 3.0, np.full(5, 0.2, dtype=np.float32); n_default += 1
|
| 257 |
+
else:
|
| 258 |
+
emos = emos_from_probs(probs9, target_map.get(sid))
|
| 259 |
+
if emos is None:
|
| 260 |
+
emos = 3.0; n_default += 1
|
| 261 |
+
else:
|
| 262 |
+
n_emos += 1
|
| 263 |
+
cat5 = cat5_from_probs(probs9)
|
| 264 |
+
# VAD từ audeering
|
| 265 |
+
vad3 = audeering_vad(wav)
|
| 266 |
+
if vad3 is None:
|
| 267 |
+
vad3 = np.array([3.0, 3.0, 3.0], dtype=np.float32); n_vad_def += 1
|
| 268 |
+
qmos = qmos_scores.get(name, 3.0)
|
| 269 |
+
f.write(f"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},"
|
| 270 |
+
f"{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\n")
|
| 271 |
+
print(f"Ghi {len(dev_names)} dòng → {out_path} | EMOS thật {n_emos}, mặc định {n_default} | VAD mặc định {n_vad_def}")
|
| 272 |
+
|
| 273 |
+
answer_path = os.path.join(OUT_DIR, "answer.txt")
|
| 274 |
+
build_answer(answer_path)
|
| 275 |
+
|
| 276 |
+
# %% [markdown]
|
| 277 |
+
# ## 7. Validate + đóng zip
|
| 278 |
+
|
| 279 |
+
# %%
|
| 280 |
+
def validate(path):
|
| 281 |
+
import csv
|
| 282 |
+
with open(path) as f:
|
| 283 |
+
rows = list(csv.reader(f))
|
| 284 |
+
header = rows[0]
|
| 285 |
+
assert header[0] == "wav" and "QMOS" in header and "EMOS" in header, "Header sai"
|
| 286 |
+
for i, r in enumerate(rows[1:], 2):
|
| 287 |
+
assert len(r) == len(header), f"Dòng {i} sai số cột"
|
| 288 |
+
print(f"OK: {len(rows)-1} dòng, header = {header}")
|
| 289 |
+
|
| 290 |
+
validate(answer_path)
|
| 291 |
+
os.system(f"cd {OUT_DIR} && zip -j submission_track2_exp05_vad-audeering.zip answer.txt && unzip -l submission_track2_exp05_vad-audeering.zip")
|
| 292 |
+
print("Sẵn sàng nộp:", os.path.join(OUT_DIR, "submission_track2_exp05_vad-audeering.zip"))
|
| 293 |
+
|
| 294 |
+
# %% [markdown]
|
| 295 |
+
# ## Ghi chú
|
| 296 |
+
# - **Quan hệ với exp03:** exp03 = SAILER lo cả EMOS+CAT+VAD (giữ nguyên, file `exp03_emos_sailer`).
|
| 297 |
+
# exp05 (file này) chỉ **đổi VAD sang audeering**, EMOS/CAT vẫn SAILER → nộp 2 bản để A/B từng cột VAD.
|
| 298 |
+
# - **Lần đầu** đặt `LIMIT = 20`, kiểm tra VAL/ARO/DOM ∈ [1,5] hợp lý (không toàn 3 / không âm).
|
| 299 |
+
# Nếu giá trị lệch → có thể sai thứ tự arousal/dominance/valence, báo lại để chỉnh.
|
| 300 |
+
# - Khi chạy để ý dòng `backbone: thiếu N key, dư M key`: thiếu/dư vài key phụ là bình thường;
|
| 301 |
+
# thiếu hàng trăm key = sai tiền tố → báo lại.
|
| 302 |
+
# - Nếu audeering thắng VAL nhưng thua ARO/DOM so SAILER → bản tối ưu = trộn cột
|
| 303 |
+
# (VAL từ audeering, ARO/DOM từ exp03). Ghi kết quả vào `docs/04_experiments_log.md` (exp05).
|
track2/exp06_qmos_train.ipynb
ADDED
|
@@ -0,0 +1,628 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "e2d94d72",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"# VMC2026 Track 2 — exp06 (TRAIN QMOS head) — Kaggle\n",
|
| 9 |
+
"\n",
|
| 10 |
+
"**Mục tiêu:** QMOS là cột **duy nhất chưa train** (đang dùng UTMOS zero-shot → SRCC kẹt 0.414).\n",
|
| 11 |
+
"`train.csv` CÓ sẵn cột `qMOS` → ta train 1 **head hồi quy nhỏ** trên đặc trưng SSL (đã cache ở exp04)\n",
|
| 12 |
+
"để vượt 0.414.\n",
|
| 13 |
+
"\n",
|
| 14 |
+
"## Ý tưởng (đọc 1 lần cho hiểu)\n",
|
| 15 |
+
"- Tái dùng đặc trưng **emotion2vec + SAILER** đã trích & cache trong `fusion_cache/` (exp04) → KHÔNG trích lại.\n",
|
| 16 |
+
"- Thêm **chính điểm UTMOS** (SpeechMOS) làm 1 đặc trưng đầu vào → head chỉ cần **học chỉnh sửa (residual)**\n",
|
| 17 |
+
" quanh 0.414 thay vì học lại từ đầu → an toàn, gần như chắc chắn ≥ UTMOS đơn lẻ.\n",
|
| 18 |
+
"- Nhãn vàng QMOS = **TB `qMOS` theo wav** (gộp các listener trong `train.csv`).\n",
|
| 19 |
+
"- Có **val nội bộ 10%** → đo SRCC, so thẳng với UTMOS trên CÙNG tập val → biết có cải thiện thật\n",
|
| 20 |
+
" **trước khi** tốn lượt nộp CodaBench.\n",
|
| 21 |
+
"- Cuối cùng: **GIỮ NGUYÊN exp04** (5 cột cảm xúc đang thắng), chỉ **thay cột QMOS** trong `answer.txt`.\n",
|
| 22 |
+
"\n",
|
| 23 |
+
"```\n",
|
| 24 |
+
" mỗi wav ─► [e2v_emb | e2v_probs5 | sailer_emb | sailer_probs9 | sailer_vad3 | UTMOS] ─► MLP ─► QMOS\n",
|
| 25 |
+
" (head train)\n",
|
| 26 |
+
"```\n",
|
| 27 |
+
"\n",
|
| 28 |
+
"**Cách chạy trên Kaggle:** Settings → Accelerator = **GPU T4**, Internet = **On**.\n",
|
| 29 |
+
"+ Add Input: (1) dataset Track 2 (15.477 wav, có `sets/train.csv`) ; (2) — nếu có — dataset chứa\n",
|
| 30 |
+
"`fusion_cache/*.npz` đã Save Version ở exp04 (đỡ ~15') ; (3) file `answer.txt` của exp04 để ghép cột.\n",
|
| 31 |
+
"Lần đầu đặt `LIMIT_TRAIN=300`, `LIMIT_DEV=20` để bắt lỗi setup, OK rồi đặt `None`."
|
| 32 |
+
]
|
| 33 |
+
},
|
| 34 |
+
{
|
| 35 |
+
"cell_type": "markdown",
|
| 36 |
+
"id": "b42d5d49",
|
| 37 |
+
"metadata": {},
|
| 38 |
+
"source": [
|
| 39 |
+
"## 0. Cấu hình — SỬA Ở ĐÂY"
|
| 40 |
+
]
|
| 41 |
+
},
|
| 42 |
+
{
|
| 43 |
+
"cell_type": "code",
|
| 44 |
+
"execution_count": null,
|
| 45 |
+
"id": "93e29194",
|
| 46 |
+
"metadata": {},
|
| 47 |
+
"outputs": [],
|
| 48 |
+
"source": [
|
| 49 |
+
"import os\n",
|
| 50 |
+
"\n",
|
| 51 |
+
"# ── Data Track 2 ─────────────────────────────────────────────────────────────\n",
|
| 52 |
+
"DATA_ROOT = \"/kaggle/input/vmc2026-track2-full/vmc2026-track2\" # << SỬA slug cho khớp Add Input\n",
|
| 53 |
+
"WAV_DIR = f\"{DATA_ROOT}/wav\"\n",
|
| 54 |
+
"TRAIN_CSV = f\"{DATA_ROOT}/sets/train.csv\" # nhãn người nghe: lisID|wavID|qMOS|emoCat|eMOS|val|dom|aro\n",
|
| 55 |
+
"DEV_SCP = f\"{DATA_ROOT}/sets/dev.scp\" # danh sách wav tập DEV\n",
|
| 56 |
+
"\n",
|
| 57 |
+
"OUT_DIR = \"/kaggle/working\"\n",
|
| 58 |
+
"# Dùng CHUNG cache với exp04. Nếu đã Save Version cache ở exp04, trỏ CACHE_DIR vào dataset đó\n",
|
| 59 |
+
"# (vd \"/kaggle/input/<slug-cache>/fusion_cache\") để khỏi trích lại; nếu không, để mặc định sẽ tự trích.\n",
|
| 60 |
+
"CACHE_DIR = \"/kaggle/working/fusion_cache\"\n",
|
| 61 |
+
"os.makedirs(CACHE_DIR, exist_ok=True)\n",
|
| 62 |
+
"\n",
|
| 63 |
+
"# File answer.txt của exp04 (5 cột cảm xúc đang thắng) để GHÉP cột QMOS mới vào.\n",
|
| 64 |
+
"# Trỏ tới nơi bạn đặt file exp04. Nếu không có, notebook vẫn xuất qmos_dev.csv riêng + cảnh báo.\n",
|
| 65 |
+
"EXP04_ANSWER = \"/kaggle/input/exp04-answer/answer.txt\" # << SỬA; hoặc \"/kaggle/working/answer.txt\"\n",
|
| 66 |
+
"\n",
|
| 67 |
+
"# ── Đặc trưng dùng cho QMOS ──────────────────────────────────────────────────\n",
|
| 68 |
+
"USE_E2V = True # nối embedding emotion2vec\n",
|
| 69 |
+
"USE_SAILER = True # nối embedding SAILER/WavLM\n",
|
| 70 |
+
"USE_CLASSPROB = True # nối thêm xác suất lớp (e2v5 + sailer9 + vad3)\n",
|
| 71 |
+
"USE_UTMOS_FEAT = True # nối thêm điểm UTMOS làm 1 đặc trưng (neo residual quanh 0.414)\n",
|
| 72 |
+
"\n",
|
| 73 |
+
"# ── Siêu tham số train head ──────────────────────────────────────────────────\n",
|
| 74 |
+
"DEVICE = \"cuda\"\n",
|
| 75 |
+
"HIDDEN = 256\n",
|
| 76 |
+
"DROPOUT = 0.3\n",
|
| 77 |
+
"LR = 1e-3\n",
|
| 78 |
+
"EPOCHS = 120\n",
|
| 79 |
+
"BATCH = 64\n",
|
| 80 |
+
"VAL_FRAC = 0.10\n",
|
| 81 |
+
"PATIENCE = 20\n",
|
| 82 |
+
"SEED = 42\n",
|
| 83 |
+
"RANK_LAMBDA = 0.0 # 0 = chỉ MSE. >0 (vd 0.2) = cộng thêm pairwise ranking loss (tối ưu thứ hạng=SRCC)\n",
|
| 84 |
+
"\n",
|
| 85 |
+
"LIMIT_TRAIN = None # số nhỏ (vd 300) để chạy thử; None = full\n",
|
| 86 |
+
"LIMIT_DEV = None\n",
|
| 87 |
+
"\n",
|
| 88 |
+
"def stem(p):\n",
|
| 89 |
+
" return os.path.splitext(os.path.basename(str(p)))[0]\n",
|
| 90 |
+
"\n",
|
| 91 |
+
"assert USE_E2V or USE_SAILER or USE_UTMOS_FEAT, \"Phải bật ít nhất 1 nguồn đặc trưng.\"\n",
|
| 92 |
+
"print(\"DATA_ROOT:\", DATA_ROOT)\n",
|
| 93 |
+
"for p in [WAV_DIR, TRAIN_CSV, DEV_SCP]:\n",
|
| 94 |
+
" print((\" ✅ \" if os.path.exists(p) else \" ❌ THIẾU \") + p)"
|
| 95 |
+
]
|
| 96 |
+
},
|
| 97 |
+
{
|
| 98 |
+
"cell_type": "markdown",
|
| 99 |
+
"id": "47ac221d",
|
| 100 |
+
"metadata": {},
|
| 101 |
+
"source": [
|
| 102 |
+
"## 1. Cài đặt + (nếu cần) tải code SAILER\n",
|
| 103 |
+
"emotion2vec qua `funasr`; SAILER cần `WavLMWrapper` trong repo `vox-profile-release` (clone + sys.path).\n",
|
| 104 |
+
"Nếu cache đã đủ thì các model này sẽ KHÔNG được nạp (chỉ nạp khi còn file phải trích)."
|
| 105 |
+
]
|
| 106 |
+
},
|
| 107 |
+
{
|
| 108 |
+
"cell_type": "code",
|
| 109 |
+
"execution_count": null,
|
| 110 |
+
"id": "99ba1947",
|
| 111 |
+
"metadata": {},
|
| 112 |
+
"outputs": [],
|
| 113 |
+
"source": [
|
| 114 |
+
"import sys, subprocess\n",
|
| 115 |
+
"\n",
|
| 116 |
+
"def pip_install(*pkgs):\n",
|
| 117 |
+
" subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *pkgs], check=True)\n",
|
| 118 |
+
"\n",
|
| 119 |
+
"pip_install(\"speechmos\", \"funasr\", \"librosa\", \"soundfile\", \"pandas\", \"scipy\", \"scikit-learn\", \"tqdm\")\n",
|
| 120 |
+
"\n",
|
| 121 |
+
"if USE_SAILER:\n",
|
| 122 |
+
" pip_install(\"loralib\", \"speechbrain\")\n",
|
| 123 |
+
" REPO_DIR = \"/kaggle/working/vox-profile-release\"\n",
|
| 124 |
+
" if not os.path.exists(REPO_DIR):\n",
|
| 125 |
+
" subprocess.run([\"git\", \"clone\", \"--depth\", \"1\",\n",
|
| 126 |
+
" \"https://github.com/tiantiaf0627/vox-profile-release.git\", REPO_DIR], check=True)\n",
|
| 127 |
+
" if REPO_DIR not in sys.path:\n",
|
| 128 |
+
" sys.path.insert(0, REPO_DIR)"
|
| 129 |
+
]
|
| 130 |
+
},
|
| 131 |
+
{
|
| 132 |
+
"cell_type": "markdown",
|
| 133 |
+
"id": "ac9dcefc",
|
| 134 |
+
"metadata": {},
|
| 135 |
+
"source": [
|
| 136 |
+
"## 2. Nhãn vàng QMOS (gộp `qMOS` theo wavID)"
|
| 137 |
+
]
|
| 138 |
+
},
|
| 139 |
+
{
|
| 140 |
+
"cell_type": "code",
|
| 141 |
+
"execution_count": null,
|
| 142 |
+
"id": "db4a41a5",
|
| 143 |
+
"metadata": {},
|
| 144 |
+
"outputs": [],
|
| 145 |
+
"source": [
|
| 146 |
+
"import numpy as np\n",
|
| 147 |
+
"import pandas as pd\n",
|
| 148 |
+
"\n",
|
| 149 |
+
"def load_qmos_labels():\n",
|
| 150 |
+
" \"\"\"train.csv (sep='|') → DataFrame [wavID, qmos] với qmos = TB theo wav.\"\"\"\n",
|
| 151 |
+
" df = pd.read_csv(TRAIN_CSV, sep=\"|\")\n",
|
| 152 |
+
" cols = {c.lower().strip(): c for c in df.columns}\n",
|
| 153 |
+
" wav_col = cols.get(\"wavid\") or cols.get(\"wav\") or list(df.columns)[1]\n",
|
| 154 |
+
" qmos_col = cols.get(\"qmos\") or cols.get(\"qMOS\".lower()) or cols.get(\"mos\")\n",
|
| 155 |
+
" assert qmos_col, f\"Không thấy cột qMOS trong train.csv (cột: {list(df.columns)})\"\n",
|
| 156 |
+
" df[\"_stem\"] = df[wav_col].map(stem)\n",
|
| 157 |
+
" g = df.groupby(\"_stem\")[qmos_col].mean().reset_index()\n",
|
| 158 |
+
" g.columns = [\"wavID\", \"qmos\"]\n",
|
| 159 |
+
" return g\n",
|
| 160 |
+
"\n",
|
| 161 |
+
"qmos_df = load_qmos_labels()\n",
|
| 162 |
+
"print(f\"wav train (gộp): {len(qmos_df)}\")\n",
|
| 163 |
+
"print(\"qMOS:\", qmos_df[\"qmos\"].describe()[[\"mean\", \"std\", \"min\", \"max\"]].to_dict())\n",
|
| 164 |
+
"qmos_df.head()"
|
| 165 |
+
]
|
| 166 |
+
},
|
| 167 |
+
{
|
| 168 |
+
"cell_type": "markdown",
|
| 169 |
+
"id": "dfd7df0c",
|
| 170 |
+
"metadata": {},
|
| 171 |
+
"source": [
|
| 172 |
+
"## 3. Trích / nạp đặc trưng (cache CHUNG với exp04) + điểm UTMOS\n",
|
| 173 |
+
"- `extract_e2v` / `extract_sailer`: y hệt exp04, cache `e2v_<tag>.npz` / `sailer_<tag>.npz`.\n",
|
| 174 |
+
"- `extract_utmos`: chấm UTMOS từng wav → cache `utmos_<tag>.npz` (dùng vừa làm đặc trưng, vừa làm baseline so sánh)."
|
| 175 |
+
]
|
| 176 |
+
},
|
| 177 |
+
{
|
| 178 |
+
"cell_type": "code",
|
| 179 |
+
"execution_count": null,
|
| 180 |
+
"id": "ec1e63a1",
|
| 181 |
+
"metadata": {
|
| 182 |
+
"lines_to_next_cell": 1
|
| 183 |
+
},
|
| 184 |
+
"outputs": [],
|
| 185 |
+
"source": [
|
| 186 |
+
"import torch\n",
|
| 187 |
+
"import torch.nn.functional as F\n",
|
| 188 |
+
"\n",
|
| 189 |
+
"device = DEVICE if torch.cuda.is_available() else \"cpu\"\n",
|
| 190 |
+
"print(\"Device:\", device, (\"✅ \" + torch.cuda.get_device_name(0)) if device == \"cuda\" else \"⚠️ CPU\")\n",
|
| 191 |
+
"\n",
|
| 192 |
+
"EMOTIONS5 = [\"angry\", \"happy\", \"neutral\", \"sad\", \"surprised\"]\n",
|
| 193 |
+
"\n",
|
| 194 |
+
"def extract_e2v(stems, tag):\n",
|
| 195 |
+
" \"\"\"→ dict {stem: emb_full[D1+5]}. Cache CACHE_DIR/e2v_<tag>.npz (giống exp04).\"\"\"\n",
|
| 196 |
+
" from tqdm.auto import tqdm\n",
|
| 197 |
+
" cache_path = os.path.join(CACHE_DIR, f\"e2v_{tag}.npz\")\n",
|
| 198 |
+
" store = {}\n",
|
| 199 |
+
" if os.path.exists(cache_path):\n",
|
| 200 |
+
" z = np.load(cache_path, allow_pickle=True)\n",
|
| 201 |
+
" store = {k: z[k] for k in z.files}\n",
|
| 202 |
+
" print(f\"[e2v/{tag}] nạp cache: {len(store)}\")\n",
|
| 203 |
+
" todo = [s for s in stems if s not in store]\n",
|
| 204 |
+
" if todo:\n",
|
| 205 |
+
" from funasr import AutoModel\n",
|
| 206 |
+
" m = AutoModel(model=\"iic/emotion2vec_plus_large\", hub=\"hf\", device=device)\n",
|
| 207 |
+
" for i, s in enumerate(tqdm(todo, desc=f\"e2v {tag}\")):\n",
|
| 208 |
+
" wav = os.path.join(WAV_DIR, s + \".wav\")\n",
|
| 209 |
+
" if not os.path.exists(wav):\n",
|
| 210 |
+
" continue\n",
|
| 211 |
+
" r = m.generate(wav, granularity=\"utterance\", extract_embedding=True)[0]\n",
|
| 212 |
+
" emb = np.asarray(r[\"feats\"], dtype=np.float32).reshape(-1)\n",
|
| 213 |
+
" probs = {e: 0.0 for e in EMOTIONS5}\n",
|
| 214 |
+
" for lab, sc in zip(r[\"labels\"], r[\"scores\"]):\n",
|
| 215 |
+
" name = lab.split(\"/\")[-1]\n",
|
| 216 |
+
" if name in probs:\n",
|
| 217 |
+
" probs[name] = float(sc)\n",
|
| 218 |
+
" tot = sum(probs.values())\n",
|
| 219 |
+
" p5 = np.array([probs[e] / tot if tot > 0 else 0.2 for e in EMOTIONS5], dtype=np.float32)\n",
|
| 220 |
+
" store[s] = np.concatenate([emb, p5]).astype(np.float32)\n",
|
| 221 |
+
" if (i + 1) % 500 == 0:\n",
|
| 222 |
+
" np.savez(cache_path, **store)\n",
|
| 223 |
+
" np.savez(cache_path, **store)\n",
|
| 224 |
+
" del m\n",
|
| 225 |
+
" torch.cuda.empty_cache() if device == \"cuda\" else None\n",
|
| 226 |
+
" return store # mỗi value = [D1 | 5]\n",
|
| 227 |
+
"\n",
|
| 228 |
+
"def _pool_feat(features):\n",
|
| 229 |
+
" f = features.detach().cpu().numpy()\n",
|
| 230 |
+
" if f.ndim <= 1:\n",
|
| 231 |
+
" return f.reshape(-1).astype(np.float32)\n",
|
| 232 |
+
" return f.mean(axis=tuple(range(f.ndim - 1))).reshape(-1).astype(np.float32)\n",
|
| 233 |
+
"\n",
|
| 234 |
+
"def extract_sailer(stems, tag):\n",
|
| 235 |
+
" \"\"\"→ dict {stem: vec[D2+9+3]}. Cache CACHE_DIR/sailer_<tag>.npz (giống exp04).\"\"\"\n",
|
| 236 |
+
" import librosa\n",
|
| 237 |
+
" from tqdm.auto import tqdm\n",
|
| 238 |
+
" cache_path = os.path.join(CACHE_DIR, f\"sailer_{tag}.npz\")\n",
|
| 239 |
+
" store = {}\n",
|
| 240 |
+
" if os.path.exists(cache_path):\n",
|
| 241 |
+
" z = np.load(cache_path, allow_pickle=True)\n",
|
| 242 |
+
" store = {k: z[k] for k in z.files}\n",
|
| 243 |
+
" print(f\"[sailer/{tag}] nạp cache: {len(store)}\")\n",
|
| 244 |
+
" todo = [s for s in stems if s not in store]\n",
|
| 245 |
+
" if todo:\n",
|
| 246 |
+
" from src.model.emotion.wavlm_emotion import WavLMWrapper\n",
|
| 247 |
+
" sailer = WavLMWrapper.from_pretrained(\"tiantiaf/wavlm-large-categorical-emotion\").to(device).eval()\n",
|
| 248 |
+
" with torch.no_grad():\n",
|
| 249 |
+
" for i, s in enumerate(tqdm(todo, desc=f\"sailer {tag}\")):\n",
|
| 250 |
+
" wav = os.path.join(WAV_DIR, s + \".wav\")\n",
|
| 251 |
+
" if not os.path.exists(wav):\n",
|
| 252 |
+
" continue\n",
|
| 253 |
+
" wave, _ = librosa.load(wav, sr=16000, mono=True)\n",
|
| 254 |
+
" wave = wave[: 15 * 16000]\n",
|
| 255 |
+
" data = torch.from_numpy(wave).float().unsqueeze(0).to(device)\n",
|
| 256 |
+
" logits, feat, _det, arousal, valence, dominance = sailer(data, return_feature=True)\n",
|
| 257 |
+
" emb = _pool_feat(feat)\n",
|
| 258 |
+
" p9 = F.softmax(logits, dim=1)[0].detach().cpu().numpy().astype(np.float32)\n",
|
| 259 |
+
" vad3 = np.array([1 + 4 * float(valence.item()),\n",
|
| 260 |
+
" 1 + 4 * float(arousal.item()),\n",
|
| 261 |
+
" 1 + 4 * float(dominance.item())], dtype=np.float32)\n",
|
| 262 |
+
" store[s] = np.concatenate([emb, p9, vad3]).astype(np.float32)\n",
|
| 263 |
+
" if (i + 1) % 500 == 0:\n",
|
| 264 |
+
" np.savez(cache_path, **store)\n",
|
| 265 |
+
" np.savez(cache_path, **store)\n",
|
| 266 |
+
" del sailer\n",
|
| 267 |
+
" torch.cuda.empty_cache() if device == \"cuda\" else None\n",
|
| 268 |
+
" return store # mỗi value = [D2 | 9 | 3]\n",
|
| 269 |
+
"\n",
|
| 270 |
+
"def extract_utmos(names, tag):\n",
|
| 271 |
+
" \"\"\"Chấm UTMOS từng wav (theo TÊN file, vì DEV gọi .wav theo tên). → dict {stem: score}.\n",
|
| 272 |
+
" Cache CACHE_DIR/utmos_<tag>.npz. Dùng vừa làm đặc trưng vừa làm baseline so sánh.\"\"\"\n",
|
| 273 |
+
" import librosa\n",
|
| 274 |
+
" from tqdm.auto import tqdm\n",
|
| 275 |
+
" cache_path = os.path.join(CACHE_DIR, f\"utmos_{tag}.npz\")\n",
|
| 276 |
+
" store = {}\n",
|
| 277 |
+
" if os.path.exists(cache_path):\n",
|
| 278 |
+
" z = np.load(cache_path, allow_pickle=True)\n",
|
| 279 |
+
" store = {k: float(z[k]) for k in z.files}\n",
|
| 280 |
+
" print(f\"[utmos/{tag}] nạp cache: {len(store)}\")\n",
|
| 281 |
+
" todo = [n for n in names if stem(n) not in store]\n",
|
| 282 |
+
" if todo:\n",
|
| 283 |
+
" predictor = torch.hub.load(\"tarepan/SpeechMOS:v1.2.0\", \"utmos22_strong\",\n",
|
| 284 |
+
" trust_repo=True).to(device).eval()\n",
|
| 285 |
+
" with torch.no_grad():\n",
|
| 286 |
+
" for i, n in enumerate(tqdm(todo, desc=f\"utmos {tag}\")):\n",
|
| 287 |
+
" wav = os.path.join(WAV_DIR, n if n.endswith(\".wav\") else n + \".wav\")\n",
|
| 288 |
+
" if not os.path.exists(wav):\n",
|
| 289 |
+
" continue\n",
|
| 290 |
+
" wave, _ = librosa.load(wav, sr=16000, mono=True)\n",
|
| 291 |
+
" sc = float(predictor(torch.from_numpy(wave).unsqueeze(0).to(device), sr=16000).mean().item())\n",
|
| 292 |
+
" store[stem(n)] = sc\n",
|
| 293 |
+
" if (i + 1) % 500 == 0:\n",
|
| 294 |
+
" np.savez(cache_path, **{k: np.float32(v) for k, v in store.items()})\n",
|
| 295 |
+
" np.savez(cache_path, **{k: np.float32(v) for k, v in store.items()})\n",
|
| 296 |
+
" del predictor\n",
|
| 297 |
+
" torch.cuda.empty_cache() if device == \"cuda\" else None\n",
|
| 298 |
+
" return store"
|
| 299 |
+
]
|
| 300 |
+
},
|
| 301 |
+
{
|
| 302 |
+
"cell_type": "markdown",
|
| 303 |
+
"id": "aed7338b",
|
| 304 |
+
"metadata": {},
|
| 305 |
+
"source": [
|
| 306 |
+
"## 4. Dựng feature + nhãn cho train"
|
| 307 |
+
]
|
| 308 |
+
},
|
| 309 |
+
{
|
| 310 |
+
"cell_type": "code",
|
| 311 |
+
"execution_count": null,
|
| 312 |
+
"id": "c09bb508",
|
| 313 |
+
"metadata": {},
|
| 314 |
+
"outputs": [],
|
| 315 |
+
"source": [
|
| 316 |
+
"train_stems = list(qmos_df[\"wavID\"])\n",
|
| 317 |
+
"if LIMIT_TRAIN:\n",
|
| 318 |
+
" train_stems = train_stems[:LIMIT_TRAIN]\n",
|
| 319 |
+
"\n",
|
| 320 |
+
"e2v_tr = extract_e2v(train_stems, \"train\") if USE_E2V else {}\n",
|
| 321 |
+
"sailer_tr = extract_sailer(train_stems, \"train\") if USE_SAILER else {}\n",
|
| 322 |
+
"utmos_tr = extract_utmos(train_stems, \"train\") if USE_UTMOS_FEAT else {}\n",
|
| 323 |
+
"\n",
|
| 324 |
+
"def qmos_feature(sid, e2v_map, sailer_map, utmos_map):\n",
|
| 325 |
+
" \"\"\"Nối đặc trưng QMOS cho 1 wav. None nếu thiếu phần bắt buộc.\"\"\"\n",
|
| 326 |
+
" parts = []\n",
|
| 327 |
+
" if USE_E2V:\n",
|
| 328 |
+
" v = e2v_map.get(sid)\n",
|
| 329 |
+
" if v is None:\n",
|
| 330 |
+
" return None\n",
|
| 331 |
+
" parts.append(v[:-5]) # emb e2v\n",
|
| 332 |
+
" if USE_CLASSPROB:\n",
|
| 333 |
+
" parts.append(v[-5:]) # probs5\n",
|
| 334 |
+
" if USE_SAILER:\n",
|
| 335 |
+
" v = sailer_map.get(sid)\n",
|
| 336 |
+
" if v is None:\n",
|
| 337 |
+
" return None\n",
|
| 338 |
+
" parts.append(v[:-12]) # emb sailer\n",
|
| 339 |
+
" if USE_CLASSPROB:\n",
|
| 340 |
+
" parts.append(v[-12:]) # probs9 + vad3\n",
|
| 341 |
+
" if USE_UTMOS_FEAT:\n",
|
| 342 |
+
" u = utmos_map.get(sid)\n",
|
| 343 |
+
" if u is None:\n",
|
| 344 |
+
" return None\n",
|
| 345 |
+
" parts.append(np.array([u], dtype=np.float32))\n",
|
| 346 |
+
" return np.concatenate(parts).astype(np.float32)\n",
|
| 347 |
+
"\n",
|
| 348 |
+
"lab = qmos_df.set_index(\"wavID\")[\"qmos\"]\n",
|
| 349 |
+
"X, y = [], []\n",
|
| 350 |
+
"for s in train_stems:\n",
|
| 351 |
+
" f = qmos_feature(s, e2v_tr, sailer_tr, utmos_tr)\n",
|
| 352 |
+
" if f is None or s not in lab.index:\n",
|
| 353 |
+
" continue\n",
|
| 354 |
+
" X.append(f)\n",
|
| 355 |
+
" y.append(float(lab.loc[s]))\n",
|
| 356 |
+
"\n",
|
| 357 |
+
"X = np.stack(X).astype(np.float32)\n",
|
| 358 |
+
"y = np.array(y, dtype=np.float32)\n",
|
| 359 |
+
"FEAT_DIM = X.shape[1]\n",
|
| 360 |
+
"print(f\"Train: X={X.shape} y={y.shape}\")\n",
|
| 361 |
+
"\n",
|
| 362 |
+
"feat_mean = X.mean(0, keepdims=True)\n",
|
| 363 |
+
"feat_std = X.std(0, keepdims=True) + 1e-6\n",
|
| 364 |
+
"Xn = (X - feat_mean) / feat_std\n",
|
| 365 |
+
"y_mu, y_sd = float(y.mean()), float(y.std() + 1e-6)\n",
|
| 366 |
+
"yn = (y - y_mu) / y_sd"
|
| 367 |
+
]
|
| 368 |
+
},
|
| 369 |
+
{
|
| 370 |
+
"cell_type": "markdown",
|
| 371 |
+
"id": "82cc65f8",
|
| 372 |
+
"metadata": {},
|
| 373 |
+
"source": [
|
| 374 |
+
"## 5. Train head QMOS + so với UTMOS trên CÙNG val nội bộ\n",
|
| 375 |
+
"- Head = MLP nhỏ (`Linear→ReLU→Dropout ×2 → 1`). Loss = MSE (+ tùy chọn pairwise ranking).\n",
|
| 376 |
+
"- In **SRCC head** và **SRCC UTMOS** trên cùng tập val → biết head có thật sự vượt 0.414 không."
|
| 377 |
+
]
|
| 378 |
+
},
|
| 379 |
+
{
|
| 380 |
+
"cell_type": "code",
|
| 381 |
+
"execution_count": null,
|
| 382 |
+
"id": "324ab564",
|
| 383 |
+
"metadata": {
|
| 384 |
+
"lines_to_next_cell": 1
|
| 385 |
+
},
|
| 386 |
+
"outputs": [],
|
| 387 |
+
"source": [
|
| 388 |
+
"import torch.nn as nn\n",
|
| 389 |
+
"from scipy.stats import spearmanr\n",
|
| 390 |
+
"from sklearn.model_selection import train_test_split\n",
|
| 391 |
+
"\n",
|
| 392 |
+
"torch.manual_seed(SEED); np.random.seed(SEED)\n",
|
| 393 |
+
"idx_all = np.arange(X.shape[0])\n",
|
| 394 |
+
"tr_idx, va_idx = train_test_split(idx_all, test_size=VAL_FRAC, random_state=SEED)\n",
|
| 395 |
+
"\n",
|
| 396 |
+
"def to_t(a):\n",
|
| 397 |
+
" return torch.tensor(a, dtype=torch.float32, device=device)\n",
|
| 398 |
+
"\n",
|
| 399 |
+
"Xn_t = to_t(Xn); yn_t = to_t(yn).unsqueeze(1)\n",
|
| 400 |
+
"\n",
|
| 401 |
+
"class QMOSHead(nn.Module):\n",
|
| 402 |
+
" def __init__(self, d_in, h, p):\n",
|
| 403 |
+
" super().__init__()\n",
|
| 404 |
+
" self.net = nn.Sequential(\n",
|
| 405 |
+
" nn.Linear(d_in, h), nn.ReLU(), nn.Dropout(p),\n",
|
| 406 |
+
" nn.Linear(h, h), nn.ReLU(), nn.Dropout(p),\n",
|
| 407 |
+
" nn.Linear(h, 1),\n",
|
| 408 |
+
" )\n",
|
| 409 |
+
"\n",
|
| 410 |
+
" def forward(self, x):\n",
|
| 411 |
+
" return self.net(x)\n",
|
| 412 |
+
"\n",
|
| 413 |
+
"model = QMOSHead(FEAT_DIM, HIDDEN, DROPOUT).to(device)\n",
|
| 414 |
+
"opt = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-5)\n",
|
| 415 |
+
"mse = nn.MSELoss()\n",
|
| 416 |
+
"\n",
|
| 417 |
+
"def pairwise_rank_loss(pred, target):\n",
|
| 418 |
+
" \"\"\"Khuyến khích pred xếp hạng giống target (margin ranking trên các cặp trong batch).\"\"\"\n",
|
| 419 |
+
" n = pred.shape[0]\n",
|
| 420 |
+
" if n < 2:\n",
|
| 421 |
+
" return torch.zeros((), device=device)\n",
|
| 422 |
+
" pi, pj = pred.unsqueeze(0), pred.unsqueeze(1)\n",
|
| 423 |
+
" ti, tj = target.unsqueeze(0), target.unsqueeze(1)\n",
|
| 424 |
+
" sign = torch.sign(ti - tj) # +1 nếu i nên cao hơn j\n",
|
| 425 |
+
" diff = pi - pj\n",
|
| 426 |
+
" # hinge: phạt khi thứ tự sai\n",
|
| 427 |
+
" return torch.relu(-sign * diff).mean()\n",
|
| 428 |
+
"\n",
|
| 429 |
+
"@torch.no_grad()\n",
|
| 430 |
+
"def eval_val():\n",
|
| 431 |
+
" model.eval()\n",
|
| 432 |
+
" p = model(Xn_t[va_idx]).cpu().numpy().ravel()\n",
|
| 433 |
+
" srcc_head = spearmanr(p, y[va_idx]).correlation\n",
|
| 434 |
+
" out = {\"head\": float(srcc_head)}\n",
|
| 435 |
+
" if USE_UTMOS_FEAT:\n",
|
| 436 |
+
" u = X[va_idx, -1] # cột UTMOS (đặc trưng cuối, chưa chuẩn hóa)\n",
|
| 437 |
+
" out[\"utmos\"] = float(spearmanr(u, y[va_idx]).correlation)\n",
|
| 438 |
+
" return out\n",
|
| 439 |
+
"\n",
|
| 440 |
+
"best, best_state, bad = -1e9, None, 0\n",
|
| 441 |
+
"tr_t = torch.tensor(tr_idx, device=device)\n",
|
| 442 |
+
"for ep in range(1, EPOCHS + 1):\n",
|
| 443 |
+
" model.train()\n",
|
| 444 |
+
" perm = tr_t[torch.randperm(len(tr_t), device=device)]\n",
|
| 445 |
+
" run = 0.0\n",
|
| 446 |
+
" for i in range(0, len(perm), BATCH):\n",
|
| 447 |
+
" b = perm[i:i + BATCH]\n",
|
| 448 |
+
" opt.zero_grad()\n",
|
| 449 |
+
" pred = model(Xn_t[b])\n",
|
| 450 |
+
" loss = mse(pred, yn_t[b])\n",
|
| 451 |
+
" if RANK_LAMBDA > 0:\n",
|
| 452 |
+
" loss = loss + RANK_LAMBDA * pairwise_rank_loss(pred.ravel(), yn_t[b].ravel())\n",
|
| 453 |
+
" loss.backward(); opt.step()\n",
|
| 454 |
+
" run += loss.item() * len(b)\n",
|
| 455 |
+
" m = eval_val()\n",
|
| 456 |
+
" if m[\"head\"] > best:\n",
|
| 457 |
+
" best = m[\"head\"]\n",
|
| 458 |
+
" best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}\n",
|
| 459 |
+
" bad = 0\n",
|
| 460 |
+
" else:\n",
|
| 461 |
+
" bad += 1\n",
|
| 462 |
+
" if ep % 5 == 0 or ep == 1:\n",
|
| 463 |
+
" extra = f\" | UTMOS={m['utmos']:.4f}\" if \"utmos\" in m else \"\"\n",
|
| 464 |
+
" print(f\"epoch {ep:3d} | loss {run/len(perm):.4f} | head SRCC={m['head']:.4f}{extra} | best {best:.4f}\")\n",
|
| 465 |
+
" if bad >= PATIENCE:\n",
|
| 466 |
+
" print(f\"Early stop ở epoch {ep}.\")\n",
|
| 467 |
+
" break\n",
|
| 468 |
+
"\n",
|
| 469 |
+
"model.load_state_dict(best_state)\n",
|
| 470 |
+
"final = eval_val()\n",
|
| 471 |
+
"print(\"\\n✅ VAL (nội bộ):\")\n",
|
| 472 |
+
"print(f\" QMOS head SRCC = {final['head']:.4f}\")\n",
|
| 473 |
+
"if \"utmos\" in final:\n",
|
| 474 |
+
" print(f\" UTMOS baseline = {final['utmos']:.4f} (mốc leaderboard 0.414)\")\n",
|
| 475 |
+
" print(\" →\", \"✅ HEAD VƯỢT UTMOS\" if final[\"head\"] > final[\"utmos\"] else \"⚠️ chưa vượt — thử tăng EPOCHS / RANK_LAMBDA / bật thêm đặc trưng\")\n",
|
| 476 |
+
"\n",
|
| 477 |
+
"torch.save({\"state\": best_state, \"feat_mean\": feat_mean, \"feat_std\": feat_std,\n",
|
| 478 |
+
" \"y_mu\": y_mu, \"y_sd\": y_sd, \"FEAT_DIM\": FEAT_DIM,\n",
|
| 479 |
+
" \"USE_E2V\": USE_E2V, \"USE_SAILER\": USE_SAILER,\n",
|
| 480 |
+
" \"USE_CLASSPROB\": USE_CLASSPROB, \"USE_UTMOS_FEAT\": USE_UTMOS_FEAT,\n",
|
| 481 |
+
" \"val_srcc\": best}, os.path.join(OUT_DIR, \"qmos_head.pt\"))\n",
|
| 482 |
+
"print(\"Đã lưu\", os.path.join(OUT_DIR, \"qmos_head.pt\"))"
|
| 483 |
+
]
|
| 484 |
+
},
|
| 485 |
+
{
|
| 486 |
+
"cell_type": "markdown",
|
| 487 |
+
"id": "d33a7aca",
|
| 488 |
+
"metadata": {},
|
| 489 |
+
"source": [
|
| 490 |
+
"## 6. Dự đoán QMOS cho DEV"
|
| 491 |
+
]
|
| 492 |
+
},
|
| 493 |
+
{
|
| 494 |
+
"cell_type": "code",
|
| 495 |
+
"execution_count": null,
|
| 496 |
+
"id": "69efbd00",
|
| 497 |
+
"metadata": {
|
| 498 |
+
"lines_to_next_cell": 1
|
| 499 |
+
},
|
| 500 |
+
"outputs": [],
|
| 501 |
+
"source": [
|
| 502 |
+
"def list_dev():\n",
|
| 503 |
+
" with open(DEV_SCP) as f:\n",
|
| 504 |
+
" return [ln.strip() for ln in f if ln.strip()]\n",
|
| 505 |
+
"\n",
|
| 506 |
+
"dev_names = list_dev()\n",
|
| 507 |
+
"if LIMIT_DEV:\n",
|
| 508 |
+
" dev_names = dev_names[:LIMIT_DEV]\n",
|
| 509 |
+
"dev_stems = [stem(n) for n in dev_names]\n",
|
| 510 |
+
"print(\"DEV:\", len(dev_names), \"mẫu\")\n",
|
| 511 |
+
"\n",
|
| 512 |
+
"e2v_dev = extract_e2v(dev_stems, \"dev\") if USE_E2V else {}\n",
|
| 513 |
+
"sailer_dev = extract_sailer(dev_stems, \"dev\") if USE_SAILER else {}\n",
|
| 514 |
+
"utmos_dev = extract_utmos(dev_names, \"dev\") if USE_UTMOS_FEAT else {}\n",
|
| 515 |
+
"\n",
|
| 516 |
+
"@torch.no_grad()\n",
|
| 517 |
+
"def predict_qmos(sid):\n",
|
| 518 |
+
" f = qmos_feature(sid, e2v_dev, sailer_dev, utmos_dev)\n",
|
| 519 |
+
" if f is None:\n",
|
| 520 |
+
" return None\n",
|
| 521 |
+
" fn = (f[None, :] - feat_mean) / feat_std\n",
|
| 522 |
+
" model.eval()\n",
|
| 523 |
+
" return float(model(to_t(fn)).item()) * y_sd + y_mu # đảo z-score\n",
|
| 524 |
+
"\n",
|
| 525 |
+
"qmos_pred = {}\n",
|
| 526 |
+
"n_real = n_def = 0\n",
|
| 527 |
+
"for n in dev_names:\n",
|
| 528 |
+
" sid = stem(n)\n",
|
| 529 |
+
" p = predict_qmos(sid)\n",
|
| 530 |
+
" if p is None:\n",
|
| 531 |
+
" p = utmos_dev.get(sid, 3.0) # rơi về UTMOS nếu thiếu feature\n",
|
| 532 |
+
" n_def += 1\n",
|
| 533 |
+
" else:\n",
|
| 534 |
+
" n_real += 1\n",
|
| 535 |
+
" qmos_pred[n] = p\n",
|
| 536 |
+
"print(f\"QMOS dự đoán: head thật {n_real}, dự phòng UTMOS {n_def}\")\n",
|
| 537 |
+
"\n",
|
| 538 |
+
"# Lưu riêng (để ghép tay nếu cần)\n",
|
| 539 |
+
"import csv\n",
|
| 540 |
+
"qmos_csv = os.path.join(OUT_DIR, \"qmos_dev.csv\")\n",
|
| 541 |
+
"with open(qmos_csv, \"w\", newline=\"\") as f:\n",
|
| 542 |
+
" w = csv.writer(f); w.writerow([\"wav\", \"QMOS\"])\n",
|
| 543 |
+
" for n in dev_names:\n",
|
| 544 |
+
" w.writerow([n, f\"{qmos_pred[n]:.6g}\"])\n",
|
| 545 |
+
"print(\"Đã ghi\", qmos_csv)"
|
| 546 |
+
]
|
| 547 |
+
},
|
| 548 |
+
{
|
| 549 |
+
"cell_type": "markdown",
|
| 550 |
+
"id": "f3e47def",
|
| 551 |
+
"metadata": {},
|
| 552 |
+
"source": [
|
| 553 |
+
"## 7. Ghép QMOS mới vào answer.txt của exp04 → bản nộp mới\n",
|
| 554 |
+
"Giữ NGUYÊN 5 cột cảm xúc đang thắng (EMOS/CAT/VAL/ARO/DOM), chỉ thay cột QMOS."
|
| 555 |
+
]
|
| 556 |
+
},
|
| 557 |
+
{
|
| 558 |
+
"cell_type": "code",
|
| 559 |
+
"execution_count": null,
|
| 560 |
+
"id": "a3b94589",
|
| 561 |
+
"metadata": {},
|
| 562 |
+
"outputs": [],
|
| 563 |
+
"source": [
|
| 564 |
+
"def merge_into_exp04(exp04_path, out_path):\n",
|
| 565 |
+
" if not os.path.exists(exp04_path):\n",
|
| 566 |
+
" print(f\"⚠️ Không thấy {exp04_path} → BỎ QUA ghép. Hãy dùng qmos_dev.csv để thay cột QMOS thủ công,\")\n",
|
| 567 |
+
" print(\" hoặc trỏ EXP04_ANSWER đúng đường dẫn answer.txt của exp04 rồi chạy lại cell này.\")\n",
|
| 568 |
+
" return False\n",
|
| 569 |
+
" with open(exp04_path) as f:\n",
|
| 570 |
+
" rows = list(csv.reader(f))\n",
|
| 571 |
+
" header = rows[0]\n",
|
| 572 |
+
" qi = header.index(\"QMOS\")\n",
|
| 573 |
+
" wi = header.index(\"wav\")\n",
|
| 574 |
+
" n_swapped = n_miss = 0\n",
|
| 575 |
+
" with open(out_path, \"w\", newline=\"\") as f:\n",
|
| 576 |
+
" w = csv.writer(f); w.writerow(header)\n",
|
| 577 |
+
" for r in rows[1:]:\n",
|
| 578 |
+
" name = r[wi]\n",
|
| 579 |
+
" if name in qmos_pred:\n",
|
| 580 |
+
" r[qi] = f\"{qmos_pred[name]:.6g}\"; n_swapped += 1\n",
|
| 581 |
+
" else:\n",
|
| 582 |
+
" n_miss += 1\n",
|
| 583 |
+
" w.writerow(r)\n",
|
| 584 |
+
" print(f\"Ghép xong → {out_path} | thay {n_swapped} cột QMOS, thiếu {n_miss} (giữ QMOS cũ)\")\n",
|
| 585 |
+
" return True\n",
|
| 586 |
+
"\n",
|
| 587 |
+
"merged = os.path.join(OUT_DIR, \"answer.txt\")\n",
|
| 588 |
+
"ok = merge_into_exp04(EXP04_ANSWER, merged)\n",
|
| 589 |
+
"\n",
|
| 590 |
+
"if ok:\n",
|
| 591 |
+
" # validate + zip\n",
|
| 592 |
+
" with open(merged) as f:\n",
|
| 593 |
+
" rows = list(csv.reader(f))\n",
|
| 594 |
+
" assert rows[0][0] == \"wav\" and \"QMOS\" in rows[0]\n",
|
| 595 |
+
" for i, r in enumerate(rows[1:], 2):\n",
|
| 596 |
+
" assert len(r) == len(rows[0]), f\"Dòng {i} sai số cột\"\n",
|
| 597 |
+
" print(f\"OK: {len(rows)-1} dòng, header = {rows[0]}\")\n",
|
| 598 |
+
" os.system(f\"cd {OUT_DIR} && zip -j submission_track2_exp06_qmos.zip answer.txt \"\n",
|
| 599 |
+
" f\"&& unzip -l submission_track2_exp06_qmos.zip\")\n",
|
| 600 |
+
" print(\"Sẵn sàng nộp:\", os.path.join(OUT_DIR, \"submission_track2_exp06_qmos.zip\"))"
|
| 601 |
+
]
|
| 602 |
+
},
|
| 603 |
+
{
|
| 604 |
+
"cell_type": "markdown",
|
| 605 |
+
"id": "0a517b97",
|
| 606 |
+
"metadata": {},
|
| 607 |
+
"source": [
|
| 608 |
+
"## Ghi chú\n",
|
| 609 |
+
"- **Lần đầu** đặt `LIMIT_TRAIN=300`, `LIMIT_DEV=20` để bắt lỗi; OK rồi đặt `None`.\n",
|
| 610 |
+
"- **So sánh công bằng**: mục 5 in cả `head SRCC` và `UTMOS SRCC` trên CÙNG val nội bộ → chỉ nộp khi head > UTMOS.\n",
|
| 611 |
+
"- Nếu head **chưa vượt** 0.414: thử (a) tăng `EPOCHS`; (b) bật `RANK_LAMBDA=0.2` (tối ưu thứ hạng);\n",
|
| 612 |
+
" (c) đảm bảo `USE_UTMOS_FEAT=True` (neo residual); (d) thử bỏ bớt đặc trưng nhiễu (tắt `USE_CLASSPROB`).\n",
|
| 613 |
+
"- **Ablation QMOS cho paper**: bật/tắt `USE_E2V/USE_SAILER/USE_UTMOS_FEAT/USE_CLASSPROB` → ghi `docs/04_experiments_log.md` (exp06).\n",
|
| 614 |
+
"- Cache dùng CHUNG `fusion_cache/` với exp04 → nhớ **Save Version** giữ lại (gồm `utmos_*.npz` mới).\n",
|
| 615 |
+
"- Ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (mục exp06)."
|
| 616 |
+
]
|
| 617 |
+
}
|
| 618 |
+
],
|
| 619 |
+
"metadata": {
|
| 620 |
+
"jupytext": {
|
| 621 |
+
"cell_metadata_filter": "-all",
|
| 622 |
+
"main_language": "python",
|
| 623 |
+
"notebook_metadata_filter": "-all"
|
| 624 |
+
}
|
| 625 |
+
},
|
| 626 |
+
"nbformat": 4,
|
| 627 |
+
"nbformat_minor": 5
|
| 628 |
+
}
|
track2/exp06_qmos_train_pipeline.py
ADDED
|
@@ -0,0 +1,502 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %% [markdown]
|
| 2 |
+
# # VMC2026 Track 2 — exp06 (TRAIN QMOS head) — Kaggle
|
| 3 |
+
#
|
| 4 |
+
# **Mục tiêu:** QMOS là cột **duy nhất chưa train** (đang dùng UTMOS zero-shot → SRCC kẹt 0.414).
|
| 5 |
+
# `train.csv` CÓ sẵn cột `qMOS` → ta train 1 **head hồi quy nhỏ** trên đặc trưng SSL (đã cache ở exp04)
|
| 6 |
+
# để vượt 0.414.
|
| 7 |
+
#
|
| 8 |
+
# ## Ý tưởng (đọc 1 lần cho hiểu)
|
| 9 |
+
# - Tái dùng đặc trưng **emotion2vec + SAILER** đã trích & cache trong `fusion_cache/` (exp04) → KHÔNG trích lại.
|
| 10 |
+
# - Thêm **chính điểm UTMOS** (SpeechMOS) làm 1 đặc trưng đầu vào → head chỉ cần **học chỉnh sửa (residual)**
|
| 11 |
+
# quanh 0.414 thay vì học lại từ đầu → an toàn, gần như chắc chắn ≥ UTMOS đơn lẻ.
|
| 12 |
+
# - Nhãn vàng QMOS = **TB `qMOS` theo wav** (gộp các listener trong `train.csv`).
|
| 13 |
+
# - Có **val nội bộ 10%** → đo SRCC, so thẳng với UTMOS trên CÙNG tập val → biết có cải thiện thật
|
| 14 |
+
# **trước khi** tốn lượt nộp CodaBench.
|
| 15 |
+
# - Cuối cùng: **GIỮ NGUYÊN exp04** (5 cột cảm xúc đang thắng), chỉ **thay cột QMOS** trong `answer.txt`.
|
| 16 |
+
#
|
| 17 |
+
# ```
|
| 18 |
+
# mỗi wav ─► [e2v_emb | e2v_probs5 | sailer_emb | sailer_probs9 | sailer_vad3 | UTMOS] ─► MLP ─► QMOS
|
| 19 |
+
# (head train)
|
| 20 |
+
# ```
|
| 21 |
+
#
|
| 22 |
+
# **Cách chạy trên Kaggle:** Settings → Accelerator = **GPU T4**, Internet = **On**.
|
| 23 |
+
# + Add Input: (1) dataset Track 2 (15.477 wav, có `sets/train.csv`) ; (2) — nếu có — dataset chứa
|
| 24 |
+
# `fusion_cache/*.npz` đã Save Version ở exp04 (đỡ ~15') ; (3) file `answer.txt` của exp04 để ghép cột.
|
| 25 |
+
# Lần đầu đặt `LIMIT_TRAIN=300`, `LIMIT_DEV=20` để bắt lỗi setup, OK rồi đặt `None`.
|
| 26 |
+
|
| 27 |
+
# %% [markdown]
|
| 28 |
+
# ## 0. Cấu hình — SỬA Ở ĐÂY
|
| 29 |
+
|
| 30 |
+
# %%
|
| 31 |
+
import os
|
| 32 |
+
|
| 33 |
+
# ── Data Track 2 ─────────────────────────────────────────────────────────────
|
| 34 |
+
DATA_ROOT = "/kaggle/input/vmc2026-track2-full/vmc2026-track2" # << SỬA slug cho khớp Add Input
|
| 35 |
+
WAV_DIR = f"{DATA_ROOT}/wav"
|
| 36 |
+
TRAIN_CSV = f"{DATA_ROOT}/sets/train.csv" # nhãn người nghe: lisID|wavID|qMOS|emoCat|eMOS|val|dom|aro
|
| 37 |
+
DEV_SCP = f"{DATA_ROOT}/sets/dev.scp" # danh sách wav tập DEV
|
| 38 |
+
|
| 39 |
+
OUT_DIR = "/kaggle/working"
|
| 40 |
+
# Dùng CHUNG cache với exp04. Nếu đã Save Version cache ở exp04, trỏ CACHE_DIR vào dataset đó
|
| 41 |
+
# (vd "/kaggle/input/<slug-cache>/fusion_cache") để khỏi trích lại; nếu không, để mặc định sẽ tự trích.
|
| 42 |
+
CACHE_DIR = "/kaggle/working/fusion_cache"
|
| 43 |
+
os.makedirs(CACHE_DIR, exist_ok=True)
|
| 44 |
+
|
| 45 |
+
# File answer.txt của exp04 (5 cột cảm xúc đang thắng) để GHÉP cột QMOS mới vào.
|
| 46 |
+
# Trỏ tới nơi bạn đặt file exp04. Nếu không có, notebook vẫn xuất qmos_dev.csv riêng + cảnh báo.
|
| 47 |
+
EXP04_ANSWER = "/kaggle/input/exp04-answer/answer.txt" # << SỬA; hoặc "/kaggle/working/answer.txt"
|
| 48 |
+
|
| 49 |
+
# ── Đặc trưng dùng cho QMOS ──────────────────────────────────────────────────
|
| 50 |
+
USE_E2V = True # nối embedding emotion2vec
|
| 51 |
+
USE_SAILER = True # nối embedding SAILER/WavLM
|
| 52 |
+
USE_CLASSPROB = True # nối thêm xác suất lớp (e2v5 + sailer9 + vad3)
|
| 53 |
+
USE_UTMOS_FEAT = True # nối thêm điểm UTMOS làm 1 đặc trưng (neo residual quanh 0.414)
|
| 54 |
+
|
| 55 |
+
# ── Siêu tham số train head ──────────────────────────────────────────────────
|
| 56 |
+
DEVICE = "cuda"
|
| 57 |
+
HIDDEN = 256
|
| 58 |
+
DROPOUT = 0.3
|
| 59 |
+
LR = 1e-3
|
| 60 |
+
EPOCHS = 120
|
| 61 |
+
BATCH = 64
|
| 62 |
+
VAL_FRAC = 0.10
|
| 63 |
+
PATIENCE = 20
|
| 64 |
+
SEED = 42
|
| 65 |
+
RANK_LAMBDA = 0.0 # 0 = chỉ MSE. >0 (vd 0.2) = cộng thêm pairwise ranking loss (tối ưu thứ hạng=SRCC)
|
| 66 |
+
|
| 67 |
+
LIMIT_TRAIN = None # số nhỏ (vd 300) để chạy thử; None = full
|
| 68 |
+
LIMIT_DEV = None
|
| 69 |
+
|
| 70 |
+
def stem(p):
|
| 71 |
+
return os.path.splitext(os.path.basename(str(p)))[0]
|
| 72 |
+
|
| 73 |
+
assert USE_E2V or USE_SAILER or USE_UTMOS_FEAT, "Phải bật ít nhất 1 nguồn đặc trưng."
|
| 74 |
+
print("DATA_ROOT:", DATA_ROOT)
|
| 75 |
+
for p in [WAV_DIR, TRAIN_CSV, DEV_SCP]:
|
| 76 |
+
print((" ✅ " if os.path.exists(p) else " ❌ THIẾU ") + p)
|
| 77 |
+
|
| 78 |
+
# %% [markdown]
|
| 79 |
+
# ## 1. Cài đặt + (nếu cần) tải code SAILER
|
| 80 |
+
# emotion2vec qua `funasr`; SAILER cần `WavLMWrapper` trong repo `vox-profile-release` (clone + sys.path).
|
| 81 |
+
# Nếu cache đã đủ thì các model này sẽ KHÔNG được nạp (chỉ nạp khi còn file phải trích).
|
| 82 |
+
|
| 83 |
+
# %%
|
| 84 |
+
import sys, subprocess
|
| 85 |
+
|
| 86 |
+
def pip_install(*pkgs):
|
| 87 |
+
subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs], check=True)
|
| 88 |
+
|
| 89 |
+
pip_install("speechmos", "funasr", "librosa", "soundfile", "pandas", "scipy", "scikit-learn", "tqdm")
|
| 90 |
+
|
| 91 |
+
if USE_SAILER:
|
| 92 |
+
pip_install("loralib", "speechbrain")
|
| 93 |
+
REPO_DIR = "/kaggle/working/vox-profile-release"
|
| 94 |
+
if not os.path.exists(REPO_DIR):
|
| 95 |
+
subprocess.run(["git", "clone", "--depth", "1",
|
| 96 |
+
"https://github.com/tiantiaf0627/vox-profile-release.git", REPO_DIR], check=True)
|
| 97 |
+
if REPO_DIR not in sys.path:
|
| 98 |
+
sys.path.insert(0, REPO_DIR)
|
| 99 |
+
|
| 100 |
+
# %% [markdown]
|
| 101 |
+
# ## 2. Nhãn vàng QMOS (gộp `qMOS` theo wavID)
|
| 102 |
+
|
| 103 |
+
# %%
|
| 104 |
+
import numpy as np
|
| 105 |
+
import pandas as pd
|
| 106 |
+
|
| 107 |
+
def load_qmos_labels():
|
| 108 |
+
"""train.csv (sep='|') → DataFrame [wavID, qmos] với qmos = TB theo wav."""
|
| 109 |
+
df = pd.read_csv(TRAIN_CSV, sep="|")
|
| 110 |
+
cols = {c.lower().strip(): c for c in df.columns}
|
| 111 |
+
wav_col = cols.get("wavid") or cols.get("wav") or list(df.columns)[1]
|
| 112 |
+
qmos_col = cols.get("qmos") or cols.get("qMOS".lower()) or cols.get("mos")
|
| 113 |
+
assert qmos_col, f"Không thấy cột qMOS trong train.csv (cột: {list(df.columns)})"
|
| 114 |
+
df["_stem"] = df[wav_col].map(stem)
|
| 115 |
+
g = df.groupby("_stem")[qmos_col].mean().reset_index()
|
| 116 |
+
g.columns = ["wavID", "qmos"]
|
| 117 |
+
return g
|
| 118 |
+
|
| 119 |
+
qmos_df = load_qmos_labels()
|
| 120 |
+
print(f"wav train (gộp): {len(qmos_df)}")
|
| 121 |
+
print("qMOS:", qmos_df["qmos"].describe()[["mean", "std", "min", "max"]].to_dict())
|
| 122 |
+
qmos_df.head()
|
| 123 |
+
|
| 124 |
+
# %% [markdown]
|
| 125 |
+
# ## 3. Trích / nạp đặc trưng (cache CHUNG với exp04) + điểm UTMOS
|
| 126 |
+
# - `extract_e2v` / `extract_sailer`: y hệt exp04, cache `e2v_<tag>.npz` / `sailer_<tag>.npz`.
|
| 127 |
+
# - `extract_utmos`: chấm UTMOS từng wav → cache `utmos_<tag>.npz` (dùng vừa làm đặc trưng, vừa làm baseline so sánh).
|
| 128 |
+
|
| 129 |
+
# %%
|
| 130 |
+
import torch
|
| 131 |
+
import torch.nn.functional as F
|
| 132 |
+
|
| 133 |
+
device = DEVICE if torch.cuda.is_available() else "cpu"
|
| 134 |
+
print("Device:", device, ("✅ " + torch.cuda.get_device_name(0)) if device == "cuda" else "⚠️ CPU")
|
| 135 |
+
|
| 136 |
+
EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"]
|
| 137 |
+
|
| 138 |
+
def extract_e2v(stems, tag):
|
| 139 |
+
"""→ dict {stem: emb_full[D1+5]}. Cache CACHE_DIR/e2v_<tag>.npz (giống exp04)."""
|
| 140 |
+
from tqdm.auto import tqdm
|
| 141 |
+
cache_path = os.path.join(CACHE_DIR, f"e2v_{tag}.npz")
|
| 142 |
+
store = {}
|
| 143 |
+
if os.path.exists(cache_path):
|
| 144 |
+
z = np.load(cache_path, allow_pickle=True)
|
| 145 |
+
store = {k: z[k] for k in z.files}
|
| 146 |
+
print(f"[e2v/{tag}] nạp cache: {len(store)}")
|
| 147 |
+
todo = [s for s in stems if s not in store]
|
| 148 |
+
if todo:
|
| 149 |
+
from funasr import AutoModel
|
| 150 |
+
m = AutoModel(model="iic/emotion2vec_plus_large", hub="hf", device=device)
|
| 151 |
+
for i, s in enumerate(tqdm(todo, desc=f"e2v {tag}")):
|
| 152 |
+
wav = os.path.join(WAV_DIR, s + ".wav")
|
| 153 |
+
if not os.path.exists(wav):
|
| 154 |
+
continue
|
| 155 |
+
r = m.generate(wav, granularity="utterance", extract_embedding=True)[0]
|
| 156 |
+
emb = np.asarray(r["feats"], dtype=np.float32).reshape(-1)
|
| 157 |
+
probs = {e: 0.0 for e in EMOTIONS5}
|
| 158 |
+
for lab, sc in zip(r["labels"], r["scores"]):
|
| 159 |
+
name = lab.split("/")[-1]
|
| 160 |
+
if name in probs:
|
| 161 |
+
probs[name] = float(sc)
|
| 162 |
+
tot = sum(probs.values())
|
| 163 |
+
p5 = np.array([probs[e] / tot if tot > 0 else 0.2 for e in EMOTIONS5], dtype=np.float32)
|
| 164 |
+
store[s] = np.concatenate([emb, p5]).astype(np.float32)
|
| 165 |
+
if (i + 1) % 500 == 0:
|
| 166 |
+
np.savez(cache_path, **store)
|
| 167 |
+
np.savez(cache_path, **store)
|
| 168 |
+
del m
|
| 169 |
+
torch.cuda.empty_cache() if device == "cuda" else None
|
| 170 |
+
return store # mỗi value = [D1 | 5]
|
| 171 |
+
|
| 172 |
+
def _pool_feat(features):
|
| 173 |
+
f = features.detach().cpu().numpy()
|
| 174 |
+
if f.ndim <= 1:
|
| 175 |
+
return f.reshape(-1).astype(np.float32)
|
| 176 |
+
return f.mean(axis=tuple(range(f.ndim - 1))).reshape(-1).astype(np.float32)
|
| 177 |
+
|
| 178 |
+
def extract_sailer(stems, tag):
|
| 179 |
+
"""→ dict {stem: vec[D2+9+3]}. Cache CACHE_DIR/sailer_<tag>.npz (giống exp04)."""
|
| 180 |
+
import librosa
|
| 181 |
+
from tqdm.auto import tqdm
|
| 182 |
+
cache_path = os.path.join(CACHE_DIR, f"sailer_{tag}.npz")
|
| 183 |
+
store = {}
|
| 184 |
+
if os.path.exists(cache_path):
|
| 185 |
+
z = np.load(cache_path, allow_pickle=True)
|
| 186 |
+
store = {k: z[k] for k in z.files}
|
| 187 |
+
print(f"[sailer/{tag}] nạp cache: {len(store)}")
|
| 188 |
+
todo = [s for s in stems if s not in store]
|
| 189 |
+
if todo:
|
| 190 |
+
from src.model.emotion.wavlm_emotion import WavLMWrapper
|
| 191 |
+
sailer = WavLMWrapper.from_pretrained("tiantiaf/wavlm-large-categorical-emotion").to(device).eval()
|
| 192 |
+
with torch.no_grad():
|
| 193 |
+
for i, s in enumerate(tqdm(todo, desc=f"sailer {tag}")):
|
| 194 |
+
wav = os.path.join(WAV_DIR, s + ".wav")
|
| 195 |
+
if not os.path.exists(wav):
|
| 196 |
+
continue
|
| 197 |
+
wave, _ = librosa.load(wav, sr=16000, mono=True)
|
| 198 |
+
wave = wave[: 15 * 16000]
|
| 199 |
+
data = torch.from_numpy(wave).float().unsqueeze(0).to(device)
|
| 200 |
+
logits, feat, _det, arousal, valence, dominance = sailer(data, return_feature=True)
|
| 201 |
+
emb = _pool_feat(feat)
|
| 202 |
+
p9 = F.softmax(logits, dim=1)[0].detach().cpu().numpy().astype(np.float32)
|
| 203 |
+
vad3 = np.array([1 + 4 * float(valence.item()),
|
| 204 |
+
1 + 4 * float(arousal.item()),
|
| 205 |
+
1 + 4 * float(dominance.item())], dtype=np.float32)
|
| 206 |
+
store[s] = np.concatenate([emb, p9, vad3]).astype(np.float32)
|
| 207 |
+
if (i + 1) % 500 == 0:
|
| 208 |
+
np.savez(cache_path, **store)
|
| 209 |
+
np.savez(cache_path, **store)
|
| 210 |
+
del sailer
|
| 211 |
+
torch.cuda.empty_cache() if device == "cuda" else None
|
| 212 |
+
return store # mỗi value = [D2 | 9 | 3]
|
| 213 |
+
|
| 214 |
+
def extract_utmos(names, tag):
|
| 215 |
+
"""Chấm UTMOS từng wav (theo TÊN file, vì DEV gọi .wav theo tên). → dict {stem: score}.
|
| 216 |
+
Cache CACHE_DIR/utmos_<tag>.npz. Dùng vừa làm đặc trưng vừa làm baseline so sánh."""
|
| 217 |
+
import librosa
|
| 218 |
+
from tqdm.auto import tqdm
|
| 219 |
+
cache_path = os.path.join(CACHE_DIR, f"utmos_{tag}.npz")
|
| 220 |
+
store = {}
|
| 221 |
+
if os.path.exists(cache_path):
|
| 222 |
+
z = np.load(cache_path, allow_pickle=True)
|
| 223 |
+
store = {k: float(z[k]) for k in z.files}
|
| 224 |
+
print(f"[utmos/{tag}] nạp cache: {len(store)}")
|
| 225 |
+
todo = [n for n in names if stem(n) not in store]
|
| 226 |
+
if todo:
|
| 227 |
+
predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong",
|
| 228 |
+
trust_repo=True).to(device).eval()
|
| 229 |
+
with torch.no_grad():
|
| 230 |
+
for i, n in enumerate(tqdm(todo, desc=f"utmos {tag}")):
|
| 231 |
+
wav = os.path.join(WAV_DIR, n if n.endswith(".wav") else n + ".wav")
|
| 232 |
+
if not os.path.exists(wav):
|
| 233 |
+
continue
|
| 234 |
+
wave, _ = librosa.load(wav, sr=16000, mono=True)
|
| 235 |
+
sc = float(predictor(torch.from_numpy(wave).unsqueeze(0).to(device), sr=16000).mean().item())
|
| 236 |
+
store[stem(n)] = sc
|
| 237 |
+
if (i + 1) % 500 == 0:
|
| 238 |
+
np.savez(cache_path, **{k: np.float32(v) for k, v in store.items()})
|
| 239 |
+
np.savez(cache_path, **{k: np.float32(v) for k, v in store.items()})
|
| 240 |
+
del predictor
|
| 241 |
+
torch.cuda.empty_cache() if device == "cuda" else None
|
| 242 |
+
return store
|
| 243 |
+
|
| 244 |
+
# %% [markdown]
|
| 245 |
+
# ## 4. Dựng feature + nhãn cho train
|
| 246 |
+
|
| 247 |
+
# %%
|
| 248 |
+
train_stems = list(qmos_df["wavID"])
|
| 249 |
+
if LIMIT_TRAIN:
|
| 250 |
+
train_stems = train_stems[:LIMIT_TRAIN]
|
| 251 |
+
|
| 252 |
+
e2v_tr = extract_e2v(train_stems, "train") if USE_E2V else {}
|
| 253 |
+
sailer_tr = extract_sailer(train_stems, "train") if USE_SAILER else {}
|
| 254 |
+
utmos_tr = extract_utmos(train_stems, "train") if USE_UTMOS_FEAT else {}
|
| 255 |
+
|
| 256 |
+
def qmos_feature(sid, e2v_map, sailer_map, utmos_map):
|
| 257 |
+
"""Nối đặc trưng QMOS cho 1 wav. None nếu thiếu phần bắt buộc."""
|
| 258 |
+
parts = []
|
| 259 |
+
if USE_E2V:
|
| 260 |
+
v = e2v_map.get(sid)
|
| 261 |
+
if v is None:
|
| 262 |
+
return None
|
| 263 |
+
parts.append(v[:-5]) # emb e2v
|
| 264 |
+
if USE_CLASSPROB:
|
| 265 |
+
parts.append(v[-5:]) # probs5
|
| 266 |
+
if USE_SAILER:
|
| 267 |
+
v = sailer_map.get(sid)
|
| 268 |
+
if v is None:
|
| 269 |
+
return None
|
| 270 |
+
parts.append(v[:-12]) # emb sailer
|
| 271 |
+
if USE_CLASSPROB:
|
| 272 |
+
parts.append(v[-12:]) # probs9 + vad3
|
| 273 |
+
if USE_UTMOS_FEAT:
|
| 274 |
+
u = utmos_map.get(sid)
|
| 275 |
+
if u is None:
|
| 276 |
+
return None
|
| 277 |
+
parts.append(np.array([u], dtype=np.float32))
|
| 278 |
+
return np.concatenate(parts).astype(np.float32)
|
| 279 |
+
|
| 280 |
+
lab = qmos_df.set_index("wavID")["qmos"]
|
| 281 |
+
X, y = [], []
|
| 282 |
+
for s in train_stems:
|
| 283 |
+
f = qmos_feature(s, e2v_tr, sailer_tr, utmos_tr)
|
| 284 |
+
if f is None or s not in lab.index:
|
| 285 |
+
continue
|
| 286 |
+
X.append(f)
|
| 287 |
+
y.append(float(lab.loc[s]))
|
| 288 |
+
|
| 289 |
+
X = np.stack(X).astype(np.float32)
|
| 290 |
+
y = np.array(y, dtype=np.float32)
|
| 291 |
+
FEAT_DIM = X.shape[1]
|
| 292 |
+
print(f"Train: X={X.shape} y={y.shape}")
|
| 293 |
+
|
| 294 |
+
feat_mean = X.mean(0, keepdims=True)
|
| 295 |
+
feat_std = X.std(0, keepdims=True) + 1e-6
|
| 296 |
+
Xn = (X - feat_mean) / feat_std
|
| 297 |
+
y_mu, y_sd = float(y.mean()), float(y.std() + 1e-6)
|
| 298 |
+
yn = (y - y_mu) / y_sd
|
| 299 |
+
|
| 300 |
+
# %% [markdown]
|
| 301 |
+
# ## 5. Train head QMOS + so với UTMOS trên CÙNG val nội bộ
|
| 302 |
+
# - Head = MLP nhỏ (`Linear→ReLU→Dropout ×2 → 1`). Loss = MSE (+ tùy chọn pairwise ranking).
|
| 303 |
+
# - In **SRCC head** và **SRCC UTMOS** trên cùng tập val → biết head có thật sự vượt 0.414 không.
|
| 304 |
+
|
| 305 |
+
# %%
|
| 306 |
+
import torch.nn as nn
|
| 307 |
+
from scipy.stats import spearmanr
|
| 308 |
+
from sklearn.model_selection import train_test_split
|
| 309 |
+
|
| 310 |
+
torch.manual_seed(SEED); np.random.seed(SEED)
|
| 311 |
+
idx_all = np.arange(X.shape[0])
|
| 312 |
+
tr_idx, va_idx = train_test_split(idx_all, test_size=VAL_FRAC, random_state=SEED)
|
| 313 |
+
|
| 314 |
+
def to_t(a):
|
| 315 |
+
return torch.tensor(a, dtype=torch.float32, device=device)
|
| 316 |
+
|
| 317 |
+
Xn_t = to_t(Xn); yn_t = to_t(yn).unsqueeze(1)
|
| 318 |
+
|
| 319 |
+
class QMOSHead(nn.Module):
|
| 320 |
+
def __init__(self, d_in, h, p):
|
| 321 |
+
super().__init__()
|
| 322 |
+
self.net = nn.Sequential(
|
| 323 |
+
nn.Linear(d_in, h), nn.ReLU(), nn.Dropout(p),
|
| 324 |
+
nn.Linear(h, h), nn.ReLU(), nn.Dropout(p),
|
| 325 |
+
nn.Linear(h, 1),
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
def forward(self, x):
|
| 329 |
+
return self.net(x)
|
| 330 |
+
|
| 331 |
+
model = QMOSHead(FEAT_DIM, HIDDEN, DROPOUT).to(device)
|
| 332 |
+
opt = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-5)
|
| 333 |
+
mse = nn.MSELoss()
|
| 334 |
+
|
| 335 |
+
def pairwise_rank_loss(pred, target):
|
| 336 |
+
"""Khuyến khích pred xếp hạng giống target (margin ranking trên các cặp trong batch)."""
|
| 337 |
+
n = pred.shape[0]
|
| 338 |
+
if n < 2:
|
| 339 |
+
return torch.zeros((), device=device)
|
| 340 |
+
pi, pj = pred.unsqueeze(0), pred.unsqueeze(1)
|
| 341 |
+
ti, tj = target.unsqueeze(0), target.unsqueeze(1)
|
| 342 |
+
sign = torch.sign(ti - tj) # +1 nếu i nên cao hơn j
|
| 343 |
+
diff = pi - pj
|
| 344 |
+
# hinge: phạt khi thứ tự sai
|
| 345 |
+
return torch.relu(-sign * diff).mean()
|
| 346 |
+
|
| 347 |
+
@torch.no_grad()
|
| 348 |
+
def eval_val():
|
| 349 |
+
model.eval()
|
| 350 |
+
p = model(Xn_t[va_idx]).cpu().numpy().ravel()
|
| 351 |
+
srcc_head = spearmanr(p, y[va_idx]).correlation
|
| 352 |
+
out = {"head": float(srcc_head)}
|
| 353 |
+
if USE_UTMOS_FEAT:
|
| 354 |
+
u = X[va_idx, -1] # cột UTMOS (đặc trưng cuối, chưa chuẩn hóa)
|
| 355 |
+
out["utmos"] = float(spearmanr(u, y[va_idx]).correlation)
|
| 356 |
+
return out
|
| 357 |
+
|
| 358 |
+
best, best_state, bad = -1e9, None, 0
|
| 359 |
+
tr_t = torch.tensor(tr_idx, device=device)
|
| 360 |
+
for ep in range(1, EPOCHS + 1):
|
| 361 |
+
model.train()
|
| 362 |
+
perm = tr_t[torch.randperm(len(tr_t), device=device)]
|
| 363 |
+
run = 0.0
|
| 364 |
+
for i in range(0, len(perm), BATCH):
|
| 365 |
+
b = perm[i:i + BATCH]
|
| 366 |
+
opt.zero_grad()
|
| 367 |
+
pred = model(Xn_t[b])
|
| 368 |
+
loss = mse(pred, yn_t[b])
|
| 369 |
+
if RANK_LAMBDA > 0:
|
| 370 |
+
loss = loss + RANK_LAMBDA * pairwise_rank_loss(pred.ravel(), yn_t[b].ravel())
|
| 371 |
+
loss.backward(); opt.step()
|
| 372 |
+
run += loss.item() * len(b)
|
| 373 |
+
m = eval_val()
|
| 374 |
+
if m["head"] > best:
|
| 375 |
+
best = m["head"]
|
| 376 |
+
best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
|
| 377 |
+
bad = 0
|
| 378 |
+
else:
|
| 379 |
+
bad += 1
|
| 380 |
+
if ep % 5 == 0 or ep == 1:
|
| 381 |
+
extra = f" | UTMOS={m['utmos']:.4f}" if "utmos" in m else ""
|
| 382 |
+
print(f"epoch {ep:3d} | loss {run/len(perm):.4f} | head SRCC={m['head']:.4f}{extra} | best {best:.4f}")
|
| 383 |
+
if bad >= PATIENCE:
|
| 384 |
+
print(f"Early stop ở epoch {ep}.")
|
| 385 |
+
break
|
| 386 |
+
|
| 387 |
+
model.load_state_dict(best_state)
|
| 388 |
+
final = eval_val()
|
| 389 |
+
print("\n✅ VAL (nội bộ):")
|
| 390 |
+
print(f" QMOS head SRCC = {final['head']:.4f}")
|
| 391 |
+
if "utmos" in final:
|
| 392 |
+
print(f" UTMOS baseline = {final['utmos']:.4f} (mốc leaderboard 0.414)")
|
| 393 |
+
print(" →", "✅ HEAD VƯỢT UTMOS" if final["head"] > final["utmos"] else "⚠️ chưa vượt — thử tăng EPOCHS / RANK_LAMBDA / bật thêm đặc trưng")
|
| 394 |
+
|
| 395 |
+
torch.save({"state": best_state, "feat_mean": feat_mean, "feat_std": feat_std,
|
| 396 |
+
"y_mu": y_mu, "y_sd": y_sd, "FEAT_DIM": FEAT_DIM,
|
| 397 |
+
"USE_E2V": USE_E2V, "USE_SAILER": USE_SAILER,
|
| 398 |
+
"USE_CLASSPROB": USE_CLASSPROB, "USE_UTMOS_FEAT": USE_UTMOS_FEAT,
|
| 399 |
+
"val_srcc": best}, os.path.join(OUT_DIR, "qmos_head.pt"))
|
| 400 |
+
print("Đã lưu", os.path.join(OUT_DIR, "qmos_head.pt"))
|
| 401 |
+
|
| 402 |
+
# %% [markdown]
|
| 403 |
+
# ## 6. Dự đoán QMOS cho DEV
|
| 404 |
+
|
| 405 |
+
# %%
|
| 406 |
+
def list_dev():
|
| 407 |
+
with open(DEV_SCP) as f:
|
| 408 |
+
return [ln.strip() for ln in f if ln.strip()]
|
| 409 |
+
|
| 410 |
+
dev_names = list_dev()
|
| 411 |
+
if LIMIT_DEV:
|
| 412 |
+
dev_names = dev_names[:LIMIT_DEV]
|
| 413 |
+
dev_stems = [stem(n) for n in dev_names]
|
| 414 |
+
print("DEV:", len(dev_names), "mẫu")
|
| 415 |
+
|
| 416 |
+
e2v_dev = extract_e2v(dev_stems, "dev") if USE_E2V else {}
|
| 417 |
+
sailer_dev = extract_sailer(dev_stems, "dev") if USE_SAILER else {}
|
| 418 |
+
utmos_dev = extract_utmos(dev_names, "dev") if USE_UTMOS_FEAT else {}
|
| 419 |
+
|
| 420 |
+
@torch.no_grad()
|
| 421 |
+
def predict_qmos(sid):
|
| 422 |
+
f = qmos_feature(sid, e2v_dev, sailer_dev, utmos_dev)
|
| 423 |
+
if f is None:
|
| 424 |
+
return None
|
| 425 |
+
fn = (f[None, :] - feat_mean) / feat_std
|
| 426 |
+
model.eval()
|
| 427 |
+
return float(model(to_t(fn)).item()) * y_sd + y_mu # đảo z-score
|
| 428 |
+
|
| 429 |
+
qmos_pred = {}
|
| 430 |
+
n_real = n_def = 0
|
| 431 |
+
for n in dev_names:
|
| 432 |
+
sid = stem(n)
|
| 433 |
+
p = predict_qmos(sid)
|
| 434 |
+
if p is None:
|
| 435 |
+
p = utmos_dev.get(sid, 3.0) # rơi về UTMOS nếu thiếu feature
|
| 436 |
+
n_def += 1
|
| 437 |
+
else:
|
| 438 |
+
n_real += 1
|
| 439 |
+
qmos_pred[n] = p
|
| 440 |
+
print(f"QMOS dự đoán: head thật {n_real}, dự phòng UTMOS {n_def}")
|
| 441 |
+
|
| 442 |
+
# Lưu riêng (để ghép tay nếu cần)
|
| 443 |
+
import csv
|
| 444 |
+
qmos_csv = os.path.join(OUT_DIR, "qmos_dev.csv")
|
| 445 |
+
with open(qmos_csv, "w", newline="") as f:
|
| 446 |
+
w = csv.writer(f); w.writerow(["wav", "QMOS"])
|
| 447 |
+
for n in dev_names:
|
| 448 |
+
w.writerow([n, f"{qmos_pred[n]:.6g}"])
|
| 449 |
+
print("Đã ghi", qmos_csv)
|
| 450 |
+
|
| 451 |
+
# %% [markdown]
|
| 452 |
+
# ## 7. Ghép QMOS mới vào answer.txt của exp04 → bản nộp mới
|
| 453 |
+
# Giữ NGUYÊN 5 cột cảm xúc đang thắng (EMOS/CAT/VAL/ARO/DOM), chỉ thay cột QMOS.
|
| 454 |
+
|
| 455 |
+
# %%
|
| 456 |
+
def merge_into_exp04(exp04_path, out_path):
|
| 457 |
+
if not os.path.exists(exp04_path):
|
| 458 |
+
print(f"⚠️ Không thấy {exp04_path} → BỎ QUA ghép. Hãy dùng qmos_dev.csv để thay cột QMOS thủ công,")
|
| 459 |
+
print(" hoặc trỏ EXP04_ANSWER đúng đường dẫn answer.txt của exp04 rồi chạy lại cell này.")
|
| 460 |
+
return False
|
| 461 |
+
with open(exp04_path) as f:
|
| 462 |
+
rows = list(csv.reader(f))
|
| 463 |
+
header = rows[0]
|
| 464 |
+
qi = header.index("QMOS")
|
| 465 |
+
wi = header.index("wav")
|
| 466 |
+
n_swapped = n_miss = 0
|
| 467 |
+
with open(out_path, "w", newline="") as f:
|
| 468 |
+
w = csv.writer(f); w.writerow(header)
|
| 469 |
+
for r in rows[1:]:
|
| 470 |
+
name = r[wi]
|
| 471 |
+
if name in qmos_pred:
|
| 472 |
+
r[qi] = f"{qmos_pred[name]:.6g}"; n_swapped += 1
|
| 473 |
+
else:
|
| 474 |
+
n_miss += 1
|
| 475 |
+
w.writerow(r)
|
| 476 |
+
print(f"Ghép xong → {out_path} | thay {n_swapped} cột QMOS, thiếu {n_miss} (giữ QMOS cũ)")
|
| 477 |
+
return True
|
| 478 |
+
|
| 479 |
+
merged = os.path.join(OUT_DIR, "answer.txt")
|
| 480 |
+
ok = merge_into_exp04(EXP04_ANSWER, merged)
|
| 481 |
+
|
| 482 |
+
if ok:
|
| 483 |
+
# validate + zip
|
| 484 |
+
with open(merged) as f:
|
| 485 |
+
rows = list(csv.reader(f))
|
| 486 |
+
assert rows[0][0] == "wav" and "QMOS" in rows[0]
|
| 487 |
+
for i, r in enumerate(rows[1:], 2):
|
| 488 |
+
assert len(r) == len(rows[0]), f"Dòng {i} sai số cột"
|
| 489 |
+
print(f"OK: {len(rows)-1} dòng, header = {rows[0]}")
|
| 490 |
+
os.system(f"cd {OUT_DIR} && zip -j submission_track2_exp06_qmos.zip answer.txt "
|
| 491 |
+
f"&& unzip -l submission_track2_exp06_qmos.zip")
|
| 492 |
+
print("Sẵn sàng nộp:", os.path.join(OUT_DIR, "submission_track2_exp06_qmos.zip"))
|
| 493 |
+
|
| 494 |
+
# %% [markdown]
|
| 495 |
+
# ## Ghi chú
|
| 496 |
+
# - **Lần đầu** đặt `LIMIT_TRAIN=300`, `LIMIT_DEV=20` để bắt lỗi; OK rồi đặt `None`.
|
| 497 |
+
# - **So sánh công bằng**: mục 5 in cả `head SRCC` và `UTMOS SRCC` trên CÙNG val nội bộ → chỉ nộp khi head > UTMOS.
|
| 498 |
+
# - Nếu head **chưa vượt** 0.414: thử (a) tăng `EPOCHS`; (b) bật `RANK_LAMBDA=0.2` (tối ưu thứ hạng);
|
| 499 |
+
# (c) đảm bảo `USE_UTMOS_FEAT=True` (neo residual); (d) thử bỏ bớt đặc trưng nhiễu (tắt `USE_CLASSPROB`).
|
| 500 |
+
# - **Ablation QMOS cho paper**: bật/tắt `USE_E2V/USE_SAILER/USE_UTMOS_FEAT/USE_CLASSPROB` → ghi `docs/04_experiments_log.md` (exp06).
|
| 501 |
+
# - Cache dùng CHUNG `fusion_cache/` với exp04 → nhớ **Save Version** giữ lại (gồm `utmos_*.npz` mới).
|
| 502 |
+
# - Ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (mục exp06).
|
track2/exp07_fusion_qmos.ipynb
ADDED
|
@@ -0,0 +1,780 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "c75f9ad6",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"# VMC2026 Track 2 — exp07 (FUSION + QMOS head, HỢP NHẤT 6 cột) — Kaggle\n",
|
| 9 |
+
"\n",
|
| 10 |
+
"**Khác exp04 ở đâu:** exp04 để **QMOS riêng** (UTMOS zero-shot). exp07 **gộp luôn QMOS vào trunk chung**\n",
|
| 11 |
+
"→ 1 model multi-task dự đoán **đủ 6 đầu ra**: QMOS · EMOS · CAT · VAL · ARO · DOM.\n",
|
| 12 |
+
"\n",
|
| 13 |
+
"## Giả thuyết (của bạn) cần kiểm chứng\n",
|
| 14 |
+
"\"Chất giọng tự nhiên có liên quan tới cảm nhận cảm xúc\" → nếu đúng, QMOS sẽ **hưởng lợi** từ biểu diễn\n",
|
| 15 |
+
"cảm xúc chung (emotion2vec + SAILER). **Rủi ro:** 2 backbone này chuyên *cảm xúc*, chưa chắc bắt tốt\n",
|
| 16 |
+
"*lỗi chất lượng/artifact* (thứ UTMOS chuyên trị) → QMOS có thể **thua** UTMOS, hoặc gộp làm **tụt** EMOS/VAD.\n",
|
| 17 |
+
"\n",
|
| 18 |
+
"## Lưới an toàn trong thiết kế\n",
|
| 19 |
+
"- **Vẫn đưa điểm UTMOS làm 1 đầu vào** cho QMOS head (`USE_UTMOS_FEAT`) → head học **chỉnh sửa** quanh\n",
|
| 20 |
+
" 0.414 thay vì học lại từ đầu → khó tệ hơn UTMOS.\n",
|
| 21 |
+
"- **In SRCC cả 6 cột + so mốc exp04** (EMOS 0.788 · CAT err 0.145 · VAL 0.578 · ARO 0.754 · DOM 0.706)\n",
|
| 22 |
+
" → cảnh báo ngay nếu gộp QMOS làm tụt 5 cột cảm xúc.\n",
|
| 23 |
+
"- **File riêng**, KHÔNG đụng `exp04_fusion_pipeline.py` (exp04 vẫn nguyên).\n",
|
| 24 |
+
"\n",
|
| 25 |
+
"```\n",
|
| 26 |
+
" mỗi wav ─► [e2v_emb | e2v_p5 | sailer_emb | sailer_p9 | sailer_vad3] ─► TRUNK chung\n",
|
| 27 |
+
" │\n",
|
| 28 |
+
" ┌──────────────┬───────────────┬─────────────┬───────────────────┤\n",
|
| 29 |
+
" [QMOS head] [EMOS head] [CAT head] [VAD head]\n",
|
| 30 |
+
" trunk + UTMOS trunk + target trunk trunk\n",
|
| 31 |
+
"```\n",
|
| 32 |
+
"\n",
|
| 33 |
+
"**Cách chạy:** GPU T4 + Internet On → Add Input dataset Track 2 → sửa `DATA_ROOT` → Run All.\n",
|
| 34 |
+
"Lần đầu đặt `LIMIT_TRAIN=300`, `LIMIT_DEV=20`. Dùng CHUNG cache `fusion_cache/` với exp04 (thêm `utmos_*.npz`)."
|
| 35 |
+
]
|
| 36 |
+
},
|
| 37 |
+
{
|
| 38 |
+
"cell_type": "markdown",
|
| 39 |
+
"id": "b4e814c4",
|
| 40 |
+
"metadata": {},
|
| 41 |
+
"source": [
|
| 42 |
+
"## 0. Cấu hình — SỬA Ở ĐÂY"
|
| 43 |
+
]
|
| 44 |
+
},
|
| 45 |
+
{
|
| 46 |
+
"cell_type": "code",
|
| 47 |
+
"execution_count": null,
|
| 48 |
+
"id": "57b9eedb",
|
| 49 |
+
"metadata": {},
|
| 50 |
+
"outputs": [],
|
| 51 |
+
"source": [
|
| 52 |
+
"import os\n",
|
| 53 |
+
"\n",
|
| 54 |
+
"DATA_ROOT = \"/kaggle/input/vmc2026-track2-full/vmc2026-track2\" # << SỬA slug\n",
|
| 55 |
+
"WAV_DIR = f\"{DATA_ROOT}/wav\"\n",
|
| 56 |
+
"METADATA_CSV = f\"{DATA_ROOT}/metadata.csv\" # wavID|emotion|transcript (KHÔNG header)\n",
|
| 57 |
+
"TRAIN_CSV = f\"{DATA_ROOT}/sets/train.csv\" # lisID|wavID|qMOS|emoCat|eMOS|val|dom|aro\n",
|
| 58 |
+
"DEV_SCP = f\"{DATA_ROOT}/sets/dev.scp\"\n",
|
| 59 |
+
"\n",
|
| 60 |
+
"OUT_DIR = \"/kaggle/working\"\n",
|
| 61 |
+
"CACHE_DIR = \"/kaggle/working/fusion_cache\" # dùng CHUNG với exp04 (thêm utmos_*.npz)\n",
|
| 62 |
+
"os.makedirs(CACHE_DIR, exist_ok=True)\n",
|
| 63 |
+
"\n",
|
| 64 |
+
"# ── Siêu tham số ─────────────────────────────────────────────────────────────\n",
|
| 65 |
+
"DEVICE = \"cuda\"\n",
|
| 66 |
+
"TRUNK_HIDDEN = 512\n",
|
| 67 |
+
"HEAD_HIDDEN = 128\n",
|
| 68 |
+
"DROPOUT = 0.3\n",
|
| 69 |
+
"LR = 1e-3\n",
|
| 70 |
+
"EPOCHS = 80\n",
|
| 71 |
+
"BATCH = 64\n",
|
| 72 |
+
"VAL_FRAC = 0.10\n",
|
| 73 |
+
"PATIENCE = 15\n",
|
| 74 |
+
"SEED = 42\n",
|
| 75 |
+
"\n",
|
| 76 |
+
"USE_UNCERTAINTY = True # tự cân 6 loss (Kendall); False = dùng LOSS_W cố định\n",
|
| 77 |
+
"LOSS_W = {\"qmos\": 1.0, \"emos\": 1.0, \"cat\": 1.0, \"val\": 1.0, \"aro\": 1.0, \"dom\": 1.0}\n",
|
| 78 |
+
"USE_E2V = True\n",
|
| 79 |
+
"USE_SAILER = True\n",
|
| 80 |
+
"USE_CLASSPROB = True\n",
|
| 81 |
+
"USE_UTMOS_FEAT = True # đưa điểm UTMOS làm đầu vào QMOS head (neo residual quanh 0.414)\n",
|
| 82 |
+
"\n",
|
| 83 |
+
"LIMIT_TRAIN = None\n",
|
| 84 |
+
"LIMIT_DEV = None\n",
|
| 85 |
+
"\n",
|
| 86 |
+
"# Mốc exp04 để so (cảnh báo nếu tụt khi gộp QMOS)\n",
|
| 87 |
+
"EXP04 = {\"emos\": 0.788, \"cat_err\": 0.145, \"val\": 0.578, \"aro\": 0.754, \"dom\": 0.706, \"qmos_utmos\": 0.414}\n",
|
| 88 |
+
"\n",
|
| 89 |
+
"EMOTIONS5 = [\"angry\", \"happy\", \"neutral\", \"sad\", \"surprised\"]\n",
|
| 90 |
+
"SAILER9 = [\"Anger\", \"Contempt\", \"Disgust\", \"Fear\", \"Happiness\", \"Neutral\", \"Sadness\", \"Surprise\", \"Other\"]\n",
|
| 91 |
+
"EMO2SAILER = {\"angry\": 0, \"happy\": 4, \"neutral\": 5, \"sad\": 6, \"surprised\": 7}\n",
|
| 92 |
+
"\n",
|
| 93 |
+
"_EMO_ALIAS = {\n",
|
| 94 |
+
" \"angry\": \"angry\", \"anger\": \"angry\",\n",
|
| 95 |
+
" \"happy\": \"happy\", \"happiness\": \"happy\", \"joy\": \"happy\",\n",
|
| 96 |
+
" \"neutral\": \"neutral\", \"calm\": \"neutral\",\n",
|
| 97 |
+
" \"sad\": \"sad\", \"sadness\": \"sad\",\n",
|
| 98 |
+
" \"surprise\": \"surprised\", \"surprised\": \"surprised\", \"surprising\": \"surprised\",\n",
|
| 99 |
+
"}\n",
|
| 100 |
+
"\n",
|
| 101 |
+
"def norm_emotion(label):\n",
|
| 102 |
+
" key = str(label).strip().lower()\n",
|
| 103 |
+
" return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)\n",
|
| 104 |
+
"\n",
|
| 105 |
+
"def stem(p):\n",
|
| 106 |
+
" return os.path.splitext(os.path.basename(str(p)))[0]\n",
|
| 107 |
+
"\n",
|
| 108 |
+
"assert USE_E2V or USE_SAILER, \"Phải bật ít nhất 1 backbone.\"\n",
|
| 109 |
+
"print(\"DATA_ROOT:\", DATA_ROOT)\n",
|
| 110 |
+
"for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP]:\n",
|
| 111 |
+
" print((\" ✅ \" if os.path.exists(p) else \" ❌ THIẾU \") + p)"
|
| 112 |
+
]
|
| 113 |
+
},
|
| 114 |
+
{
|
| 115 |
+
"cell_type": "markdown",
|
| 116 |
+
"id": "547ccf32",
|
| 117 |
+
"metadata": {},
|
| 118 |
+
"source": [
|
| 119 |
+
"## 1. Cài đặt + tải code SAILER"
|
| 120 |
+
]
|
| 121 |
+
},
|
| 122 |
+
{
|
| 123 |
+
"cell_type": "code",
|
| 124 |
+
"execution_count": null,
|
| 125 |
+
"id": "132b9321",
|
| 126 |
+
"metadata": {},
|
| 127 |
+
"outputs": [],
|
| 128 |
+
"source": [
|
| 129 |
+
"import sys, subprocess\n",
|
| 130 |
+
"\n",
|
| 131 |
+
"def pip_install(*pkgs):\n",
|
| 132 |
+
" subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *pkgs], check=True)\n",
|
| 133 |
+
"\n",
|
| 134 |
+
"pip_install(\"speechmos\", \"funasr\", \"librosa\", \"soundfile\", \"pandas\", \"scipy\", \"scikit-learn\", \"tqdm\")\n",
|
| 135 |
+
"\n",
|
| 136 |
+
"if USE_SAILER:\n",
|
| 137 |
+
" pip_install(\"loralib\", \"speechbrain\")\n",
|
| 138 |
+
" REPO_DIR = \"/kaggle/working/vox-profile-release\"\n",
|
| 139 |
+
" if not os.path.exists(REPO_DIR):\n",
|
| 140 |
+
" subprocess.run([\"git\", \"clone\", \"--depth\", \"1\",\n",
|
| 141 |
+
" \"https://github.com/tiantiaf0627/vox-profile-release.git\", REPO_DIR], check=True)\n",
|
| 142 |
+
" if REPO_DIR not in sys.path:\n",
|
| 143 |
+
" sys.path.insert(0, REPO_DIR)"
|
| 144 |
+
]
|
| 145 |
+
},
|
| 146 |
+
{
|
| 147 |
+
"cell_type": "markdown",
|
| 148 |
+
"id": "75c6f07c",
|
| 149 |
+
"metadata": {},
|
| 150 |
+
"source": [
|
| 151 |
+
"## 2. Đọc & gộp nhãn (gộp theo wavID) — THÊM cột qMOS\n",
|
| 152 |
+
"Khác exp04: gộp thêm **qMOS** (= TB `qMOS` theo wav) làm nhãn cho QMOS head."
|
| 153 |
+
]
|
| 154 |
+
},
|
| 155 |
+
{
|
| 156 |
+
"cell_type": "code",
|
| 157 |
+
"execution_count": null,
|
| 158 |
+
"id": "3c73a7fb",
|
| 159 |
+
"metadata": {},
|
| 160 |
+
"outputs": [],
|
| 161 |
+
"source": [
|
| 162 |
+
"import numpy as np\n",
|
| 163 |
+
"import pandas as pd\n",
|
| 164 |
+
"\n",
|
| 165 |
+
"def load_target_emotions():\n",
|
| 166 |
+
" tgt = {}\n",
|
| 167 |
+
" with open(METADATA_CSV, encoding=\"utf-8\") as f:\n",
|
| 168 |
+
" for ln in f:\n",
|
| 169 |
+
" parts = ln.strip().split(\"|\")\n",
|
| 170 |
+
" if len(parts) < 2:\n",
|
| 171 |
+
" continue\n",
|
| 172 |
+
" tgt[stem(parts[0])] = norm_emotion(parts[1])\n",
|
| 173 |
+
" return tgt\n",
|
| 174 |
+
"\n",
|
| 175 |
+
"def _col(cols_map, *names, default_idx=None, df=None):\n",
|
| 176 |
+
" for n in names:\n",
|
| 177 |
+
" if n in cols_map:\n",
|
| 178 |
+
" return cols_map[n]\n",
|
| 179 |
+
" return list(df.columns)[default_idx] if default_idx is not None else None\n",
|
| 180 |
+
"\n",
|
| 181 |
+
"def parse_emocat_votes(cell):\n",
|
| 182 |
+
" v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
|
| 183 |
+
" for tok in str(cell).replace(\"/\", \",\").replace(\";\", \",\").replace(\"|\", \",\").replace(\" \", \",\").split(\",\"):\n",
|
| 184 |
+
" e = norm_emotion(tok)\n",
|
| 185 |
+
" if e in EMOTIONS5:\n",
|
| 186 |
+
" v[EMOTIONS5.index(e)] += 1.0\n",
|
| 187 |
+
" return v\n",
|
| 188 |
+
"\n",
|
| 189 |
+
"def load_train_labels():\n",
|
| 190 |
+
" \"\"\"train.csv → DataFrame [wavID, qmos, emos, val, aro, dom, cat0..cat4] gộp theo wav.\"\"\"\n",
|
| 191 |
+
" df = pd.read_csv(TRAIN_CSV, sep=\"|\")\n",
|
| 192 |
+
" cols = {c.lower().strip(): c for c in df.columns}\n",
|
| 193 |
+
" wav_col = _col(cols, \"wavid\", \"wav\", default_idx=1, df=df)\n",
|
| 194 |
+
" qmos_col = _col(cols, \"qmos\", \"mos\")\n",
|
| 195 |
+
" emos_col = _col(cols, \"emos\", \"emo\", \"emomos\")\n",
|
| 196 |
+
" val_col = _col(cols, \"val\", \"valence\")\n",
|
| 197 |
+
" aro_col = _col(cols, \"aro\", \"arousal\")\n",
|
| 198 |
+
" dom_col = _col(cols, \"dom\", \"dominance\")\n",
|
| 199 |
+
" cat_col = _col(cols, \"emocat\", \"cat\", \"emotion\")\n",
|
| 200 |
+
" assert qmos_col, f\"Không thấy cột qMOS trong train.csv (cột: {list(df.columns)})\"\n",
|
| 201 |
+
" assert emos_col, f\"Không thấy cột eMOS trong train.csv (cột: {list(df.columns)})\"\n",
|
| 202 |
+
"\n",
|
| 203 |
+
" df[\"_stem\"] = df[wav_col].map(stem)\n",
|
| 204 |
+
" rows = []\n",
|
| 205 |
+
" for sid, g in df.groupby(\"_stem\"):\n",
|
| 206 |
+
" rec = {\"wavID\": sid,\n",
|
| 207 |
+
" \"qmos\": float(g[qmos_col].mean()),\n",
|
| 208 |
+
" \"emos\": float(g[emos_col].mean())}\n",
|
| 209 |
+
" rec[\"val\"] = float(g[val_col].mean()) if val_col else np.nan\n",
|
| 210 |
+
" rec[\"aro\"] = float(g[aro_col].mean()) if aro_col else np.nan\n",
|
| 211 |
+
" rec[\"dom\"] = float(g[dom_col].mean()) if dom_col else np.nan\n",
|
| 212 |
+
" votes = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
|
| 213 |
+
" if cat_col:\n",
|
| 214 |
+
" for cell in g[cat_col]:\n",
|
| 215 |
+
" votes += parse_emocat_votes(cell)\n",
|
| 216 |
+
" s = votes.sum()\n",
|
| 217 |
+
" cat = votes / s if s > 0 else np.full(len(EMOTIONS5), 1.0 / len(EMOTIONS5), dtype=np.float32)\n",
|
| 218 |
+
" for i in range(len(EMOTIONS5)):\n",
|
| 219 |
+
" rec[f\"cat{i}\"] = float(cat[i])\n",
|
| 220 |
+
" rows.append(rec)\n",
|
| 221 |
+
" return pd.DataFrame(rows)\n",
|
| 222 |
+
"\n",
|
| 223 |
+
"target_map = load_target_emotions()\n",
|
| 224 |
+
"train_df = load_train_labels()\n",
|
| 225 |
+
"HAS_VAD = bool(train_df[\"val\"].notna().any())\n",
|
| 226 |
+
"print(f\"Target: {len(target_map)} | wav train (gộp): {len(train_df)} | có VAD: {HAS_VAD}\")\n",
|
| 227 |
+
"print(\"qMOS:\", train_df[\"qmos\"].describe()[[\"mean\", \"std\", \"min\", \"max\"]].to_dict())\n",
|
| 228 |
+
"print(\"eMOS:\", train_df[\"emos\"].describe()[[\"mean\", \"std\", \"min\", \"max\"]].to_dict())\n",
|
| 229 |
+
"train_df.head()"
|
| 230 |
+
]
|
| 231 |
+
},
|
| 232 |
+
{
|
| 233 |
+
"cell_type": "markdown",
|
| 234 |
+
"id": "0726b340",
|
| 235 |
+
"metadata": {},
|
| 236 |
+
"source": [
|
| 237 |
+
"## 3. Trích đặc trưng 2 backbone + điểm UTMOS (cache CHUNG với exp04)"
|
| 238 |
+
]
|
| 239 |
+
},
|
| 240 |
+
{
|
| 241 |
+
"cell_type": "code",
|
| 242 |
+
"execution_count": null,
|
| 243 |
+
"id": "ae27e424",
|
| 244 |
+
"metadata": {
|
| 245 |
+
"lines_to_next_cell": 1
|
| 246 |
+
},
|
| 247 |
+
"outputs": [],
|
| 248 |
+
"source": [
|
| 249 |
+
"import torch\n",
|
| 250 |
+
"import torch.nn.functional as F\n",
|
| 251 |
+
"\n",
|
| 252 |
+
"device = DEVICE if torch.cuda.is_available() else \"cpu\"\n",
|
| 253 |
+
"print(\"Device:\", device, (\"✅ \" + torch.cuda.get_device_name(0)) if device == \"cuda\" else \"⚠️ CPU\")\n",
|
| 254 |
+
"\n",
|
| 255 |
+
"def extract_e2v(stems, tag):\n",
|
| 256 |
+
" from tqdm.auto import tqdm\n",
|
| 257 |
+
" cache_path = os.path.join(CACHE_DIR, f\"e2v_{tag}.npz\")\n",
|
| 258 |
+
" store = {}\n",
|
| 259 |
+
" if os.path.exists(cache_path):\n",
|
| 260 |
+
" z = np.load(cache_path, allow_pickle=True)\n",
|
| 261 |
+
" store = {k: z[k] for k in z.files}\n",
|
| 262 |
+
" print(f\"[e2v/{tag}] nạp cache: {len(store)}\")\n",
|
| 263 |
+
" todo = [s for s in stems if s not in store]\n",
|
| 264 |
+
" if todo:\n",
|
| 265 |
+
" from funasr import AutoModel\n",
|
| 266 |
+
" m = AutoModel(model=\"iic/emotion2vec_plus_large\", hub=\"hf\", device=device)\n",
|
| 267 |
+
" for i, s in enumerate(tqdm(todo, desc=f\"e2v {tag}\")):\n",
|
| 268 |
+
" wav = os.path.join(WAV_DIR, s + \".wav\")\n",
|
| 269 |
+
" if not os.path.exists(wav):\n",
|
| 270 |
+
" continue\n",
|
| 271 |
+
" r = m.generate(wav, granularity=\"utterance\", extract_embedding=True)[0]\n",
|
| 272 |
+
" emb = np.asarray(r[\"feats\"], dtype=np.float32).reshape(-1)\n",
|
| 273 |
+
" probs = {e: 0.0 for e in EMOTIONS5}\n",
|
| 274 |
+
" for lab, sc in zip(r[\"labels\"], r[\"scores\"]):\n",
|
| 275 |
+
" name = lab.split(\"/\")[-1]\n",
|
| 276 |
+
" if name in probs:\n",
|
| 277 |
+
" probs[name] = float(sc)\n",
|
| 278 |
+
" tot = sum(probs.values())\n",
|
| 279 |
+
" p5 = np.array([probs[e] / tot if tot > 0 else 0.2 for e in EMOTIONS5], dtype=np.float32)\n",
|
| 280 |
+
" store[s] = np.concatenate([emb, p5]).astype(np.float32)\n",
|
| 281 |
+
" if (i + 1) % 500 == 0:\n",
|
| 282 |
+
" np.savez(cache_path, **store)\n",
|
| 283 |
+
" np.savez(cache_path, **store)\n",
|
| 284 |
+
" del m\n",
|
| 285 |
+
" torch.cuda.empty_cache() if device == \"cuda\" else None\n",
|
| 286 |
+
" return {s: (v[:-5], v[-5:]) for s, v in store.items()}\n",
|
| 287 |
+
"\n",
|
| 288 |
+
"def _pool_feat(features):\n",
|
| 289 |
+
" f = features.detach().cpu().numpy()\n",
|
| 290 |
+
" if f.ndim <= 1:\n",
|
| 291 |
+
" return f.reshape(-1).astype(np.float32)\n",
|
| 292 |
+
" return f.mean(axis=tuple(range(f.ndim - 1))).reshape(-1).astype(np.float32)\n",
|
| 293 |
+
"\n",
|
| 294 |
+
"def extract_sailer(stems, tag):\n",
|
| 295 |
+
" import librosa\n",
|
| 296 |
+
" from tqdm.auto import tqdm\n",
|
| 297 |
+
" cache_path = os.path.join(CACHE_DIR, f\"sailer_{tag}.npz\")\n",
|
| 298 |
+
" store = {}\n",
|
| 299 |
+
" if os.path.exists(cache_path):\n",
|
| 300 |
+
" z = np.load(cache_path, allow_pickle=True)\n",
|
| 301 |
+
" store = {k: z[k] for k in z.files}\n",
|
| 302 |
+
" print(f\"[sailer/{tag}] nạp cache: {len(store)}\")\n",
|
| 303 |
+
" todo = [s for s in stems if s not in store]\n",
|
| 304 |
+
" if todo:\n",
|
| 305 |
+
" from src.model.emotion.wavlm_emotion import WavLMWrapper\n",
|
| 306 |
+
" sailer = WavLMWrapper.from_pretrained(\"tiantiaf/wavlm-large-categorical-emotion\").to(device).eval()\n",
|
| 307 |
+
" with torch.no_grad():\n",
|
| 308 |
+
" for i, s in enumerate(tqdm(todo, desc=f\"sailer {tag}\")):\n",
|
| 309 |
+
" wav = os.path.join(WAV_DIR, s + \".wav\")\n",
|
| 310 |
+
" if not os.path.exists(wav):\n",
|
| 311 |
+
" continue\n",
|
| 312 |
+
" wave, _ = librosa.load(wav, sr=16000, mono=True)\n",
|
| 313 |
+
" wave = wave[: 15 * 16000]\n",
|
| 314 |
+
" data = torch.from_numpy(wave).float().unsqueeze(0).to(device)\n",
|
| 315 |
+
" logits, feat, _det, arousal, valence, dominance = sailer(data, return_feature=True)\n",
|
| 316 |
+
" emb = _pool_feat(feat)\n",
|
| 317 |
+
" p9 = F.softmax(logits, dim=1)[0].detach().cpu().numpy().astype(np.float32)\n",
|
| 318 |
+
" vad3 = np.array([1 + 4 * float(valence.item()),\n",
|
| 319 |
+
" 1 + 4 * float(arousal.item()),\n",
|
| 320 |
+
" 1 + 4 * float(dominance.item())], dtype=np.float32)\n",
|
| 321 |
+
" store[s] = np.concatenate([emb, p9, vad3]).astype(np.float32)\n",
|
| 322 |
+
" if (i + 1) % 500 == 0:\n",
|
| 323 |
+
" np.savez(cache_path, **store)\n",
|
| 324 |
+
" np.savez(cache_path, **store)\n",
|
| 325 |
+
" del sailer\n",
|
| 326 |
+
" torch.cuda.empty_cache() if device == \"cuda\" else None\n",
|
| 327 |
+
" return {s: (v[:-12], v[-12:-3], v[-3:]) for s, v in store.items()}\n",
|
| 328 |
+
"\n",
|
| 329 |
+
"def extract_utmos(names, tag):\n",
|
| 330 |
+
" \"\"\"Chấm UTMOS từng wav (theo TÊN, vì DEV gọi .wav theo tên). → dict {stem: score}.\n",
|
| 331 |
+
" Cache CACHE_DIR/utmos_<tag>.npz. Dùng vừa làm đầu vào QMOS head, vừa làm baseline so sánh.\"\"\"\n",
|
| 332 |
+
" import librosa\n",
|
| 333 |
+
" from tqdm.auto import tqdm\n",
|
| 334 |
+
" cache_path = os.path.join(CACHE_DIR, f\"utmos_{tag}.npz\")\n",
|
| 335 |
+
" store = {}\n",
|
| 336 |
+
" if os.path.exists(cache_path):\n",
|
| 337 |
+
" z = np.load(cache_path, allow_pickle=True)\n",
|
| 338 |
+
" store = {k: float(z[k]) for k in z.files}\n",
|
| 339 |
+
" print(f\"[utmos/{tag}] nạp cache: {len(store)}\")\n",
|
| 340 |
+
" todo = [n for n in names if stem(n) not in store]\n",
|
| 341 |
+
" if todo:\n",
|
| 342 |
+
" predictor = torch.hub.load(\"tarepan/SpeechMOS:v1.2.0\", \"utmos22_strong\",\n",
|
| 343 |
+
" trust_repo=True).to(device).eval()\n",
|
| 344 |
+
" with torch.no_grad():\n",
|
| 345 |
+
" for i, n in enumerate(tqdm(todo, desc=f\"utmos {tag}\")):\n",
|
| 346 |
+
" wav = os.path.join(WAV_DIR, n if str(n).endswith(\".wav\") else n + \".wav\")\n",
|
| 347 |
+
" if not os.path.exists(wav):\n",
|
| 348 |
+
" continue\n",
|
| 349 |
+
" wave, _ = librosa.load(wav, sr=16000, mono=True)\n",
|
| 350 |
+
" store[stem(n)] = float(predictor(torch.from_numpy(wave).unsqueeze(0).to(device),\n",
|
| 351 |
+
" sr=16000).mean().item())\n",
|
| 352 |
+
" if (i + 1) % 500 == 0:\n",
|
| 353 |
+
" np.savez(cache_path, **{k: np.float32(v) for k, v in store.items()})\n",
|
| 354 |
+
" np.savez(cache_path, **{k: np.float32(v) for k, v in store.items()})\n",
|
| 355 |
+
" del predictor\n",
|
| 356 |
+
" torch.cuda.empty_cache() if device == \"cuda\" else None\n",
|
| 357 |
+
" return store"
|
| 358 |
+
]
|
| 359 |
+
},
|
| 360 |
+
{
|
| 361 |
+
"cell_type": "markdown",
|
| 362 |
+
"id": "e9fad1a0",
|
| 363 |
+
"metadata": {},
|
| 364 |
+
"source": [
|
| 365 |
+
"## 4. Dựng feature + nhãn cho train\n",
|
| 366 |
+
"Feature audio (cảm xúc) = `[e2v_emb | e2v_p5 | sailer_emb | sailer_p9 | sailer_vad3]` (như exp04).\n",
|
| 367 |
+
"Thêm: vector **UTMOS** (1 số/ wav) cho QMOS head, và nhãn **qMOS**."
|
| 368 |
+
]
|
| 369 |
+
},
|
| 370 |
+
{
|
| 371 |
+
"cell_type": "code",
|
| 372 |
+
"execution_count": null,
|
| 373 |
+
"id": "768f374b",
|
| 374 |
+
"metadata": {},
|
| 375 |
+
"outputs": [],
|
| 376 |
+
"source": [
|
| 377 |
+
"train_stems = list(train_df[\"wavID\"])\n",
|
| 378 |
+
"if LIMIT_TRAIN:\n",
|
| 379 |
+
" train_stems = train_stems[:LIMIT_TRAIN]\n",
|
| 380 |
+
"\n",
|
| 381 |
+
"e2v_tr = extract_e2v(train_stems, \"train\") if USE_E2V else {}\n",
|
| 382 |
+
"sailer_tr = extract_sailer(train_stems, \"train\") if USE_SAILER else {}\n",
|
| 383 |
+
"utmos_tr = extract_utmos(train_stems, \"train\") if USE_UTMOS_FEAT else {}\n",
|
| 384 |
+
"\n",
|
| 385 |
+
"def audio_feature(sid, e2v_map, sailer_map):\n",
|
| 386 |
+
" parts = []\n",
|
| 387 |
+
" if USE_E2V:\n",
|
| 388 |
+
" pk = e2v_map.get(sid)\n",
|
| 389 |
+
" if pk is None:\n",
|
| 390 |
+
" return None\n",
|
| 391 |
+
" emb, p5 = pk\n",
|
| 392 |
+
" parts.append(emb)\n",
|
| 393 |
+
" if USE_CLASSPROB:\n",
|
| 394 |
+
" parts.append(p5)\n",
|
| 395 |
+
" if USE_SAILER:\n",
|
| 396 |
+
" pk = sailer_map.get(sid)\n",
|
| 397 |
+
" if pk is None:\n",
|
| 398 |
+
" return None\n",
|
| 399 |
+
" emb, p9, vad3 = pk\n",
|
| 400 |
+
" parts.append(emb)\n",
|
| 401 |
+
" if USE_CLASSPROB:\n",
|
| 402 |
+
" parts.append(p9); parts.append(vad3)\n",
|
| 403 |
+
" return np.concatenate(parts).astype(np.float32)\n",
|
| 404 |
+
"\n",
|
| 405 |
+
"def onehot_target(tgt):\n",
|
| 406 |
+
" v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
|
| 407 |
+
" if tgt in EMOTIONS5:\n",
|
| 408 |
+
" v[EMOTIONS5.index(tgt)] = 1.0\n",
|
| 409 |
+
" return v\n",
|
| 410 |
+
"\n",
|
| 411 |
+
"lab = train_df.set_index(\"wavID\")\n",
|
| 412 |
+
"X, T, U, y_qmos, y_emos, y_vad, y_cat = [], [], [], [], [], [], []\n",
|
| 413 |
+
"for s in train_stems:\n",
|
| 414 |
+
" f = audio_feature(s, e2v_tr, sailer_tr)\n",
|
| 415 |
+
" tgt = target_map.get(s)\n",
|
| 416 |
+
" if f is None or tgt is None or s not in lab.index:\n",
|
| 417 |
+
" continue\n",
|
| 418 |
+
" if USE_UTMOS_FEAT and s not in utmos_tr:\n",
|
| 419 |
+
" continue\n",
|
| 420 |
+
" X.append(f)\n",
|
| 421 |
+
" T.append(onehot_target(tgt))\n",
|
| 422 |
+
" U.append(utmos_tr.get(s, 3.0) if USE_UTMOS_FEAT else 0.0)\n",
|
| 423 |
+
" y_qmos.append(lab.loc[s, \"qmos\"])\n",
|
| 424 |
+
" y_emos.append(lab.loc[s, \"emos\"])\n",
|
| 425 |
+
" y_vad.append([lab.loc[s, \"val\"], lab.loc[s, \"aro\"], lab.loc[s, \"dom\"]])\n",
|
| 426 |
+
" y_cat.append([lab.loc[s, f\"cat{i}\"] for i in range(len(EMOTIONS5))])\n",
|
| 427 |
+
"\n",
|
| 428 |
+
"X = np.stack(X).astype(np.float32)\n",
|
| 429 |
+
"T = np.stack(T).astype(np.float32)\n",
|
| 430 |
+
"U = np.array(U, dtype=np.float32).reshape(-1, 1)\n",
|
| 431 |
+
"y_qmos = np.array(y_qmos, dtype=np.float32)\n",
|
| 432 |
+
"y_emos = np.array(y_emos, dtype=np.float32)\n",
|
| 433 |
+
"y_vad = np.array(y_vad, dtype=np.float32)\n",
|
| 434 |
+
"y_cat = np.array(y_cat, dtype=np.float32)\n",
|
| 435 |
+
"FEAT_DIM = X.shape[1]\n",
|
| 436 |
+
"print(f\"Train: X={X.shape} U={U.shape} qmos={y_qmos.shape} emos={y_emos.shape} vad={y_vad.shape}\")\n",
|
| 437 |
+
"\n",
|
| 438 |
+
"# Chuẩn hóa feature audio + UTMOS (z-score), lưu mean/std.\n",
|
| 439 |
+
"feat_mean = X.mean(0, keepdims=True); feat_std = X.std(0, keepdims=True) + 1e-6\n",
|
| 440 |
+
"Xn = (X - feat_mean) / feat_std\n",
|
| 441 |
+
"u_mu, u_sd = float(U.mean()), float(U.std() + 1e-6)\n",
|
| 442 |
+
"Un = (U - u_mu) / u_sd\n",
|
| 443 |
+
"\n",
|
| 444 |
+
"# Chuẩn hóa nhãn liên tục về z-score.\n",
|
| 445 |
+
"qmos_mu, qmos_sd = float(y_qmos.mean()), float(y_qmos.std() + 1e-6)\n",
|
| 446 |
+
"y_qmos_z = (y_qmos - qmos_mu) / qmos_sd\n",
|
| 447 |
+
"emos_mu, emos_sd = float(y_emos.mean()), float(y_emos.std() + 1e-6)\n",
|
| 448 |
+
"y_emos_z = (y_emos - emos_mu) / emos_sd\n",
|
| 449 |
+
"if HAS_VAD:\n",
|
| 450 |
+
" vad_mu = np.nanmean(y_vad, axis=0); vad_sd = np.nanstd(y_vad, axis=0) + 1e-6\n",
|
| 451 |
+
" y_vad_z = (y_vad - vad_mu) / vad_sd\n",
|
| 452 |
+
"else:\n",
|
| 453 |
+
" vad_mu = np.zeros(3, dtype=np.float32); vad_sd = np.ones(3, dtype=np.float32)\n",
|
| 454 |
+
" y_vad_z = np.zeros_like(y_vad)"
|
| 455 |
+
]
|
| 456 |
+
},
|
| 457 |
+
{
|
| 458 |
+
"cell_type": "markdown",
|
| 459 |
+
"id": "5903e6ca",
|
| 460 |
+
"metadata": {},
|
| 461 |
+
"source": [
|
| 462 |
+
"## 5. Model fusion multi-task (6 head) + train loop\n",
|
| 463 |
+
"Thêm so exp04: **QMOS head** nhận `[trunk | UTMOS]` → 1; `qmos` vào uncertainty weighting (6 task)."
|
| 464 |
+
]
|
| 465 |
+
},
|
| 466 |
+
{
|
| 467 |
+
"cell_type": "code",
|
| 468 |
+
"execution_count": null,
|
| 469 |
+
"id": "68a3a836",
|
| 470 |
+
"metadata": {
|
| 471 |
+
"lines_to_next_cell": 1
|
| 472 |
+
},
|
| 473 |
+
"outputs": [],
|
| 474 |
+
"source": [
|
| 475 |
+
"import torch.nn as nn\n",
|
| 476 |
+
"from scipy.stats import spearmanr\n",
|
| 477 |
+
"from sklearn.model_selection import train_test_split\n",
|
| 478 |
+
"\n",
|
| 479 |
+
"torch.manual_seed(SEED); np.random.seed(SEED)\n",
|
| 480 |
+
"N_EMO = len(EMOTIONS5)\n",
|
| 481 |
+
"idx_all = np.arange(X.shape[0])\n",
|
| 482 |
+
"tr_idx, va_idx = train_test_split(idx_all, test_size=VAL_FRAC, random_state=SEED)\n",
|
| 483 |
+
"\n",
|
| 484 |
+
"def to_t(a):\n",
|
| 485 |
+
" return torch.tensor(a, dtype=torch.float32, device=device)\n",
|
| 486 |
+
"\n",
|
| 487 |
+
"Xn_t, T_t, Un_t = to_t(Xn), to_t(T), to_t(Un)\n",
|
| 488 |
+
"qmos_t = to_t(y_qmos_z).unsqueeze(1)\n",
|
| 489 |
+
"emos_t = to_t(y_emos_z).unsqueeze(1)\n",
|
| 490 |
+
"vad_t = to_t(y_vad_z)\n",
|
| 491 |
+
"cat_t = to_t(y_cat)\n",
|
| 492 |
+
"\n",
|
| 493 |
+
"class FusionMTL6(nn.Module):\n",
|
| 494 |
+
" def __init__(self, d_in, trunk_h, head_h, p, n_emo, use_utmos):\n",
|
| 495 |
+
" super().__init__()\n",
|
| 496 |
+
" self.use_utmos = use_utmos\n",
|
| 497 |
+
" self.trunk = nn.Sequential(\n",
|
| 498 |
+
" nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),\n",
|
| 499 |
+
" nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p),\n",
|
| 500 |
+
" )\n",
|
| 501 |
+
" self.qmos = nn.Sequential( # nhận [trunk | utmos] nếu bật\n",
|
| 502 |
+
" nn.Linear(trunk_h + (1 if use_utmos else 0), head_h), nn.ReLU(), nn.Dropout(p),\n",
|
| 503 |
+
" nn.Linear(head_h, 1))\n",
|
| 504 |
+
" self.emos = nn.Sequential( # nhận [trunk | target]\n",
|
| 505 |
+
" nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))\n",
|
| 506 |
+
" self.cat = nn.Sequential(\n",
|
| 507 |
+
" nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))\n",
|
| 508 |
+
" self.vad = nn.Sequential(\n",
|
| 509 |
+
" nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))\n",
|
| 510 |
+
"\n",
|
| 511 |
+
" def forward(self, x, tgt, utmos):\n",
|
| 512 |
+
" h = self.trunk(x)\n",
|
| 513 |
+
" qmos_in = torch.cat([h, utmos], dim=1) if self.use_utmos else h\n",
|
| 514 |
+
" qmos = self.qmos(qmos_in)\n",
|
| 515 |
+
" emos = self.emos(torch.cat([h, tgt], dim=1))\n",
|
| 516 |
+
" cat_logits = self.cat(h)\n",
|
| 517 |
+
" vad = self.vad(h)\n",
|
| 518 |
+
" return qmos, emos, cat_logits, vad\n",
|
| 519 |
+
"\n",
|
| 520 |
+
"model = FusionMTL6(FEAT_DIM, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO, USE_UTMOS_FEAT).to(device)\n",
|
| 521 |
+
"\n",
|
| 522 |
+
"TASKS = [\"qmos\", \"emos\", \"cat\", \"val\", \"aro\", \"dom\"]\n",
|
| 523 |
+
"log_var = nn.Parameter(torch.zeros(len(TASKS), device=device))\n",
|
| 524 |
+
"params = list(model.parameters()) + ([log_var] if USE_UNCERTAINTY else [])\n",
|
| 525 |
+
"opt = torch.optim.Adam(params, lr=LR, weight_decay=1e-5)\n",
|
| 526 |
+
"mse = nn.MSELoss(reduction=\"none\")\n",
|
| 527 |
+
"\n",
|
| 528 |
+
"def soft_ce(logits, target_dist):\n",
|
| 529 |
+
" logq = F.log_softmax(logits, dim=1)\n",
|
| 530 |
+
" return -(target_dist * logq).sum(dim=1)\n",
|
| 531 |
+
"\n",
|
| 532 |
+
"def task_losses(qmos_p, emos_p, cat_logits, vad_p, b):\n",
|
| 533 |
+
" L = {}\n",
|
| 534 |
+
" L[\"qmos\"] = mse(qmos_p, qmos_t[b]).mean()\n",
|
| 535 |
+
" L[\"emos\"] = mse(emos_p, emos_t[b]).mean()\n",
|
| 536 |
+
" L[\"cat\"] = soft_ce(cat_logits, cat_t[b]).mean()\n",
|
| 537 |
+
" if HAS_VAD:\n",
|
| 538 |
+
" L[\"val\"] = mse(vad_p[:, 0:1], vad_t[b, 0:1]).mean()\n",
|
| 539 |
+
" L[\"aro\"] = mse(vad_p[:, 1:2], vad_t[b, 1:2]).mean()\n",
|
| 540 |
+
" L[\"dom\"] = mse(vad_p[:, 2:3], vad_t[b, 2:3]).mean()\n",
|
| 541 |
+
" else:\n",
|
| 542 |
+
" z = torch.zeros((), device=device)\n",
|
| 543 |
+
" L[\"val\"] = L[\"aro\"] = L[\"dom\"] = z\n",
|
| 544 |
+
" return L\n",
|
| 545 |
+
"\n",
|
| 546 |
+
"def combine(L):\n",
|
| 547 |
+
" if USE_UNCERTAINTY:\n",
|
| 548 |
+
" tot = 0.0\n",
|
| 549 |
+
" for i, t in enumerate(TASKS):\n",
|
| 550 |
+
" tot = tot + torch.exp(-log_var[i]) * L[t] + log_var[i]\n",
|
| 551 |
+
" return tot\n",
|
| 552 |
+
" return sum(LOSS_W[t] * L[t] for t in TASKS)\n",
|
| 553 |
+
"\n",
|
| 554 |
+
"@torch.no_grad()\n",
|
| 555 |
+
"def eval_val():\n",
|
| 556 |
+
" model.eval()\n",
|
| 557 |
+
" qp, ep, cl, vp = model(Xn_t[va_idx], T_t[va_idx], Un_t[va_idx])\n",
|
| 558 |
+
" qp = qp.cpu().numpy().ravel(); ep = ep.cpu().numpy().ravel()\n",
|
| 559 |
+
" out = {\"qmos\": spearmanr(qp, y_qmos[va_idx]).correlation,\n",
|
| 560 |
+
" \"emos\": spearmanr(ep, y_emos[va_idx]).correlation}\n",
|
| 561 |
+
" if USE_UTMOS_FEAT:\n",
|
| 562 |
+
" out[\"qmos_utmos\"] = spearmanr(U[va_idx, 0], y_qmos[va_idx]).correlation # baseline UTMOS đơn lẻ\n",
|
| 563 |
+
" if HAS_VAD:\n",
|
| 564 |
+
" vp = vp.cpu().numpy()\n",
|
| 565 |
+
" for j, t in enumerate([\"val\", \"aro\", \"dom\"]):\n",
|
| 566 |
+
" out[t] = spearmanr(vp[:, j], y_vad[va_idx, j]).correlation\n",
|
| 567 |
+
" q = F.softmax(cl, dim=1).cpu().numpy(); p = y_cat[va_idx]\n",
|
| 568 |
+
" kl = (p * (np.log(p + 1e-9) - np.log(q + 1e-9))).sum(1).mean()\n",
|
| 569 |
+
" out[\"cat_negkl\"] = float(-kl)\n",
|
| 570 |
+
" return out\n",
|
| 571 |
+
"\n",
|
| 572 |
+
"def val_score(m):\n",
|
| 573 |
+
" \"\"\"Điểm tổng early-stop = TB SRCC các task liên tục (qmos+emos+VAD).\"\"\"\n",
|
| 574 |
+
" keys = [\"qmos\", \"emos\"] + ([\"val\", \"aro\", \"dom\"] if HAS_VAD else [])\n",
|
| 575 |
+
" return float(np.mean([m[k] for k in keys]))\n",
|
| 576 |
+
"\n",
|
| 577 |
+
"best_score, best_state, bad = -1e9, None, 0\n",
|
| 578 |
+
"tr_t = torch.tensor(tr_idx, device=device)\n",
|
| 579 |
+
"for ep_i in range(1, EPOCHS + 1):\n",
|
| 580 |
+
" model.train()\n",
|
| 581 |
+
" perm = tr_t[torch.randperm(len(tr_t), device=device)]\n",
|
| 582 |
+
" run = 0.0\n",
|
| 583 |
+
" for i in range(0, len(perm), BATCH):\n",
|
| 584 |
+
" b = perm[i:i + BATCH]\n",
|
| 585 |
+
" opt.zero_grad()\n",
|
| 586 |
+
" qmos_p, emos_p, cat_logits, vad_p = model(Xn_t[b], T_t[b], Un_t[b])\n",
|
| 587 |
+
" L = task_losses(qmos_p, emos_p, cat_logits, vad_p, b)\n",
|
| 588 |
+
" loss = combine(L)\n",
|
| 589 |
+
" loss.backward(); opt.step()\n",
|
| 590 |
+
" run += loss.item() * len(b)\n",
|
| 591 |
+
" m = eval_val()\n",
|
| 592 |
+
" sc = val_score(m)\n",
|
| 593 |
+
" if sc > best_score:\n",
|
| 594 |
+
" best_score = sc\n",
|
| 595 |
+
" best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}\n",
|
| 596 |
+
" bad = 0\n",
|
| 597 |
+
" else:\n",
|
| 598 |
+
" bad += 1\n",
|
| 599 |
+
" if ep_i % 5 == 0 or ep_i == 1:\n",
|
| 600 |
+
" msg = \" \".join(f\"{k}={m[k]:.3f}\" for k in [\"qmos\", \"emos\", \"val\", \"aro\", \"dom\"] if k in m)\n",
|
| 601 |
+
" print(f\"epoch {ep_i:3d} | loss {run/len(perm):.4f} | {msg} | best {best_score:.4f}\")\n",
|
| 602 |
+
" if bad >= PATIENCE:\n",
|
| 603 |
+
" print(f\"Early stop ở epoch {ep_i}.\")\n",
|
| 604 |
+
" break\n",
|
| 605 |
+
"\n",
|
| 606 |
+
"model.load_state_dict(best_state)\n",
|
| 607 |
+
"final = eval_val()\n",
|
| 608 |
+
"print(\"\\n✅ VAL (nội bộ) — exp07 (fusion + QMOS head):\")\n",
|
| 609 |
+
"print(f\" QMOS SRCC = {final['qmos']:.4f}\", end=\"\")\n",
|
| 610 |
+
"if \"qmos_utmos\" in final:\n",
|
| 611 |
+
" tag = \"✅ vượt UTMOS\" if final[\"qmos\"] > final[\"qmos_utmos\"] else \"⚠️ CHƯA vượt UTMOS\"\n",
|
| 612 |
+
" print(f\" (UTMOS đơn lẻ = {final['qmos_utmos']:.4f} → {tag})\")\n",
|
| 613 |
+
"else:\n",
|
| 614 |
+
" print()\n",
|
| 615 |
+
"print(f\" EMOS SRCC = {final['emos']:.4f} (mốc exp04 = {EXP04['emos']})\")\n",
|
| 616 |
+
"if HAS_VAD:\n",
|
| 617 |
+
" print(f\" VAL/ARO/DOM = {final['val']:.4f}/{final['aro']:.4f}/{final['dom']:.4f}\"\n",
|
| 618 |
+
" f\" (mốc exp04 = {EXP04['val']}/{EXP04['aro']}/{EXP04['dom']})\")\n",
|
| 619 |
+
"# Cảnh báo negative transfer (gộp QMOS làm tụt cảm xúc)\n",
|
| 620 |
+
"warn = []\n",
|
| 621 |
+
"if final[\"emos\"] < EXP04[\"emos\"] - 0.02:\n",
|
| 622 |
+
" warn.append(f\"EMOS {final['emos']:.3f} < {EXP04['emos']}\")\n",
|
| 623 |
+
"if HAS_VAD:\n",
|
| 624 |
+
" for t in [\"val\", \"aro\", \"dom\"]:\n",
|
| 625 |
+
" if final[t] < EXP04[t] - 0.02:\n",
|
| 626 |
+
" warn.append(f\"{t.upper()} {final[t]:.3f} < {EXP04[t]}\")\n",
|
| 627 |
+
"if warn:\n",
|
| 628 |
+
" print(\" ⚠️ NEGATIVE TRANSFER? Cảm xúc tụt so exp04:\", \"; \".join(warn),\n",
|
| 629 |
+
" \"\\n → cân nhắc giữ exp04 cho 5 cột cảm xúc + chỉ lấy QMOS từ exp07/exp06.\")\n",
|
| 630 |
+
"else:\n",
|
| 631 |
+
" print(\" ✅ Không thấy 5 cột cảm xúc tụt rõ so exp04.\")\n",
|
| 632 |
+
"if USE_UNCERTAINTY:\n",
|
| 633 |
+
" print(\" log σ² mỗi task:\", {t: round(float(log_var[i]), 3) for i, t in enumerate(TASKS)})\n",
|
| 634 |
+
"\n",
|
| 635 |
+
"torch.save({\"state\": best_state, \"feat_mean\": feat_mean, \"feat_std\": feat_std,\n",
|
| 636 |
+
" \"u_mu\": u_mu, \"u_sd\": u_sd,\n",
|
| 637 |
+
" \"qmos_mu\": qmos_mu, \"qmos_sd\": qmos_sd, \"emos_mu\": emos_mu, \"emos_sd\": emos_sd,\n",
|
| 638 |
+
" \"vad_mu\": vad_mu, \"vad_sd\": vad_sd, \"FEAT_DIM\": FEAT_DIM,\n",
|
| 639 |
+
" \"USE_E2V\": USE_E2V, \"USE_SAILER\": USE_SAILER, \"USE_CLASSPROB\": USE_CLASSPROB,\n",
|
| 640 |
+
" \"USE_UTMOS_FEAT\": USE_UTMOS_FEAT, \"val_score\": best_score},\n",
|
| 641 |
+
" os.path.join(OUT_DIR, \"fusion_qmos_mtl.pt\"))\n",
|
| 642 |
+
"print(\"Đã lưu\", os.path.join(OUT_DIR, \"fusion_qmos_mtl.pt\"))"
|
| 643 |
+
]
|
| 644 |
+
},
|
| 645 |
+
{
|
| 646 |
+
"cell_type": "markdown",
|
| 647 |
+
"id": "9a788c48",
|
| 648 |
+
"metadata": {},
|
| 649 |
+
"source": [
|
| 650 |
+
"## 6. Dự đoán DEV → `answer.txt` đủ 6 cột (QMOS giờ từ HEAD, không phải SpeechMOS riêng)"
|
| 651 |
+
]
|
| 652 |
+
},
|
| 653 |
+
{
|
| 654 |
+
"cell_type": "code",
|
| 655 |
+
"execution_count": null,
|
| 656 |
+
"id": "8acd813f",
|
| 657 |
+
"metadata": {
|
| 658 |
+
"lines_to_next_cell": 1
|
| 659 |
+
},
|
| 660 |
+
"outputs": [],
|
| 661 |
+
"source": [
|
| 662 |
+
"def list_dev():\n",
|
| 663 |
+
" with open(DEV_SCP) as f:\n",
|
| 664 |
+
" return [ln.strip() for ln in f if ln.strip()]\n",
|
| 665 |
+
"\n",
|
| 666 |
+
"dev_names = list_dev()\n",
|
| 667 |
+
"if LIMIT_DEV:\n",
|
| 668 |
+
" dev_names = dev_names[:LIMIT_DEV]\n",
|
| 669 |
+
"dev_stems = [stem(n) for n in dev_names]\n",
|
| 670 |
+
"print(\"DEV:\", len(dev_names), \"mẫu\")\n",
|
| 671 |
+
"\n",
|
| 672 |
+
"e2v_dev = extract_e2v(dev_stems, \"dev\") if USE_E2V else {}\n",
|
| 673 |
+
"sailer_dev = extract_sailer(dev_stems, \"dev\") if USE_SAILER else {}\n",
|
| 674 |
+
"utmos_dev = extract_utmos(dev_names, \"dev\") if USE_UTMOS_FEAT else {}\n",
|
| 675 |
+
"\n",
|
| 676 |
+
"@torch.no_grad()\n",
|
| 677 |
+
"def predict_all(sid):\n",
|
| 678 |
+
" f = audio_feature(sid, e2v_dev, sailer_dev)\n",
|
| 679 |
+
" if f is None:\n",
|
| 680 |
+
" return None\n",
|
| 681 |
+
" fn = (f[None, :] - feat_mean) / feat_std\n",
|
| 682 |
+
" tgt = onehot_target(target_map.get(sid))[None, :]\n",
|
| 683 |
+
" u = np.array([[utmos_dev.get(sid, 3.0)]], dtype=np.float32)\n",
|
| 684 |
+
" un = (u - u_mu) / u_sd\n",
|
| 685 |
+
" model.eval()\n",
|
| 686 |
+
" qmos_p, emos_p, cat_logits, vad_p = model(to_t(fn), to_t(tgt), to_t(un))\n",
|
| 687 |
+
" qmos = float(qmos_p.item()) * qmos_sd + qmos_mu\n",
|
| 688 |
+
" emos = float(emos_p.item()) * emos_sd + emos_mu\n",
|
| 689 |
+
" cat5 = F.softmax(cat_logits, dim=1)[0].cpu().numpy()\n",
|
| 690 |
+
" vad3 = vad_p[0].cpu().numpy() * vad_sd + vad_mu\n",
|
| 691 |
+
" return qmos, emos, cat5, vad3\n",
|
| 692 |
+
"\n",
|
| 693 |
+
"def fmt_cat(probs5):\n",
|
| 694 |
+
" return \"|\".join(f\"{e}:{probs5[i]:.6g}\" for i, e in enumerate(EMOTIONS5))\n",
|
| 695 |
+
"\n",
|
| 696 |
+
"def build_answer(out_path):\n",
|
| 697 |
+
" from tqdm.auto import tqdm\n",
|
| 698 |
+
" n_real = n_default = 0\n",
|
| 699 |
+
" with open(out_path, \"w\") as f:\n",
|
| 700 |
+
" f.write(\"wav,QMOS,EMOS,CAT,VAL,ARO,DOM\\n\")\n",
|
| 701 |
+
" for name in tqdm(dev_names, desc=\"answer\"):\n",
|
| 702 |
+
" sid = stem(name)\n",
|
| 703 |
+
" pred = predict_all(sid)\n",
|
| 704 |
+
" if pred is None:\n",
|
| 705 |
+
" # rơi về: QMOS=UTMOS nếu có, còn lại mặc định\n",
|
| 706 |
+
" qmos = utmos_dev.get(sid, 3.0)\n",
|
| 707 |
+
" emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0])\n",
|
| 708 |
+
" n_default += 1\n",
|
| 709 |
+
" else:\n",
|
| 710 |
+
" qmos, emos, cat5, vad3 = pred\n",
|
| 711 |
+
" n_real += 1\n",
|
| 712 |
+
" f.write(f\"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},\"\n",
|
| 713 |
+
" f\"{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\\n\")\n",
|
| 714 |
+
" print(f\"Ghi {len(dev_names)} dòng → {out_path} | head thật {n_real}, mặc định {n_default}\")\n",
|
| 715 |
+
"\n",
|
| 716 |
+
"answer_path = os.path.join(OUT_DIR, \"answer.txt\")\n",
|
| 717 |
+
"build_answer(answer_path)"
|
| 718 |
+
]
|
| 719 |
+
},
|
| 720 |
+
{
|
| 721 |
+
"cell_type": "markdown",
|
| 722 |
+
"id": "0ac67bb5",
|
| 723 |
+
"metadata": {},
|
| 724 |
+
"source": [
|
| 725 |
+
"## 7. Validate + đóng zip"
|
| 726 |
+
]
|
| 727 |
+
},
|
| 728 |
+
{
|
| 729 |
+
"cell_type": "code",
|
| 730 |
+
"execution_count": null,
|
| 731 |
+
"id": "8244b0c3",
|
| 732 |
+
"metadata": {},
|
| 733 |
+
"outputs": [],
|
| 734 |
+
"source": [
|
| 735 |
+
"def validate(path):\n",
|
| 736 |
+
" import csv\n",
|
| 737 |
+
" with open(path) as f:\n",
|
| 738 |
+
" rows = list(csv.reader(f))\n",
|
| 739 |
+
" header = rows[0]\n",
|
| 740 |
+
" assert header[0] == \"wav\" and \"QMOS\" in header and \"EMOS\" in header, \"Header sai\"\n",
|
| 741 |
+
" for i, r in enumerate(rows[1:], 2):\n",
|
| 742 |
+
" assert len(r) == len(header), f\"Dòng {i} sai số cột\"\n",
|
| 743 |
+
" print(f\"OK: {len(rows)-1} dòng, header = {header}\")\n",
|
| 744 |
+
"\n",
|
| 745 |
+
"validate(answer_path)\n",
|
| 746 |
+
"os.system(f\"cd {OUT_DIR} && zip -j submission_track2_exp07_fusion_qmos.zip answer.txt \"\n",
|
| 747 |
+
" f\"&& unzip -l submission_track2_exp07_fusion_qmos.zip\")\n",
|
| 748 |
+
"print(\"Sẵn sàng nộp:\", os.path.join(OUT_DIR, \"submission_track2_exp07_fusion_qmos.zip\"))"
|
| 749 |
+
]
|
| 750 |
+
},
|
| 751 |
+
{
|
| 752 |
+
"cell_type": "markdown",
|
| 753 |
+
"id": "a73c1c11",
|
| 754 |
+
"metadata": {},
|
| 755 |
+
"source": [
|
| 756 |
+
"## Ghi chú\n",
|
| 757 |
+
"- **Lần đầu** đặt `LIMIT_TRAIN=300`, `LIMIT_DEV=20`; OK rồi đặt `None`.\n",
|
| 758 |
+
"- **Đọc kết quả mục 5 theo 2 câu hỏi:**\n",
|
| 759 |
+
" 1. QMOS head có **vượt UTMOS đơn lẻ (0.414)** không? (dòng \"vượt/CHƯA vượt UTMOS\")\n",
|
| 760 |
+
" 2. Gộp QMOS có **làm tụt** EMOS/VAD so exp04 không? (dòng \"NEGATIVE TRANSFER?\")\n",
|
| 761 |
+
"- **Quyết định nộp:**\n",
|
| 762 |
+
" - Nếu QMOS↑ và cảm xúc KHÔNG tụt → nộp answer.txt exp07 (1 model trọn 6 cột — đẹp cho paper).\n",
|
| 763 |
+
" - Nếu QMOS↑ nhưng cảm xúc TỤT → giữ exp04 cho 5 cột cảm xúc, chỉ lấy **cột QMOS** của exp07/exp06 ghép vào.\n",
|
| 764 |
+
" - Nếu QMOS không vượt UTMOS → kết luận \"chất lượng trực giao cảm xúc\" (vẫn là phát hiện cho paper); giữ exp04.\n",
|
| 765 |
+
"- **Ablation cho paper**: `USE_UTMOS_FEAT=False` (QMOS chỉ từ trunk cảm xúc) → đo trực tiếp giả thuyết của bạn.\n",
|
| 766 |
+
"- Cache dùng CHUNG `fusion_cache/` với exp04 → **Save Version** giữ lại.\n",
|
| 767 |
+
"- Ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (mục exp07)."
|
| 768 |
+
]
|
| 769 |
+
}
|
| 770 |
+
],
|
| 771 |
+
"metadata": {
|
| 772 |
+
"jupytext": {
|
| 773 |
+
"cell_metadata_filter": "-all",
|
| 774 |
+
"main_language": "python",
|
| 775 |
+
"notebook_metadata_filter": "-all"
|
| 776 |
+
}
|
| 777 |
+
},
|
| 778 |
+
"nbformat": 4,
|
| 779 |
+
"nbformat_minor": 5
|
| 780 |
+
}
|
track2/exp07_fusion_qmos_pipeline.py
ADDED
|
@@ -0,0 +1,654 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %% [markdown]
|
| 2 |
+
# # VMC2026 Track 2 — exp07 (FUSION + QMOS head, HỢP NHẤT 6 cột) — Kaggle
|
| 3 |
+
#
|
| 4 |
+
# **Khác exp04 ở đâu:** exp04 để **QMOS riêng** (UTMOS zero-shot). exp07 **gộp luôn QMOS vào trunk chung**
|
| 5 |
+
# → 1 model multi-task dự đoán **đủ 6 đầu ra**: QMOS · EMOS · CAT · VAL · ARO · DOM.
|
| 6 |
+
#
|
| 7 |
+
# ## Giả thuyết (của bạn) cần kiểm chứng
|
| 8 |
+
# "Chất giọng tự nhiên có liên quan tới cảm nhận cảm xúc" → nếu đúng, QMOS sẽ **hưởng lợi** từ biểu diễn
|
| 9 |
+
# cảm xúc chung (emotion2vec + SAILER). **Rủi ro:** 2 backbone này chuyên *cảm xúc*, chưa chắc bắt tốt
|
| 10 |
+
# *lỗi chất lượng/artifact* (thứ UTMOS chuyên trị) → QMOS có thể **thua** UTMOS, hoặc gộp làm **tụt** EMOS/VAD.
|
| 11 |
+
#
|
| 12 |
+
# ## Lưới an toàn trong thiết kế
|
| 13 |
+
# - **Vẫn đưa điểm UTMOS làm 1 đầu vào** cho QMOS head (`USE_UTMOS_FEAT`) → head học **chỉnh sửa** quanh
|
| 14 |
+
# 0.414 thay vì học lại từ đầu → khó tệ hơn UTMOS.
|
| 15 |
+
# - **In SRCC cả 6 cột + so mốc exp04** (EMOS 0.788 · CAT err 0.145 · VAL 0.578 · ARO 0.754 · DOM 0.706)
|
| 16 |
+
# → cảnh báo ngay nếu gộp QMOS làm tụt 5 cột cảm xúc.
|
| 17 |
+
# - **File riêng**, KHÔNG đụng `exp04_fusion_pipeline.py` (exp04 vẫn nguyên).
|
| 18 |
+
#
|
| 19 |
+
# ```
|
| 20 |
+
# mỗi wav ─► [e2v_emb | e2v_p5 | sailer_emb | sailer_p9 | sailer_vad3] ─► TRUNK chung
|
| 21 |
+
# │
|
| 22 |
+
# ┌──────────────┬───────────────┬─────────────┬───────────────────┤
|
| 23 |
+
# [QMOS head] [EMOS head] [CAT head] [VAD head]
|
| 24 |
+
# trunk + UTMOS trunk + target trunk trunk
|
| 25 |
+
# ```
|
| 26 |
+
#
|
| 27 |
+
# **Cách chạy:** GPU T4 + Internet On → Add Input dataset Track 2 → sửa `DATA_ROOT` → Run All.
|
| 28 |
+
# Lần đầu đặt `LIMIT_TRAIN=300`, `LIMIT_DEV=20`. Dùng CHUNG cache `fusion_cache/` với exp04 (thêm `utmos_*.npz`).
|
| 29 |
+
|
| 30 |
+
# %% [markdown]
|
| 31 |
+
# ## 0. Cấu hình — SỬA Ở ĐÂY
|
| 32 |
+
|
| 33 |
+
# %%
|
| 34 |
+
import os
|
| 35 |
+
|
| 36 |
+
DATA_ROOT = "/kaggle/input/vmc2026-track2-full/vmc2026-track2" # << SỬA slug
|
| 37 |
+
WAV_DIR = f"{DATA_ROOT}/wav"
|
| 38 |
+
METADATA_CSV = f"{DATA_ROOT}/metadata.csv" # wavID|emotion|transcript (KHÔNG header)
|
| 39 |
+
TRAIN_CSV = f"{DATA_ROOT}/sets/train.csv" # lisID|wavID|qMOS|emoCat|eMOS|val|dom|aro
|
| 40 |
+
DEV_SCP = f"{DATA_ROOT}/sets/dev.scp"
|
| 41 |
+
|
| 42 |
+
OUT_DIR = "/kaggle/working"
|
| 43 |
+
CACHE_DIR = "/kaggle/working/fusion_cache" # dùng CHUNG với exp04 (thêm utmos_*.npz)
|
| 44 |
+
os.makedirs(CACHE_DIR, exist_ok=True)
|
| 45 |
+
|
| 46 |
+
# ── Siêu tham số ─────────────────────────────────────────────────────────────
|
| 47 |
+
DEVICE = "cuda"
|
| 48 |
+
TRUNK_HIDDEN = 512
|
| 49 |
+
HEAD_HIDDEN = 128
|
| 50 |
+
DROPOUT = 0.3
|
| 51 |
+
LR = 1e-3
|
| 52 |
+
EPOCHS = 80
|
| 53 |
+
BATCH = 64
|
| 54 |
+
VAL_FRAC = 0.10
|
| 55 |
+
PATIENCE = 15
|
| 56 |
+
SEED = 42
|
| 57 |
+
|
| 58 |
+
USE_UNCERTAINTY = True # tự cân 6 loss (Kendall); False = dùng LOSS_W cố định
|
| 59 |
+
LOSS_W = {"qmos": 1.0, "emos": 1.0, "cat": 1.0, "val": 1.0, "aro": 1.0, "dom": 1.0}
|
| 60 |
+
USE_E2V = True
|
| 61 |
+
USE_SAILER = True
|
| 62 |
+
USE_CLASSPROB = True
|
| 63 |
+
USE_UTMOS_FEAT = True # đưa điểm UTMOS làm đầu vào QMOS head (neo residual quanh 0.414)
|
| 64 |
+
|
| 65 |
+
LIMIT_TRAIN = None
|
| 66 |
+
LIMIT_DEV = None
|
| 67 |
+
|
| 68 |
+
# Mốc exp04 để so (cảnh báo nếu tụt khi gộp QMOS)
|
| 69 |
+
EXP04 = {"emos": 0.788, "cat_err": 0.145, "val": 0.578, "aro": 0.754, "dom": 0.706, "qmos_utmos": 0.414}
|
| 70 |
+
|
| 71 |
+
EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"]
|
| 72 |
+
SAILER9 = ["Anger", "Contempt", "Disgust", "Fear", "Happiness", "Neutral", "Sadness", "Surprise", "Other"]
|
| 73 |
+
EMO2SAILER = {"angry": 0, "happy": 4, "neutral": 5, "sad": 6, "surprised": 7}
|
| 74 |
+
|
| 75 |
+
_EMO_ALIAS = {
|
| 76 |
+
"angry": "angry", "anger": "angry",
|
| 77 |
+
"happy": "happy", "happiness": "happy", "joy": "happy",
|
| 78 |
+
"neutral": "neutral", "calm": "neutral",
|
| 79 |
+
"sad": "sad", "sadness": "sad",
|
| 80 |
+
"surprise": "surprised", "surprised": "surprised", "surprising": "surprised",
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
def norm_emotion(label):
|
| 84 |
+
key = str(label).strip().lower()
|
| 85 |
+
return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)
|
| 86 |
+
|
| 87 |
+
def stem(p):
|
| 88 |
+
return os.path.splitext(os.path.basename(str(p)))[0]
|
| 89 |
+
|
| 90 |
+
assert USE_E2V or USE_SAILER, "Phải bật ít nhất 1 backbone."
|
| 91 |
+
print("DATA_ROOT:", DATA_ROOT)
|
| 92 |
+
for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP]:
|
| 93 |
+
print((" ✅ " if os.path.exists(p) else " ❌ THIẾU ") + p)
|
| 94 |
+
|
| 95 |
+
# %% [markdown]
|
| 96 |
+
# ## 1. Cài đặt + tải code SAILER
|
| 97 |
+
|
| 98 |
+
# %%
|
| 99 |
+
import sys, subprocess
|
| 100 |
+
|
| 101 |
+
def pip_install(*pkgs):
|
| 102 |
+
subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs], check=True)
|
| 103 |
+
|
| 104 |
+
pip_install("speechmos", "funasr", "librosa", "soundfile", "pandas", "scipy", "scikit-learn", "tqdm")
|
| 105 |
+
|
| 106 |
+
if USE_SAILER:
|
| 107 |
+
pip_install("loralib", "speechbrain")
|
| 108 |
+
REPO_DIR = "/kaggle/working/vox-profile-release"
|
| 109 |
+
if not os.path.exists(REPO_DIR):
|
| 110 |
+
subprocess.run(["git", "clone", "--depth", "1",
|
| 111 |
+
"https://github.com/tiantiaf0627/vox-profile-release.git", REPO_DIR], check=True)
|
| 112 |
+
if REPO_DIR not in sys.path:
|
| 113 |
+
sys.path.insert(0, REPO_DIR)
|
| 114 |
+
|
| 115 |
+
# %% [markdown]
|
| 116 |
+
# ## 2. Đọc & gộp nhãn (gộp theo wavID) — THÊM cột qMOS
|
| 117 |
+
# Khác exp04: gộp thêm **qMOS** (= TB `qMOS` theo wav) làm nhãn cho QMOS head.
|
| 118 |
+
|
| 119 |
+
# %%
|
| 120 |
+
import numpy as np
|
| 121 |
+
import pandas as pd
|
| 122 |
+
|
| 123 |
+
def load_target_emotions():
|
| 124 |
+
tgt = {}
|
| 125 |
+
with open(METADATA_CSV, encoding="utf-8") as f:
|
| 126 |
+
for ln in f:
|
| 127 |
+
parts = ln.strip().split("|")
|
| 128 |
+
if len(parts) < 2:
|
| 129 |
+
continue
|
| 130 |
+
tgt[stem(parts[0])] = norm_emotion(parts[1])
|
| 131 |
+
return tgt
|
| 132 |
+
|
| 133 |
+
def _col(cols_map, *names, default_idx=None, df=None):
|
| 134 |
+
for n in names:
|
| 135 |
+
if n in cols_map:
|
| 136 |
+
return cols_map[n]
|
| 137 |
+
return list(df.columns)[default_idx] if default_idx is not None else None
|
| 138 |
+
|
| 139 |
+
def parse_emocat_votes(cell):
|
| 140 |
+
v = np.zeros(len(EMOTIONS5), dtype=np.float32)
|
| 141 |
+
for tok in str(cell).replace("/", ",").replace(";", ",").replace("|", ",").replace(" ", ",").split(","):
|
| 142 |
+
e = norm_emotion(tok)
|
| 143 |
+
if e in EMOTIONS5:
|
| 144 |
+
v[EMOTIONS5.index(e)] += 1.0
|
| 145 |
+
return v
|
| 146 |
+
|
| 147 |
+
def load_train_labels():
|
| 148 |
+
"""train.csv → DataFrame [wavID, qmos, emos, val, aro, dom, cat0..cat4] gộp theo wav."""
|
| 149 |
+
df = pd.read_csv(TRAIN_CSV, sep="|")
|
| 150 |
+
cols = {c.lower().strip(): c for c in df.columns}
|
| 151 |
+
wav_col = _col(cols, "wavid", "wav", default_idx=1, df=df)
|
| 152 |
+
qmos_col = _col(cols, "qmos", "mos")
|
| 153 |
+
emos_col = _col(cols, "emos", "emo", "emomos")
|
| 154 |
+
val_col = _col(cols, "val", "valence")
|
| 155 |
+
aro_col = _col(cols, "aro", "arousal")
|
| 156 |
+
dom_col = _col(cols, "dom", "dominance")
|
| 157 |
+
cat_col = _col(cols, "emocat", "cat", "emotion")
|
| 158 |
+
assert qmos_col, f"Không thấy cột qMOS trong train.csv (cột: {list(df.columns)})"
|
| 159 |
+
assert emos_col, f"Không thấy cột eMOS trong train.csv (cột: {list(df.columns)})"
|
| 160 |
+
|
| 161 |
+
df["_stem"] = df[wav_col].map(stem)
|
| 162 |
+
rows = []
|
| 163 |
+
for sid, g in df.groupby("_stem"):
|
| 164 |
+
rec = {"wavID": sid,
|
| 165 |
+
"qmos": float(g[qmos_col].mean()),
|
| 166 |
+
"emos": float(g[emos_col].mean())}
|
| 167 |
+
rec["val"] = float(g[val_col].mean()) if val_col else np.nan
|
| 168 |
+
rec["aro"] = float(g[aro_col].mean()) if aro_col else np.nan
|
| 169 |
+
rec["dom"] = float(g[dom_col].mean()) if dom_col else np.nan
|
| 170 |
+
votes = np.zeros(len(EMOTIONS5), dtype=np.float32)
|
| 171 |
+
if cat_col:
|
| 172 |
+
for cell in g[cat_col]:
|
| 173 |
+
votes += parse_emocat_votes(cell)
|
| 174 |
+
s = votes.sum()
|
| 175 |
+
cat = votes / s if s > 0 else np.full(len(EMOTIONS5), 1.0 / len(EMOTIONS5), dtype=np.float32)
|
| 176 |
+
for i in range(len(EMOTIONS5)):
|
| 177 |
+
rec[f"cat{i}"] = float(cat[i])
|
| 178 |
+
rows.append(rec)
|
| 179 |
+
return pd.DataFrame(rows)
|
| 180 |
+
|
| 181 |
+
target_map = load_target_emotions()
|
| 182 |
+
train_df = load_train_labels()
|
| 183 |
+
HAS_VAD = bool(train_df["val"].notna().any())
|
| 184 |
+
print(f"Target: {len(target_map)} | wav train (gộp): {len(train_df)} | có VAD: {HAS_VAD}")
|
| 185 |
+
print("qMOS:", train_df["qmos"].describe()[["mean", "std", "min", "max"]].to_dict())
|
| 186 |
+
print("eMOS:", train_df["emos"].describe()[["mean", "std", "min", "max"]].to_dict())
|
| 187 |
+
train_df.head()
|
| 188 |
+
|
| 189 |
+
# %% [markdown]
|
| 190 |
+
# ## 3. Trích đặc trưng 2 backbone + điểm UTMOS (cache CHUNG với exp04)
|
| 191 |
+
|
| 192 |
+
# %%
|
| 193 |
+
import torch
|
| 194 |
+
import torch.nn.functional as F
|
| 195 |
+
|
| 196 |
+
device = DEVICE if torch.cuda.is_available() else "cpu"
|
| 197 |
+
print("Device:", device, ("✅ " + torch.cuda.get_device_name(0)) if device == "cuda" else "⚠️ CPU")
|
| 198 |
+
|
| 199 |
+
def extract_e2v(stems, tag):
|
| 200 |
+
from tqdm.auto import tqdm
|
| 201 |
+
cache_path = os.path.join(CACHE_DIR, f"e2v_{tag}.npz")
|
| 202 |
+
store = {}
|
| 203 |
+
if os.path.exists(cache_path):
|
| 204 |
+
z = np.load(cache_path, allow_pickle=True)
|
| 205 |
+
store = {k: z[k] for k in z.files}
|
| 206 |
+
print(f"[e2v/{tag}] nạp cache: {len(store)}")
|
| 207 |
+
todo = [s for s in stems if s not in store]
|
| 208 |
+
if todo:
|
| 209 |
+
from funasr import AutoModel
|
| 210 |
+
m = AutoModel(model="iic/emotion2vec_plus_large", hub="hf", device=device)
|
| 211 |
+
for i, s in enumerate(tqdm(todo, desc=f"e2v {tag}")):
|
| 212 |
+
wav = os.path.join(WAV_DIR, s + ".wav")
|
| 213 |
+
if not os.path.exists(wav):
|
| 214 |
+
continue
|
| 215 |
+
r = m.generate(wav, granularity="utterance", extract_embedding=True)[0]
|
| 216 |
+
emb = np.asarray(r["feats"], dtype=np.float32).reshape(-1)
|
| 217 |
+
probs = {e: 0.0 for e in EMOTIONS5}
|
| 218 |
+
for lab, sc in zip(r["labels"], r["scores"]):
|
| 219 |
+
name = lab.split("/")[-1]
|
| 220 |
+
if name in probs:
|
| 221 |
+
probs[name] = float(sc)
|
| 222 |
+
tot = sum(probs.values())
|
| 223 |
+
p5 = np.array([probs[e] / tot if tot > 0 else 0.2 for e in EMOTIONS5], dtype=np.float32)
|
| 224 |
+
store[s] = np.concatenate([emb, p5]).astype(np.float32)
|
| 225 |
+
if (i + 1) % 500 == 0:
|
| 226 |
+
np.savez(cache_path, **store)
|
| 227 |
+
np.savez(cache_path, **store)
|
| 228 |
+
del m
|
| 229 |
+
torch.cuda.empty_cache() if device == "cuda" else None
|
| 230 |
+
return {s: (v[:-5], v[-5:]) for s, v in store.items()}
|
| 231 |
+
|
| 232 |
+
def _pool_feat(features):
|
| 233 |
+
f = features.detach().cpu().numpy()
|
| 234 |
+
if f.ndim <= 1:
|
| 235 |
+
return f.reshape(-1).astype(np.float32)
|
| 236 |
+
return f.mean(axis=tuple(range(f.ndim - 1))).reshape(-1).astype(np.float32)
|
| 237 |
+
|
| 238 |
+
def extract_sailer(stems, tag):
|
| 239 |
+
import librosa
|
| 240 |
+
from tqdm.auto import tqdm
|
| 241 |
+
cache_path = os.path.join(CACHE_DIR, f"sailer_{tag}.npz")
|
| 242 |
+
store = {}
|
| 243 |
+
if os.path.exists(cache_path):
|
| 244 |
+
z = np.load(cache_path, allow_pickle=True)
|
| 245 |
+
store = {k: z[k] for k in z.files}
|
| 246 |
+
print(f"[sailer/{tag}] nạp cache: {len(store)}")
|
| 247 |
+
todo = [s for s in stems if s not in store]
|
| 248 |
+
if todo:
|
| 249 |
+
from src.model.emotion.wavlm_emotion import WavLMWrapper
|
| 250 |
+
sailer = WavLMWrapper.from_pretrained("tiantiaf/wavlm-large-categorical-emotion").to(device).eval()
|
| 251 |
+
with torch.no_grad():
|
| 252 |
+
for i, s in enumerate(tqdm(todo, desc=f"sailer {tag}")):
|
| 253 |
+
wav = os.path.join(WAV_DIR, s + ".wav")
|
| 254 |
+
if not os.path.exists(wav):
|
| 255 |
+
continue
|
| 256 |
+
wave, _ = librosa.load(wav, sr=16000, mono=True)
|
| 257 |
+
wave = wave[: 15 * 16000]
|
| 258 |
+
data = torch.from_numpy(wave).float().unsqueeze(0).to(device)
|
| 259 |
+
logits, feat, _det, arousal, valence, dominance = sailer(data, return_feature=True)
|
| 260 |
+
emb = _pool_feat(feat)
|
| 261 |
+
p9 = F.softmax(logits, dim=1)[0].detach().cpu().numpy().astype(np.float32)
|
| 262 |
+
vad3 = np.array([1 + 4 * float(valence.item()),
|
| 263 |
+
1 + 4 * float(arousal.item()),
|
| 264 |
+
1 + 4 * float(dominance.item())], dtype=np.float32)
|
| 265 |
+
store[s] = np.concatenate([emb, p9, vad3]).astype(np.float32)
|
| 266 |
+
if (i + 1) % 500 == 0:
|
| 267 |
+
np.savez(cache_path, **store)
|
| 268 |
+
np.savez(cache_path, **store)
|
| 269 |
+
del sailer
|
| 270 |
+
torch.cuda.empty_cache() if device == "cuda" else None
|
| 271 |
+
return {s: (v[:-12], v[-12:-3], v[-3:]) for s, v in store.items()}
|
| 272 |
+
|
| 273 |
+
def extract_utmos(names, tag):
|
| 274 |
+
"""Chấm UTMOS từng wav (theo TÊN, vì DEV gọi .wav theo tên). → dict {stem: score}.
|
| 275 |
+
Cache CACHE_DIR/utmos_<tag>.npz. Dùng vừa làm đầu vào QMOS head, vừa làm baseline so sánh."""
|
| 276 |
+
import librosa
|
| 277 |
+
from tqdm.auto import tqdm
|
| 278 |
+
cache_path = os.path.join(CACHE_DIR, f"utmos_{tag}.npz")
|
| 279 |
+
store = {}
|
| 280 |
+
if os.path.exists(cache_path):
|
| 281 |
+
z = np.load(cache_path, allow_pickle=True)
|
| 282 |
+
store = {k: float(z[k]) for k in z.files}
|
| 283 |
+
print(f"[utmos/{tag}] nạp cache: {len(store)}")
|
| 284 |
+
todo = [n for n in names if stem(n) not in store]
|
| 285 |
+
if todo:
|
| 286 |
+
predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong",
|
| 287 |
+
trust_repo=True).to(device).eval()
|
| 288 |
+
with torch.no_grad():
|
| 289 |
+
for i, n in enumerate(tqdm(todo, desc=f"utmos {tag}")):
|
| 290 |
+
wav = os.path.join(WAV_DIR, n if str(n).endswith(".wav") else n + ".wav")
|
| 291 |
+
if not os.path.exists(wav):
|
| 292 |
+
continue
|
| 293 |
+
wave, _ = librosa.load(wav, sr=16000, mono=True)
|
| 294 |
+
store[stem(n)] = float(predictor(torch.from_numpy(wave).unsqueeze(0).to(device),
|
| 295 |
+
sr=16000).mean().item())
|
| 296 |
+
if (i + 1) % 500 == 0:
|
| 297 |
+
np.savez(cache_path, **{k: np.float32(v) for k, v in store.items()})
|
| 298 |
+
np.savez(cache_path, **{k: np.float32(v) for k, v in store.items()})
|
| 299 |
+
del predictor
|
| 300 |
+
torch.cuda.empty_cache() if device == "cuda" else None
|
| 301 |
+
return store
|
| 302 |
+
|
| 303 |
+
# %% [markdown]
|
| 304 |
+
# ## 4. Dựng feature + nhãn cho train
|
| 305 |
+
# Feature audio (cảm xúc) = `[e2v_emb | e2v_p5 | sailer_emb | sailer_p9 | sailer_vad3]` (như exp04).
|
| 306 |
+
# Thêm: vector **UTMOS** (1 số/ wav) cho QMOS head, và nhãn **qMOS**.
|
| 307 |
+
|
| 308 |
+
# %%
|
| 309 |
+
train_stems = list(train_df["wavID"])
|
| 310 |
+
if LIMIT_TRAIN:
|
| 311 |
+
train_stems = train_stems[:LIMIT_TRAIN]
|
| 312 |
+
|
| 313 |
+
e2v_tr = extract_e2v(train_stems, "train") if USE_E2V else {}
|
| 314 |
+
sailer_tr = extract_sailer(train_stems, "train") if USE_SAILER else {}
|
| 315 |
+
utmos_tr = extract_utmos(train_stems, "train") if USE_UTMOS_FEAT else {}
|
| 316 |
+
|
| 317 |
+
def audio_feature(sid, e2v_map, sailer_map):
|
| 318 |
+
parts = []
|
| 319 |
+
if USE_E2V:
|
| 320 |
+
pk = e2v_map.get(sid)
|
| 321 |
+
if pk is None:
|
| 322 |
+
return None
|
| 323 |
+
emb, p5 = pk
|
| 324 |
+
parts.append(emb)
|
| 325 |
+
if USE_CLASSPROB:
|
| 326 |
+
parts.append(p5)
|
| 327 |
+
if USE_SAILER:
|
| 328 |
+
pk = sailer_map.get(sid)
|
| 329 |
+
if pk is None:
|
| 330 |
+
return None
|
| 331 |
+
emb, p9, vad3 = pk
|
| 332 |
+
parts.append(emb)
|
| 333 |
+
if USE_CLASSPROB:
|
| 334 |
+
parts.append(p9); parts.append(vad3)
|
| 335 |
+
return np.concatenate(parts).astype(np.float32)
|
| 336 |
+
|
| 337 |
+
def onehot_target(tgt):
|
| 338 |
+
v = np.zeros(len(EMOTIONS5), dtype=np.float32)
|
| 339 |
+
if tgt in EMOTIONS5:
|
| 340 |
+
v[EMOTIONS5.index(tgt)] = 1.0
|
| 341 |
+
return v
|
| 342 |
+
|
| 343 |
+
lab = train_df.set_index("wavID")
|
| 344 |
+
X, T, U, y_qmos, y_emos, y_vad, y_cat = [], [], [], [], [], [], []
|
| 345 |
+
for s in train_stems:
|
| 346 |
+
f = audio_feature(s, e2v_tr, sailer_tr)
|
| 347 |
+
tgt = target_map.get(s)
|
| 348 |
+
if f is None or tgt is None or s not in lab.index:
|
| 349 |
+
continue
|
| 350 |
+
if USE_UTMOS_FEAT and s not in utmos_tr:
|
| 351 |
+
continue
|
| 352 |
+
X.append(f)
|
| 353 |
+
T.append(onehot_target(tgt))
|
| 354 |
+
U.append(utmos_tr.get(s, 3.0) if USE_UTMOS_FEAT else 0.0)
|
| 355 |
+
y_qmos.append(lab.loc[s, "qmos"])
|
| 356 |
+
y_emos.append(lab.loc[s, "emos"])
|
| 357 |
+
y_vad.append([lab.loc[s, "val"], lab.loc[s, "aro"], lab.loc[s, "dom"]])
|
| 358 |
+
y_cat.append([lab.loc[s, f"cat{i}"] for i in range(len(EMOTIONS5))])
|
| 359 |
+
|
| 360 |
+
X = np.stack(X).astype(np.float32)
|
| 361 |
+
T = np.stack(T).astype(np.float32)
|
| 362 |
+
U = np.array(U, dtype=np.float32).reshape(-1, 1)
|
| 363 |
+
y_qmos = np.array(y_qmos, dtype=np.float32)
|
| 364 |
+
y_emos = np.array(y_emos, dtype=np.float32)
|
| 365 |
+
y_vad = np.array(y_vad, dtype=np.float32)
|
| 366 |
+
y_cat = np.array(y_cat, dtype=np.float32)
|
| 367 |
+
FEAT_DIM = X.shape[1]
|
| 368 |
+
print(f"Train: X={X.shape} U={U.shape} qmos={y_qmos.shape} emos={y_emos.shape} vad={y_vad.shape}")
|
| 369 |
+
|
| 370 |
+
# Chuẩn hóa feature audio + UTMOS (z-score), lưu mean/std.
|
| 371 |
+
feat_mean = X.mean(0, keepdims=True); feat_std = X.std(0, keepdims=True) + 1e-6
|
| 372 |
+
Xn = (X - feat_mean) / feat_std
|
| 373 |
+
u_mu, u_sd = float(U.mean()), float(U.std() + 1e-6)
|
| 374 |
+
Un = (U - u_mu) / u_sd
|
| 375 |
+
|
| 376 |
+
# Chuẩn hóa nhãn liên tục về z-score.
|
| 377 |
+
qmos_mu, qmos_sd = float(y_qmos.mean()), float(y_qmos.std() + 1e-6)
|
| 378 |
+
y_qmos_z = (y_qmos - qmos_mu) / qmos_sd
|
| 379 |
+
emos_mu, emos_sd = float(y_emos.mean()), float(y_emos.std() + 1e-6)
|
| 380 |
+
y_emos_z = (y_emos - emos_mu) / emos_sd
|
| 381 |
+
if HAS_VAD:
|
| 382 |
+
vad_mu = np.nanmean(y_vad, axis=0); vad_sd = np.nanstd(y_vad, axis=0) + 1e-6
|
| 383 |
+
y_vad_z = (y_vad - vad_mu) / vad_sd
|
| 384 |
+
else:
|
| 385 |
+
vad_mu = np.zeros(3, dtype=np.float32); vad_sd = np.ones(3, dtype=np.float32)
|
| 386 |
+
y_vad_z = np.zeros_like(y_vad)
|
| 387 |
+
|
| 388 |
+
# %% [markdown]
|
| 389 |
+
# ## 5. Model fusion multi-task (6 head) + train loop
|
| 390 |
+
# Thêm so exp04: **QMOS head** nhận `[trunk | UTMOS]` → 1; `qmos` vào uncertainty weighting (6 task).
|
| 391 |
+
|
| 392 |
+
# %%
|
| 393 |
+
import torch.nn as nn
|
| 394 |
+
from scipy.stats import spearmanr
|
| 395 |
+
from sklearn.model_selection import train_test_split
|
| 396 |
+
|
| 397 |
+
torch.manual_seed(SEED); np.random.seed(SEED)
|
| 398 |
+
N_EMO = len(EMOTIONS5)
|
| 399 |
+
idx_all = np.arange(X.shape[0])
|
| 400 |
+
tr_idx, va_idx = train_test_split(idx_all, test_size=VAL_FRAC, random_state=SEED)
|
| 401 |
+
|
| 402 |
+
def to_t(a):
|
| 403 |
+
return torch.tensor(a, dtype=torch.float32, device=device)
|
| 404 |
+
|
| 405 |
+
Xn_t, T_t, Un_t = to_t(Xn), to_t(T), to_t(Un)
|
| 406 |
+
qmos_t = to_t(y_qmos_z).unsqueeze(1)
|
| 407 |
+
emos_t = to_t(y_emos_z).unsqueeze(1)
|
| 408 |
+
vad_t = to_t(y_vad_z)
|
| 409 |
+
cat_t = to_t(y_cat)
|
| 410 |
+
|
| 411 |
+
class FusionMTL6(nn.Module):
|
| 412 |
+
def __init__(self, d_in, trunk_h, head_h, p, n_emo, use_utmos):
|
| 413 |
+
super().__init__()
|
| 414 |
+
self.use_utmos = use_utmos
|
| 415 |
+
self.trunk = nn.Sequential(
|
| 416 |
+
nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),
|
| 417 |
+
nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p),
|
| 418 |
+
)
|
| 419 |
+
self.qmos = nn.Sequential( # nhận [trunk | utmos] nếu bật
|
| 420 |
+
nn.Linear(trunk_h + (1 if use_utmos else 0), head_h), nn.ReLU(), nn.Dropout(p),
|
| 421 |
+
nn.Linear(head_h, 1))
|
| 422 |
+
self.emos = nn.Sequential( # nhận [trunk | target]
|
| 423 |
+
nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))
|
| 424 |
+
self.cat = nn.Sequential(
|
| 425 |
+
nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))
|
| 426 |
+
self.vad = nn.Sequential(
|
| 427 |
+
nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))
|
| 428 |
+
|
| 429 |
+
def forward(self, x, tgt, utmos):
|
| 430 |
+
h = self.trunk(x)
|
| 431 |
+
qmos_in = torch.cat([h, utmos], dim=1) if self.use_utmos else h
|
| 432 |
+
qmos = self.qmos(qmos_in)
|
| 433 |
+
emos = self.emos(torch.cat([h, tgt], dim=1))
|
| 434 |
+
cat_logits = self.cat(h)
|
| 435 |
+
vad = self.vad(h)
|
| 436 |
+
return qmos, emos, cat_logits, vad
|
| 437 |
+
|
| 438 |
+
model = FusionMTL6(FEAT_DIM, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO, USE_UTMOS_FEAT).to(device)
|
| 439 |
+
|
| 440 |
+
TASKS = ["qmos", "emos", "cat", "val", "aro", "dom"]
|
| 441 |
+
log_var = nn.Parameter(torch.zeros(len(TASKS), device=device))
|
| 442 |
+
params = list(model.parameters()) + ([log_var] if USE_UNCERTAINTY else [])
|
| 443 |
+
opt = torch.optim.Adam(params, lr=LR, weight_decay=1e-5)
|
| 444 |
+
mse = nn.MSELoss(reduction="none")
|
| 445 |
+
|
| 446 |
+
def soft_ce(logits, target_dist):
|
| 447 |
+
logq = F.log_softmax(logits, dim=1)
|
| 448 |
+
return -(target_dist * logq).sum(dim=1)
|
| 449 |
+
|
| 450 |
+
def task_losses(qmos_p, emos_p, cat_logits, vad_p, b):
|
| 451 |
+
L = {}
|
| 452 |
+
L["qmos"] = mse(qmos_p, qmos_t[b]).mean()
|
| 453 |
+
L["emos"] = mse(emos_p, emos_t[b]).mean()
|
| 454 |
+
L["cat"] = soft_ce(cat_logits, cat_t[b]).mean()
|
| 455 |
+
if HAS_VAD:
|
| 456 |
+
L["val"] = mse(vad_p[:, 0:1], vad_t[b, 0:1]).mean()
|
| 457 |
+
L["aro"] = mse(vad_p[:, 1:2], vad_t[b, 1:2]).mean()
|
| 458 |
+
L["dom"] = mse(vad_p[:, 2:3], vad_t[b, 2:3]).mean()
|
| 459 |
+
else:
|
| 460 |
+
z = torch.zeros((), device=device)
|
| 461 |
+
L["val"] = L["aro"] = L["dom"] = z
|
| 462 |
+
return L
|
| 463 |
+
|
| 464 |
+
def combine(L):
|
| 465 |
+
if USE_UNCERTAINTY:
|
| 466 |
+
tot = 0.0
|
| 467 |
+
for i, t in enumerate(TASKS):
|
| 468 |
+
tot = tot + torch.exp(-log_var[i]) * L[t] + log_var[i]
|
| 469 |
+
return tot
|
| 470 |
+
return sum(LOSS_W[t] * L[t] for t in TASKS)
|
| 471 |
+
|
| 472 |
+
@torch.no_grad()
|
| 473 |
+
def eval_val():
|
| 474 |
+
model.eval()
|
| 475 |
+
qp, ep, cl, vp = model(Xn_t[va_idx], T_t[va_idx], Un_t[va_idx])
|
| 476 |
+
qp = qp.cpu().numpy().ravel(); ep = ep.cpu().numpy().ravel()
|
| 477 |
+
out = {"qmos": spearmanr(qp, y_qmos[va_idx]).correlation,
|
| 478 |
+
"emos": spearmanr(ep, y_emos[va_idx]).correlation}
|
| 479 |
+
if USE_UTMOS_FEAT:
|
| 480 |
+
out["qmos_utmos"] = spearmanr(U[va_idx, 0], y_qmos[va_idx]).correlation # baseline UTMOS đơn lẻ
|
| 481 |
+
if HAS_VAD:
|
| 482 |
+
vp = vp.cpu().numpy()
|
| 483 |
+
for j, t in enumerate(["val", "aro", "dom"]):
|
| 484 |
+
out[t] = spearmanr(vp[:, j], y_vad[va_idx, j]).correlation
|
| 485 |
+
q = F.softmax(cl, dim=1).cpu().numpy(); p = y_cat[va_idx]
|
| 486 |
+
kl = (p * (np.log(p + 1e-9) - np.log(q + 1e-9))).sum(1).mean()
|
| 487 |
+
out["cat_negkl"] = float(-kl)
|
| 488 |
+
return out
|
| 489 |
+
|
| 490 |
+
def val_score(m):
|
| 491 |
+
"""Điểm tổng early-stop = TB SRCC các task liên tục (qmos+emos+VAD)."""
|
| 492 |
+
keys = ["qmos", "emos"] + (["val", "aro", "dom"] if HAS_VAD else [])
|
| 493 |
+
return float(np.mean([m[k] for k in keys]))
|
| 494 |
+
|
| 495 |
+
best_score, best_state, bad = -1e9, None, 0
|
| 496 |
+
tr_t = torch.tensor(tr_idx, device=device)
|
| 497 |
+
for ep_i in range(1, EPOCHS + 1):
|
| 498 |
+
model.train()
|
| 499 |
+
perm = tr_t[torch.randperm(len(tr_t), device=device)]
|
| 500 |
+
run = 0.0
|
| 501 |
+
for i in range(0, len(perm), BATCH):
|
| 502 |
+
b = perm[i:i + BATCH]
|
| 503 |
+
opt.zero_grad()
|
| 504 |
+
qmos_p, emos_p, cat_logits, vad_p = model(Xn_t[b], T_t[b], Un_t[b])
|
| 505 |
+
L = task_losses(qmos_p, emos_p, cat_logits, vad_p, b)
|
| 506 |
+
loss = combine(L)
|
| 507 |
+
loss.backward(); opt.step()
|
| 508 |
+
run += loss.item() * len(b)
|
| 509 |
+
m = eval_val()
|
| 510 |
+
sc = val_score(m)
|
| 511 |
+
if sc > best_score:
|
| 512 |
+
best_score = sc
|
| 513 |
+
best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
|
| 514 |
+
bad = 0
|
| 515 |
+
else:
|
| 516 |
+
bad += 1
|
| 517 |
+
if ep_i % 5 == 0 or ep_i == 1:
|
| 518 |
+
msg = " ".join(f"{k}={m[k]:.3f}" for k in ["qmos", "emos", "val", "aro", "dom"] if k in m)
|
| 519 |
+
print(f"epoch {ep_i:3d} | loss {run/len(perm):.4f} | {msg} | best {best_score:.4f}")
|
| 520 |
+
if bad >= PATIENCE:
|
| 521 |
+
print(f"Early stop ở epoch {ep_i}.")
|
| 522 |
+
break
|
| 523 |
+
|
| 524 |
+
model.load_state_dict(best_state)
|
| 525 |
+
final = eval_val()
|
| 526 |
+
print("\n✅ VAL (nội bộ) — exp07 (fusion + QMOS head):")
|
| 527 |
+
print(f" QMOS SRCC = {final['qmos']:.4f}", end="")
|
| 528 |
+
if "qmos_utmos" in final:
|
| 529 |
+
tag = "✅ vượt UTMOS" if final["qmos"] > final["qmos_utmos"] else "⚠️ CHƯA vượt UTMOS"
|
| 530 |
+
print(f" (UTMOS đơn lẻ = {final['qmos_utmos']:.4f} → {tag})")
|
| 531 |
+
else:
|
| 532 |
+
print()
|
| 533 |
+
print(f" EMOS SRCC = {final['emos']:.4f} (mốc exp04 = {EXP04['emos']})")
|
| 534 |
+
if HAS_VAD:
|
| 535 |
+
print(f" VAL/ARO/DOM = {final['val']:.4f}/{final['aro']:.4f}/{final['dom']:.4f}"
|
| 536 |
+
f" (mốc exp04 = {EXP04['val']}/{EXP04['aro']}/{EXP04['dom']})")
|
| 537 |
+
# Cảnh báo negative transfer (gộp QMOS làm tụt cảm xúc)
|
| 538 |
+
warn = []
|
| 539 |
+
if final["emos"] < EXP04["emos"] - 0.02:
|
| 540 |
+
warn.append(f"EMOS {final['emos']:.3f} < {EXP04['emos']}")
|
| 541 |
+
if HAS_VAD:
|
| 542 |
+
for t in ["val", "aro", "dom"]:
|
| 543 |
+
if final[t] < EXP04[t] - 0.02:
|
| 544 |
+
warn.append(f"{t.upper()} {final[t]:.3f} < {EXP04[t]}")
|
| 545 |
+
if warn:
|
| 546 |
+
print(" ⚠️ NEGATIVE TRANSFER? Cảm xúc tụt so exp04:", "; ".join(warn),
|
| 547 |
+
"\n → cân nhắc giữ exp04 cho 5 cột cảm xúc + chỉ lấy QMOS từ exp07/exp06.")
|
| 548 |
+
else:
|
| 549 |
+
print(" ✅ Không thấy 5 cột cảm xúc tụt rõ so exp04.")
|
| 550 |
+
if USE_UNCERTAINTY:
|
| 551 |
+
print(" log σ² mỗi task:", {t: round(float(log_var[i]), 3) for i, t in enumerate(TASKS)})
|
| 552 |
+
|
| 553 |
+
torch.save({"state": best_state, "feat_mean": feat_mean, "feat_std": feat_std,
|
| 554 |
+
"u_mu": u_mu, "u_sd": u_sd,
|
| 555 |
+
"qmos_mu": qmos_mu, "qmos_sd": qmos_sd, "emos_mu": emos_mu, "emos_sd": emos_sd,
|
| 556 |
+
"vad_mu": vad_mu, "vad_sd": vad_sd, "FEAT_DIM": FEAT_DIM,
|
| 557 |
+
"USE_E2V": USE_E2V, "USE_SAILER": USE_SAILER, "USE_CLASSPROB": USE_CLASSPROB,
|
| 558 |
+
"USE_UTMOS_FEAT": USE_UTMOS_FEAT, "val_score": best_score},
|
| 559 |
+
os.path.join(OUT_DIR, "fusion_qmos_mtl.pt"))
|
| 560 |
+
print("Đã lưu", os.path.join(OUT_DIR, "fusion_qmos_mtl.pt"))
|
| 561 |
+
|
| 562 |
+
# %% [markdown]
|
| 563 |
+
# ## 6. Dự đoán DEV → `answer.txt` đủ 6 cột (QMOS giờ từ HEAD, không phải SpeechMOS riêng)
|
| 564 |
+
|
| 565 |
+
# %%
|
| 566 |
+
def list_dev():
|
| 567 |
+
with open(DEV_SCP) as f:
|
| 568 |
+
return [ln.strip() for ln in f if ln.strip()]
|
| 569 |
+
|
| 570 |
+
dev_names = list_dev()
|
| 571 |
+
if LIMIT_DEV:
|
| 572 |
+
dev_names = dev_names[:LIMIT_DEV]
|
| 573 |
+
dev_stems = [stem(n) for n in dev_names]
|
| 574 |
+
print("DEV:", len(dev_names), "mẫu")
|
| 575 |
+
|
| 576 |
+
e2v_dev = extract_e2v(dev_stems, "dev") if USE_E2V else {}
|
| 577 |
+
sailer_dev = extract_sailer(dev_stems, "dev") if USE_SAILER else {}
|
| 578 |
+
utmos_dev = extract_utmos(dev_names, "dev") if USE_UTMOS_FEAT else {}
|
| 579 |
+
|
| 580 |
+
@torch.no_grad()
|
| 581 |
+
def predict_all(sid):
|
| 582 |
+
f = audio_feature(sid, e2v_dev, sailer_dev)
|
| 583 |
+
if f is None:
|
| 584 |
+
return None
|
| 585 |
+
fn = (f[None, :] - feat_mean) / feat_std
|
| 586 |
+
tgt = onehot_target(target_map.get(sid))[None, :]
|
| 587 |
+
u = np.array([[utmos_dev.get(sid, 3.0)]], dtype=np.float32)
|
| 588 |
+
un = (u - u_mu) / u_sd
|
| 589 |
+
model.eval()
|
| 590 |
+
qmos_p, emos_p, cat_logits, vad_p = model(to_t(fn), to_t(tgt), to_t(un))
|
| 591 |
+
qmos = float(qmos_p.item()) * qmos_sd + qmos_mu
|
| 592 |
+
emos = float(emos_p.item()) * emos_sd + emos_mu
|
| 593 |
+
cat5 = F.softmax(cat_logits, dim=1)[0].cpu().numpy()
|
| 594 |
+
vad3 = vad_p[0].cpu().numpy() * vad_sd + vad_mu
|
| 595 |
+
return qmos, emos, cat5, vad3
|
| 596 |
+
|
| 597 |
+
def fmt_cat(probs5):
|
| 598 |
+
return "|".join(f"{e}:{probs5[i]:.6g}" for i, e in enumerate(EMOTIONS5))
|
| 599 |
+
|
| 600 |
+
def build_answer(out_path):
|
| 601 |
+
from tqdm.auto import tqdm
|
| 602 |
+
n_real = n_default = 0
|
| 603 |
+
with open(out_path, "w") as f:
|
| 604 |
+
f.write("wav,QMOS,EMOS,CAT,VAL,ARO,DOM\n")
|
| 605 |
+
for name in tqdm(dev_names, desc="answer"):
|
| 606 |
+
sid = stem(name)
|
| 607 |
+
pred = predict_all(sid)
|
| 608 |
+
if pred is None:
|
| 609 |
+
# rơi về: QMOS=UTMOS nếu có, còn lại mặc định
|
| 610 |
+
qmos = utmos_dev.get(sid, 3.0)
|
| 611 |
+
emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0])
|
| 612 |
+
n_default += 1
|
| 613 |
+
else:
|
| 614 |
+
qmos, emos, cat5, vad3 = pred
|
| 615 |
+
n_real += 1
|
| 616 |
+
f.write(f"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},"
|
| 617 |
+
f"{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\n")
|
| 618 |
+
print(f"Ghi {len(dev_names)} dòng → {out_path} | head thật {n_real}, mặc định {n_default}")
|
| 619 |
+
|
| 620 |
+
answer_path = os.path.join(OUT_DIR, "answer.txt")
|
| 621 |
+
build_answer(answer_path)
|
| 622 |
+
|
| 623 |
+
# %% [markdown]
|
| 624 |
+
# ## 7. Validate + đóng zip
|
| 625 |
+
|
| 626 |
+
# %%
|
| 627 |
+
def validate(path):
|
| 628 |
+
import csv
|
| 629 |
+
with open(path) as f:
|
| 630 |
+
rows = list(csv.reader(f))
|
| 631 |
+
header = rows[0]
|
| 632 |
+
assert header[0] == "wav" and "QMOS" in header and "EMOS" in header, "Header sai"
|
| 633 |
+
for i, r in enumerate(rows[1:], 2):
|
| 634 |
+
assert len(r) == len(header), f"Dòng {i} sai số cột"
|
| 635 |
+
print(f"OK: {len(rows)-1} dòng, header = {header}")
|
| 636 |
+
|
| 637 |
+
validate(answer_path)
|
| 638 |
+
os.system(f"cd {OUT_DIR} && zip -j submission_track2_exp07_fusion_qmos.zip answer.txt "
|
| 639 |
+
f"&& unzip -l submission_track2_exp07_fusion_qmos.zip")
|
| 640 |
+
print("Sẵn sàng nộp:", os.path.join(OUT_DIR, "submission_track2_exp07_fusion_qmos.zip"))
|
| 641 |
+
|
| 642 |
+
# %% [markdown]
|
| 643 |
+
# ## Ghi chú
|
| 644 |
+
# - **Lần đầu** đặt `LIMIT_TRAIN=300`, `LIMIT_DEV=20`; OK rồi đặt `None`.
|
| 645 |
+
# - **Đọc kết quả mục 5 theo 2 câu hỏi:**
|
| 646 |
+
# 1. QMOS head có **vượt UTMOS đơn lẻ (0.414)** không? (dòng "vượt/CHƯA vượt UTMOS")
|
| 647 |
+
# 2. Gộp QMOS có **làm tụt** EMOS/VAD so exp04 không? (dòng "NEGATIVE TRANSFER?")
|
| 648 |
+
# - **Quyết định nộp:**
|
| 649 |
+
# - Nếu QMOS↑ và cảm xúc KHÔNG tụt → nộp answer.txt exp07 (1 model trọn 6 cột — đẹp cho paper).
|
| 650 |
+
# - Nếu QMOS↑ nhưng cảm xúc TỤT → giữ exp04 cho 5 cột cảm xúc, chỉ lấy **cột QMOS** của exp07/exp06 ghép vào.
|
| 651 |
+
# - Nếu QMOS không vượt UTMOS → kết luận "chất lượng trực giao cảm xúc" (vẫn là phát hiện cho paper); giữ exp04.
|
| 652 |
+
# - **Ablation cho paper**: `USE_UTMOS_FEAT=False` (QMOS chỉ từ trunk cảm xúc) → đo trực tiếp giả thuyết của bạn.
|
| 653 |
+
# - Cache dùng CHUNG `fusion_cache/` với exp04 → **Save Version** giữ lại.
|
| 654 |
+
# - Ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (mục exp07).
|
track2/exp08_finetune_emotion.ipynb
ADDED
|
@@ -0,0 +1,820 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "ee3b7231",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"# VMC2026 Track 2 — exp08 (FINE-TUNE WavLM cho 5 cột cảm xúc) — Kaggle\n",
|
| 9 |
+
"\n",
|
| 10 |
+
"**Khác mọi exp trước:** exp03–07 đều **đóng băng** backbone (chỉ trích đặc trưng + train head nhỏ trên cache).\n",
|
| 11 |
+
"exp08 **MỞ BĂNG (fine-tune)** WavLM-large để nó học lại đặc trưng riêng cho bài MOS cảm xúc 2026.\n",
|
| 12 |
+
"\n",
|
| 13 |
+
"## Thiết kế (chốt với mentor 5/6)\n",
|
| 14 |
+
"```\n",
|
| 15 |
+
" wav ─┬─► WavLM-large (warm-start SAILER, TRAINABLE: chỉ mở băng N lớp trên) ─► pool ─► emb_wavlm ┐\n",
|
| 16 |
+
" └─► audeering MSP-dim (FROZEN, cache .npz) ─► [emb_aud | vad3] ├─► TRUNK ─┬─► EMOS (+target)\n",
|
| 17 |
+
" ┘ ├─► CAT (5)\n",
|
| 18 |
+
" └─► VAD (3)\n",
|
| 19 |
+
" QMOS: KHÔNG train ở đây → mượn cột QMOS của exp07 (0.548) hoặc UTMOSv2 (T05, vô địch VMC2024).\n",
|
| 20 |
+
"```\n",
|
| 21 |
+
"- **Warm-start:** khởi tạo WavLM từ checkpoint **SAILER** (`tiantiaf/wavlm-large-categorical-emotion`,\n",
|
| 22 |
+
" đã giỏi cảm xúc) thay vì WavLM \"trắng\" → điểm xuất phát tốt hơn nhiều.\n",
|
| 23 |
+
"- **Phụ (frozen):** audeering — dimensional, bổ trợ góc nhìn categorical của WavLM, kỳ vọng kéo **VAL**.\n",
|
| 24 |
+
"- **Đóng băng partial:** chỉ train `UNFREEZE_TOP_LAYERS` lớp Transformer trên cùng + feature-extractor giữ băng\n",
|
| 25 |
+
" → tiết kiệm VRAM T4 + chống overfit (chỉ 12.7k mẫu).\n",
|
| 26 |
+
"\n",
|
| 27 |
+
"## ⚠️ Đánh đổi phải biết trước (so freeze+head)\n",
|
| 28 |
+
"- **Mất lợi thế cache:** mỗi epoch chạy lại cả WavLM (forward+backward) → chậm & đốt giờ GPU (30h/tuần).\n",
|
| 29 |
+
" → **Lần đầu BẮT BUỘC đặt `LIMIT_TRAIN=300`, `LIMIT_DEV=20`** để chỉnh trơn rồi mới `None`.\n",
|
| 30 |
+
"- **Dễ overfit / OOM:** nếu OOM → giảm `BATCH`, tăng `ACCUM`, giảm `MAX_SECONDS`, giảm `UNFREEZE_TOP_LAYERS`.\n",
|
| 31 |
+
"- **Lưới an toàn:** exp07 vẫn là bản nộp vô địch tới khi exp08 **thắng trên VAL nội bộ** (đừng đốt lượt nộp).\n",
|
| 32 |
+
"\n",
|
| 33 |
+
"**Cách chạy Kaggle:** GPU **T4** + Internet **On** → Add Input dataset Track 2 → sửa `DATA_ROOT` → Run All."
|
| 34 |
+
]
|
| 35 |
+
},
|
| 36 |
+
{
|
| 37 |
+
"cell_type": "markdown",
|
| 38 |
+
"id": "656d8385",
|
| 39 |
+
"metadata": {},
|
| 40 |
+
"source": [
|
| 41 |
+
"## 0. Cấu hình — SỬA Ở ĐÂY"
|
| 42 |
+
]
|
| 43 |
+
},
|
| 44 |
+
{
|
| 45 |
+
"cell_type": "code",
|
| 46 |
+
"execution_count": null,
|
| 47 |
+
"id": "38d86264",
|
| 48 |
+
"metadata": {},
|
| 49 |
+
"outputs": [],
|
| 50 |
+
"source": [
|
| 51 |
+
"import os\n",
|
| 52 |
+
"\n",
|
| 53 |
+
"DATA_ROOT = \"/kaggle/input/vmc2026-track2-full/vmc2026-track2\" # << SỬA slug cho khớp Add Input\n",
|
| 54 |
+
"WAV_DIR = f\"{DATA_ROOT}/wav\"\n",
|
| 55 |
+
"METADATA_CSV = f\"{DATA_ROOT}/metadata.csv\" # wavID|emotion|transcript (KHÔNG header)\n",
|
| 56 |
+
"TRAIN_CSV = f\"{DATA_ROOT}/sets/train.csv\" # lisID|wavID|qMOS|emoCat|eMOS|val|dom|aro\n",
|
| 57 |
+
"DEV_SCP = f\"{DATA_ROOT}/sets/dev.scp\"\n",
|
| 58 |
+
"\n",
|
| 59 |
+
"OUT_DIR = \"/kaggle/working\"\n",
|
| 60 |
+
"CACHE_DIR = \"/kaggle/working/ft_cache\" # cache audeering (.npz) — backbone WavLM KHÔNG cache (đang train)\n",
|
| 61 |
+
"os.makedirs(CACHE_DIR, exist_ok=True)\n",
|
| 62 |
+
"\n",
|
| 63 |
+
"# (Tùy chọn) TÁI DÙNG cache audeering cũ: trỏ tới dataset chứa aud_train.npz/aud_dev.npz → tự copy sang CACHE_DIR.\n",
|
| 64 |
+
"# Để \"\" nếu chạy mới hoàn toàn. /kaggle/input read-only nên phải copy sang working để ghi/append.\n",
|
| 65 |
+
"CACHE_INPUT = \"/kaggle/input/datasets/minhtoan2/cache-exp8\" # << SỬA slug cho khớp (hoặc \"\")\n",
|
| 66 |
+
"if CACHE_INPUT and os.path.isdir(CACHE_INPUT):\n",
|
| 67 |
+
" import shutil\n",
|
| 68 |
+
" _n = 0\n",
|
| 69 |
+
" for _fn in os.listdir(CACHE_INPUT):\n",
|
| 70 |
+
" if _fn.startswith(\"aud_\") and _fn.endswith(\".npz\"):\n",
|
| 71 |
+
" shutil.copy(os.path.join(CACHE_INPUT, _fn), os.path.join(CACHE_DIR, _fn)); _n += 1\n",
|
| 72 |
+
" print(f\"📦 Tái dùng cache: copy {_n} file aud_*.npz từ {CACHE_INPUT} → {CACHE_DIR}\")\n",
|
| 73 |
+
"\n",
|
| 74 |
+
"# Mượn cột QMOS của exp07 (tốt nhất 0.548). Trỏ tới answer.txt exp07 nếu có; không thì dùng UTMOSv2.\n",
|
| 75 |
+
"EXP07_ANSWER = \"/kaggle/input/exp07-answer/answer.txt\" # << (tùy chọn) Add Input answer.txt exp07; không có → UTMOSv2\n",
|
| 76 |
+
"\n",
|
| 77 |
+
"# ── Fine-tune / siêu tham số ─────────────────────────────────────────────────\n",
|
| 78 |
+
"DEVICE = \"cuda\"\n",
|
| 79 |
+
"SR = 16000\n",
|
| 80 |
+
"MAX_SECONDS = 8 # cắt audio để chặn bộ nhớ backprop; OOM thì giảm còn 6\n",
|
| 81 |
+
"UNFREEZE_TOP_LAYERS = 6 # số lớp Transformer trên cùng được train (0 = freeze hết = quay về head-only)\n",
|
| 82 |
+
"TRUNK_HIDDEN = 512\n",
|
| 83 |
+
"HEAD_HIDDEN = 128\n",
|
| 84 |
+
"DROPOUT = 0.3\n",
|
| 85 |
+
"LR_BACKBONE = 1e-5 # LR nhỏ cho backbone fine-tune\n",
|
| 86 |
+
"LR_HEAD = 1e-3 # LR lớn cho trunk + head (train từ đầu)\n",
|
| 87 |
+
"WEIGHT_DECAY = 1e-5\n",
|
| 88 |
+
"EPOCHS = 12 # TRẦN; early-stop quyết định số epoch thực (8 hơi thấp cho lần chạy thật)\n",
|
| 89 |
+
"PATIENCE = 3 # dừng khi val SRCC không lên 3 epoch; LUÔN giữ best_state\n",
|
| 90 |
+
"BATCH = 4 # nhỏ vì backbone to; tăng ACCUM để bù\n",
|
| 91 |
+
"ACCUM = 8 # effective batch = BATCH*ACCUM = 32\n",
|
| 92 |
+
"VAL_FRAC = 0.10\n",
|
| 93 |
+
"SEED = 42\n",
|
| 94 |
+
"USE_AMP = True # mixed precision fp16 — tiết kiệm VRAM\n",
|
| 95 |
+
"USE_GRAD_CKPT = True # gradient checkpointing — tiết kiệm VRAM (đổi lấy chậm hơn)\n",
|
| 96 |
+
"USE_AUDEERING = True # nhánh phụ frozen audeering; False = chỉ WavLM\n",
|
| 97 |
+
"USE_UNCERTAINTY = True # tự cân 5 loss (Kendall); False = trọng số 1.0\n",
|
| 98 |
+
"\n",
|
| 99 |
+
"LIMIT_TRAIN = 300 # << LẦN ĐẦU để 300; chạy thật đặt None\n",
|
| 100 |
+
"LIMIT_DEV = 20 # << LẦN ĐẦU để 20; chạy thật đặt None\n",
|
| 101 |
+
"\n",
|
| 102 |
+
"# Mốc exp07 để so (cảnh báo nếu fine-tune KHÔNG thắng → giữ exp07)\n",
|
| 103 |
+
"EXP07 = {\"emos\": 0.795, \"cat_err\": 0.153, \"val\": 0.581, \"aro\": 0.752, \"dom\": 0.705}\n",
|
| 104 |
+
"\n",
|
| 105 |
+
"EMOTIONS5 = [\"angry\", \"happy\", \"neutral\", \"sad\", \"surprised\"]\n",
|
| 106 |
+
"\n",
|
| 107 |
+
"_EMO_ALIAS = {\n",
|
| 108 |
+
" \"angry\": \"angry\", \"anger\": \"angry\",\n",
|
| 109 |
+
" \"happy\": \"happy\", \"happiness\": \"happy\", \"joy\": \"happy\",\n",
|
| 110 |
+
" \"neutral\": \"neutral\", \"calm\": \"neutral\",\n",
|
| 111 |
+
" \"sad\": \"sad\", \"sadness\": \"sad\",\n",
|
| 112 |
+
" \"surprise\": \"surprised\", \"surprised\": \"surprised\", \"surprising\": \"surprised\",\n",
|
| 113 |
+
"}\n",
|
| 114 |
+
"\n",
|
| 115 |
+
"def norm_emotion(label):\n",
|
| 116 |
+
" key = str(label).strip().lower()\n",
|
| 117 |
+
" return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)\n",
|
| 118 |
+
"\n",
|
| 119 |
+
"def stem(p):\n",
|
| 120 |
+
" return os.path.splitext(os.path.basename(str(p)))[0]\n",
|
| 121 |
+
"\n",
|
| 122 |
+
"print(\"DATA_ROOT:\", DATA_ROOT)\n",
|
| 123 |
+
"for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP]:\n",
|
| 124 |
+
" print((\" ✅ \" if os.path.exists(p) else \" ❌ THIẾU \") + p)\n",
|
| 125 |
+
"print(f\"Fine-tune: mở băng {UNFREEZE_TOP_LAYERS} lớp trên · BATCH {BATCH}×ACCUM {ACCUM} · MAX {MAX_SECONDS}s\")"
|
| 126 |
+
]
|
| 127 |
+
},
|
| 128 |
+
{
|
| 129 |
+
"cell_type": "markdown",
|
| 130 |
+
"id": "ed538923",
|
| 131 |
+
"metadata": {},
|
| 132 |
+
"source": [
|
| 133 |
+
"## 1. Cài đặt + tải code SAILER (clone + sys.path, KHÔNG pip install -e .)"
|
| 134 |
+
]
|
| 135 |
+
},
|
| 136 |
+
{
|
| 137 |
+
"cell_type": "code",
|
| 138 |
+
"execution_count": null,
|
| 139 |
+
"id": "f052d016",
|
| 140 |
+
"metadata": {},
|
| 141 |
+
"outputs": [],
|
| 142 |
+
"source": [
|
| 143 |
+
"import sys, subprocess\n",
|
| 144 |
+
"\n",
|
| 145 |
+
"def pip_install(*pkgs):\n",
|
| 146 |
+
" subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *pkgs], check=True)\n",
|
| 147 |
+
"\n",
|
| 148 |
+
"pip_install(\"loralib\", \"speechbrain\", \"speechmos\", \"librosa\", \"soundfile\",\n",
|
| 149 |
+
" \"scipy\", \"scikit-learn\", \"pandas\", \"tqdm\")\n",
|
| 150 |
+
"\n",
|
| 151 |
+
"REPO_DIR = \"/kaggle/working/vox-profile-release\"\n",
|
| 152 |
+
"if not os.path.exists(REPO_DIR):\n",
|
| 153 |
+
" subprocess.run([\"git\", \"clone\", \"--depth\", \"1\",\n",
|
| 154 |
+
" \"https://github.com/tiantiaf0627/vox-profile-release.git\", REPO_DIR], check=True)\n",
|
| 155 |
+
"if REPO_DIR not in sys.path:\n",
|
| 156 |
+
" sys.path.insert(0, REPO_DIR)"
|
| 157 |
+
]
|
| 158 |
+
},
|
| 159 |
+
{
|
| 160 |
+
"cell_type": "markdown",
|
| 161 |
+
"id": "0bf41e8c",
|
| 162 |
+
"metadata": {},
|
| 163 |
+
"source": [
|
| 164 |
+
"## 2. Nạp SAILER → lấy backbone WavLM bên trong để FINE-TUNE\n",
|
| 165 |
+
"Thay vì gọi wrapper như hộp đen, ta **lôi module WavLM-large (HuggingFace) bên trong wrapper** ra\n",
|
| 166 |
+
"→ toàn quyền đóng băng/mở băng từng lớp + tự pool. Nếu không tìm thấy (cấu trúc lạ) → **fallback**\n",
|
| 167 |
+
"nạp `microsoft/wavlm-large` trắng (mất warm-start, có cảnh báo)."
|
| 168 |
+
]
|
| 169 |
+
},
|
| 170 |
+
{
|
| 171 |
+
"cell_type": "code",
|
| 172 |
+
"execution_count": null,
|
| 173 |
+
"id": "50a7cac6",
|
| 174 |
+
"metadata": {
|
| 175 |
+
"lines_to_next_cell": 1
|
| 176 |
+
},
|
| 177 |
+
"outputs": [],
|
| 178 |
+
"source": [
|
| 179 |
+
"import torch\n",
|
| 180 |
+
"import torch.nn as nn\n",
|
| 181 |
+
"import torch.nn.functional as F\n",
|
| 182 |
+
"\n",
|
| 183 |
+
"device = DEVICE if torch.cuda.is_available() else \"cpu\"\n",
|
| 184 |
+
"print(\"Device:\", device, (\"✅ \" + torch.cuda.get_device_name(0)) if device == \"cuda\" else \"⚠️ CPU (rất chậm!)\")\n",
|
| 185 |
+
"\n",
|
| 186 |
+
"def find_hf_backbone(module):\n",
|
| 187 |
+
" \"\"\"Tìm submodule kiểu HF Wav2Vec2/WavLM backbone: có .feature_extractor và .encoder.layers.\"\"\"\n",
|
| 188 |
+
" cands = []\n",
|
| 189 |
+
" for name, m in module.named_modules():\n",
|
| 190 |
+
" enc = getattr(m, \"encoder\", None)\n",
|
| 191 |
+
" if getattr(m, \"feature_extractor\", None) is not None and enc is not None \\\n",
|
| 192 |
+
" and getattr(enc, \"layers\", None) is not None:\n",
|
| 193 |
+
" cands.append((name, m))\n",
|
| 194 |
+
" if not cands:\n",
|
| 195 |
+
" return None, None\n",
|
| 196 |
+
" cands.sort(key=lambda nm: sum(p.numel() for p in nm[1].parameters()), reverse=True)\n",
|
| 197 |
+
" return cands[0]\n",
|
| 198 |
+
"\n",
|
| 199 |
+
"wavlm = None\n",
|
| 200 |
+
"try:\n",
|
| 201 |
+
" from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402\n",
|
| 202 |
+
" _wrapper = WavLMWrapper.from_pretrained(\"tiantiaf/wavlm-large-categorical-emotion\")\n",
|
| 203 |
+
" name, wavlm = find_hf_backbone(_wrapper)\n",
|
| 204 |
+
" if wavlm is not None:\n",
|
| 205 |
+
" print(f\"✅ Warm-start SAILER: lấy backbone WavLM bên trong wrapper tại '.{name}' \"\n",
|
| 206 |
+
" f\"({sum(p.numel() for p in wavlm.parameters())/1e6:.0f}M params)\")\n",
|
| 207 |
+
" else:\n",
|
| 208 |
+
" print(\"⚠️ Không tìm thấy backbone HF bên trong wrapper SAILER → sẽ fallback WavLM trắng.\")\n",
|
| 209 |
+
"except Exception as e:\n",
|
| 210 |
+
" print(\"⚠️ Lỗi nạp SAILER wrapper:\", repr(e), \"→ fallback WavLM trắng.\")\n",
|
| 211 |
+
"\n",
|
| 212 |
+
"if wavlm is None:\n",
|
| 213 |
+
" from transformers import WavLMModel\n",
|
| 214 |
+
" wavlm = WavLMModel.from_pretrained(\"microsoft/wavlm-large\")\n",
|
| 215 |
+
" print(\"ℹ️ Fallback: nạp microsoft/wavlm-large (KHÔNG warm-start SAILER).\")\n",
|
| 216 |
+
"\n",
|
| 217 |
+
"wavlm = wavlm.to(device)\n",
|
| 218 |
+
"WAVLM_DIM = int(wavlm.config.hidden_size)\n",
|
| 219 |
+
"\n",
|
| 220 |
+
"# ── Đóng băng partial: feature-extractor + tất cả trừ UNFREEZE_TOP_LAYERS lớp trên ──\n",
|
| 221 |
+
"for p in wavlm.parameters():\n",
|
| 222 |
+
" p.requires_grad = False\n",
|
| 223 |
+
"enc_layers = wavlm.encoder.layers\n",
|
| 224 |
+
"n_layers = len(enc_layers)\n",
|
| 225 |
+
"for layer in enc_layers[max(0, n_layers - UNFREEZE_TOP_LAYERS):]:\n",
|
| 226 |
+
" for p in layer.parameters():\n",
|
| 227 |
+
" p.requires_grad = True\n",
|
| 228 |
+
"n_train = sum(p.numel() for p in wavlm.parameters() if p.requires_grad)\n",
|
| 229 |
+
"print(f\"WavLM: {n_layers} lớp encoder · mở băng {min(UNFREEZE_TOP_LAYERS, n_layers)} lớp trên \"\n",
|
| 230 |
+
" f\"→ {n_train/1e6:.1f}M param train (trên dim {WAVLM_DIM})\")\n",
|
| 231 |
+
"\n",
|
| 232 |
+
"if USE_GRAD_CKPT:\n",
|
| 233 |
+
" wavlm.gradient_checkpointing_enable()\n",
|
| 234 |
+
" if hasattr(wavlm, \"enable_input_require_grads\"):\n",
|
| 235 |
+
" wavlm.enable_input_require_grads() # cần khi grad-ckpt + lớp dưới đóng băng\n",
|
| 236 |
+
"\n",
|
| 237 |
+
"def masked_mean(hidden, attn_mask):\n",
|
| 238 |
+
" \"\"\"Mean-pool theo thời gian, bỏ qua phần pad (giữ gradient).\"\"\"\n",
|
| 239 |
+
" if attn_mask is None:\n",
|
| 240 |
+
" return hidden.mean(dim=1)\n",
|
| 241 |
+
" try:\n",
|
| 242 |
+
" fm = wavlm._get_feature_vector_attention_mask(hidden.shape[1], attn_mask)\n",
|
| 243 |
+
" except Exception:\n",
|
| 244 |
+
" return hidden.mean(dim=1)\n",
|
| 245 |
+
" fm = fm.unsqueeze(-1).to(hidden.dtype)\n",
|
| 246 |
+
" return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)\n",
|
| 247 |
+
"\n",
|
| 248 |
+
"def wavlm_embed(input_values, attn_mask):\n",
|
| 249 |
+
" out = wavlm(input_values, attention_mask=attn_mask).last_hidden_state # [B,T,D]\n",
|
| 250 |
+
" return masked_mean(out, attn_mask)"
|
| 251 |
+
]
|
| 252 |
+
},
|
| 253 |
+
{
|
| 254 |
+
"cell_type": "markdown",
|
| 255 |
+
"id": "d8b8b8de",
|
| 256 |
+
"metadata": {},
|
| 257 |
+
"source": [
|
| 258 |
+
"## 3. Nạp audeering MSP-dim (FROZEN) — đặc trưng phụ\n",
|
| 259 |
+
"Lấy `[emb_pool(1024) | vad3(1–5)]` mỗi wav rồi **cache .npz** (chạy 1 lần). Kỹ thuật nạp head tay\n",
|
| 260 |
+
"y hệt exp05 (tránh lỗi version transformers khi subclass `Wav2Vec2PreTrainedModel`)."
|
| 261 |
+
]
|
| 262 |
+
},
|
| 263 |
+
{
|
| 264 |
+
"cell_type": "code",
|
| 265 |
+
"execution_count": null,
|
| 266 |
+
"id": "8731aa54",
|
| 267 |
+
"metadata": {},
|
| 268 |
+
"outputs": [],
|
| 269 |
+
"source": [
|
| 270 |
+
"AUD_DIM = 0\n",
|
| 271 |
+
"aud_backbone = aud_head = aud_proc = None\n",
|
| 272 |
+
"if USE_AUDEERING:\n",
|
| 273 |
+
" from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor\n",
|
| 274 |
+
" from huggingface_hub import hf_hub_download\n",
|
| 275 |
+
" AUD_NAME = \"audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim\"\n",
|
| 276 |
+
" aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)\n",
|
| 277 |
+
" aud_cfg = Wav2Vec2Config.from_pretrained(AUD_NAME)\n",
|
| 278 |
+
" aud_backbone = Wav2Vec2Model(aud_cfg)\n",
|
| 279 |
+
" try:\n",
|
| 280 |
+
" _sd = __import__(\"safetensors.torch\", fromlist=[\"load_file\"]).load_file(\n",
|
| 281 |
+
" hf_hub_download(AUD_NAME, \"model.safetensors\"))\n",
|
| 282 |
+
" except Exception:\n",
|
| 283 |
+
" _sd = torch.load(hf_hub_download(AUD_NAME, \"pytorch_model.bin\"), map_location=\"cpu\")\n",
|
| 284 |
+
" bb_sd = {k[len(\"wav2vec2.\"):]: v for k, v in _sd.items() if k.startswith(\"wav2vec2.\")}\n",
|
| 285 |
+
" missing, unexpected = aud_backbone.load_state_dict(bb_sd, strict=False)\n",
|
| 286 |
+
" print(f\" audeering backbone: thiếu {len(missing)} / dư {len(unexpected)} key (strict=False)\")\n",
|
| 287 |
+
" _hid = _sd[\"classifier.dense.weight\"].shape[0]\n",
|
| 288 |
+
" _out = _sd[\"classifier.out_proj.weight\"].shape[0]\n",
|
| 289 |
+
" aud_head = nn.Sequential(nn.Linear(_hid, _hid), nn.Tanh(), nn.Linear(_hid, _out))\n",
|
| 290 |
+
" aud_head[0].weight.data.copy_(_sd[\"classifier.dense.weight\"]); aud_head[0].bias.data.copy_(_sd[\"classifier.dense.bias\"])\n",
|
| 291 |
+
" aud_head[2].weight.data.copy_(_sd[\"classifier.out_proj.weight\"]); aud_head[2].bias.data.copy_(_sd[\"classifier.out_proj.bias\"])\n",
|
| 292 |
+
" aud_backbone = aud_backbone.to(device).eval()\n",
|
| 293 |
+
" aud_head = aud_head.to(device).eval()\n",
|
| 294 |
+
" AUD_DIM = _hid + 3 # emb_pool + [VAL,ARO,DOM]\n",
|
| 295 |
+
" print(f\"✅ audeering frozen (đặc trưng phụ {AUD_DIM}-D = emb {_hid} + vad 3)\")"
|
| 296 |
+
]
|
| 297 |
+
},
|
| 298 |
+
{
|
| 299 |
+
"cell_type": "code",
|
| 300 |
+
"execution_count": null,
|
| 301 |
+
"id": "d12e4737",
|
| 302 |
+
"metadata": {
|
| 303 |
+
"lines_to_next_cell": 1
|
| 304 |
+
},
|
| 305 |
+
"outputs": [],
|
| 306 |
+
"source": [
|
| 307 |
+
"import numpy as np\n",
|
| 308 |
+
"import librosa\n",
|
| 309 |
+
"from tqdm.auto import tqdm\n",
|
| 310 |
+
"\n",
|
| 311 |
+
"def load_wav(name_or_stem, in_wav_dir=True):\n",
|
| 312 |
+
" p = name_or_stem if os.path.isabs(str(name_or_stem)) else os.path.join(\n",
|
| 313 |
+
" WAV_DIR, name_or_stem if str(name_or_stem).endswith(\".wav\") else str(name_or_stem) + \".wav\")\n",
|
| 314 |
+
" if not os.path.exists(p):\n",
|
| 315 |
+
" return None\n",
|
| 316 |
+
" wave, _ = librosa.load(p, sr=SR, mono=True)\n",
|
| 317 |
+
" return wave[: MAX_SECONDS * SR].astype(np.float32)\n",
|
| 318 |
+
"\n",
|
| 319 |
+
"@torch.no_grad()\n",
|
| 320 |
+
"def extract_audeering(stems, tag):\n",
|
| 321 |
+
" \"\"\"→ dict {stem: float32[AUD_DIM]}; cache CACHE_DIR/aud_<tag>.npz (resume mỗi 500).\"\"\"\n",
|
| 322 |
+
" if not USE_AUDEERING:\n",
|
| 323 |
+
" return {}\n",
|
| 324 |
+
" cache_path = os.path.join(CACHE_DIR, f\"aud_{tag}.npz\")\n",
|
| 325 |
+
" store = {}\n",
|
| 326 |
+
" if os.path.exists(cache_path):\n",
|
| 327 |
+
" z = np.load(cache_path, allow_pickle=True)\n",
|
| 328 |
+
" store = {k: z[k] for k in z.files}\n",
|
| 329 |
+
" print(f\"[aud/{tag}] nạp cache: {len(store)}\")\n",
|
| 330 |
+
" todo = [s for s in stems if s not in store]\n",
|
| 331 |
+
" for i, s in enumerate(tqdm(todo, desc=f\"audeering {tag}\")):\n",
|
| 332 |
+
" wave = load_wav(s)\n",
|
| 333 |
+
" if wave is None:\n",
|
| 334 |
+
" continue\n",
|
| 335 |
+
" x = aud_proc(wave, sampling_rate=SR).input_values[0]\n",
|
| 336 |
+
" x = torch.from_numpy(np.asarray(x, dtype=np.float32)).unsqueeze(0).to(device)\n",
|
| 337 |
+
" h = aud_backbone(x)[0].mean(dim=1) # [1, hid]\n",
|
| 338 |
+
" out = aud_head(h)[0].cpu().numpy() # [arousal, dominance, valence] ∈[0,1]\n",
|
| 339 |
+
" vad = np.array([1 + 4 * out[2], 1 + 4 * out[0], 1 + 4 * out[1]], dtype=np.float32) # [VAL,ARO,DOM]\n",
|
| 340 |
+
" store[s] = np.concatenate([h[0].cpu().numpy(), vad]).astype(np.float32)\n",
|
| 341 |
+
" if (i + 1) % 500 == 0:\n",
|
| 342 |
+
" np.savez(cache_path, **store)\n",
|
| 343 |
+
" if todo:\n",
|
| 344 |
+
" np.savez(cache_path, **store)\n",
|
| 345 |
+
" return store"
|
| 346 |
+
]
|
| 347 |
+
},
|
| 348 |
+
{
|
| 349 |
+
"cell_type": "markdown",
|
| 350 |
+
"id": "3397dbe7",
|
| 351 |
+
"metadata": {},
|
| 352 |
+
"source": [
|
| 353 |
+
"## 4. Đọc & gộp nhãn theo wavID (EMOS / VAD / CAT) — như exp04/07 nhưng KHÔNG cần qMOS"
|
| 354 |
+
]
|
| 355 |
+
},
|
| 356 |
+
{
|
| 357 |
+
"cell_type": "code",
|
| 358 |
+
"execution_count": null,
|
| 359 |
+
"id": "df3b95e3",
|
| 360 |
+
"metadata": {},
|
| 361 |
+
"outputs": [],
|
| 362 |
+
"source": [
|
| 363 |
+
"import pandas as pd\n",
|
| 364 |
+
"\n",
|
| 365 |
+
"def load_target_emotions():\n",
|
| 366 |
+
" tgt = {}\n",
|
| 367 |
+
" with open(METADATA_CSV, encoding=\"utf-8\") as f:\n",
|
| 368 |
+
" for ln in f:\n",
|
| 369 |
+
" parts = ln.strip().split(\"|\")\n",
|
| 370 |
+
" if len(parts) >= 2:\n",
|
| 371 |
+
" tgt[stem(parts[0])] = norm_emotion(parts[1])\n",
|
| 372 |
+
" return tgt\n",
|
| 373 |
+
"\n",
|
| 374 |
+
"def _col(cols_map, *names, df=None, default_idx=None):\n",
|
| 375 |
+
" for n in names:\n",
|
| 376 |
+
" if n in cols_map:\n",
|
| 377 |
+
" return cols_map[n]\n",
|
| 378 |
+
" return list(df.columns)[default_idx] if default_idx is not None else None\n",
|
| 379 |
+
"\n",
|
| 380 |
+
"def parse_emocat_votes(cell):\n",
|
| 381 |
+
" v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
|
| 382 |
+
" for tok in str(cell).replace(\"/\", \",\").replace(\";\", \",\").replace(\"|\", \",\").replace(\" \", \",\").split(\",\"):\n",
|
| 383 |
+
" e = norm_emotion(tok)\n",
|
| 384 |
+
" if e in EMOTIONS5:\n",
|
| 385 |
+
" v[EMOTIONS5.index(e)] += 1.0\n",
|
| 386 |
+
" return v\n",
|
| 387 |
+
"\n",
|
| 388 |
+
"def load_train_labels():\n",
|
| 389 |
+
" df = pd.read_csv(TRAIN_CSV, sep=\"|\")\n",
|
| 390 |
+
" cols = {c.lower().strip(): c for c in df.columns}\n",
|
| 391 |
+
" wav_col = _col(cols, \"wavid\", \"wav\", df=df, default_idx=1)\n",
|
| 392 |
+
" emos_col = _col(cols, \"emos\", \"emo\", \"emomos\")\n",
|
| 393 |
+
" val_col = _col(cols, \"val\", \"valence\"); aro_col = _col(cols, \"aro\", \"arousal\"); dom_col = _col(cols, \"dom\", \"dominance\")\n",
|
| 394 |
+
" cat_col = _col(cols, \"emocat\", \"cat\", \"emotion\")\n",
|
| 395 |
+
" assert emos_col, f\"Không thấy cột eMOS (cột: {list(df.columns)})\"\n",
|
| 396 |
+
" df[\"_stem\"] = df[wav_col].map(stem)\n",
|
| 397 |
+
" rows = []\n",
|
| 398 |
+
" for sid, g in df.groupby(\"_stem\"):\n",
|
| 399 |
+
" rec = {\"wavID\": sid, \"emos\": float(g[emos_col].mean())}\n",
|
| 400 |
+
" rec[\"val\"] = float(g[val_col].mean()) if val_col else np.nan\n",
|
| 401 |
+
" rec[\"aro\"] = float(g[aro_col].mean()) if aro_col else np.nan\n",
|
| 402 |
+
" rec[\"dom\"] = float(g[dom_col].mean()) if dom_col else np.nan\n",
|
| 403 |
+
" votes = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
|
| 404 |
+
" if cat_col:\n",
|
| 405 |
+
" for cell in g[cat_col]:\n",
|
| 406 |
+
" votes += parse_emocat_votes(cell)\n",
|
| 407 |
+
" s = votes.sum()\n",
|
| 408 |
+
" cat = votes / s if s > 0 else np.full(len(EMOTIONS5), 0.2, dtype=np.float32)\n",
|
| 409 |
+
" for i in range(len(EMOTIONS5)):\n",
|
| 410 |
+
" rec[f\"cat{i}\"] = float(cat[i])\n",
|
| 411 |
+
" rows.append(rec)\n",
|
| 412 |
+
" return pd.DataFrame(rows)\n",
|
| 413 |
+
"\n",
|
| 414 |
+
"target_map = load_target_emotions()\n",
|
| 415 |
+
"train_df = load_train_labels()\n",
|
| 416 |
+
"HAS_VAD = bool(train_df[\"val\"].notna().any())\n",
|
| 417 |
+
"print(f\"Target: {len(target_map)} | wav train (gộp): {len(train_df)} | có VAD: {HAS_VAD}\")"
|
| 418 |
+
]
|
| 419 |
+
},
|
| 420 |
+
{
|
| 421 |
+
"cell_type": "markdown",
|
| 422 |
+
"id": "48ea29a7",
|
| 423 |
+
"metadata": {},
|
| 424 |
+
"source": [
|
| 425 |
+
"## 5. Dataset / DataLoader (load wav theo batch — KHÔNG cache WavLM vì đang train)"
|
| 426 |
+
]
|
| 427 |
+
},
|
| 428 |
+
{
|
| 429 |
+
"cell_type": "code",
|
| 430 |
+
"execution_count": null,
|
| 431 |
+
"id": "478f2af9",
|
| 432 |
+
"metadata": {},
|
| 433 |
+
"outputs": [],
|
| 434 |
+
"source": [
|
| 435 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 436 |
+
"\n",
|
| 437 |
+
"train_stems = [s for s in train_df[\"wavID\"] if target_map.get(s) is not None]\n",
|
| 438 |
+
"if LIMIT_TRAIN:\n",
|
| 439 |
+
" train_stems = train_stems[:LIMIT_TRAIN]\n",
|
| 440 |
+
"aud_tr = extract_audeering(train_stems, \"train\")\n",
|
| 441 |
+
"\n",
|
| 442 |
+
"lab = train_df.set_index(\"wavID\")\n",
|
| 443 |
+
"\n",
|
| 444 |
+
"# Chuẩn hóa nhãn liên tục về z-score (để các MSE cùng thang) — lưu để giải mã lúc dự đoán.\n",
|
| 445 |
+
"def _zfit(arr):\n",
|
| 446 |
+
" a = np.asarray(arr, dtype=np.float32)\n",
|
| 447 |
+
" return float(np.nanmean(a)), float(np.nanstd(a) + 1e-6)\n",
|
| 448 |
+
"\n",
|
| 449 |
+
"emos_mu, emos_sd = _zfit([lab.loc[s, \"emos\"] for s in train_stems])\n",
|
| 450 |
+
"if HAS_VAD:\n",
|
| 451 |
+
" vad_mu = np.array([_zfit([lab.loc[s, c] for s in train_stems])[0] for c in [\"val\", \"aro\", \"dom\"]], dtype=np.float32)\n",
|
| 452 |
+
" vad_sd = np.array([_zfit([lab.loc[s, c] for s in train_stems])[1] for c in [\"val\", \"aro\", \"dom\"]], dtype=np.float32)\n",
|
| 453 |
+
"else:\n",
|
| 454 |
+
" vad_mu = np.zeros(3, dtype=np.float32); vad_sd = np.ones(3, dtype=np.float32)\n",
|
| 455 |
+
"\n",
|
| 456 |
+
"def onehot_target(tgt):\n",
|
| 457 |
+
" v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
|
| 458 |
+
" if tgt in EMOTIONS5:\n",
|
| 459 |
+
" v[EMOTIONS5.index(tgt)] = 1.0\n",
|
| 460 |
+
" return v\n",
|
| 461 |
+
"\n",
|
| 462 |
+
"class EmoDataset(Dataset):\n",
|
| 463 |
+
" def __init__(self, stems):\n",
|
| 464 |
+
" self.stems = [s for s in stems if (load_wav(s) is not None) and ((not USE_AUDEERING) or s in aud_tr)]\n",
|
| 465 |
+
" def __len__(self):\n",
|
| 466 |
+
" return len(self.stems)\n",
|
| 467 |
+
" def __getitem__(self, i):\n",
|
| 468 |
+
" s = self.stems[i]\n",
|
| 469 |
+
" wave = load_wav(s)\n",
|
| 470 |
+
" emos = (float(lab.loc[s, \"emos\"]) - emos_mu) / emos_sd\n",
|
| 471 |
+
" if HAS_VAD:\n",
|
| 472 |
+
" vad = (np.array([lab.loc[s, \"val\"], lab.loc[s, \"aro\"], lab.loc[s, \"dom\"]], np.float32) - vad_mu) / vad_sd\n",
|
| 473 |
+
" else:\n",
|
| 474 |
+
" vad = np.zeros(3, dtype=np.float32)\n",
|
| 475 |
+
" cat = np.array([lab.loc[s, f\"cat{j}\"] for j in range(len(EMOTIONS5))], dtype=np.float32)\n",
|
| 476 |
+
" aud = aud_tr[s] if USE_AUDEERING else np.zeros(0, dtype=np.float32)\n",
|
| 477 |
+
" return {\"wave\": wave, \"tgt\": onehot_target(target_map.get(s)), \"aud\": aud,\n",
|
| 478 |
+
" \"emos\": np.float32(emos), \"vad\": vad, \"cat\": cat,\n",
|
| 479 |
+
" \"emos_raw\": np.float32(lab.loc[s, \"emos\"]),\n",
|
| 480 |
+
" \"vad_raw\": np.array([lab.loc[s, \"val\"], lab.loc[s, \"aro\"], lab.loc[s, \"dom\"]], np.float32)}\n",
|
| 481 |
+
"\n",
|
| 482 |
+
"def collate(batch):\n",
|
| 483 |
+
" lens = [len(b[\"wave\"]) for b in batch]\n",
|
| 484 |
+
" L = max(lens)\n",
|
| 485 |
+
" waves = np.zeros((len(batch), L), dtype=np.float32)\n",
|
| 486 |
+
" mask = np.zeros((len(batch), L), dtype=np.float32)\n",
|
| 487 |
+
" for i, b in enumerate(batch):\n",
|
| 488 |
+
" waves[i, : len(b[\"wave\"])] = b[\"wave\"]; mask[i, : len(b[\"wave\"])] = 1.0\n",
|
| 489 |
+
" out = {\n",
|
| 490 |
+
" \"input_values\": torch.from_numpy(waves), \"attn_mask\": torch.from_numpy(mask).long(),\n",
|
| 491 |
+
" \"tgt\": torch.from_numpy(np.stack([b[\"tgt\"] for b in batch])),\n",
|
| 492 |
+
" \"aud\": torch.from_numpy(np.stack([b[\"aud\"] for b in batch])) if USE_AUDEERING else None,\n",
|
| 493 |
+
" \"emos\": torch.from_numpy(np.stack([b[\"emos\"] for b in batch])).unsqueeze(1),\n",
|
| 494 |
+
" \"vad\": torch.from_numpy(np.stack([b[\"vad\"] for b in batch])),\n",
|
| 495 |
+
" \"cat\": torch.from_numpy(np.stack([b[\"cat\"] for b in batch])),\n",
|
| 496 |
+
" \"emos_raw\": np.stack([b[\"emos_raw\"] for b in batch]),\n",
|
| 497 |
+
" \"vad_raw\": np.stack([b[\"vad_raw\"] for b in batch]),\n",
|
| 498 |
+
" }\n",
|
| 499 |
+
" return out\n",
|
| 500 |
+
"\n",
|
| 501 |
+
"from sklearn.model_selection import train_test_split\n",
|
| 502 |
+
"ds = EmoDataset(train_stems)\n",
|
| 503 |
+
"print(\"Dataset hợp lệ:\", len(ds), \"wav\")\n",
|
| 504 |
+
"tr_i, va_i = train_test_split(np.arange(len(ds)), test_size=VAL_FRAC, random_state=SEED)\n",
|
| 505 |
+
"tr_loader = DataLoader(torch.utils.data.Subset(ds, tr_i), batch_size=BATCH, shuffle=True, collate_fn=collate, num_workers=2)\n",
|
| 506 |
+
"va_loader = DataLoader(torch.utils.data.Subset(ds, va_i), batch_size=BATCH, shuffle=False, collate_fn=collate, num_workers=2)"
|
| 507 |
+
]
|
| 508 |
+
},
|
| 509 |
+
{
|
| 510 |
+
"cell_type": "markdown",
|
| 511 |
+
"id": "f3342c6f",
|
| 512 |
+
"metadata": {},
|
| 513 |
+
"source": [
|
| 514 |
+
"## 6. Head fusion (trunk + 3 head cảm xúc) + train loop (AMP + grad accumulation)"
|
| 515 |
+
]
|
| 516 |
+
},
|
| 517 |
+
{
|
| 518 |
+
"cell_type": "code",
|
| 519 |
+
"execution_count": null,
|
| 520 |
+
"id": "3671b2da",
|
| 521 |
+
"metadata": {
|
| 522 |
+
"lines_to_next_cell": 1
|
| 523 |
+
},
|
| 524 |
+
"outputs": [],
|
| 525 |
+
"source": [
|
| 526 |
+
"from scipy.stats import spearmanr\n",
|
| 527 |
+
"\n",
|
| 528 |
+
"torch.manual_seed(SEED); np.random.seed(SEED)\n",
|
| 529 |
+
"N_EMO = len(EMOTIONS5)\n",
|
| 530 |
+
"TRUNK_IN = WAVLM_DIM + (AUD_DIM if USE_AUDEERING else 0)\n",
|
| 531 |
+
"\n",
|
| 532 |
+
"class EmoHeads(nn.Module):\n",
|
| 533 |
+
" def __init__(self, d_in, trunk_h, head_h, p, n_emo):\n",
|
| 534 |
+
" super().__init__()\n",
|
| 535 |
+
" self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),\n",
|
| 536 |
+
" nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))\n",
|
| 537 |
+
" self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))\n",
|
| 538 |
+
" self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))\n",
|
| 539 |
+
" self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))\n",
|
| 540 |
+
" def forward(self, feat, tgt):\n",
|
| 541 |
+
" h = self.trunk(feat)\n",
|
| 542 |
+
" return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)\n",
|
| 543 |
+
"\n",
|
| 544 |
+
"heads = EmoHeads(TRUNK_IN, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device)\n",
|
| 545 |
+
"print(f\"Trunk input = {TRUNK_IN} (wavlm {WAVLM_DIM} + aud {AUD_DIM if USE_AUDEERING else 0})\")\n",
|
| 546 |
+
"\n",
|
| 547 |
+
"TASKS = [\"emos\", \"cat\", \"val\", \"aro\", \"dom\"]\n",
|
| 548 |
+
"log_var = nn.Parameter(torch.zeros(len(TASKS), device=device))\n",
|
| 549 |
+
"bb_params = [p for p in wavlm.parameters() if p.requires_grad]\n",
|
| 550 |
+
"head_params = list(heads.parameters()) + ([log_var] if USE_UNCERTAINTY else [])\n",
|
| 551 |
+
"opt = torch.optim.AdamW([\n",
|
| 552 |
+
" {\"params\": bb_params, \"lr\": LR_BACKBONE},\n",
|
| 553 |
+
" {\"params\": head_params, \"lr\": LR_HEAD},\n",
|
| 554 |
+
"], weight_decay=WEIGHT_DECAY)\n",
|
| 555 |
+
"scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP and device == \"cuda\")\n",
|
| 556 |
+
"mse = nn.MSELoss()\n",
|
| 557 |
+
"\n",
|
| 558 |
+
"def soft_ce(logits, target_dist):\n",
|
| 559 |
+
" return -(target_dist * F.log_softmax(logits, dim=1)).sum(1).mean()\n",
|
| 560 |
+
"\n",
|
| 561 |
+
"def forward_batch(b):\n",
|
| 562 |
+
" feat_wavlm = wavlm_embed(b[\"input_values\"].to(device), b[\"attn_mask\"].to(device))\n",
|
| 563 |
+
" if USE_AUDEERING:\n",
|
| 564 |
+
" feat = torch.cat([feat_wavlm, b[\"aud\"].to(device)], dim=1)\n",
|
| 565 |
+
" else:\n",
|
| 566 |
+
" feat = feat_wavlm\n",
|
| 567 |
+
" return heads(feat, b[\"tgt\"].to(device))\n",
|
| 568 |
+
"\n",
|
| 569 |
+
"def compute_loss(emos_p, cat_l, vad_p, b):\n",
|
| 570 |
+
" L = {}\n",
|
| 571 |
+
" L[\"emos\"] = mse(emos_p, b[\"emos\"].to(device))\n",
|
| 572 |
+
" L[\"cat\"] = soft_ce(cat_l, b[\"cat\"].to(device))\n",
|
| 573 |
+
" if HAS_VAD:\n",
|
| 574 |
+
" vt = b[\"vad\"].to(device)\n",
|
| 575 |
+
" L[\"val\"] = mse(vad_p[:, 0:1], vt[:, 0:1]); L[\"aro\"] = mse(vad_p[:, 1:2], vt[:, 1:2]); L[\"dom\"] = mse(vad_p[:, 2:3], vt[:, 2:3])\n",
|
| 576 |
+
" else:\n",
|
| 577 |
+
" z = torch.zeros((), device=device); L[\"val\"] = L[\"aro\"] = L[\"dom\"] = z\n",
|
| 578 |
+
" if USE_UNCERTAINTY:\n",
|
| 579 |
+
" return sum(torch.exp(-log_var[i]) * L[t] + log_var[i] for i, t in enumerate(TASKS))\n",
|
| 580 |
+
" return sum(L.values())\n",
|
| 581 |
+
"\n",
|
| 582 |
+
"@torch.no_grad()\n",
|
| 583 |
+
"def evaluate():\n",
|
| 584 |
+
" wavlm.eval(); heads.eval()\n",
|
| 585 |
+
" P = {\"emos\": [], \"val\": [], \"aro\": [], \"dom\": []}; Y = {\"emos\": [], \"val\": [], \"aro\": [], \"dom\": []}\n",
|
| 586 |
+
" catP, catY = [], []\n",
|
| 587 |
+
" for b in va_loader:\n",
|
| 588 |
+
" with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
|
| 589 |
+
" emos_p, cat_l, vad_p = forward_batch(b)\n",
|
| 590 |
+
" P[\"emos\"] += emos_p.float().cpu().numpy().ravel().tolist(); Y[\"emos\"] += b[\"emos_raw\"].tolist()\n",
|
| 591 |
+
" vad_p = vad_p.float().cpu().numpy()\n",
|
| 592 |
+
" for j, t in enumerate([\"val\", \"aro\", \"dom\"]):\n",
|
| 593 |
+
" P[t] += vad_p[:, j].tolist(); Y[t] += b[\"vad_raw\"][:, j].tolist()\n",
|
| 594 |
+
" catP.append(F.softmax(cat_l, 1).float().cpu().numpy()); catY.append(b[\"cat\"])\n",
|
| 595 |
+
" out = {}\n",
|
| 596 |
+
" for t in [\"emos\"] + ([\"val\", \"aro\", \"dom\"] if HAS_VAD else []):\n",
|
| 597 |
+
" out[t] = spearmanr(P[t], Y[t]).correlation\n",
|
| 598 |
+
" q = np.concatenate(catP); p = np.concatenate(catY)\n",
|
| 599 |
+
" out[\"cat_err\"] = float(np.abs(q - p).sum(1).mean()) # ~ tổng |Δ| trung bình (xấp xỉ CAT-ERR)\n",
|
| 600 |
+
" return out\n",
|
| 601 |
+
"\n",
|
| 602 |
+
"def mean_srcc(m):\n",
|
| 603 |
+
" keys = [\"emos\"] + ([\"val\", \"aro\", \"dom\"] if HAS_VAD else [])\n",
|
| 604 |
+
" return float(np.mean([m[k] for k in keys]))\n",
|
| 605 |
+
"\n",
|
| 606 |
+
"# Lưu checkpoint FULL (có backbone WavLM) — gọi NGAY mỗi best để kernel chết giữa chừng vẫn còn file.\n",
|
| 607 |
+
"CKPT_PATH = os.path.join(OUT_DIR, \"ft_emotion_full.pt\")\n",
|
| 608 |
+
"def save_full_ckpt(state, val_emos=float(\"nan\")):\n",
|
| 609 |
+
" torch.save({\"wavlm\": state[\"wavlm\"], \"heads\": state[\"heads\"],\n",
|
| 610 |
+
" \"emos_mu\": emos_mu, \"emos_sd\": emos_sd, \"vad_mu\": vad_mu, \"vad_sd\": vad_sd,\n",
|
| 611 |
+
" \"WAVLM_DIM\": WAVLM_DIM, \"AUD_DIM\": AUD_DIM,\n",
|
| 612 |
+
" \"UNFREEZE_TOP_LAYERS\": UNFREEZE_TOP_LAYERS, \"val_emos\": float(val_emos)}, CKPT_PATH)\n",
|
| 613 |
+
"\n",
|
| 614 |
+
"best, best_state, bad = -1e9, None, 0\n",
|
| 615 |
+
"for ep in range(1, EPOCHS + 1):\n",
|
| 616 |
+
" wavlm.train(); heads.train()\n",
|
| 617 |
+
" opt.zero_grad(); run = 0.0; nb = 0\n",
|
| 618 |
+
" for step, b in enumerate(tqdm(tr_loader, desc=f\"epoch {ep}\")):\n",
|
| 619 |
+
" with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
|
| 620 |
+
" emos_p, cat_l, vad_p = forward_batch(b)\n",
|
| 621 |
+
" loss = compute_loss(emos_p, cat_l, vad_p, b) / ACCUM\n",
|
| 622 |
+
" scaler.scale(loss).backward()\n",
|
| 623 |
+
" if (step + 1) % ACCUM == 0:\n",
|
| 624 |
+
" scaler.step(opt); scaler.update(); opt.zero_grad()\n",
|
| 625 |
+
" run += loss.item() * ACCUM; nb += 1\n",
|
| 626 |
+
" m = evaluate(); sc = mean_srcc(m)\n",
|
| 627 |
+
" msg = \" \".join(f\"{k}={m[k]:.3f}\" for k in [\"emos\", \"val\", \"aro\", \"dom\"] if k in m)\n",
|
| 628 |
+
" print(f\"epoch {ep:2d} | loss {run/max(nb,1):.4f} | {msg} | cat_err {m['cat_err']:.3f} | mean {sc:.4f} (best {max(best,sc):.4f})\")\n",
|
| 629 |
+
" if sc > best:\n",
|
| 630 |
+
" best = sc\n",
|
| 631 |
+
" best_state = {\"wavlm\": {k: v.cpu().clone() for k, v in wavlm.state_dict().items()},\n",
|
| 632 |
+
" \"heads\": {k: v.cpu().clone() for k, v in heads.state_dict().items()}}\n",
|
| 633 |
+
" save_full_ckpt(best_state, m[\"emos\"]) # LƯU NGAY mỗi best → an toàn nếu kernel chết\n",
|
| 634 |
+
" print(f\" 💾 lưu best → {CKPT_PATH} (epoch {ep}, mean {sc:.4f})\")\n",
|
| 635 |
+
" bad = 0\n",
|
| 636 |
+
" else:\n",
|
| 637 |
+
" bad += 1\n",
|
| 638 |
+
" if bad >= PATIENCE:\n",
|
| 639 |
+
" print(f\"Early stop ở epoch {ep}.\"); break\n",
|
| 640 |
+
"\n",
|
| 641 |
+
"if best_state:\n",
|
| 642 |
+
" wavlm.load_state_dict(best_state[\"wavlm\"]); heads.load_state_dict(best_state[\"heads\"])\n",
|
| 643 |
+
"final = evaluate()\n",
|
| 644 |
+
"print(\"\\n✅ VAL (nội bộ) — exp08 (fine-tune WavLM cho cảm xúc):\")\n",
|
| 645 |
+
"print(f\" EMOS={final['emos']:.4f} (exp07 {EXP07['emos']})\")\n",
|
| 646 |
+
"if HAS_VAD:\n",
|
| 647 |
+
" print(f\" VAL/ARO/DOM={final['val']:.4f}/{final['aro']:.4f}/{final['dom']:.4f} \"\n",
|
| 648 |
+
" f\"(exp07 {EXP07['val']}/{EXP07['aro']}/{EXP07['dom']})\")\n",
|
| 649 |
+
"warn = [f\"EMOS {final['emos']:.3f}<{EXP07['emos']}\"] if final[\"emos\"] < EXP07[\"emos\"] - 0.005 else []\n",
|
| 650 |
+
"if HAS_VAD:\n",
|
| 651 |
+
" warn += [f\"{t.upper()} {final[t]:.3f}<{EXP07[t]}\" for t in [\"val\", \"aro\", \"dom\"] if final[t] < EXP07[t] - 0.005]\n",
|
| 652 |
+
"print(\" ⚠️ CHƯA thắng exp07 ở:\", \"; \".join(warn), \"→ cân nhắc giữ exp07.\" if warn else \"\")\n",
|
| 653 |
+
"if not warn:\n",
|
| 654 |
+
" print(\" ✅ Fine-tune thắng/ngang exp07 ở mọi cột cảm xúc → đáng nộp.\")\n",
|
| 655 |
+
"# Lưu lần cuối từ best (đã lưu sẵn mỗi best trong loop; đây là phát cuối cho chắc).\n",
|
| 656 |
+
"save_full_ckpt(best_state if best_state else\n",
|
| 657 |
+
" {\"wavlm\": wavlm.state_dict(), \"heads\": heads.state_dict()}, final[\"emos\"])\n",
|
| 658 |
+
"print(f\"✅ Đã lưu {CKPT_PATH} (CÓ backbone WavLM + heads → resume được). \"\n",
|
| 659 |
+
" f\"NHỚ Save Version để file ra Output!\")"
|
| 660 |
+
]
|
| 661 |
+
},
|
| 662 |
+
{
|
| 663 |
+
"cell_type": "markdown",
|
| 664 |
+
"id": "7b0b1b42",
|
| 665 |
+
"metadata": {},
|
| 666 |
+
"source": [
|
| 667 |
+
"## 7. Dự đoán DEV → answer.txt (5 cột cảm xúc từ exp08; QMOS mượn exp07 hoặc UTMOS)"
|
| 668 |
+
]
|
| 669 |
+
},
|
| 670 |
+
{
|
| 671 |
+
"cell_type": "code",
|
| 672 |
+
"execution_count": null,
|
| 673 |
+
"id": "b29f616c",
|
| 674 |
+
"metadata": {
|
| 675 |
+
"lines_to_next_cell": 1
|
| 676 |
+
},
|
| 677 |
+
"outputs": [],
|
| 678 |
+
"source": [
|
| 679 |
+
"def list_dev():\n",
|
| 680 |
+
" with open(DEV_SCP) as f:\n",
|
| 681 |
+
" return [ln.strip() for ln in f if ln.strip()]\n",
|
| 682 |
+
"\n",
|
| 683 |
+
"dev_names = list_dev()\n",
|
| 684 |
+
"if LIMIT_DEV:\n",
|
| 685 |
+
" dev_names = dev_names[:LIMIT_DEV]\n",
|
| 686 |
+
"dev_stems = [stem(n) for n in dev_names]\n",
|
| 687 |
+
"print(\"DEV:\", len(dev_names), \"mẫu\")\n",
|
| 688 |
+
"aud_dev = extract_audeering(dev_stems, \"dev\")\n",
|
| 689 |
+
"\n",
|
| 690 |
+
"# QMOS: ưu tiên mượn cột QMOS của exp07; không có file → chấm UTMOSv2 (T05, vô địch VMC2024).\n",
|
| 691 |
+
"def load_exp07_qmos():\n",
|
| 692 |
+
" if EXP07_ANSWER and os.path.exists(EXP07_ANSWER):\n",
|
| 693 |
+
" import csv\n",
|
| 694 |
+
" d = {}\n",
|
| 695 |
+
" with open(EXP07_ANSWER) as f:\n",
|
| 696 |
+
" r = csv.DictReader(f)\n",
|
| 697 |
+
" for row in r:\n",
|
| 698 |
+
" d[row[\"wav\"]] = float(row[\"QMOS\"]); d[stem(row[\"wav\"])] = float(row[\"QMOS\"])\n",
|
| 699 |
+
" print(f\"✅ Mượn QMOS từ exp07 ({EXP07_ANSWER}): {len(d)//2} wav\")\n",
|
| 700 |
+
" return d\n",
|
| 701 |
+
" return None\n",
|
| 702 |
+
"\n",
|
| 703 |
+
"qmos_map = load_exp07_qmos()\n",
|
| 704 |
+
"if qmos_map is None:\n",
|
| 705 |
+
" print(\"ℹ️ Không có answer.txt exp07 → chấm QMOS bằng UTMOSv2 (T05, vô địch VMC2024 Track 1).\")\n",
|
| 706 |
+
" pip_install(\"git+https://github.com/sarulab-speech/UTMOSv2.git\") # cần Internet On, checkpoint tự tải\n",
|
| 707 |
+
" import utmosv2\n",
|
| 708 |
+
" v2 = utmosv2.create_model(pretrained=True)\n",
|
| 709 |
+
" qmos_map = {}\n",
|
| 710 |
+
" for n in tqdm(dev_names, desc=\"UTMOSv2\"):\n",
|
| 711 |
+
" wav = os.path.join(WAV_DIR, n if str(n).endswith(\".wav\") else str(n) + \".wav\")\n",
|
| 712 |
+
" if not os.path.exists(wav):\n",
|
| 713 |
+
" continue\n",
|
| 714 |
+
" out = v2.predict(input_path=wav) # trả float hoặc dict {'predicted_mos': ...} tùy phiên bản\n",
|
| 715 |
+
" qmos_map[n] = float(out[\"predicted_mos\"]) if isinstance(out, dict) else float(out)\n",
|
| 716 |
+
" del v2; torch.cuda.empty_cache() if device == \"cuda\" else None\n",
|
| 717 |
+
"\n",
|
| 718 |
+
"@torch.no_grad()\n",
|
| 719 |
+
"def predict_emotion(sid):\n",
|
| 720 |
+
" wave = load_wav(sid)\n",
|
| 721 |
+
" if wave is None or (USE_AUDEERING and sid not in aud_dev):\n",
|
| 722 |
+
" return None\n",
|
| 723 |
+
" wavlm.eval(); heads.eval()\n",
|
| 724 |
+
" iv = torch.from_numpy(wave).unsqueeze(0).to(device)\n",
|
| 725 |
+
" am = torch.ones((1, len(wave)), dtype=torch.long, device=device)\n",
|
| 726 |
+
" tgt = torch.from_numpy(onehot_target(target_map.get(sid))).unsqueeze(0).to(device)\n",
|
| 727 |
+
" with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
|
| 728 |
+
" fw = wavlm_embed(iv, am)\n",
|
| 729 |
+
" if USE_AUDEERING:\n",
|
| 730 |
+
" aud = torch.from_numpy(aud_dev[sid]).unsqueeze(0).to(device)\n",
|
| 731 |
+
" feat = torch.cat([fw, aud], dim=1)\n",
|
| 732 |
+
" else:\n",
|
| 733 |
+
" feat = fw\n",
|
| 734 |
+
" emos_p, cat_l, vad_p = heads(feat, tgt)\n",
|
| 735 |
+
" emos = float(emos_p.item()) * emos_sd + emos_mu\n",
|
| 736 |
+
" cat5 = F.softmax(cat_l, 1)[0].float().cpu().numpy()\n",
|
| 737 |
+
" vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu\n",
|
| 738 |
+
" return emos, cat5, vad3\n",
|
| 739 |
+
"\n",
|
| 740 |
+
"def fmt_cat(p5):\n",
|
| 741 |
+
" return \"|\".join(f\"{e}:{p5[i]:.6g}\" for i, e in enumerate(EMOTIONS5))\n",
|
| 742 |
+
"\n",
|
| 743 |
+
"def build_answer(out_path):\n",
|
| 744 |
+
" n_real = n_def = 0\n",
|
| 745 |
+
" with open(out_path, \"w\") as f:\n",
|
| 746 |
+
" f.write(\"wav,QMOS,EMOS,CAT,VAL,ARO,DOM\\n\")\n",
|
| 747 |
+
" for name in tqdm(dev_names, desc=\"answer\"):\n",
|
| 748 |
+
" sid = stem(name)\n",
|
| 749 |
+
" pr = predict_emotion(sid)\n",
|
| 750 |
+
" if pr is None:\n",
|
| 751 |
+
" emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0]); n_def += 1\n",
|
| 752 |
+
" else:\n",
|
| 753 |
+
" emos, cat5, vad3 = pr; n_real += 1\n",
|
| 754 |
+
" qmos = qmos_map.get(name, qmos_map.get(sid, 3.0))\n",
|
| 755 |
+
" f.write(f\"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\\n\")\n",
|
| 756 |
+
" print(f\"Ghi {len(dev_names)} dòng → {out_path} | cảm xúc thật {n_real}, mặc định {n_def}\")\n",
|
| 757 |
+
"\n",
|
| 758 |
+
"answer_path = os.path.join(OUT_DIR, \"answer.txt\")\n",
|
| 759 |
+
"build_answer(answer_path)"
|
| 760 |
+
]
|
| 761 |
+
},
|
| 762 |
+
{
|
| 763 |
+
"cell_type": "markdown",
|
| 764 |
+
"id": "2e9cab58",
|
| 765 |
+
"metadata": {},
|
| 766 |
+
"source": [
|
| 767 |
+
"## 8. Validate + đóng zip"
|
| 768 |
+
]
|
| 769 |
+
},
|
| 770 |
+
{
|
| 771 |
+
"cell_type": "code",
|
| 772 |
+
"execution_count": null,
|
| 773 |
+
"id": "88cb0280",
|
| 774 |
+
"metadata": {},
|
| 775 |
+
"outputs": [],
|
| 776 |
+
"source": [
|
| 777 |
+
"def validate(path):\n",
|
| 778 |
+
" import csv\n",
|
| 779 |
+
" with open(path) as f:\n",
|
| 780 |
+
" rows = list(csv.reader(f))\n",
|
| 781 |
+
" assert rows[0][0] == \"wav\" and \"QMOS\" in rows[0] and \"EMOS\" in rows[0], \"Header sai\"\n",
|
| 782 |
+
" for i, r in enumerate(rows[1:], 2):\n",
|
| 783 |
+
" assert len(r) == len(rows[0]), f\"Dòng {i} sai số cột\"\n",
|
| 784 |
+
" print(f\"OK: {len(rows)-1} dòng, header = {rows[0]}\")\n",
|
| 785 |
+
"\n",
|
| 786 |
+
"validate(answer_path)\n",
|
| 787 |
+
"os.system(f\"cd {OUT_DIR} && zip -j submission_track2_exp08_ft-emotion.zip answer.txt \"\n",
|
| 788 |
+
" f\"&& unzip -l submission_track2_exp08_ft-emotion.zip\")\n",
|
| 789 |
+
"print(\"Sẵn sàng nộp:\", os.path.join(OUT_DIR, \"submission_track2_exp08_ft-emotion.zip\"))"
|
| 790 |
+
]
|
| 791 |
+
},
|
| 792 |
+
{
|
| 793 |
+
"cell_type": "markdown",
|
| 794 |
+
"id": "e2018df3",
|
| 795 |
+
"metadata": {},
|
| 796 |
+
"source": [
|
| 797 |
+
"## Ghi chú\n",
|
| 798 |
+
"- **Lần đầu** `LIMIT_TRAIN=300`, `LIMIT_DEV=20` để kiểm tra chạy trơn (1 epoch xong không OOM); rồi đặt `None`.\n",
|
| 799 |
+
"- **OOM trên T4?** giảm theo thứ tự: `MAX_SECONDS` (8→6) → `UNFREEZE_TOP_LAYERS` (6→4→2) → `BATCH` (4→2, tăng `ACCUM`).\n",
|
| 800 |
+
"- **Đọc mục 6:** so EMOS/VAD VAL nội bộ với mốc exp07 (EMOS 0.795 · VAL 0.581 · ARO 0.752 · DOM 0.705).\n",
|
| 801 |
+
" - Nếu fine-tune **thắng** → nộp answer.txt exp08 (5 cột cảm xúc của exp08 + QMOS mượn exp07).\n",
|
| 802 |
+
" - Nếu **thua** → giữ exp07; vẫn là kết quả cho paper (\"fine-tune chưa vượt frozen-fusion trên data nhỏ\").\n",
|
| 803 |
+
"- **QMOS:** Add Input answer.txt exp07 vào `/kaggle/input/exp07-answer/answer.txt` để mượn cột QMOS 0.548;\n",
|
| 804 |
+
" không có thì tự chấm UTMOSv2 (T05, vô địch VMC2024 — mạnh hơn UTMOS, cần Internet On).\n",
|
| 805 |
+
"- **Ablation cho paper:** `UNFREEZE_TOP_LAYERS=0` (≈ head-only) vs `=6` (fine-tune) → bảng \"frozen vs fine-tuned\".\n",
|
| 806 |
+
" `USE_AUDEERING=False` → đo đóng góp nhánh phụ.\n",
|
| 807 |
+
"- Ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (mục exp08)."
|
| 808 |
+
]
|
| 809 |
+
}
|
| 810 |
+
],
|
| 811 |
+
"metadata": {
|
| 812 |
+
"jupytext": {
|
| 813 |
+
"cell_metadata_filter": "-all",
|
| 814 |
+
"main_language": "python",
|
| 815 |
+
"notebook_metadata_filter": "-all"
|
| 816 |
+
}
|
| 817 |
+
},
|
| 818 |
+
"nbformat": 4,
|
| 819 |
+
"nbformat_minor": 5
|
| 820 |
+
}
|
track2/exp08_finetune_emotion_pipeline.py
ADDED
|
@@ -0,0 +1,673 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %% [markdown]
|
| 2 |
+
# # VMC2026 Track 2 — exp08 (FINE-TUNE WavLM cho 5 cột cảm xúc) — Kaggle
|
| 3 |
+
#
|
| 4 |
+
# **Khác mọi exp trước:** exp03–07 đều **đóng băng** backbone (chỉ trích đặc trưng + train head nhỏ trên cache).
|
| 5 |
+
# exp08 **MỞ BĂNG (fine-tune)** WavLM-large để nó học lại đặc trưng riêng cho bài MOS cảm xúc 2026.
|
| 6 |
+
#
|
| 7 |
+
# ## Thiết kế (chốt với mentor 5/6)
|
| 8 |
+
# ```
|
| 9 |
+
# wav ─┬─► WavLM-large (warm-start SAILER, TRAINABLE: chỉ mở băng N lớp trên) ─► pool ─► emb_wavlm ┐
|
| 10 |
+
# └─► audeering MSP-dim (FROZEN, cache .npz) ─► [emb_aud | vad3] ├─► TRUNK ─┬─► EMOS (+target)
|
| 11 |
+
# ┘ ├─► CAT (5)
|
| 12 |
+
# └─► VAD (3)
|
| 13 |
+
# QMOS: KHÔNG train ở đây → mượn cột QMOS của exp07 (0.548) hoặc UTMOSv2 (T05, vô địch VMC2024).
|
| 14 |
+
# ```
|
| 15 |
+
# - **Warm-start:** khởi tạo WavLM từ checkpoint **SAILER** (`tiantiaf/wavlm-large-categorical-emotion`,
|
| 16 |
+
# đã giỏi cảm xúc) thay vì WavLM "trắng" → điểm xuất phát tốt hơn nhiều.
|
| 17 |
+
# - **Phụ (frozen):** audeering — dimensional, bổ trợ góc nhìn categorical của WavLM, kỳ vọng kéo **VAL**.
|
| 18 |
+
# - **Đóng băng partial:** chỉ train `UNFREEZE_TOP_LAYERS` lớp Transformer trên cùng + feature-extractor giữ băng
|
| 19 |
+
# → tiết kiệm VRAM T4 + chống overfit (chỉ 12.7k mẫu).
|
| 20 |
+
#
|
| 21 |
+
# ## ⚠️ Đánh đổi phải biết trước (so freeze+head)
|
| 22 |
+
# - **Mất lợi thế cache:** mỗi epoch chạy lại cả WavLM (forward+backward) → chậm & đốt giờ GPU (30h/tuần).
|
| 23 |
+
# → **Lần đầu BẮT BUỘC đặt `LIMIT_TRAIN=300`, `LIMIT_DEV=20`** để chỉnh trơn rồi mới `None`.
|
| 24 |
+
# - **Dễ overfit / OOM:** nếu OOM → giảm `BATCH`, tăng `ACCUM`, giảm `MAX_SECONDS`, giảm `UNFREEZE_TOP_LAYERS`.
|
| 25 |
+
# - **Lưới an toàn:** exp07 vẫn là bản nộp vô địch tới khi exp08 **thắng trên VAL nội bộ** (đừng đốt lượt nộp).
|
| 26 |
+
#
|
| 27 |
+
# **Cách chạy Kaggle:** GPU **T4** + Internet **On** → Add Input dataset Track 2 → sửa `DATA_ROOT` → Run All.
|
| 28 |
+
|
| 29 |
+
# %% [markdown]
|
| 30 |
+
# ## 0. Cấu hình — SỬA Ở ĐÂY
|
| 31 |
+
|
| 32 |
+
# %%
|
| 33 |
+
import os
|
| 34 |
+
|
| 35 |
+
DATA_ROOT = "/kaggle/input/vmc2026-track2-full/vmc2026-track2" # << SỬA slug cho khớp Add Input
|
| 36 |
+
WAV_DIR = f"{DATA_ROOT}/wav"
|
| 37 |
+
METADATA_CSV = f"{DATA_ROOT}/metadata.csv" # wavID|emotion|transcript (KHÔNG header)
|
| 38 |
+
TRAIN_CSV = f"{DATA_ROOT}/sets/train.csv" # lisID|wavID|qMOS|emoCat|eMOS|val|dom|aro
|
| 39 |
+
DEV_SCP = f"{DATA_ROOT}/sets/dev.scp"
|
| 40 |
+
|
| 41 |
+
OUT_DIR = "/kaggle/working"
|
| 42 |
+
CACHE_DIR = "/kaggle/working/ft_cache" # cache audeering (.npz) — backbone WavLM KHÔNG cache (đang train)
|
| 43 |
+
os.makedirs(CACHE_DIR, exist_ok=True)
|
| 44 |
+
|
| 45 |
+
# (Tùy chọn) TÁI DÙNG cache audeering cũ: trỏ tới dataset chứa aud_train.npz/aud_dev.npz → tự copy sang CACHE_DIR.
|
| 46 |
+
# Để "" nếu chạy mới hoàn toàn. /kaggle/input read-only nên phải copy sang working để ghi/append.
|
| 47 |
+
CACHE_INPUT = "/kaggle/input/datasets/minhtoan2/cache-exp8" # << SỬA slug cho khớp (hoặc "")
|
| 48 |
+
if CACHE_INPUT and os.path.isdir(CACHE_INPUT):
|
| 49 |
+
import shutil
|
| 50 |
+
_n = 0
|
| 51 |
+
for _fn in os.listdir(CACHE_INPUT):
|
| 52 |
+
if _fn.startswith("aud_") and _fn.endswith(".npz"):
|
| 53 |
+
shutil.copy(os.path.join(CACHE_INPUT, _fn), os.path.join(CACHE_DIR, _fn)); _n += 1
|
| 54 |
+
print(f"📦 Tái dùng cache: copy {_n} file aud_*.npz từ {CACHE_INPUT} → {CACHE_DIR}")
|
| 55 |
+
|
| 56 |
+
# Mượn cột QMOS của exp07 (tốt nhất 0.548). Trỏ tới answer.txt exp07 nếu có; không thì dùng UTMOSv2.
|
| 57 |
+
EXP07_ANSWER = "/kaggle/input/exp07-answer/answer.txt" # << (tùy chọn) Add Input answer.txt exp07; không có → UTMOSv2
|
| 58 |
+
|
| 59 |
+
# ── Fine-tune / siêu tham số ─────────────────────────────────────────────────
|
| 60 |
+
DEVICE = "cuda"
|
| 61 |
+
SR = 16000
|
| 62 |
+
MAX_SECONDS = 8 # cắt audio để chặn bộ nhớ backprop; OOM thì giảm còn 6
|
| 63 |
+
UNFREEZE_TOP_LAYERS = 6 # số lớp Transformer trên cùng được train (0 = freeze hết = quay về head-only)
|
| 64 |
+
TRUNK_HIDDEN = 512
|
| 65 |
+
HEAD_HIDDEN = 128
|
| 66 |
+
DROPOUT = 0.3
|
| 67 |
+
LR_BACKBONE = 1e-5 # LR nhỏ cho backbone fine-tune
|
| 68 |
+
LR_HEAD = 1e-3 # LR lớn cho trunk + head (train từ đầu)
|
| 69 |
+
WEIGHT_DECAY = 1e-5
|
| 70 |
+
EPOCHS = 12 # TRẦN; early-stop quyết định số epoch thực (8 hơi thấp cho lần chạy thật)
|
| 71 |
+
PATIENCE = 3 # dừng khi val SRCC không lên 3 epoch; LUÔN giữ best_state
|
| 72 |
+
BATCH = 4 # nhỏ vì backbone to; tăng ACCUM để bù
|
| 73 |
+
ACCUM = 8 # effective batch = BATCH*ACCUM = 32
|
| 74 |
+
VAL_FRAC = 0.10
|
| 75 |
+
SEED = 42
|
| 76 |
+
USE_AMP = True # mixed precision fp16 — tiết kiệm VRAM
|
| 77 |
+
USE_GRAD_CKPT = True # gradient checkpointing — tiết kiệm VRAM (đổi lấy chậm hơn)
|
| 78 |
+
USE_AUDEERING = True # nhánh phụ frozen audeering; False = chỉ WavLM
|
| 79 |
+
USE_UNCERTAINTY = True # tự cân 5 loss (Kendall); False = trọng số 1.0
|
| 80 |
+
|
| 81 |
+
LIMIT_TRAIN = 300 # << LẦN ĐẦU để 300; chạy thật đặt None
|
| 82 |
+
LIMIT_DEV = 20 # << LẦN ĐẦU để 20; chạy thật đặt None
|
| 83 |
+
|
| 84 |
+
# Mốc exp07 để so (cảnh báo nếu fine-tune KHÔNG thắng → giữ exp07)
|
| 85 |
+
EXP07 = {"emos": 0.795, "cat_err": 0.153, "val": 0.581, "aro": 0.752, "dom": 0.705}
|
| 86 |
+
|
| 87 |
+
EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"]
|
| 88 |
+
|
| 89 |
+
_EMO_ALIAS = {
|
| 90 |
+
"angry": "angry", "anger": "angry",
|
| 91 |
+
"happy": "happy", "happiness": "happy", "joy": "happy",
|
| 92 |
+
"neutral": "neutral", "calm": "neutral",
|
| 93 |
+
"sad": "sad", "sadness": "sad",
|
| 94 |
+
"surprise": "surprised", "surprised": "surprised", "surprising": "surprised",
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
def norm_emotion(label):
|
| 98 |
+
key = str(label).strip().lower()
|
| 99 |
+
return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)
|
| 100 |
+
|
| 101 |
+
def stem(p):
|
| 102 |
+
return os.path.splitext(os.path.basename(str(p)))[0]
|
| 103 |
+
|
| 104 |
+
print("DATA_ROOT:", DATA_ROOT)
|
| 105 |
+
for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP]:
|
| 106 |
+
print((" ✅ " if os.path.exists(p) else " ❌ THIẾU ") + p)
|
| 107 |
+
print(f"Fine-tune: mở băng {UNFREEZE_TOP_LAYERS} lớp trên · BATCH {BATCH}×ACCUM {ACCUM} · MAX {MAX_SECONDS}s")
|
| 108 |
+
|
| 109 |
+
# %% [markdown]
|
| 110 |
+
# ## 1. Cài đặt + tải code SAILER (clone + sys.path, KHÔNG pip install -e .)
|
| 111 |
+
|
| 112 |
+
# %%
|
| 113 |
+
import sys, subprocess
|
| 114 |
+
|
| 115 |
+
def pip_install(*pkgs):
|
| 116 |
+
subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs], check=True)
|
| 117 |
+
|
| 118 |
+
pip_install("loralib", "speechbrain", "speechmos", "librosa", "soundfile",
|
| 119 |
+
"scipy", "scikit-learn", "pandas", "tqdm")
|
| 120 |
+
|
| 121 |
+
REPO_DIR = "/kaggle/working/vox-profile-release"
|
| 122 |
+
if not os.path.exists(REPO_DIR):
|
| 123 |
+
subprocess.run(["git", "clone", "--depth", "1",
|
| 124 |
+
"https://github.com/tiantiaf0627/vox-profile-release.git", REPO_DIR], check=True)
|
| 125 |
+
if REPO_DIR not in sys.path:
|
| 126 |
+
sys.path.insert(0, REPO_DIR)
|
| 127 |
+
|
| 128 |
+
# %% [markdown]
|
| 129 |
+
# ## 2. Nạp SAILER → lấy backbone WavLM bên trong để FINE-TUNE
|
| 130 |
+
# Thay vì gọi wrapper như hộp đen, ta **lôi module WavLM-large (HuggingFace) bên trong wrapper** ra
|
| 131 |
+
# → toàn quyền đóng băng/mở băng từng lớp + tự pool. Nếu không tìm thấy (cấu trúc lạ) → **fallback**
|
| 132 |
+
# nạp `microsoft/wavlm-large` trắng (mất warm-start, có cảnh báo).
|
| 133 |
+
|
| 134 |
+
# %%
|
| 135 |
+
import torch
|
| 136 |
+
import torch.nn as nn
|
| 137 |
+
import torch.nn.functional as F
|
| 138 |
+
|
| 139 |
+
device = DEVICE if torch.cuda.is_available() else "cpu"
|
| 140 |
+
print("Device:", device, ("✅ " + torch.cuda.get_device_name(0)) if device == "cuda" else "⚠️ CPU (rất chậm!)")
|
| 141 |
+
|
| 142 |
+
def find_hf_backbone(module):
|
| 143 |
+
"""Tìm submodule kiểu HF Wav2Vec2/WavLM backbone: có .feature_extractor và .encoder.layers."""
|
| 144 |
+
cands = []
|
| 145 |
+
for name, m in module.named_modules():
|
| 146 |
+
enc = getattr(m, "encoder", None)
|
| 147 |
+
if getattr(m, "feature_extractor", None) is not None and enc is not None \
|
| 148 |
+
and getattr(enc, "layers", None) is not None:
|
| 149 |
+
cands.append((name, m))
|
| 150 |
+
if not cands:
|
| 151 |
+
return None, None
|
| 152 |
+
cands.sort(key=lambda nm: sum(p.numel() for p in nm[1].parameters()), reverse=True)
|
| 153 |
+
return cands[0]
|
| 154 |
+
|
| 155 |
+
wavlm = None
|
| 156 |
+
try:
|
| 157 |
+
from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402
|
| 158 |
+
_wrapper = WavLMWrapper.from_pretrained("tiantiaf/wavlm-large-categorical-emotion")
|
| 159 |
+
name, wavlm = find_hf_backbone(_wrapper)
|
| 160 |
+
if wavlm is not None:
|
| 161 |
+
print(f"✅ Warm-start SAILER: lấy backbone WavLM bên trong wrapper tại '.{name}' "
|
| 162 |
+
f"({sum(p.numel() for p in wavlm.parameters())/1e6:.0f}M params)")
|
| 163 |
+
else:
|
| 164 |
+
print("⚠️ Không tìm thấy backbone HF bên trong wrapper SAILER → sẽ fallback WavLM trắng.")
|
| 165 |
+
except Exception as e:
|
| 166 |
+
print("⚠️ Lỗi nạp SAILER wrapper:", repr(e), "→ fallback WavLM trắng.")
|
| 167 |
+
|
| 168 |
+
if wavlm is None:
|
| 169 |
+
from transformers import WavLMModel
|
| 170 |
+
wavlm = WavLMModel.from_pretrained("microsoft/wavlm-large")
|
| 171 |
+
print("ℹ️ Fallback: nạp microsoft/wavlm-large (KHÔNG warm-start SAILER).")
|
| 172 |
+
|
| 173 |
+
wavlm = wavlm.to(device)
|
| 174 |
+
WAVLM_DIM = int(wavlm.config.hidden_size)
|
| 175 |
+
|
| 176 |
+
# ── Đóng băng partial: feature-extractor + tất cả trừ UNFREEZE_TOP_LAYERS lớp trên ──
|
| 177 |
+
for p in wavlm.parameters():
|
| 178 |
+
p.requires_grad = False
|
| 179 |
+
enc_layers = wavlm.encoder.layers
|
| 180 |
+
n_layers = len(enc_layers)
|
| 181 |
+
for layer in enc_layers[max(0, n_layers - UNFREEZE_TOP_LAYERS):]:
|
| 182 |
+
for p in layer.parameters():
|
| 183 |
+
p.requires_grad = True
|
| 184 |
+
n_train = sum(p.numel() for p in wavlm.parameters() if p.requires_grad)
|
| 185 |
+
print(f"WavLM: {n_layers} lớp encoder · mở băng {min(UNFREEZE_TOP_LAYERS, n_layers)} lớp trên "
|
| 186 |
+
f"→ {n_train/1e6:.1f}M param train (trên dim {WAVLM_DIM})")
|
| 187 |
+
|
| 188 |
+
if USE_GRAD_CKPT:
|
| 189 |
+
wavlm.gradient_checkpointing_enable()
|
| 190 |
+
if hasattr(wavlm, "enable_input_require_grads"):
|
| 191 |
+
wavlm.enable_input_require_grads() # cần khi grad-ckpt + lớp dưới đóng băng
|
| 192 |
+
|
| 193 |
+
def masked_mean(hidden, attn_mask):
|
| 194 |
+
"""Mean-pool theo thời gian, bỏ qua phần pad (giữ gradient)."""
|
| 195 |
+
if attn_mask is None:
|
| 196 |
+
return hidden.mean(dim=1)
|
| 197 |
+
try:
|
| 198 |
+
fm = wavlm._get_feature_vector_attention_mask(hidden.shape[1], attn_mask)
|
| 199 |
+
except Exception:
|
| 200 |
+
return hidden.mean(dim=1)
|
| 201 |
+
fm = fm.unsqueeze(-1).to(hidden.dtype)
|
| 202 |
+
return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)
|
| 203 |
+
|
| 204 |
+
def wavlm_embed(input_values, attn_mask):
|
| 205 |
+
out = wavlm(input_values, attention_mask=attn_mask).last_hidden_state # [B,T,D]
|
| 206 |
+
return masked_mean(out, attn_mask)
|
| 207 |
+
|
| 208 |
+
# %% [markdown]
|
| 209 |
+
# ## 3. Nạp audeering MSP-dim (FROZEN) — đặc trưng phụ
|
| 210 |
+
# Lấy `[emb_pool(1024) | vad3(1–5)]` mỗi wav rồi **cache .npz** (chạy 1 lần). Kỹ thuật nạp head tay
|
| 211 |
+
# y hệt exp05 (tránh lỗi version transformers khi subclass `Wav2Vec2PreTrainedModel`).
|
| 212 |
+
|
| 213 |
+
# %%
|
| 214 |
+
AUD_DIM = 0
|
| 215 |
+
aud_backbone = aud_head = aud_proc = None
|
| 216 |
+
if USE_AUDEERING:
|
| 217 |
+
from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor
|
| 218 |
+
from huggingface_hub import hf_hub_download
|
| 219 |
+
AUD_NAME = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
|
| 220 |
+
aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)
|
| 221 |
+
aud_cfg = Wav2Vec2Config.from_pretrained(AUD_NAME)
|
| 222 |
+
aud_backbone = Wav2Vec2Model(aud_cfg)
|
| 223 |
+
try:
|
| 224 |
+
_sd = __import__("safetensors.torch", fromlist=["load_file"]).load_file(
|
| 225 |
+
hf_hub_download(AUD_NAME, "model.safetensors"))
|
| 226 |
+
except Exception:
|
| 227 |
+
_sd = torch.load(hf_hub_download(AUD_NAME, "pytorch_model.bin"), map_location="cpu")
|
| 228 |
+
bb_sd = {k[len("wav2vec2."):]: v for k, v in _sd.items() if k.startswith("wav2vec2.")}
|
| 229 |
+
missing, unexpected = aud_backbone.load_state_dict(bb_sd, strict=False)
|
| 230 |
+
print(f" audeering backbone: thiếu {len(missing)} / dư {len(unexpected)} key (strict=False)")
|
| 231 |
+
_hid = _sd["classifier.dense.weight"].shape[0]
|
| 232 |
+
_out = _sd["classifier.out_proj.weight"].shape[0]
|
| 233 |
+
aud_head = nn.Sequential(nn.Linear(_hid, _hid), nn.Tanh(), nn.Linear(_hid, _out))
|
| 234 |
+
aud_head[0].weight.data.copy_(_sd["classifier.dense.weight"]); aud_head[0].bias.data.copy_(_sd["classifier.dense.bias"])
|
| 235 |
+
aud_head[2].weight.data.copy_(_sd["classifier.out_proj.weight"]); aud_head[2].bias.data.copy_(_sd["classifier.out_proj.bias"])
|
| 236 |
+
aud_backbone = aud_backbone.to(device).eval()
|
| 237 |
+
aud_head = aud_head.to(device).eval()
|
| 238 |
+
AUD_DIM = _hid + 3 # emb_pool + [VAL,ARO,DOM]
|
| 239 |
+
print(f"✅ audeering frozen (đặc trưng phụ {AUD_DIM}-D = emb {_hid} + vad 3)")
|
| 240 |
+
|
| 241 |
+
# %%
|
| 242 |
+
import numpy as np
|
| 243 |
+
import librosa
|
| 244 |
+
from tqdm.auto import tqdm
|
| 245 |
+
|
| 246 |
+
def load_wav(name_or_stem, in_wav_dir=True):
|
| 247 |
+
p = name_or_stem if os.path.isabs(str(name_or_stem)) else os.path.join(
|
| 248 |
+
WAV_DIR, name_or_stem if str(name_or_stem).endswith(".wav") else str(name_or_stem) + ".wav")
|
| 249 |
+
if not os.path.exists(p):
|
| 250 |
+
return None
|
| 251 |
+
wave, _ = librosa.load(p, sr=SR, mono=True)
|
| 252 |
+
return wave[: MAX_SECONDS * SR].astype(np.float32)
|
| 253 |
+
|
| 254 |
+
@torch.no_grad()
|
| 255 |
+
def extract_audeering(stems, tag):
|
| 256 |
+
"""→ dict {stem: float32[AUD_DIM]}; cache CACHE_DIR/aud_<tag>.npz (resume mỗi 500)."""
|
| 257 |
+
if not USE_AUDEERING:
|
| 258 |
+
return {}
|
| 259 |
+
cache_path = os.path.join(CACHE_DIR, f"aud_{tag}.npz")
|
| 260 |
+
store = {}
|
| 261 |
+
if os.path.exists(cache_path):
|
| 262 |
+
z = np.load(cache_path, allow_pickle=True)
|
| 263 |
+
store = {k: z[k] for k in z.files}
|
| 264 |
+
print(f"[aud/{tag}] nạp cache: {len(store)}")
|
| 265 |
+
todo = [s for s in stems if s not in store]
|
| 266 |
+
for i, s in enumerate(tqdm(todo, desc=f"audeering {tag}")):
|
| 267 |
+
wave = load_wav(s)
|
| 268 |
+
if wave is None:
|
| 269 |
+
continue
|
| 270 |
+
x = aud_proc(wave, sampling_rate=SR).input_values[0]
|
| 271 |
+
x = torch.from_numpy(np.asarray(x, dtype=np.float32)).unsqueeze(0).to(device)
|
| 272 |
+
h = aud_backbone(x)[0].mean(dim=1) # [1, hid]
|
| 273 |
+
out = aud_head(h)[0].cpu().numpy() # [arousal, dominance, valence] ∈[0,1]
|
| 274 |
+
vad = np.array([1 + 4 * out[2], 1 + 4 * out[0], 1 + 4 * out[1]], dtype=np.float32) # [VAL,ARO,DOM]
|
| 275 |
+
store[s] = np.concatenate([h[0].cpu().numpy(), vad]).astype(np.float32)
|
| 276 |
+
if (i + 1) % 500 == 0:
|
| 277 |
+
np.savez(cache_path, **store)
|
| 278 |
+
if todo:
|
| 279 |
+
np.savez(cache_path, **store)
|
| 280 |
+
return store
|
| 281 |
+
|
| 282 |
+
# %% [markdown]
|
| 283 |
+
# ## 4. Đọc & gộp nhãn theo wavID (EMOS / VAD / CAT) — như exp04/07 nhưng KHÔNG cần qMOS
|
| 284 |
+
|
| 285 |
+
# %%
|
| 286 |
+
import pandas as pd
|
| 287 |
+
|
| 288 |
+
def load_target_emotions():
|
| 289 |
+
tgt = {}
|
| 290 |
+
with open(METADATA_CSV, encoding="utf-8") as f:
|
| 291 |
+
for ln in f:
|
| 292 |
+
parts = ln.strip().split("|")
|
| 293 |
+
if len(parts) >= 2:
|
| 294 |
+
tgt[stem(parts[0])] = norm_emotion(parts[1])
|
| 295 |
+
return tgt
|
| 296 |
+
|
| 297 |
+
def _col(cols_map, *names, df=None, default_idx=None):
|
| 298 |
+
for n in names:
|
| 299 |
+
if n in cols_map:
|
| 300 |
+
return cols_map[n]
|
| 301 |
+
return list(df.columns)[default_idx] if default_idx is not None else None
|
| 302 |
+
|
| 303 |
+
def parse_emocat_votes(cell):
|
| 304 |
+
v = np.zeros(len(EMOTIONS5), dtype=np.float32)
|
| 305 |
+
for tok in str(cell).replace("/", ",").replace(";", ",").replace("|", ",").replace(" ", ",").split(","):
|
| 306 |
+
e = norm_emotion(tok)
|
| 307 |
+
if e in EMOTIONS5:
|
| 308 |
+
v[EMOTIONS5.index(e)] += 1.0
|
| 309 |
+
return v
|
| 310 |
+
|
| 311 |
+
def load_train_labels():
|
| 312 |
+
df = pd.read_csv(TRAIN_CSV, sep="|")
|
| 313 |
+
cols = {c.lower().strip(): c for c in df.columns}
|
| 314 |
+
wav_col = _col(cols, "wavid", "wav", df=df, default_idx=1)
|
| 315 |
+
emos_col = _col(cols, "emos", "emo", "emomos")
|
| 316 |
+
val_col = _col(cols, "val", "valence"); aro_col = _col(cols, "aro", "arousal"); dom_col = _col(cols, "dom", "dominance")
|
| 317 |
+
cat_col = _col(cols, "emocat", "cat", "emotion")
|
| 318 |
+
assert emos_col, f"Không thấy cột eMOS (cột: {list(df.columns)})"
|
| 319 |
+
df["_stem"] = df[wav_col].map(stem)
|
| 320 |
+
rows = []
|
| 321 |
+
for sid, g in df.groupby("_stem"):
|
| 322 |
+
rec = {"wavID": sid, "emos": float(g[emos_col].mean())}
|
| 323 |
+
rec["val"] = float(g[val_col].mean()) if val_col else np.nan
|
| 324 |
+
rec["aro"] = float(g[aro_col].mean()) if aro_col else np.nan
|
| 325 |
+
rec["dom"] = float(g[dom_col].mean()) if dom_col else np.nan
|
| 326 |
+
votes = np.zeros(len(EMOTIONS5), dtype=np.float32)
|
| 327 |
+
if cat_col:
|
| 328 |
+
for cell in g[cat_col]:
|
| 329 |
+
votes += parse_emocat_votes(cell)
|
| 330 |
+
s = votes.sum()
|
| 331 |
+
cat = votes / s if s > 0 else np.full(len(EMOTIONS5), 0.2, dtype=np.float32)
|
| 332 |
+
for i in range(len(EMOTIONS5)):
|
| 333 |
+
rec[f"cat{i}"] = float(cat[i])
|
| 334 |
+
rows.append(rec)
|
| 335 |
+
return pd.DataFrame(rows)
|
| 336 |
+
|
| 337 |
+
target_map = load_target_emotions()
|
| 338 |
+
train_df = load_train_labels()
|
| 339 |
+
HAS_VAD = bool(train_df["val"].notna().any())
|
| 340 |
+
print(f"Target: {len(target_map)} | wav train (gộp): {len(train_df)} | có VAD: {HAS_VAD}")
|
| 341 |
+
|
| 342 |
+
# %% [markdown]
|
| 343 |
+
# ## 5. Dataset / DataLoader (load wav theo batch — KHÔNG cache WavLM vì đang train)
|
| 344 |
+
|
| 345 |
+
# %%
|
| 346 |
+
from torch.utils.data import Dataset, DataLoader
|
| 347 |
+
|
| 348 |
+
train_stems = [s for s in train_df["wavID"] if target_map.get(s) is not None]
|
| 349 |
+
if LIMIT_TRAIN:
|
| 350 |
+
train_stems = train_stems[:LIMIT_TRAIN]
|
| 351 |
+
aud_tr = extract_audeering(train_stems, "train")
|
| 352 |
+
|
| 353 |
+
lab = train_df.set_index("wavID")
|
| 354 |
+
|
| 355 |
+
# Chuẩn hóa nhãn liên tục về z-score (để các MSE cùng thang) — lưu để giải mã lúc dự đoán.
|
| 356 |
+
def _zfit(arr):
|
| 357 |
+
a = np.asarray(arr, dtype=np.float32)
|
| 358 |
+
return float(np.nanmean(a)), float(np.nanstd(a) + 1e-6)
|
| 359 |
+
|
| 360 |
+
emos_mu, emos_sd = _zfit([lab.loc[s, "emos"] for s in train_stems])
|
| 361 |
+
if HAS_VAD:
|
| 362 |
+
vad_mu = np.array([_zfit([lab.loc[s, c] for s in train_stems])[0] for c in ["val", "aro", "dom"]], dtype=np.float32)
|
| 363 |
+
vad_sd = np.array([_zfit([lab.loc[s, c] for s in train_stems])[1] for c in ["val", "aro", "dom"]], dtype=np.float32)
|
| 364 |
+
else:
|
| 365 |
+
vad_mu = np.zeros(3, dtype=np.float32); vad_sd = np.ones(3, dtype=np.float32)
|
| 366 |
+
|
| 367 |
+
def onehot_target(tgt):
|
| 368 |
+
v = np.zeros(len(EMOTIONS5), dtype=np.float32)
|
| 369 |
+
if tgt in EMOTIONS5:
|
| 370 |
+
v[EMOTIONS5.index(tgt)] = 1.0
|
| 371 |
+
return v
|
| 372 |
+
|
| 373 |
+
class EmoDataset(Dataset):
|
| 374 |
+
def __init__(self, stems):
|
| 375 |
+
self.stems = [s for s in stems if (load_wav(s) is not None) and ((not USE_AUDEERING) or s in aud_tr)]
|
| 376 |
+
def __len__(self):
|
| 377 |
+
return len(self.stems)
|
| 378 |
+
def __getitem__(self, i):
|
| 379 |
+
s = self.stems[i]
|
| 380 |
+
wave = load_wav(s)
|
| 381 |
+
emos = (float(lab.loc[s, "emos"]) - emos_mu) / emos_sd
|
| 382 |
+
if HAS_VAD:
|
| 383 |
+
vad = (np.array([lab.loc[s, "val"], lab.loc[s, "aro"], lab.loc[s, "dom"]], np.float32) - vad_mu) / vad_sd
|
| 384 |
+
else:
|
| 385 |
+
vad = np.zeros(3, dtype=np.float32)
|
| 386 |
+
cat = np.array([lab.loc[s, f"cat{j}"] for j in range(len(EMOTIONS5))], dtype=np.float32)
|
| 387 |
+
aud = aud_tr[s] if USE_AUDEERING else np.zeros(0, dtype=np.float32)
|
| 388 |
+
return {"wave": wave, "tgt": onehot_target(target_map.get(s)), "aud": aud,
|
| 389 |
+
"emos": np.float32(emos), "vad": vad, "cat": cat,
|
| 390 |
+
"emos_raw": np.float32(lab.loc[s, "emos"]),
|
| 391 |
+
"vad_raw": np.array([lab.loc[s, "val"], lab.loc[s, "aro"], lab.loc[s, "dom"]], np.float32)}
|
| 392 |
+
|
| 393 |
+
def collate(batch):
|
| 394 |
+
lens = [len(b["wave"]) for b in batch]
|
| 395 |
+
L = max(lens)
|
| 396 |
+
waves = np.zeros((len(batch), L), dtype=np.float32)
|
| 397 |
+
mask = np.zeros((len(batch), L), dtype=np.float32)
|
| 398 |
+
for i, b in enumerate(batch):
|
| 399 |
+
waves[i, : len(b["wave"])] = b["wave"]; mask[i, : len(b["wave"])] = 1.0
|
| 400 |
+
out = {
|
| 401 |
+
"input_values": torch.from_numpy(waves), "attn_mask": torch.from_numpy(mask).long(),
|
| 402 |
+
"tgt": torch.from_numpy(np.stack([b["tgt"] for b in batch])),
|
| 403 |
+
"aud": torch.from_numpy(np.stack([b["aud"] for b in batch])) if USE_AUDEERING else None,
|
| 404 |
+
"emos": torch.from_numpy(np.stack([b["emos"] for b in batch])).unsqueeze(1),
|
| 405 |
+
"vad": torch.from_numpy(np.stack([b["vad"] for b in batch])),
|
| 406 |
+
"cat": torch.from_numpy(np.stack([b["cat"] for b in batch])),
|
| 407 |
+
"emos_raw": np.stack([b["emos_raw"] for b in batch]),
|
| 408 |
+
"vad_raw": np.stack([b["vad_raw"] for b in batch]),
|
| 409 |
+
}
|
| 410 |
+
return out
|
| 411 |
+
|
| 412 |
+
from sklearn.model_selection import train_test_split
|
| 413 |
+
ds = EmoDataset(train_stems)
|
| 414 |
+
print("Dataset hợp lệ:", len(ds), "wav")
|
| 415 |
+
tr_i, va_i = train_test_split(np.arange(len(ds)), test_size=VAL_FRAC, random_state=SEED)
|
| 416 |
+
tr_loader = DataLoader(torch.utils.data.Subset(ds, tr_i), batch_size=BATCH, shuffle=True, collate_fn=collate, num_workers=2)
|
| 417 |
+
va_loader = DataLoader(torch.utils.data.Subset(ds, va_i), batch_size=BATCH, shuffle=False, collate_fn=collate, num_workers=2)
|
| 418 |
+
|
| 419 |
+
# %% [markdown]
|
| 420 |
+
# ## 6. Head fusion (trunk + 3 head cảm xúc) + train loop (AMP + grad accumulation)
|
| 421 |
+
|
| 422 |
+
# %%
|
| 423 |
+
from scipy.stats import spearmanr
|
| 424 |
+
|
| 425 |
+
torch.manual_seed(SEED); np.random.seed(SEED)
|
| 426 |
+
N_EMO = len(EMOTIONS5)
|
| 427 |
+
TRUNK_IN = WAVLM_DIM + (AUD_DIM if USE_AUDEERING else 0)
|
| 428 |
+
|
| 429 |
+
class EmoHeads(nn.Module):
|
| 430 |
+
def __init__(self, d_in, trunk_h, head_h, p, n_emo):
|
| 431 |
+
super().__init__()
|
| 432 |
+
self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),
|
| 433 |
+
nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))
|
| 434 |
+
self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))
|
| 435 |
+
self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))
|
| 436 |
+
self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))
|
| 437 |
+
def forward(self, feat, tgt):
|
| 438 |
+
h = self.trunk(feat)
|
| 439 |
+
return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)
|
| 440 |
+
|
| 441 |
+
heads = EmoHeads(TRUNK_IN, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device)
|
| 442 |
+
print(f"Trunk input = {TRUNK_IN} (wavlm {WAVLM_DIM} + aud {AUD_DIM if USE_AUDEERING else 0})")
|
| 443 |
+
|
| 444 |
+
TASKS = ["emos", "cat", "val", "aro", "dom"]
|
| 445 |
+
log_var = nn.Parameter(torch.zeros(len(TASKS), device=device))
|
| 446 |
+
bb_params = [p for p in wavlm.parameters() if p.requires_grad]
|
| 447 |
+
head_params = list(heads.parameters()) + ([log_var] if USE_UNCERTAINTY else [])
|
| 448 |
+
opt = torch.optim.AdamW([
|
| 449 |
+
{"params": bb_params, "lr": LR_BACKBONE},
|
| 450 |
+
{"params": head_params, "lr": LR_HEAD},
|
| 451 |
+
], weight_decay=WEIGHT_DECAY)
|
| 452 |
+
scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP and device == "cuda")
|
| 453 |
+
mse = nn.MSELoss()
|
| 454 |
+
|
| 455 |
+
def soft_ce(logits, target_dist):
|
| 456 |
+
return -(target_dist * F.log_softmax(logits, dim=1)).sum(1).mean()
|
| 457 |
+
|
| 458 |
+
def forward_batch(b):
|
| 459 |
+
feat_wavlm = wavlm_embed(b["input_values"].to(device), b["attn_mask"].to(device))
|
| 460 |
+
if USE_AUDEERING:
|
| 461 |
+
feat = torch.cat([feat_wavlm, b["aud"].to(device)], dim=1)
|
| 462 |
+
else:
|
| 463 |
+
feat = feat_wavlm
|
| 464 |
+
return heads(feat, b["tgt"].to(device))
|
| 465 |
+
|
| 466 |
+
def compute_loss(emos_p, cat_l, vad_p, b):
|
| 467 |
+
L = {}
|
| 468 |
+
L["emos"] = mse(emos_p, b["emos"].to(device))
|
| 469 |
+
L["cat"] = soft_ce(cat_l, b["cat"].to(device))
|
| 470 |
+
if HAS_VAD:
|
| 471 |
+
vt = b["vad"].to(device)
|
| 472 |
+
L["val"] = mse(vad_p[:, 0:1], vt[:, 0:1]); L["aro"] = mse(vad_p[:, 1:2], vt[:, 1:2]); L["dom"] = mse(vad_p[:, 2:3], vt[:, 2:3])
|
| 473 |
+
else:
|
| 474 |
+
z = torch.zeros((), device=device); L["val"] = L["aro"] = L["dom"] = z
|
| 475 |
+
if USE_UNCERTAINTY:
|
| 476 |
+
return sum(torch.exp(-log_var[i]) * L[t] + log_var[i] for i, t in enumerate(TASKS))
|
| 477 |
+
return sum(L.values())
|
| 478 |
+
|
| 479 |
+
@torch.no_grad()
|
| 480 |
+
def evaluate():
|
| 481 |
+
wavlm.eval(); heads.eval()
|
| 482 |
+
P = {"emos": [], "val": [], "aro": [], "dom": []}; Y = {"emos": [], "val": [], "aro": [], "dom": []}
|
| 483 |
+
catP, catY = [], []
|
| 484 |
+
for b in va_loader:
|
| 485 |
+
with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
|
| 486 |
+
emos_p, cat_l, vad_p = forward_batch(b)
|
| 487 |
+
P["emos"] += emos_p.float().cpu().numpy().ravel().tolist(); Y["emos"] += b["emos_raw"].tolist()
|
| 488 |
+
vad_p = vad_p.float().cpu().numpy()
|
| 489 |
+
for j, t in enumerate(["val", "aro", "dom"]):
|
| 490 |
+
P[t] += vad_p[:, j].tolist(); Y[t] += b["vad_raw"][:, j].tolist()
|
| 491 |
+
catP.append(F.softmax(cat_l, 1).float().cpu().numpy()); catY.append(b["cat"])
|
| 492 |
+
out = {}
|
| 493 |
+
for t in ["emos"] + (["val", "aro", "dom"] if HAS_VAD else []):
|
| 494 |
+
out[t] = spearmanr(P[t], Y[t]).correlation
|
| 495 |
+
q = np.concatenate(catP); p = np.concatenate(catY)
|
| 496 |
+
out["cat_err"] = float(np.abs(q - p).sum(1).mean()) # ~ tổng |Δ| trung bình (xấp xỉ CAT-ERR)
|
| 497 |
+
return out
|
| 498 |
+
|
| 499 |
+
def mean_srcc(m):
|
| 500 |
+
keys = ["emos"] + (["val", "aro", "dom"] if HAS_VAD else [])
|
| 501 |
+
return float(np.mean([m[k] for k in keys]))
|
| 502 |
+
|
| 503 |
+
# Lưu checkpoint FULL (có backbone WavLM) — gọi NGAY mỗi best để kernel chết giữa chừng vẫn còn file.
|
| 504 |
+
CKPT_PATH = os.path.join(OUT_DIR, "ft_emotion_full.pt")
|
| 505 |
+
def save_full_ckpt(state, val_emos=float("nan")):
|
| 506 |
+
torch.save({"wavlm": state["wavlm"], "heads": state["heads"],
|
| 507 |
+
"emos_mu": emos_mu, "emos_sd": emos_sd, "vad_mu": vad_mu, "vad_sd": vad_sd,
|
| 508 |
+
"WAVLM_DIM": WAVLM_DIM, "AUD_DIM": AUD_DIM,
|
| 509 |
+
"UNFREEZE_TOP_LAYERS": UNFREEZE_TOP_LAYERS, "val_emos": float(val_emos)}, CKPT_PATH)
|
| 510 |
+
|
| 511 |
+
best, best_state, bad = -1e9, None, 0
|
| 512 |
+
for ep in range(1, EPOCHS + 1):
|
| 513 |
+
wavlm.train(); heads.train()
|
| 514 |
+
opt.zero_grad(); run = 0.0; nb = 0
|
| 515 |
+
for step, b in enumerate(tqdm(tr_loader, desc=f"epoch {ep}")):
|
| 516 |
+
with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
|
| 517 |
+
emos_p, cat_l, vad_p = forward_batch(b)
|
| 518 |
+
loss = compute_loss(emos_p, cat_l, vad_p, b) / ACCUM
|
| 519 |
+
scaler.scale(loss).backward()
|
| 520 |
+
if (step + 1) % ACCUM == 0:
|
| 521 |
+
scaler.step(opt); scaler.update(); opt.zero_grad()
|
| 522 |
+
run += loss.item() * ACCUM; nb += 1
|
| 523 |
+
m = evaluate(); sc = mean_srcc(m)
|
| 524 |
+
msg = " ".join(f"{k}={m[k]:.3f}" for k in ["emos", "val", "aro", "dom"] if k in m)
|
| 525 |
+
print(f"epoch {ep:2d} | loss {run/max(nb,1):.4f} | {msg} | cat_err {m['cat_err']:.3f} | mean {sc:.4f} (best {max(best,sc):.4f})")
|
| 526 |
+
if sc > best:
|
| 527 |
+
best = sc
|
| 528 |
+
best_state = {"wavlm": {k: v.cpu().clone() for k, v in wavlm.state_dict().items()},
|
| 529 |
+
"heads": {k: v.cpu().clone() for k, v in heads.state_dict().items()}}
|
| 530 |
+
save_full_ckpt(best_state, m["emos"]) # LƯU NGAY mỗi best → an toàn nếu kernel chết
|
| 531 |
+
print(f" 💾 lưu best → {CKPT_PATH} (epoch {ep}, mean {sc:.4f})")
|
| 532 |
+
bad = 0
|
| 533 |
+
else:
|
| 534 |
+
bad += 1
|
| 535 |
+
if bad >= PATIENCE:
|
| 536 |
+
print(f"Early stop ở epoch {ep}."); break
|
| 537 |
+
|
| 538 |
+
if best_state:
|
| 539 |
+
wavlm.load_state_dict(best_state["wavlm"]); heads.load_state_dict(best_state["heads"])
|
| 540 |
+
final = evaluate()
|
| 541 |
+
print("\n✅ VAL (nội bộ) — exp08 (fine-tune WavLM cho cảm xúc):")
|
| 542 |
+
print(f" EMOS={final['emos']:.4f} (exp07 {EXP07['emos']})")
|
| 543 |
+
if HAS_VAD:
|
| 544 |
+
print(f" VAL/ARO/DOM={final['val']:.4f}/{final['aro']:.4f}/{final['dom']:.4f} "
|
| 545 |
+
f"(exp07 {EXP07['val']}/{EXP07['aro']}/{EXP07['dom']})")
|
| 546 |
+
warn = [f"EMOS {final['emos']:.3f}<{EXP07['emos']}"] if final["emos"] < EXP07["emos"] - 0.005 else []
|
| 547 |
+
if HAS_VAD:
|
| 548 |
+
warn += [f"{t.upper()} {final[t]:.3f}<{EXP07[t]}" for t in ["val", "aro", "dom"] if final[t] < EXP07[t] - 0.005]
|
| 549 |
+
print(" ⚠️ CHƯA thắng exp07 ở:", "; ".join(warn), "→ cân nhắc giữ exp07." if warn else "")
|
| 550 |
+
if not warn:
|
| 551 |
+
print(" ✅ Fine-tune thắng/ngang exp07 ở mọi cột cảm xúc → đáng nộp.")
|
| 552 |
+
# Lưu lần cuối từ best (đã lưu sẵn mỗi best trong loop; đây là phát cuối cho chắc).
|
| 553 |
+
save_full_ckpt(best_state if best_state else
|
| 554 |
+
{"wavlm": wavlm.state_dict(), "heads": heads.state_dict()}, final["emos"])
|
| 555 |
+
print(f"✅ Đã lưu {CKPT_PATH} (CÓ backbone WavLM + heads → resume được). "
|
| 556 |
+
f"NHỚ Save Version để file ra Output!")
|
| 557 |
+
|
| 558 |
+
# %% [markdown]
|
| 559 |
+
# ## 7. Dự đoán DEV → answer.txt (5 cột cảm xúc từ exp08; QMOS mượn exp07 hoặc UTMOS)
|
| 560 |
+
|
| 561 |
+
# %%
|
| 562 |
+
def list_dev():
|
| 563 |
+
with open(DEV_SCP) as f:
|
| 564 |
+
return [ln.strip() for ln in f if ln.strip()]
|
| 565 |
+
|
| 566 |
+
dev_names = list_dev()
|
| 567 |
+
if LIMIT_DEV:
|
| 568 |
+
dev_names = dev_names[:LIMIT_DEV]
|
| 569 |
+
dev_stems = [stem(n) for n in dev_names]
|
| 570 |
+
print("DEV:", len(dev_names), "mẫu")
|
| 571 |
+
aud_dev = extract_audeering(dev_stems, "dev")
|
| 572 |
+
|
| 573 |
+
# QMOS: ưu tiên mượn cột QMOS của exp07; không có file → chấm UTMOSv2 (T05, vô địch VMC2024).
|
| 574 |
+
def load_exp07_qmos():
|
| 575 |
+
if EXP07_ANSWER and os.path.exists(EXP07_ANSWER):
|
| 576 |
+
import csv
|
| 577 |
+
d = {}
|
| 578 |
+
with open(EXP07_ANSWER) as f:
|
| 579 |
+
r = csv.DictReader(f)
|
| 580 |
+
for row in r:
|
| 581 |
+
d[row["wav"]] = float(row["QMOS"]); d[stem(row["wav"])] = float(row["QMOS"])
|
| 582 |
+
print(f"✅ Mượn QMOS từ exp07 ({EXP07_ANSWER}): {len(d)//2} wav")
|
| 583 |
+
return d
|
| 584 |
+
return None
|
| 585 |
+
|
| 586 |
+
qmos_map = load_exp07_qmos()
|
| 587 |
+
if qmos_map is None:
|
| 588 |
+
print("ℹ️ Không có answer.txt exp07 → chấm QMOS bằng UTMOSv2 (T05, vô địch VMC2024 Track 1).")
|
| 589 |
+
pip_install("git+https://github.com/sarulab-speech/UTMOSv2.git") # cần Internet On, checkpoint tự tải
|
| 590 |
+
import utmosv2
|
| 591 |
+
v2 = utmosv2.create_model(pretrained=True)
|
| 592 |
+
qmos_map = {}
|
| 593 |
+
for n in tqdm(dev_names, desc="UTMOSv2"):
|
| 594 |
+
wav = os.path.join(WAV_DIR, n if str(n).endswith(".wav") else str(n) + ".wav")
|
| 595 |
+
if not os.path.exists(wav):
|
| 596 |
+
continue
|
| 597 |
+
out = v2.predict(input_path=wav) # trả float hoặc dict {'predicted_mos': ...} tùy phiên bản
|
| 598 |
+
qmos_map[n] = float(out["predicted_mos"]) if isinstance(out, dict) else float(out)
|
| 599 |
+
del v2; torch.cuda.empty_cache() if device == "cuda" else None
|
| 600 |
+
|
| 601 |
+
@torch.no_grad()
|
| 602 |
+
def predict_emotion(sid):
|
| 603 |
+
wave = load_wav(sid)
|
| 604 |
+
if wave is None or (USE_AUDEERING and sid not in aud_dev):
|
| 605 |
+
return None
|
| 606 |
+
wavlm.eval(); heads.eval()
|
| 607 |
+
iv = torch.from_numpy(wave).unsqueeze(0).to(device)
|
| 608 |
+
am = torch.ones((1, len(wave)), dtype=torch.long, device=device)
|
| 609 |
+
tgt = torch.from_numpy(onehot_target(target_map.get(sid))).unsqueeze(0).to(device)
|
| 610 |
+
with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
|
| 611 |
+
fw = wavlm_embed(iv, am)
|
| 612 |
+
if USE_AUDEERING:
|
| 613 |
+
aud = torch.from_numpy(aud_dev[sid]).unsqueeze(0).to(device)
|
| 614 |
+
feat = torch.cat([fw, aud], dim=1)
|
| 615 |
+
else:
|
| 616 |
+
feat = fw
|
| 617 |
+
emos_p, cat_l, vad_p = heads(feat, tgt)
|
| 618 |
+
emos = float(emos_p.item()) * emos_sd + emos_mu
|
| 619 |
+
cat5 = F.softmax(cat_l, 1)[0].float().cpu().numpy()
|
| 620 |
+
vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu
|
| 621 |
+
return emos, cat5, vad3
|
| 622 |
+
|
| 623 |
+
def fmt_cat(p5):
|
| 624 |
+
return "|".join(f"{e}:{p5[i]:.6g}" for i, e in enumerate(EMOTIONS5))
|
| 625 |
+
|
| 626 |
+
def build_answer(out_path):
|
| 627 |
+
n_real = n_def = 0
|
| 628 |
+
with open(out_path, "w") as f:
|
| 629 |
+
f.write("wav,QMOS,EMOS,CAT,VAL,ARO,DOM\n")
|
| 630 |
+
for name in tqdm(dev_names, desc="answer"):
|
| 631 |
+
sid = stem(name)
|
| 632 |
+
pr = predict_emotion(sid)
|
| 633 |
+
if pr is None:
|
| 634 |
+
emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0]); n_def += 1
|
| 635 |
+
else:
|
| 636 |
+
emos, cat5, vad3 = pr; n_real += 1
|
| 637 |
+
qmos = qmos_map.get(name, qmos_map.get(sid, 3.0))
|
| 638 |
+
f.write(f"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\n")
|
| 639 |
+
print(f"Ghi {len(dev_names)} dòng → {out_path} | cảm xúc thật {n_real}, mặc định {n_def}")
|
| 640 |
+
|
| 641 |
+
answer_path = os.path.join(OUT_DIR, "answer.txt")
|
| 642 |
+
build_answer(answer_path)
|
| 643 |
+
|
| 644 |
+
# %% [markdown]
|
| 645 |
+
# ## 8. Validate + đóng zip
|
| 646 |
+
|
| 647 |
+
# %%
|
| 648 |
+
def validate(path):
|
| 649 |
+
import csv
|
| 650 |
+
with open(path) as f:
|
| 651 |
+
rows = list(csv.reader(f))
|
| 652 |
+
assert rows[0][0] == "wav" and "QMOS" in rows[0] and "EMOS" in rows[0], "Header sai"
|
| 653 |
+
for i, r in enumerate(rows[1:], 2):
|
| 654 |
+
assert len(r) == len(rows[0]), f"Dòng {i} sai số cột"
|
| 655 |
+
print(f"OK: {len(rows)-1} dòng, header = {rows[0]}")
|
| 656 |
+
|
| 657 |
+
validate(answer_path)
|
| 658 |
+
os.system(f"cd {OUT_DIR} && zip -j submission_track2_exp08_ft-emotion.zip answer.txt "
|
| 659 |
+
f"&& unzip -l submission_track2_exp08_ft-emotion.zip")
|
| 660 |
+
print("Sẵn sàng nộp:", os.path.join(OUT_DIR, "submission_track2_exp08_ft-emotion.zip"))
|
| 661 |
+
|
| 662 |
+
# %% [markdown]
|
| 663 |
+
# ## Ghi chú
|
| 664 |
+
# - **Lần đầu** `LIMIT_TRAIN=300`, `LIMIT_DEV=20` để kiểm tra chạy trơn (1 epoch xong không OOM); rồi đặt `None`.
|
| 665 |
+
# - **OOM trên T4?** giảm theo thứ tự: `MAX_SECONDS` (8→6) → `UNFREEZE_TOP_LAYERS` (6→4→2) → `BATCH` (4→2, tăng `ACCUM`).
|
| 666 |
+
# - **Đọc mục 6:** so EMOS/VAD VAL nội bộ với mốc exp07 (EMOS 0.795 · VAL 0.581 · ARO 0.752 · DOM 0.705).
|
| 667 |
+
# - Nếu fine-tune **thắng** → nộp answer.txt exp08 (5 cột cảm xúc của exp08 + QMOS mượn exp07).
|
| 668 |
+
# - Nếu **thua** → giữ exp07; vẫn là kết quả cho paper ("fine-tune chưa vượt frozen-fusion trên data nhỏ").
|
| 669 |
+
# - **QMOS:** Add Input answer.txt exp07 vào `/kaggle/input/exp07-answer/answer.txt` để mượn cột QMOS 0.548;
|
| 670 |
+
# không có thì tự chấm UTMOSv2 (T05, vô địch VMC2024 — mạnh hơn UTMOS, cần Internet On).
|
| 671 |
+
# - **Ablation cho paper:** `UNFREEZE_TOP_LAYERS=0` (≈ head-only) vs `=6` (fine-tune) → bảng "frozen vs fine-tuned".
|
| 672 |
+
# `USE_AUDEERING=False` → đo đóng góp nhánh phụ.
|
| 673 |
+
# - Ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (mục exp08).
|
track2/exp08b_finetune_resume.ipynb
ADDED
|
@@ -0,0 +1,782 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "ce468400",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"# VMC2026 Track 2 — exp08-RESUME (fine-tune TIẾP từ checkpoint + cache) — Kaggle\n",
|
| 9 |
+
"\n",
|
| 10 |
+
"**Mục đích:** train tiếp model fine-tune cảm xúc (exp08) từ **checkpoint đã lưu** thay vì train lại từ\n",
|
| 11 |
+
"đầu — tiết kiệm giờ GPU. Tận dụng:\n",
|
| 12 |
+
"- `ft_emotion_full.pt` (CÓ cả backbone WavLM + heads + thống kê chuẩn hóa) → nạp lại đúng trạng thái.\n",
|
| 13 |
+
"- **cache audeering** `aud_*.npz` (đặc trưng frozen) → KHÔNG trích lại (~đỡ chục phút).\n",
|
| 14 |
+
"\n",
|
| 15 |
+
"> ⚠️ Bắt buộc dùng checkpoint **đủ backbone** (`ft_emotion_full.pt` từ cell \"TRAIN TIẾP\", hoặc bản\n",
|
| 16 |
+
"> `ft_emotion_meta.pt` MỚI đã vá để lưu cả `wavlm`). Bản `ft_emotion_meta.pt` CŨ chỉ có `heads` → KHÔNG dùng được.\n",
|
| 17 |
+
"\n",
|
| 18 |
+
"## Chuẩn bị input trên Kaggle (Add Input)\n",
|
| 19 |
+
"1. Dataset Track 2 (`vmc2026-track2-full`) — wav + nhãn.\n",
|
| 20 |
+
"2. **Checkpoint**: upload `ft_emotion_full.pt` thành 1 Dataset → trỏ `RESUME_CKPT`.\n",
|
| 21 |
+
"3. **Cache** (tùy chọn nhưng nên có): upload thư mục chứa `aud_train.npz`, `aud_dev.npz` → trỏ `CACHE_INPUT`.\n",
|
| 22 |
+
"4. (tùy chọn) `answer.txt` exp07 để mượn cột QMOS 0.548.\n",
|
| 23 |
+
"\n",
|
| 24 |
+
"**Cách chạy:** GPU T4 + Internet On → sửa các slug ở cell 0 → Run All. Lần đầu để `LIMIT_TRAIN=300`."
|
| 25 |
+
]
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"cell_type": "markdown",
|
| 29 |
+
"id": "1c6752ee",
|
| 30 |
+
"metadata": {},
|
| 31 |
+
"source": [
|
| 32 |
+
"## 0. Cấu hình — SỬA Ở ĐÂY"
|
| 33 |
+
]
|
| 34 |
+
},
|
| 35 |
+
{
|
| 36 |
+
"cell_type": "code",
|
| 37 |
+
"execution_count": null,
|
| 38 |
+
"id": "8d6317ac",
|
| 39 |
+
"metadata": {},
|
| 40 |
+
"outputs": [],
|
| 41 |
+
"source": [
|
| 42 |
+
"import os, shutil\n",
|
| 43 |
+
"\n",
|
| 44 |
+
"DATA_ROOT = \"/kaggle/input/vmc2026-track2-full/vmc2026-track2\" # << SỬA slug\n",
|
| 45 |
+
"WAV_DIR = f\"{DATA_ROOT}/wav\"\n",
|
| 46 |
+
"METADATA_CSV = f\"{DATA_ROOT}/metadata.csv\"\n",
|
| 47 |
+
"TRAIN_CSV = f\"{DATA_ROOT}/sets/train.csv\"\n",
|
| 48 |
+
"DEV_SCP = f\"{DATA_ROOT}/sets/dev.scp\"\n",
|
| 49 |
+
"\n",
|
| 50 |
+
"# ── Checkpoint + cache để RESUME ─────────────────────────────────────────────\n",
|
| 51 |
+
"RESUME_CKPT = \"/kaggle/input/ft-emotion-full/ft_emotion_full.pt\" # << CHECKPOINT đủ backbone\n",
|
| 52 |
+
"CACHE_INPUT = \"/kaggle/input/ft-emotion-cache\" # << thư mục chứa aud_*.npz (hoặc \"\" nếu không có)\n",
|
| 53 |
+
"EXP07_ANSWER = \"/kaggle/input/exp07-answer/answer.txt\" # << (tùy chọn) mượn QMOS 0.548; không có → UTMOSv2\n",
|
| 54 |
+
"\n",
|
| 55 |
+
"OUT_DIR = \"/kaggle/working\"\n",
|
| 56 |
+
"CACHE_DIR = \"/kaggle/working/ft_cache\" # /kaggle/input read-only → copy cache sang đây để ghi/append được\n",
|
| 57 |
+
"os.makedirs(CACHE_DIR, exist_ok=True)\n",
|
| 58 |
+
"\n",
|
| 59 |
+
"# ── Fine-tune / siêu tham số (train TIẾP) ────────────────────────────────────\n",
|
| 60 |
+
"DEVICE = \"cuda\"\n",
|
| 61 |
+
"SR = 16000\n",
|
| 62 |
+
"MAX_SECONDS = 8\n",
|
| 63 |
+
"UNFREEZE_TOP_LAYERS = 6 # PHẢI khớp checkpoint (mặc định exp08 = 6)\n",
|
| 64 |
+
"TRUNK_HIDDEN = 512 # PHẢI khớp checkpoint\n",
|
| 65 |
+
"HEAD_HIDDEN = 128 # PHẢI khớp checkpoint\n",
|
| 66 |
+
"DROPOUT = 0.3\n",
|
| 67 |
+
"LR_BACKBONE = 1e-5\n",
|
| 68 |
+
"LR_HEAD = 1e-3\n",
|
| 69 |
+
"RESUME_LR_SCALE = 1.0 # <1.0 để giảm LR khi train tiếp (vd 0.5 nếu val đã chững)\n",
|
| 70 |
+
"WEIGHT_DECAY = 1e-5\n",
|
| 71 |
+
"EPOCHS = 10 # số epoch train THÊM (run này)\n",
|
| 72 |
+
"PATIENCE = 5 # dừng khi val không lên; LUÔN giữ best\n",
|
| 73 |
+
"BATCH = 4\n",
|
| 74 |
+
"ACCUM = 8 # effective batch = 32\n",
|
| 75 |
+
"VAL_FRAC = 0.10\n",
|
| 76 |
+
"SEED = 42\n",
|
| 77 |
+
"USE_AMP = True\n",
|
| 78 |
+
"USE_GRAD_CKPT = True\n",
|
| 79 |
+
"USE_AUDEERING = True # PHẢI khớp checkpoint (exp08 = True)\n",
|
| 80 |
+
"USE_UNCERTAINTY = True\n",
|
| 81 |
+
"\n",
|
| 82 |
+
"LIMIT_TRAIN = 300 # << LẦN ĐẦU 300; chạy thật None\n",
|
| 83 |
+
"LIMIT_DEV = 20 # << LẦN ĐẦU 20; chạy thật None\n",
|
| 84 |
+
"\n",
|
| 85 |
+
"# Mốc exp07 + exp08 để so\n",
|
| 86 |
+
"EXP07 = {\"emos\": 0.795, \"cat_err\": 0.153, \"val\": 0.581, \"aro\": 0.752, \"dom\": 0.705}\n",
|
| 87 |
+
"EXP08 = {\"emos\": 0.811, \"cat_err\": 0.133, \"val\": 0.659, \"aro\": 0.793, \"dom\": 0.751} # bản đã nộp\n",
|
| 88 |
+
"\n",
|
| 89 |
+
"EMOTIONS5 = [\"angry\", \"happy\", \"neutral\", \"sad\", \"surprised\"]\n",
|
| 90 |
+
"_EMO_ALIAS = {\n",
|
| 91 |
+
" \"angry\": \"angry\", \"anger\": \"angry\",\n",
|
| 92 |
+
" \"happy\": \"happy\", \"happiness\": \"happy\", \"joy\": \"happy\",\n",
|
| 93 |
+
" \"neutral\": \"neutral\", \"calm\": \"neutral\",\n",
|
| 94 |
+
" \"sad\": \"sad\", \"sadness\": \"sad\",\n",
|
| 95 |
+
" \"surprise\": \"surprised\", \"surprised\": \"surprised\", \"surprising\": \"surprised\",\n",
|
| 96 |
+
"}\n",
|
| 97 |
+
"\n",
|
| 98 |
+
"def norm_emotion(label):\n",
|
| 99 |
+
" key = str(label).strip().lower()\n",
|
| 100 |
+
" return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)\n",
|
| 101 |
+
"\n",
|
| 102 |
+
"def stem(p):\n",
|
| 103 |
+
" return os.path.splitext(os.path.basename(str(p)))[0]\n",
|
| 104 |
+
"\n",
|
| 105 |
+
"print(\"DATA_ROOT:\", DATA_ROOT)\n",
|
| 106 |
+
"for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP, RESUME_CKPT]:\n",
|
| 107 |
+
" print((\" ✅ \" if os.path.exists(p) else \" ❌ THIẾU \") + p)\n",
|
| 108 |
+
"\n",
|
| 109 |
+
"# Copy cache (aud_*.npz) từ input read-only sang working để append được\n",
|
| 110 |
+
"if CACHE_INPUT and os.path.isdir(CACHE_INPUT):\n",
|
| 111 |
+
" n = 0\n",
|
| 112 |
+
" for fn in os.listdir(CACHE_INPUT):\n",
|
| 113 |
+
" if fn.startswith(\"aud_\") and fn.endswith(\".npz\"):\n",
|
| 114 |
+
" shutil.copy(os.path.join(CACHE_INPUT, fn), os.path.join(CACHE_DIR, fn)); n += 1\n",
|
| 115 |
+
" print(f\"📦 Copy {n} file cache audeering từ {CACHE_INPUT} → {CACHE_DIR}\")\n",
|
| 116 |
+
"else:\n",
|
| 117 |
+
" print(\"ℹ️ Không có CACHE_INPUT → sẽ tự trích audeering (chậm hơn lần đầu).\")"
|
| 118 |
+
]
|
| 119 |
+
},
|
| 120 |
+
{
|
| 121 |
+
"cell_type": "markdown",
|
| 122 |
+
"id": "57e6416a",
|
| 123 |
+
"metadata": {},
|
| 124 |
+
"source": [
|
| 125 |
+
"## 1. Cài đặt + tải code SAILER (để dựng đúng kiến trúc WavLM rồi nạp checkpoint đè lên)"
|
| 126 |
+
]
|
| 127 |
+
},
|
| 128 |
+
{
|
| 129 |
+
"cell_type": "code",
|
| 130 |
+
"execution_count": null,
|
| 131 |
+
"id": "76497a3f",
|
| 132 |
+
"metadata": {},
|
| 133 |
+
"outputs": [],
|
| 134 |
+
"source": [
|
| 135 |
+
"import sys, subprocess\n",
|
| 136 |
+
"\n",
|
| 137 |
+
"def pip_install(*pkgs):\n",
|
| 138 |
+
" subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *pkgs], check=True)\n",
|
| 139 |
+
"\n",
|
| 140 |
+
"pip_install(\"loralib\", \"speechbrain\", \"speechmos\", \"librosa\", \"soundfile\",\n",
|
| 141 |
+
" \"scipy\", \"scikit-learn\", \"pandas\", \"tqdm\")\n",
|
| 142 |
+
"\n",
|
| 143 |
+
"REPO_DIR = \"/kaggle/working/vox-profile-release\"\n",
|
| 144 |
+
"if not os.path.exists(REPO_DIR):\n",
|
| 145 |
+
" subprocess.run([\"git\", \"clone\", \"--depth\", \"1\",\n",
|
| 146 |
+
" \"https://github.com/tiantiaf0627/vox-profile-release.git\", REPO_DIR], check=True)\n",
|
| 147 |
+
"if REPO_DIR not in sys.path:\n",
|
| 148 |
+
" sys.path.insert(0, REPO_DIR)"
|
| 149 |
+
]
|
| 150 |
+
},
|
| 151 |
+
{
|
| 152 |
+
"cell_type": "markdown",
|
| 153 |
+
"id": "cf6cf213",
|
| 154 |
+
"metadata": {},
|
| 155 |
+
"source": [
|
| 156 |
+
"## 2. Dựng WavLM (như exp08) → NẠP trọng số backbone từ checkpoint\n",
|
| 157 |
+
"Dựng đúng kiến trúc (SAILER wrapper → lấy backbone HF; fallback WavLM trắng), rồi `load_state_dict`\n",
|
| 158 |
+
"bằng `ckpt[\"wavlm\"]` → khôi phục đúng trạng thái fine-tune đã lưu."
|
| 159 |
+
]
|
| 160 |
+
},
|
| 161 |
+
{
|
| 162 |
+
"cell_type": "code",
|
| 163 |
+
"execution_count": null,
|
| 164 |
+
"id": "20a2e84b",
|
| 165 |
+
"metadata": {
|
| 166 |
+
"lines_to_next_cell": 1
|
| 167 |
+
},
|
| 168 |
+
"outputs": [],
|
| 169 |
+
"source": [
|
| 170 |
+
"import torch\n",
|
| 171 |
+
"import torch.nn as nn\n",
|
| 172 |
+
"import torch.nn.functional as F\n",
|
| 173 |
+
"\n",
|
| 174 |
+
"device = DEVICE if torch.cuda.is_available() else \"cpu\"\n",
|
| 175 |
+
"print(\"Device:\", device, (\"✅ \" + torch.cuda.get_device_name(0)) if device == \"cuda\" else \"⚠️ CPU (rất chậm!)\")\n",
|
| 176 |
+
"\n",
|
| 177 |
+
"ckpt = torch.load(RESUME_CKPT, map_location=\"cpu\", weights_only=False) # ckpt có numpy (vad_mu) → cần False\n",
|
| 178 |
+
"assert \"wavlm\" in ckpt, (\"❌ Checkpoint KHÔNG có 'wavlm' (backbone). Đây là bản ft_emotion_meta.pt CŨ \"\n",
|
| 179 |
+
" \"chỉ lưu heads → không resume được. Hãy dùng ft_emotion_full.pt.\")\n",
|
| 180 |
+
"print(\"✅ Nạp checkpoint:\", RESUME_CKPT, \"| keys:\", list(ckpt.keys()))\n",
|
| 181 |
+
"\n",
|
| 182 |
+
"def find_hf_backbone(module):\n",
|
| 183 |
+
" cands = []\n",
|
| 184 |
+
" for name, m in module.named_modules():\n",
|
| 185 |
+
" enc = getattr(m, \"encoder\", None)\n",
|
| 186 |
+
" if getattr(m, \"feature_extractor\", None) is not None and enc is not None \\\n",
|
| 187 |
+
" and getattr(enc, \"layers\", None) is not None:\n",
|
| 188 |
+
" cands.append((name, m))\n",
|
| 189 |
+
" if not cands:\n",
|
| 190 |
+
" return None, None\n",
|
| 191 |
+
" cands.sort(key=lambda nm: sum(p.numel() for p in nm[1].parameters()), reverse=True)\n",
|
| 192 |
+
" return cands[0]\n",
|
| 193 |
+
"\n",
|
| 194 |
+
"wavlm = None\n",
|
| 195 |
+
"try:\n",
|
| 196 |
+
" from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402\n",
|
| 197 |
+
" _wrapper = WavLMWrapper.from_pretrained(\"tiantiaf/wavlm-large-categorical-emotion\")\n",
|
| 198 |
+
" name, wavlm = find_hf_backbone(_wrapper)\n",
|
| 199 |
+
" if wavlm is not None:\n",
|
| 200 |
+
" print(f\"✅ Dựng backbone WavLM từ SAILER wrapper tại '.{name}'\")\n",
|
| 201 |
+
"except Exception as e:\n",
|
| 202 |
+
" print(\"⚠️ Lỗi nạp SAILER wrapper:\", repr(e), \"→ fallback WavLM trắng.\")\n",
|
| 203 |
+
"\n",
|
| 204 |
+
"if wavlm is None:\n",
|
| 205 |
+
" from transformers import WavLMModel\n",
|
| 206 |
+
" wavlm = WavLMModel.from_pretrained(\"microsoft/wavlm-large\")\n",
|
| 207 |
+
" print(\"ℹ️ Fallback: microsoft/wavlm-large.\")\n",
|
| 208 |
+
"\n",
|
| 209 |
+
"wavlm = wavlm.to(device)\n",
|
| 210 |
+
"WAVLM_DIM = int(wavlm.config.hidden_size)\n",
|
| 211 |
+
"\n",
|
| 212 |
+
"# Nạp trọng số đã fine-tune từ checkpoint (đè lên kiến trúc vừa dựng)\n",
|
| 213 |
+
"miss, unexp = wavlm.load_state_dict(ckpt[\"wavlm\"], strict=False)\n",
|
| 214 |
+
"print(f\"🔁 load wavlm từ checkpoint: thiếu {len(miss)} / dư {len(unexp)} key (kỳ vọng ~0).\")\n",
|
| 215 |
+
"if len(miss) > 20 or len(unexp) > 20:\n",
|
| 216 |
+
" print(\" ⚠️ Lệch key nhiều → kiến trúc có thể không khớp checkpoint. Kiểm tra UNFREEZE/USE_AUDEERING.\")\n",
|
| 217 |
+
"\n",
|
| 218 |
+
"# Đóng băng partial: chỉ mở UNFREEZE_TOP_LAYERS lớp trên\n",
|
| 219 |
+
"for p in wavlm.parameters():\n",
|
| 220 |
+
" p.requires_grad = False\n",
|
| 221 |
+
"enc_layers = wavlm.encoder.layers\n",
|
| 222 |
+
"n_layers = len(enc_layers)\n",
|
| 223 |
+
"for layer in enc_layers[max(0, n_layers - UNFREEZE_TOP_LAYERS):]:\n",
|
| 224 |
+
" for p in layer.parameters():\n",
|
| 225 |
+
" p.requires_grad = True\n",
|
| 226 |
+
"n_train = sum(p.numel() for p in wavlm.parameters() if p.requires_grad)\n",
|
| 227 |
+
"print(f\"WavLM: {n_layers} lớp · mở băng {min(UNFREEZE_TOP_LAYERS, n_layers)} → {n_train/1e6:.1f}M param train (dim {WAVLM_DIM})\")\n",
|
| 228 |
+
"\n",
|
| 229 |
+
"if USE_GRAD_CKPT:\n",
|
| 230 |
+
" wavlm.gradient_checkpointing_enable()\n",
|
| 231 |
+
" if hasattr(wavlm, \"enable_input_require_grads\"):\n",
|
| 232 |
+
" wavlm.enable_input_require_grads()\n",
|
| 233 |
+
"\n",
|
| 234 |
+
"def masked_mean(hidden, attn_mask):\n",
|
| 235 |
+
" if attn_mask is None:\n",
|
| 236 |
+
" return hidden.mean(dim=1)\n",
|
| 237 |
+
" try:\n",
|
| 238 |
+
" fm = wavlm._get_feature_vector_attention_mask(hidden.shape[1], attn_mask)\n",
|
| 239 |
+
" except Exception:\n",
|
| 240 |
+
" return hidden.mean(dim=1)\n",
|
| 241 |
+
" fm = fm.unsqueeze(-1).to(hidden.dtype)\n",
|
| 242 |
+
" return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)\n",
|
| 243 |
+
"\n",
|
| 244 |
+
"def wavlm_embed(input_values, attn_mask):\n",
|
| 245 |
+
" out = wavlm(input_values, attention_mask=attn_mask).last_hidden_state\n",
|
| 246 |
+
" return masked_mean(out, attn_mask)"
|
| 247 |
+
]
|
| 248 |
+
},
|
| 249 |
+
{
|
| 250 |
+
"cell_type": "markdown",
|
| 251 |
+
"id": "156a5f4d",
|
| 252 |
+
"metadata": {},
|
| 253 |
+
"source": [
|
| 254 |
+
"## 3. audeering FROZEN (đặc trưng phụ) — dùng cache nếu có"
|
| 255 |
+
]
|
| 256 |
+
},
|
| 257 |
+
{
|
| 258 |
+
"cell_type": "code",
|
| 259 |
+
"execution_count": null,
|
| 260 |
+
"id": "670569c7",
|
| 261 |
+
"metadata": {
|
| 262 |
+
"lines_to_next_cell": 1
|
| 263 |
+
},
|
| 264 |
+
"outputs": [],
|
| 265 |
+
"source": [
|
| 266 |
+
"import numpy as np\n",
|
| 267 |
+
"import librosa\n",
|
| 268 |
+
"from tqdm.auto import tqdm\n",
|
| 269 |
+
"\n",
|
| 270 |
+
"AUD_DIM = 0\n",
|
| 271 |
+
"aud_backbone = aud_head = aud_proc = None\n",
|
| 272 |
+
"if USE_AUDEERING:\n",
|
| 273 |
+
" from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor\n",
|
| 274 |
+
" from huggingface_hub import hf_hub_download\n",
|
| 275 |
+
" AUD_NAME = \"audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim\"\n",
|
| 276 |
+
" aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)\n",
|
| 277 |
+
" aud_cfg = Wav2Vec2Config.from_pretrained(AUD_NAME)\n",
|
| 278 |
+
" aud_backbone = Wav2Vec2Model(aud_cfg)\n",
|
| 279 |
+
" try:\n",
|
| 280 |
+
" _sd = __import__(\"safetensors.torch\", fromlist=[\"load_file\"]).load_file(\n",
|
| 281 |
+
" hf_hub_download(AUD_NAME, \"model.safetensors\"))\n",
|
| 282 |
+
" except Exception:\n",
|
| 283 |
+
" _sd = torch.load(hf_hub_download(AUD_NAME, \"pytorch_model.bin\"), map_location=\"cpu\")\n",
|
| 284 |
+
" bb_sd = {k[len(\"wav2vec2.\"):]: v for k, v in _sd.items() if k.startswith(\"wav2vec2.\")}\n",
|
| 285 |
+
" aud_backbone.load_state_dict(bb_sd, strict=False)\n",
|
| 286 |
+
" _hid = _sd[\"classifier.dense.weight\"].shape[0]\n",
|
| 287 |
+
" _out = _sd[\"classifier.out_proj.weight\"].shape[0]\n",
|
| 288 |
+
" aud_head = nn.Sequential(nn.Linear(_hid, _hid), nn.Tanh(), nn.Linear(_hid, _out))\n",
|
| 289 |
+
" aud_head[0].weight.data.copy_(_sd[\"classifier.dense.weight\"]); aud_head[0].bias.data.copy_(_sd[\"classifier.dense.bias\"])\n",
|
| 290 |
+
" aud_head[2].weight.data.copy_(_sd[\"classifier.out_proj.weight\"]); aud_head[2].bias.data.copy_(_sd[\"classifier.out_proj.bias\"])\n",
|
| 291 |
+
" aud_backbone = aud_backbone.to(device).eval()\n",
|
| 292 |
+
" aud_head = aud_head.to(device).eval()\n",
|
| 293 |
+
" AUD_DIM = _hid + 3\n",
|
| 294 |
+
" print(f\"✅ audeering frozen ({AUD_DIM}-D)\")\n",
|
| 295 |
+
"\n",
|
| 296 |
+
"def load_wav(name_or_stem):\n",
|
| 297 |
+
" p = name_or_stem if os.path.isabs(str(name_or_stem)) else os.path.join(\n",
|
| 298 |
+
" WAV_DIR, name_or_stem if str(name_or_stem).endswith(\".wav\") else str(name_or_stem) + \".wav\")\n",
|
| 299 |
+
" if not os.path.exists(p):\n",
|
| 300 |
+
" return None\n",
|
| 301 |
+
" wave, _ = librosa.load(p, sr=SR, mono=True)\n",
|
| 302 |
+
" return wave[: MAX_SECONDS * SR].astype(np.float32)\n",
|
| 303 |
+
"\n",
|
| 304 |
+
"@torch.no_grad()\n",
|
| 305 |
+
"def extract_audeering(stems, tag):\n",
|
| 306 |
+
" if not USE_AUDEERING:\n",
|
| 307 |
+
" return {}\n",
|
| 308 |
+
" cache_path = os.path.join(CACHE_DIR, f\"aud_{tag}.npz\")\n",
|
| 309 |
+
" store = {}\n",
|
| 310 |
+
" if os.path.exists(cache_path):\n",
|
| 311 |
+
" z = np.load(cache_path, allow_pickle=True)\n",
|
| 312 |
+
" store = {k: z[k] for k in z.files}\n",
|
| 313 |
+
" print(f\"[aud/{tag}] nạp cache: {len(store)}\")\n",
|
| 314 |
+
" todo = [s for s in stems if s not in store]\n",
|
| 315 |
+
" for i, s in enumerate(tqdm(todo, desc=f\"audeering {tag}\")):\n",
|
| 316 |
+
" wave = load_wav(s)\n",
|
| 317 |
+
" if wave is None:\n",
|
| 318 |
+
" continue\n",
|
| 319 |
+
" x = aud_proc(wave, sampling_rate=SR).input_values[0]\n",
|
| 320 |
+
" x = torch.from_numpy(np.asarray(x, dtype=np.float32)).unsqueeze(0).to(device)\n",
|
| 321 |
+
" h = aud_backbone(x)[0].mean(dim=1)\n",
|
| 322 |
+
" out = aud_head(h)[0].cpu().numpy()\n",
|
| 323 |
+
" vad = np.array([1 + 4 * out[2], 1 + 4 * out[0], 1 + 4 * out[1]], dtype=np.float32)\n",
|
| 324 |
+
" store[s] = np.concatenate([h[0].cpu().numpy(), vad]).astype(np.float32)\n",
|
| 325 |
+
" if (i + 1) % 500 == 0:\n",
|
| 326 |
+
" np.savez(cache_path, **store)\n",
|
| 327 |
+
" if todo:\n",
|
| 328 |
+
" np.savez(cache_path, **store)\n",
|
| 329 |
+
" return store"
|
| 330 |
+
]
|
| 331 |
+
},
|
| 332 |
+
{
|
| 333 |
+
"cell_type": "markdown",
|
| 334 |
+
"id": "c5ed6f49",
|
| 335 |
+
"metadata": {},
|
| 336 |
+
"source": [
|
| 337 |
+
"## 4. Đọc & gộp nhãn theo wavID"
|
| 338 |
+
]
|
| 339 |
+
},
|
| 340 |
+
{
|
| 341 |
+
"cell_type": "code",
|
| 342 |
+
"execution_count": null,
|
| 343 |
+
"id": "910d097f",
|
| 344 |
+
"metadata": {},
|
| 345 |
+
"outputs": [],
|
| 346 |
+
"source": [
|
| 347 |
+
"import pandas as pd\n",
|
| 348 |
+
"\n",
|
| 349 |
+
"def load_target_emotions():\n",
|
| 350 |
+
" tgt = {}\n",
|
| 351 |
+
" with open(METADATA_CSV, encoding=\"utf-8\") as f:\n",
|
| 352 |
+
" for ln in f:\n",
|
| 353 |
+
" parts = ln.strip().split(\"|\")\n",
|
| 354 |
+
" if len(parts) >= 2:\n",
|
| 355 |
+
" tgt[stem(parts[0])] = norm_emotion(parts[1])\n",
|
| 356 |
+
" return tgt\n",
|
| 357 |
+
"\n",
|
| 358 |
+
"def _col(cols_map, *names, df=None, default_idx=None):\n",
|
| 359 |
+
" for n in names:\n",
|
| 360 |
+
" if n in cols_map:\n",
|
| 361 |
+
" return cols_map[n]\n",
|
| 362 |
+
" return list(df.columns)[default_idx] if default_idx is not None else None\n",
|
| 363 |
+
"\n",
|
| 364 |
+
"def parse_emocat_votes(cell):\n",
|
| 365 |
+
" v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
|
| 366 |
+
" for tok in str(cell).replace(\"/\", \",\").replace(\";\", \",\").replace(\"|\", \",\").replace(\" \", \",\").split(\",\"):\n",
|
| 367 |
+
" e = norm_emotion(tok)\n",
|
| 368 |
+
" if e in EMOTIONS5:\n",
|
| 369 |
+
" v[EMOTIONS5.index(e)] += 1.0\n",
|
| 370 |
+
" return v\n",
|
| 371 |
+
"\n",
|
| 372 |
+
"def load_train_labels():\n",
|
| 373 |
+
" df = pd.read_csv(TRAIN_CSV, sep=\"|\")\n",
|
| 374 |
+
" cols = {c.lower().strip(): c for c in df.columns}\n",
|
| 375 |
+
" wav_col = _col(cols, \"wavid\", \"wav\", df=df, default_idx=1)\n",
|
| 376 |
+
" emos_col = _col(cols, \"emos\", \"emo\", \"emomos\")\n",
|
| 377 |
+
" val_col = _col(cols, \"val\", \"valence\"); aro_col = _col(cols, \"aro\", \"arousal\"); dom_col = _col(cols, \"dom\", \"dominance\")\n",
|
| 378 |
+
" cat_col = _col(cols, \"emocat\", \"cat\", \"emotion\")\n",
|
| 379 |
+
" assert emos_col, f\"Không thấy cột eMOS (cột: {list(df.columns)})\"\n",
|
| 380 |
+
" df[\"_stem\"] = df[wav_col].map(stem)\n",
|
| 381 |
+
" rows = []\n",
|
| 382 |
+
" for sid, g in df.groupby(\"_stem\"):\n",
|
| 383 |
+
" rec = {\"wavID\": sid, \"emos\": float(g[emos_col].mean())}\n",
|
| 384 |
+
" rec[\"val\"] = float(g[val_col].mean()) if val_col else np.nan\n",
|
| 385 |
+
" rec[\"aro\"] = float(g[aro_col].mean()) if aro_col else np.nan\n",
|
| 386 |
+
" rec[\"dom\"] = float(g[dom_col].mean()) if dom_col else np.nan\n",
|
| 387 |
+
" votes = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
|
| 388 |
+
" if cat_col:\n",
|
| 389 |
+
" for cell in g[cat_col]:\n",
|
| 390 |
+
" votes += parse_emocat_votes(cell)\n",
|
| 391 |
+
" s = votes.sum()\n",
|
| 392 |
+
" cat = votes / s if s > 0 else np.full(len(EMOTIONS5), 0.2, dtype=np.float32)\n",
|
| 393 |
+
" for i in range(len(EMOTIONS5)):\n",
|
| 394 |
+
" rec[f\"cat{i}\"] = float(cat[i])\n",
|
| 395 |
+
" rows.append(rec)\n",
|
| 396 |
+
" return pd.DataFrame(rows)\n",
|
| 397 |
+
"\n",
|
| 398 |
+
"target_map = load_target_emotions()\n",
|
| 399 |
+
"train_df = load_train_labels()\n",
|
| 400 |
+
"HAS_VAD = bool(train_df[\"val\"].notna().any())\n",
|
| 401 |
+
"print(f\"Target: {len(target_map)} | wav train (gộp): {len(train_df)} | có VAD: {HAS_VAD}\")"
|
| 402 |
+
]
|
| 403 |
+
},
|
| 404 |
+
{
|
| 405 |
+
"cell_type": "markdown",
|
| 406 |
+
"id": "1d9509a7",
|
| 407 |
+
"metadata": {},
|
| 408 |
+
"source": [
|
| 409 |
+
"## 5. Dataset/loader — DÙNG thống kê chuẩn hóa TỪ CHECKPOINT (để khớp head đã train)"
|
| 410 |
+
]
|
| 411 |
+
},
|
| 412 |
+
{
|
| 413 |
+
"cell_type": "code",
|
| 414 |
+
"execution_count": null,
|
| 415 |
+
"id": "4c09387b",
|
| 416 |
+
"metadata": {},
|
| 417 |
+
"outputs": [],
|
| 418 |
+
"source": [
|
| 419 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 420 |
+
"from sklearn.model_selection import train_test_split\n",
|
| 421 |
+
"\n",
|
| 422 |
+
"train_stems = [s for s in train_df[\"wavID\"] if target_map.get(s) is not None]\n",
|
| 423 |
+
"if LIMIT_TRAIN:\n",
|
| 424 |
+
" train_stems = train_stems[:LIMIT_TRAIN]\n",
|
| 425 |
+
"aud_tr = extract_audeering(train_stems, \"train\")\n",
|
| 426 |
+
"\n",
|
| 427 |
+
"lab = train_df.set_index(\"wavID\")\n",
|
| 428 |
+
"\n",
|
| 429 |
+
"# QUAN TRỌNG: lấy mean/std từ checkpoint (head đã train theo thang này) thay vì tính lại.\n",
|
| 430 |
+
"emos_mu = float(ckpt[\"emos_mu\"]); emos_sd = float(ckpt[\"emos_sd\"])\n",
|
| 431 |
+
"vad_mu = np.asarray(ckpt[\"vad_mu\"], dtype=np.float32); vad_sd = np.asarray(ckpt[\"vad_sd\"], dtype=np.float32)\n",
|
| 432 |
+
"print(f\"Dùng chuẩn hóa từ ckpt: emos μ={emos_mu:.3f} σ={emos_sd:.3f} | vad μ={np.round(vad_mu,2)}\")\n",
|
| 433 |
+
"\n",
|
| 434 |
+
"def onehot_target(tgt):\n",
|
| 435 |
+
" v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
|
| 436 |
+
" if tgt in EMOTIONS5:\n",
|
| 437 |
+
" v[EMOTIONS5.index(tgt)] = 1.0\n",
|
| 438 |
+
" return v\n",
|
| 439 |
+
"\n",
|
| 440 |
+
"class EmoDataset(Dataset):\n",
|
| 441 |
+
" def __init__(self, stems):\n",
|
| 442 |
+
" self.stems = [s for s in stems if (load_wav(s) is not None) and ((not USE_AUDEERING) or s in aud_tr)]\n",
|
| 443 |
+
" def __len__(self):\n",
|
| 444 |
+
" return len(self.stems)\n",
|
| 445 |
+
" def __getitem__(self, i):\n",
|
| 446 |
+
" s = self.stems[i]\n",
|
| 447 |
+
" wave = load_wav(s)\n",
|
| 448 |
+
" emos = (float(lab.loc[s, \"emos\"]) - emos_mu) / emos_sd\n",
|
| 449 |
+
" if HAS_VAD:\n",
|
| 450 |
+
" vad = (np.array([lab.loc[s, \"val\"], lab.loc[s, \"aro\"], lab.loc[s, \"dom\"]], np.float32) - vad_mu) / vad_sd\n",
|
| 451 |
+
" else:\n",
|
| 452 |
+
" vad = np.zeros(3, dtype=np.float32)\n",
|
| 453 |
+
" cat = np.array([lab.loc[s, f\"cat{j}\"] for j in range(len(EMOTIONS5))], dtype=np.float32)\n",
|
| 454 |
+
" aud = aud_tr[s] if USE_AUDEERING else np.zeros(0, dtype=np.float32)\n",
|
| 455 |
+
" return {\"wave\": wave, \"tgt\": onehot_target(target_map.get(s)), \"aud\": aud,\n",
|
| 456 |
+
" \"emos\": np.float32(emos), \"vad\": vad, \"cat\": cat,\n",
|
| 457 |
+
" \"emos_raw\": np.float32(lab.loc[s, \"emos\"]),\n",
|
| 458 |
+
" \"vad_raw\": np.array([lab.loc[s, \"val\"], lab.loc[s, \"aro\"], lab.loc[s, \"dom\"]], np.float32)}\n",
|
| 459 |
+
"\n",
|
| 460 |
+
"def collate(batch):\n",
|
| 461 |
+
" lens = [len(b[\"wave\"]) for b in batch]\n",
|
| 462 |
+
" L = max(lens)\n",
|
| 463 |
+
" waves = np.zeros((len(batch), L), dtype=np.float32)\n",
|
| 464 |
+
" mask = np.zeros((len(batch), L), dtype=np.float32)\n",
|
| 465 |
+
" for i, b in enumerate(batch):\n",
|
| 466 |
+
" waves[i, : len(b[\"wave\"])] = b[\"wave\"]; mask[i, : len(b[\"wave\"])] = 1.0\n",
|
| 467 |
+
" return {\n",
|
| 468 |
+
" \"input_values\": torch.from_numpy(waves), \"attn_mask\": torch.from_numpy(mask).long(),\n",
|
| 469 |
+
" \"tgt\": torch.from_numpy(np.stack([b[\"tgt\"] for b in batch])),\n",
|
| 470 |
+
" \"aud\": torch.from_numpy(np.stack([b[\"aud\"] for b in batch])) if USE_AUDEERING else None,\n",
|
| 471 |
+
" \"emos\": torch.from_numpy(np.stack([b[\"emos\"] for b in batch])).unsqueeze(1),\n",
|
| 472 |
+
" \"vad\": torch.from_numpy(np.stack([b[\"vad\"] for b in batch])),\n",
|
| 473 |
+
" \"cat\": torch.from_numpy(np.stack([b[\"cat\"] for b in batch])),\n",
|
| 474 |
+
" \"emos_raw\": np.stack([b[\"emos_raw\"] for b in batch]),\n",
|
| 475 |
+
" \"vad_raw\": np.stack([b[\"vad_raw\"] for b in batch]),\n",
|
| 476 |
+
" }\n",
|
| 477 |
+
"\n",
|
| 478 |
+
"ds = EmoDataset(train_stems)\n",
|
| 479 |
+
"print(\"Dataset hợp lệ:\", len(ds), \"wav\")\n",
|
| 480 |
+
"tr_i, va_i = train_test_split(np.arange(len(ds)), test_size=VAL_FRAC, random_state=SEED)\n",
|
| 481 |
+
"tr_loader = DataLoader(torch.utils.data.Subset(ds, tr_i), batch_size=BATCH, shuffle=True, collate_fn=collate, num_workers=2)\n",
|
| 482 |
+
"va_loader = DataLoader(torch.utils.data.Subset(ds, va_i), batch_size=BATCH, shuffle=False, collate_fn=collate, num_workers=2)"
|
| 483 |
+
]
|
| 484 |
+
},
|
| 485 |
+
{
|
| 486 |
+
"cell_type": "markdown",
|
| 487 |
+
"id": "5c16d942",
|
| 488 |
+
"metadata": {},
|
| 489 |
+
"source": [
|
| 490 |
+
"## 6. Heads (NẠP từ checkpoint) + optimizer + train TIẾP"
|
| 491 |
+
]
|
| 492 |
+
},
|
| 493 |
+
{
|
| 494 |
+
"cell_type": "code",
|
| 495 |
+
"execution_count": null,
|
| 496 |
+
"id": "7adfb320",
|
| 497 |
+
"metadata": {
|
| 498 |
+
"lines_to_next_cell": 1
|
| 499 |
+
},
|
| 500 |
+
"outputs": [],
|
| 501 |
+
"source": [
|
| 502 |
+
"from scipy.stats import spearmanr\n",
|
| 503 |
+
"\n",
|
| 504 |
+
"torch.manual_seed(SEED); np.random.seed(SEED)\n",
|
| 505 |
+
"N_EMO = len(EMOTIONS5)\n",
|
| 506 |
+
"TRUNK_IN = WAVLM_DIM + (AUD_DIM if USE_AUDEERING else 0)\n",
|
| 507 |
+
"\n",
|
| 508 |
+
"class EmoHeads(nn.Module):\n",
|
| 509 |
+
" def __init__(self, d_in, trunk_h, head_h, p, n_emo):\n",
|
| 510 |
+
" super().__init__()\n",
|
| 511 |
+
" self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),\n",
|
| 512 |
+
" nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))\n",
|
| 513 |
+
" self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))\n",
|
| 514 |
+
" self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))\n",
|
| 515 |
+
" self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))\n",
|
| 516 |
+
" def forward(self, feat, tgt):\n",
|
| 517 |
+
" h = self.trunk(feat)\n",
|
| 518 |
+
" return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)\n",
|
| 519 |
+
"\n",
|
| 520 |
+
"heads = EmoHeads(TRUNK_IN, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device)\n",
|
| 521 |
+
"hmiss, hunexp = heads.load_state_dict(ckpt[\"heads\"], strict=False)\n",
|
| 522 |
+
"print(f\"🔁 load heads từ checkpoint: thiếu {len(hmiss)} / dư {len(hunexp)} key (kỳ vọng 0).\")\n",
|
| 523 |
+
"print(f\"Trunk input = {TRUNK_IN} (wavlm {WAVLM_DIM} + aud {AUD_DIM if USE_AUDEERING else 0})\")\n",
|
| 524 |
+
"\n",
|
| 525 |
+
"TASKS = [\"emos\", \"cat\", \"val\", \"aro\", \"dom\"]\n",
|
| 526 |
+
"log_var = nn.Parameter(torch.zeros(len(TASKS), device=device))\n",
|
| 527 |
+
"bb_params = [p for p in wavlm.parameters() if p.requires_grad]\n",
|
| 528 |
+
"head_params = list(heads.parameters()) + ([log_var] if USE_UNCERTAINTY else [])\n",
|
| 529 |
+
"opt = torch.optim.AdamW([\n",
|
| 530 |
+
" {\"params\": bb_params, \"lr\": LR_BACKBONE * RESUME_LR_SCALE},\n",
|
| 531 |
+
" {\"params\": head_params, \"lr\": LR_HEAD * RESUME_LR_SCALE},\n",
|
| 532 |
+
"], weight_decay=WEIGHT_DECAY)\n",
|
| 533 |
+
"scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP and device == \"cuda\")\n",
|
| 534 |
+
"mse = nn.MSELoss()\n",
|
| 535 |
+
"\n",
|
| 536 |
+
"def soft_ce(logits, target_dist):\n",
|
| 537 |
+
" return -(target_dist * F.log_softmax(logits, dim=1)).sum(1).mean()\n",
|
| 538 |
+
"\n",
|
| 539 |
+
"def forward_batch(b):\n",
|
| 540 |
+
" feat_wavlm = wavlm_embed(b[\"input_values\"].to(device), b[\"attn_mask\"].to(device))\n",
|
| 541 |
+
" feat = torch.cat([feat_wavlm, b[\"aud\"].to(device)], dim=1) if USE_AUDEERING else feat_wavlm\n",
|
| 542 |
+
" return heads(feat, b[\"tgt\"].to(device))\n",
|
| 543 |
+
"\n",
|
| 544 |
+
"def compute_loss(emos_p, cat_l, vad_p, b):\n",
|
| 545 |
+
" L = {}\n",
|
| 546 |
+
" L[\"emos\"] = mse(emos_p, b[\"emos\"].to(device))\n",
|
| 547 |
+
" L[\"cat\"] = soft_ce(cat_l, b[\"cat\"].to(device))\n",
|
| 548 |
+
" if HAS_VAD:\n",
|
| 549 |
+
" vt = b[\"vad\"].to(device)\n",
|
| 550 |
+
" L[\"val\"] = mse(vad_p[:, 0:1], vt[:, 0:1]); L[\"aro\"] = mse(vad_p[:, 1:2], vt[:, 1:2]); L[\"dom\"] = mse(vad_p[:, 2:3], vt[:, 2:3])\n",
|
| 551 |
+
" else:\n",
|
| 552 |
+
" z = torch.zeros((), device=device); L[\"val\"] = L[\"aro\"] = L[\"dom\"] = z\n",
|
| 553 |
+
" if USE_UNCERTAINTY:\n",
|
| 554 |
+
" return sum(torch.exp(-log_var[i]) * L[t] + log_var[i] for i, t in enumerate(TASKS))\n",
|
| 555 |
+
" return sum(L.values())\n",
|
| 556 |
+
"\n",
|
| 557 |
+
"@torch.no_grad()\n",
|
| 558 |
+
"def evaluate():\n",
|
| 559 |
+
" wavlm.eval(); heads.eval()\n",
|
| 560 |
+
" P = {\"emos\": [], \"val\": [], \"aro\": [], \"dom\": []}; Y = {\"emos\": [], \"val\": [], \"aro\": [], \"dom\": []}\n",
|
| 561 |
+
" catP, catY = [], []\n",
|
| 562 |
+
" for b in va_loader:\n",
|
| 563 |
+
" with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
|
| 564 |
+
" emos_p, cat_l, vad_p = forward_batch(b)\n",
|
| 565 |
+
" P[\"emos\"] += emos_p.float().cpu().numpy().ravel().tolist(); Y[\"emos\"] += b[\"emos_raw\"].tolist()\n",
|
| 566 |
+
" vad_p = vad_p.float().cpu().numpy()\n",
|
| 567 |
+
" for j, t in enumerate([\"val\", \"aro\", \"dom\"]):\n",
|
| 568 |
+
" P[t] += vad_p[:, j].tolist(); Y[t] += b[\"vad_raw\"][:, j].tolist()\n",
|
| 569 |
+
" catP.append(F.softmax(cat_l, 1).float().cpu().numpy()); catY.append(b[\"cat\"])\n",
|
| 570 |
+
" out = {}\n",
|
| 571 |
+
" for t in [\"emos\"] + ([\"val\", \"aro\", \"dom\"] if HAS_VAD else []):\n",
|
| 572 |
+
" out[t] = spearmanr(P[t], Y[t]).correlation\n",
|
| 573 |
+
" q = np.concatenate(catP); p = np.concatenate(catY)\n",
|
| 574 |
+
" out[\"cat_err\"] = float(np.abs(q - p).sum(1).mean())\n",
|
| 575 |
+
" return out\n",
|
| 576 |
+
"\n",
|
| 577 |
+
"def mean_srcc(m):\n",
|
| 578 |
+
" keys = [\"emos\"] + ([\"val\", \"aro\", \"dom\"] if HAS_VAD else [])\n",
|
| 579 |
+
" return float(np.mean([m[k] for k in keys]))\n",
|
| 580 |
+
"\n",
|
| 581 |
+
"# Init best TỪ checkpoint hiện tại → chỉ lưu nếu train tiếp TỐT HƠN\n",
|
| 582 |
+
"m0 = evaluate(); best = mean_srcc(m0)\n",
|
| 583 |
+
"best_state = {\"wavlm\": {k: v.cpu().clone() for k, v in wavlm.state_dict().items()},\n",
|
| 584 |
+
" \"heads\": {k: v.cpu().clone() for k, v in heads.state_dict().items()}}\n",
|
| 585 |
+
"print(f\"📍 Checkpoint hiện tại: mean SRCC = {best:.4f} | \"\n",
|
| 586 |
+
" + \" \".join(f\"{k}={m0[k]:.3f}\" for k in ['emos','val','aro','dom'] if k in m0))\n",
|
| 587 |
+
"\n",
|
| 588 |
+
"bad = 0\n",
|
| 589 |
+
"for ep in range(1, EPOCHS + 1):\n",
|
| 590 |
+
" wavlm.train(); heads.train()\n",
|
| 591 |
+
" opt.zero_grad(); run = 0.0; nb = 0\n",
|
| 592 |
+
" for step, b in enumerate(tqdm(tr_loader, desc=f\"+epoch {ep}\")):\n",
|
| 593 |
+
" with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
|
| 594 |
+
" emos_p, cat_l, vad_p = forward_batch(b)\n",
|
| 595 |
+
" loss = compute_loss(emos_p, cat_l, vad_p, b) / ACCUM\n",
|
| 596 |
+
" scaler.scale(loss).backward()\n",
|
| 597 |
+
" if (step + 1) % ACCUM == 0:\n",
|
| 598 |
+
" scaler.step(opt); scaler.update(); opt.zero_grad()\n",
|
| 599 |
+
" run += loss.item() * ACCUM; nb += 1\n",
|
| 600 |
+
" m = evaluate(); sc = mean_srcc(m)\n",
|
| 601 |
+
" msg = \" \".join(f\"{k}={m[k]:.3f}\" for k in [\"emos\", \"val\", \"aro\", \"dom\"] if k in m)\n",
|
| 602 |
+
" print(f\"+epoch {ep:2d} | loss {run/max(nb,1):.4f} | {msg} | cat_err {m['cat_err']:.3f} | mean {sc:.4f} (best {max(best,sc):.4f})\")\n",
|
| 603 |
+
" if sc > best:\n",
|
| 604 |
+
" best = sc\n",
|
| 605 |
+
" best_state = {\"wavlm\": {k: v.cpu().clone() for k, v in wavlm.state_dict().items()},\n",
|
| 606 |
+
" \"heads\": {k: v.cpu().clone() for k, v in heads.state_dict().items()}}\n",
|
| 607 |
+
" bad = 0\n",
|
| 608 |
+
" else:\n",
|
| 609 |
+
" bad += 1\n",
|
| 610 |
+
" if bad >= PATIENCE:\n",
|
| 611 |
+
" print(f\"Early stop (resume) ở +epoch {ep}.\"); break\n",
|
| 612 |
+
"\n",
|
| 613 |
+
"wavlm.load_state_dict(best_state[\"wavlm\"]); heads.load_state_dict(best_state[\"heads\"])\n",
|
| 614 |
+
"final = evaluate()\n",
|
| 615 |
+
"print(\"\\n✅ VAL sau resume:\")\n",
|
| 616 |
+
"print(f\" EMOS={final['emos']:.4f} (ckpt {m0['emos']:.3f} · exp08 nộp {EXP08['emos']})\")\n",
|
| 617 |
+
"if HAS_VAD:\n",
|
| 618 |
+
" print(f\" VAL/ARO/DOM={final['val']:.4f}/{final['aro']:.4f}/{final['dom']:.4f} \"\n",
|
| 619 |
+
" f\"(exp08 nộp {EXP08['val']}/{EXP08['aro']}/{EXP08['dom']})\")\n",
|
| 620 |
+
"print(f\" mean SRCC: ckpt {mean_srcc(m0):.4f} → sau resume {mean_srcc(final):.4f} \"\n",
|
| 621 |
+
" + (\"🚀 cải thiện\" if mean_srcc(final) > mean_srcc(m0) + 1e-4 else \"➖ không cải thiện (giữ ckpt cũ)\"))\n",
|
| 622 |
+
"\n",
|
| 623 |
+
"torch.save({\"wavlm\": best_state[\"wavlm\"], \"heads\": best_state[\"heads\"],\n",
|
| 624 |
+
" \"emos_mu\": emos_mu, \"emos_sd\": emos_sd, \"vad_mu\": vad_mu, \"vad_sd\": vad_sd,\n",
|
| 625 |
+
" \"WAVLM_DIM\": WAVLM_DIM, \"AUD_DIM\": AUD_DIM, \"UNFREEZE_TOP_LAYERS\": UNFREEZE_TOP_LAYERS,\n",
|
| 626 |
+
" \"val_emos\": final[\"emos\"]}, os.path.join(OUT_DIR, \"ft_emotion_full.pt\"))\n",
|
| 627 |
+
"print(\"Đã lưu FULL (có backbone):\", os.path.join(OUT_DIR, \"ft_emotion_full.pt\"))"
|
| 628 |
+
]
|
| 629 |
+
},
|
| 630 |
+
{
|
| 631 |
+
"cell_type": "markdown",
|
| 632 |
+
"id": "dd6bb5d8",
|
| 633 |
+
"metadata": {},
|
| 634 |
+
"source": [
|
| 635 |
+
"## 7. Dự đoán DEV → answer.txt (5 cột cảm xúc từ resume; QMOS mượn exp07 hoặc UTMOSv2)"
|
| 636 |
+
]
|
| 637 |
+
},
|
| 638 |
+
{
|
| 639 |
+
"cell_type": "code",
|
| 640 |
+
"execution_count": null,
|
| 641 |
+
"id": "6b7753e8",
|
| 642 |
+
"metadata": {
|
| 643 |
+
"lines_to_next_cell": 1
|
| 644 |
+
},
|
| 645 |
+
"outputs": [],
|
| 646 |
+
"source": [
|
| 647 |
+
"def list_dev():\n",
|
| 648 |
+
" with open(DEV_SCP) as f:\n",
|
| 649 |
+
" return [ln.strip() for ln in f if ln.strip()]\n",
|
| 650 |
+
"\n",
|
| 651 |
+
"dev_names = list_dev()\n",
|
| 652 |
+
"if LIMIT_DEV:\n",
|
| 653 |
+
" dev_names = dev_names[:LIMIT_DEV]\n",
|
| 654 |
+
"dev_stems = [stem(n) for n in dev_names]\n",
|
| 655 |
+
"print(\"DEV:\", len(dev_names), \"mẫu\")\n",
|
| 656 |
+
"aud_dev = extract_audeering(dev_stems, \"dev\")\n",
|
| 657 |
+
"\n",
|
| 658 |
+
"def load_exp07_qmos():\n",
|
| 659 |
+
" if EXP07_ANSWER and os.path.exists(EXP07_ANSWER):\n",
|
| 660 |
+
" import csv\n",
|
| 661 |
+
" d = {}\n",
|
| 662 |
+
" with open(EXP07_ANSWER) as f:\n",
|
| 663 |
+
" for row in csv.DictReader(f):\n",
|
| 664 |
+
" d[row[\"wav\"]] = float(row[\"QMOS\"]); d[stem(row[\"wav\"])] = float(row[\"QMOS\"])\n",
|
| 665 |
+
" print(f\"✅ Mượn QMOS từ exp07 ({EXP07_ANSWER}): {len(d)//2} wav\")\n",
|
| 666 |
+
" return d\n",
|
| 667 |
+
" return None\n",
|
| 668 |
+
"\n",
|
| 669 |
+
"qmos_map = load_exp07_qmos()\n",
|
| 670 |
+
"if qmos_map is None:\n",
|
| 671 |
+
" print(\"ℹ️ Không có answer.txt exp07 → chấm QMOS bằng UTMOSv2 (T05, vô địch VMC2024).\")\n",
|
| 672 |
+
" pip_install(\"git+https://github.com/sarulab-speech/UTMOSv2.git\")\n",
|
| 673 |
+
" import utmosv2\n",
|
| 674 |
+
" v2 = utmosv2.create_model(pretrained=True)\n",
|
| 675 |
+
" qmos_map = {}\n",
|
| 676 |
+
" for n in tqdm(dev_names, desc=\"UTMOSv2\"):\n",
|
| 677 |
+
" wav = os.path.join(WAV_DIR, n if str(n).endswith(\".wav\") else str(n) + \".wav\")\n",
|
| 678 |
+
" if not os.path.exists(wav):\n",
|
| 679 |
+
" continue\n",
|
| 680 |
+
" out = v2.predict(input_path=wav)\n",
|
| 681 |
+
" qmos_map[n] = float(out[\"predicted_mos\"]) if isinstance(out, dict) else float(out)\n",
|
| 682 |
+
" del v2; torch.cuda.empty_cache() if device == \"cuda\" else None\n",
|
| 683 |
+
"\n",
|
| 684 |
+
"@torch.no_grad()\n",
|
| 685 |
+
"def predict_emotion(sid):\n",
|
| 686 |
+
" wave = load_wav(sid)\n",
|
| 687 |
+
" if wave is None or (USE_AUDEERING and sid not in aud_dev):\n",
|
| 688 |
+
" return None\n",
|
| 689 |
+
" wavlm.eval(); heads.eval()\n",
|
| 690 |
+
" iv = torch.from_numpy(wave).unsqueeze(0).to(device)\n",
|
| 691 |
+
" am = torch.ones((1, len(wave)), dtype=torch.long, device=device)\n",
|
| 692 |
+
" tgt = torch.from_numpy(onehot_target(target_map.get(sid))).unsqueeze(0).to(device)\n",
|
| 693 |
+
" with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
|
| 694 |
+
" fw = wavlm_embed(iv, am)\n",
|
| 695 |
+
" feat = torch.cat([fw, torch.from_numpy(aud_dev[sid]).unsqueeze(0).to(device)], dim=1) if USE_AUDEERING else fw\n",
|
| 696 |
+
" emos_p, cat_l, vad_p = heads(feat, tgt)\n",
|
| 697 |
+
" emos = float(emos_p.item()) * emos_sd + emos_mu\n",
|
| 698 |
+
" cat5 = F.softmax(cat_l, 1)[0].float().cpu().numpy()\n",
|
| 699 |
+
" vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu\n",
|
| 700 |
+
" return emos, cat5, vad3\n",
|
| 701 |
+
"\n",
|
| 702 |
+
"def fmt_cat(p5):\n",
|
| 703 |
+
" return \"|\".join(f\"{e}:{p5[i]:.6g}\" for i, e in enumerate(EMOTIONS5))\n",
|
| 704 |
+
"\n",
|
| 705 |
+
"def build_answer(out_path):\n",
|
| 706 |
+
" n_real = n_def = 0\n",
|
| 707 |
+
" with open(out_path, \"w\") as f:\n",
|
| 708 |
+
" f.write(\"wav,QMOS,EMOS,CAT,VAL,ARO,DOM\\n\")\n",
|
| 709 |
+
" for name in tqdm(dev_names, desc=\"answer\"):\n",
|
| 710 |
+
" sid = stem(name)\n",
|
| 711 |
+
" pr = predict_emotion(sid)\n",
|
| 712 |
+
" if pr is None:\n",
|
| 713 |
+
" emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0]); n_def += 1\n",
|
| 714 |
+
" else:\n",
|
| 715 |
+
" emos, cat5, vad3 = pr; n_real += 1\n",
|
| 716 |
+
" qmos = qmos_map.get(name, qmos_map.get(sid, 3.0))\n",
|
| 717 |
+
" f.write(f\"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\\n\")\n",
|
| 718 |
+
" print(f\"Ghi {len(dev_names)} dòng → {out_path} | cảm xúc thật {n_real}, mặc định {n_def}\")\n",
|
| 719 |
+
"\n",
|
| 720 |
+
"answer_path = os.path.join(OUT_DIR, \"answer.txt\")\n",
|
| 721 |
+
"build_answer(answer_path)"
|
| 722 |
+
]
|
| 723 |
+
},
|
| 724 |
+
{
|
| 725 |
+
"cell_type": "markdown",
|
| 726 |
+
"id": "7dac208f",
|
| 727 |
+
"metadata": {},
|
| 728 |
+
"source": [
|
| 729 |
+
"## 8. Validate + zip"
|
| 730 |
+
]
|
| 731 |
+
},
|
| 732 |
+
{
|
| 733 |
+
"cell_type": "code",
|
| 734 |
+
"execution_count": null,
|
| 735 |
+
"id": "c123c058",
|
| 736 |
+
"metadata": {},
|
| 737 |
+
"outputs": [],
|
| 738 |
+
"source": [
|
| 739 |
+
"def validate(path):\n",
|
| 740 |
+
" import csv\n",
|
| 741 |
+
" with open(path) as f:\n",
|
| 742 |
+
" rows = list(csv.reader(f))\n",
|
| 743 |
+
" assert rows[0][0] == \"wav\" and \"QMOS\" in rows[0], \"Header sai\"\n",
|
| 744 |
+
" for i, r in enumerate(rows[1:], 2):\n",
|
| 745 |
+
" assert len(r) == len(rows[0]), f\"Dòng {i} sai số cột\"\n",
|
| 746 |
+
" print(f\"OK: {len(rows)-1} dòng, header = {rows[0]}\")\n",
|
| 747 |
+
"\n",
|
| 748 |
+
"validate(answer_path)\n",
|
| 749 |
+
"os.system(f\"cd {OUT_DIR} && zip -j submission_track2_exp08_resume.zip answer.txt && unzip -l submission_track2_exp08_resume.zip\")\n",
|
| 750 |
+
"print(\"Sẵn sàng nộp:\", os.path.join(OUT_DIR, \"submission_track2_exp08_resume.zip\"))"
|
| 751 |
+
]
|
| 752 |
+
},
|
| 753 |
+
{
|
| 754 |
+
"cell_type": "markdown",
|
| 755 |
+
"id": "e5f82cf0",
|
| 756 |
+
"metadata": {},
|
| 757 |
+
"source": [
|
| 758 |
+
"## Ghi chú\n",
|
| 759 |
+
"- **Đầu vào bắt buộc:** `RESUME_CKPT` = `ft_emotion_full.pt` (CÓ backbone). Bản `ft_emotion_meta.pt` cũ chỉ\n",
|
| 760 |
+
" có heads → cell 2 sẽ assert lỗi nhắc dùng file đủ.\n",
|
| 761 |
+
"- **Cache:** trỏ `CACHE_INPUT` tới dataset chứa `aud_train.npz`/`aud_dev.npz` → khỏi trích lại audeering.\n",
|
| 762 |
+
" Nếu LIMIT khác lần trước, cache thiếu stem nào sẽ tự trích bù (resume theo stem).\n",
|
| 763 |
+
"- **Chuẩn hóa lấy TỪ checkpoint** (`emos_mu/sd`, `vad_mu/sd`) → khớp thang head đã train (đừng tính lại).\n",
|
| 764 |
+
"- **best init từ checkpoint** → chỉ lưu nếu train tiếp THỰC SỰ tốt hơn (không sợ tụt).\n",
|
| 765 |
+
"- Nếu val chững: đặt `RESUME_LR_SCALE=0.5` (giảm LR) hoặc tăng `UNFREEZE_TOP_LAYERS` (lưu ý: mở thêm lớp\n",
|
| 766 |
+
" thì lớp mới chưa được train trong checkpoint → cần nhiều epoch hơn).\n",
|
| 767 |
+
"- QMOS: tốt nhất Add Input `answer.txt` exp07 (0.548). Để trộn cột chuẩn, xem kết quả exp08: 5 cột cảm xúc\n",
|
| 768 |
+
" resume + QMOS exp07 → hệ thống mạnh nhất 6 cột.\n",
|
| 769 |
+
"- Ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md`."
|
| 770 |
+
]
|
| 771 |
+
}
|
| 772 |
+
],
|
| 773 |
+
"metadata": {
|
| 774 |
+
"jupytext": {
|
| 775 |
+
"cell_metadata_filter": "-all",
|
| 776 |
+
"main_language": "python",
|
| 777 |
+
"notebook_metadata_filter": "-all"
|
| 778 |
+
}
|
| 779 |
+
},
|
| 780 |
+
"nbformat": 4,
|
| 781 |
+
"nbformat_minor": 5
|
| 782 |
+
}
|
track2/exp08b_finetune_resume_pipeline.py
ADDED
|
@@ -0,0 +1,642 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %% [markdown]
|
| 2 |
+
# # VMC2026 Track 2 — exp08-RESUME (fine-tune TIẾP từ checkpoint + cache) — Kaggle
|
| 3 |
+
#
|
| 4 |
+
# **Mục đích:** train tiếp model fine-tune cảm xúc (exp08) từ **checkpoint đã lưu** thay vì train lại từ
|
| 5 |
+
# đầu — tiết kiệm giờ GPU. Tận dụng:
|
| 6 |
+
# - `ft_emotion_full.pt` (CÓ cả backbone WavLM + heads + thống kê chuẩn hóa) → nạp lại đúng trạng thái.
|
| 7 |
+
# - **cache audeering** `aud_*.npz` (đặc trưng frozen) → KHÔNG trích lại (~đỡ chục phút).
|
| 8 |
+
#
|
| 9 |
+
# > ⚠️ Bắt buộc dùng checkpoint **đủ backbone** (`ft_emotion_full.pt` từ cell "TRAIN TIẾP", hoặc bản
|
| 10 |
+
# > `ft_emotion_meta.pt` MỚI đã vá để lưu cả `wavlm`). Bản `ft_emotion_meta.pt` CŨ chỉ có `heads` → KHÔNG dùng được.
|
| 11 |
+
#
|
| 12 |
+
# ## Chuẩn bị input trên Kaggle (Add Input)
|
| 13 |
+
# 1. Dataset Track 2 (`vmc2026-track2-full`) — wav + nhãn.
|
| 14 |
+
# 2. **Checkpoint**: upload `ft_emotion_full.pt` thành 1 Dataset → trỏ `RESUME_CKPT`.
|
| 15 |
+
# 3. **Cache** (tùy chọn nhưng nên có): upload thư mục chứa `aud_train.npz`, `aud_dev.npz` → trỏ `CACHE_INPUT`.
|
| 16 |
+
# 4. (tùy chọn) `answer.txt` exp07 để mượn cột QMOS 0.548.
|
| 17 |
+
#
|
| 18 |
+
# **Cách chạy:** GPU T4 + Internet On → sửa các slug ở cell 0 → Run All. Lần đầu để `LIMIT_TRAIN=300`.
|
| 19 |
+
|
| 20 |
+
# %% [markdown]
|
| 21 |
+
# ## 0. Cấu hình — SỬA Ở ĐÂY
|
| 22 |
+
|
| 23 |
+
# %%
|
| 24 |
+
import os, shutil
|
| 25 |
+
|
| 26 |
+
DATA_ROOT = "/kaggle/input/vmc2026-track2-full/vmc2026-track2" # << SỬA slug
|
| 27 |
+
WAV_DIR = f"{DATA_ROOT}/wav"
|
| 28 |
+
METADATA_CSV = f"{DATA_ROOT}/metadata.csv"
|
| 29 |
+
TRAIN_CSV = f"{DATA_ROOT}/sets/train.csv"
|
| 30 |
+
DEV_SCP = f"{DATA_ROOT}/sets/dev.scp"
|
| 31 |
+
|
| 32 |
+
# ── Checkpoint + cache để RESUME ─────────────────────────────────────────────
|
| 33 |
+
RESUME_CKPT = "/kaggle/input/ft-emotion-full/ft_emotion_full.pt" # << CHECKPOINT đủ backbone
|
| 34 |
+
CACHE_INPUT = "/kaggle/input/ft-emotion-cache" # << thư mục chứa aud_*.npz (hoặc "" nếu không có)
|
| 35 |
+
EXP07_ANSWER = "/kaggle/input/exp07-answer/answer.txt" # << (tùy chọn) mượn QMOS 0.548; không có → UTMOSv2
|
| 36 |
+
|
| 37 |
+
OUT_DIR = "/kaggle/working"
|
| 38 |
+
CACHE_DIR = "/kaggle/working/ft_cache" # /kaggle/input read-only → copy cache sang đây để ghi/append được
|
| 39 |
+
os.makedirs(CACHE_DIR, exist_ok=True)
|
| 40 |
+
|
| 41 |
+
# ── Fine-tune / siêu tham số (train TIẾP) ────────────────────────────────────
|
| 42 |
+
DEVICE = "cuda"
|
| 43 |
+
SR = 16000
|
| 44 |
+
MAX_SECONDS = 8
|
| 45 |
+
UNFREEZE_TOP_LAYERS = 6 # PHẢI khớp checkpoint (mặc định exp08 = 6)
|
| 46 |
+
TRUNK_HIDDEN = 512 # PHẢI khớp checkpoint
|
| 47 |
+
HEAD_HIDDEN = 128 # PHẢI khớp checkpoint
|
| 48 |
+
DROPOUT = 0.3
|
| 49 |
+
LR_BACKBONE = 1e-5
|
| 50 |
+
LR_HEAD = 1e-3
|
| 51 |
+
RESUME_LR_SCALE = 1.0 # <1.0 để giảm LR khi train tiếp (vd 0.5 nếu val đã chững)
|
| 52 |
+
WEIGHT_DECAY = 1e-5
|
| 53 |
+
EPOCHS = 10 # số epoch train THÊM (run này)
|
| 54 |
+
PATIENCE = 5 # dừng khi val không lên; LUÔN giữ best
|
| 55 |
+
BATCH = 4
|
| 56 |
+
ACCUM = 8 # effective batch = 32
|
| 57 |
+
VAL_FRAC = 0.10
|
| 58 |
+
SEED = 42
|
| 59 |
+
USE_AMP = True
|
| 60 |
+
USE_GRAD_CKPT = True
|
| 61 |
+
USE_AUDEERING = True # PHẢI khớp checkpoint (exp08 = True)
|
| 62 |
+
USE_UNCERTAINTY = True
|
| 63 |
+
|
| 64 |
+
LIMIT_TRAIN = 300 # << LẦN ĐẦU 300; chạy thật None
|
| 65 |
+
LIMIT_DEV = 20 # << LẦN ĐẦU 20; chạy thật None
|
| 66 |
+
|
| 67 |
+
# Mốc exp07 + exp08 để so
|
| 68 |
+
EXP07 = {"emos": 0.795, "cat_err": 0.153, "val": 0.581, "aro": 0.752, "dom": 0.705}
|
| 69 |
+
EXP08 = {"emos": 0.811, "cat_err": 0.133, "val": 0.659, "aro": 0.793, "dom": 0.751} # bản đã nộp
|
| 70 |
+
|
| 71 |
+
EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"]
|
| 72 |
+
_EMO_ALIAS = {
|
| 73 |
+
"angry": "angry", "anger": "angry",
|
| 74 |
+
"happy": "happy", "happiness": "happy", "joy": "happy",
|
| 75 |
+
"neutral": "neutral", "calm": "neutral",
|
| 76 |
+
"sad": "sad", "sadness": "sad",
|
| 77 |
+
"surprise": "surprised", "surprised": "surprised", "surprising": "surprised",
|
| 78 |
+
}
|
| 79 |
+
|
| 80 |
+
def norm_emotion(label):
|
| 81 |
+
key = str(label).strip().lower()
|
| 82 |
+
return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)
|
| 83 |
+
|
| 84 |
+
def stem(p):
|
| 85 |
+
return os.path.splitext(os.path.basename(str(p)))[0]
|
| 86 |
+
|
| 87 |
+
print("DATA_ROOT:", DATA_ROOT)
|
| 88 |
+
for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP, RESUME_CKPT]:
|
| 89 |
+
print((" ✅ " if os.path.exists(p) else " ❌ THIẾU ") + p)
|
| 90 |
+
|
| 91 |
+
# Copy cache (aud_*.npz) từ input read-only sang working để append được
|
| 92 |
+
if CACHE_INPUT and os.path.isdir(CACHE_INPUT):
|
| 93 |
+
n = 0
|
| 94 |
+
for fn in os.listdir(CACHE_INPUT):
|
| 95 |
+
if fn.startswith("aud_") and fn.endswith(".npz"):
|
| 96 |
+
shutil.copy(os.path.join(CACHE_INPUT, fn), os.path.join(CACHE_DIR, fn)); n += 1
|
| 97 |
+
print(f"📦 Copy {n} file cache audeering từ {CACHE_INPUT} → {CACHE_DIR}")
|
| 98 |
+
else:
|
| 99 |
+
print("ℹ️ Không có CACHE_INPUT → sẽ tự trích audeering (chậm hơn lần đầu).")
|
| 100 |
+
|
| 101 |
+
# %% [markdown]
|
| 102 |
+
# ## 1. Cài đặt + tải code SAILER (để dựng đúng kiến trúc WavLM rồi nạp checkpoint đè lên)
|
| 103 |
+
|
| 104 |
+
# %%
|
| 105 |
+
import sys, subprocess
|
| 106 |
+
|
| 107 |
+
def pip_install(*pkgs):
|
| 108 |
+
subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs], check=True)
|
| 109 |
+
|
| 110 |
+
pip_install("loralib", "speechbrain", "speechmos", "librosa", "soundfile",
|
| 111 |
+
"scipy", "scikit-learn", "pandas", "tqdm")
|
| 112 |
+
|
| 113 |
+
REPO_DIR = "/kaggle/working/vox-profile-release"
|
| 114 |
+
if not os.path.exists(REPO_DIR):
|
| 115 |
+
subprocess.run(["git", "clone", "--depth", "1",
|
| 116 |
+
"https://github.com/tiantiaf0627/vox-profile-release.git", REPO_DIR], check=True)
|
| 117 |
+
if REPO_DIR not in sys.path:
|
| 118 |
+
sys.path.insert(0, REPO_DIR)
|
| 119 |
+
|
| 120 |
+
# %% [markdown]
|
| 121 |
+
# ## 2. Dựng WavLM (như exp08) → NẠP trọng số backbone từ checkpoint
|
| 122 |
+
# Dựng đúng kiến trúc (SAILER wrapper → lấy backbone HF; fallback WavLM trắng), rồi `load_state_dict`
|
| 123 |
+
# bằng `ckpt["wavlm"]` → khôi phục đúng trạng thái fine-tune đã lưu.
|
| 124 |
+
|
| 125 |
+
# %%
|
| 126 |
+
import torch
|
| 127 |
+
import torch.nn as nn
|
| 128 |
+
import torch.nn.functional as F
|
| 129 |
+
|
| 130 |
+
device = DEVICE if torch.cuda.is_available() else "cpu"
|
| 131 |
+
print("Device:", device, ("✅ " + torch.cuda.get_device_name(0)) if device == "cuda" else "⚠️ CPU (rất chậm!)")
|
| 132 |
+
|
| 133 |
+
ckpt = torch.load(RESUME_CKPT, map_location="cpu", weights_only=False) # ckpt có numpy (vad_mu) → cần False
|
| 134 |
+
assert "wavlm" in ckpt, ("❌ Checkpoint KHÔNG có 'wavlm' (backbone). Đây là bản ft_emotion_meta.pt CŨ "
|
| 135 |
+
"chỉ lưu heads → không resume được. Hãy dùng ft_emotion_full.pt.")
|
| 136 |
+
print("✅ Nạp checkpoint:", RESUME_CKPT, "| keys:", list(ckpt.keys()))
|
| 137 |
+
|
| 138 |
+
def find_hf_backbone(module):
|
| 139 |
+
cands = []
|
| 140 |
+
for name, m in module.named_modules():
|
| 141 |
+
enc = getattr(m, "encoder", None)
|
| 142 |
+
if getattr(m, "feature_extractor", None) is not None and enc is not None \
|
| 143 |
+
and getattr(enc, "layers", None) is not None:
|
| 144 |
+
cands.append((name, m))
|
| 145 |
+
if not cands:
|
| 146 |
+
return None, None
|
| 147 |
+
cands.sort(key=lambda nm: sum(p.numel() for p in nm[1].parameters()), reverse=True)
|
| 148 |
+
return cands[0]
|
| 149 |
+
|
| 150 |
+
wavlm = None
|
| 151 |
+
try:
|
| 152 |
+
from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402
|
| 153 |
+
_wrapper = WavLMWrapper.from_pretrained("tiantiaf/wavlm-large-categorical-emotion")
|
| 154 |
+
name, wavlm = find_hf_backbone(_wrapper)
|
| 155 |
+
if wavlm is not None:
|
| 156 |
+
print(f"✅ Dựng backbone WavLM từ SAILER wrapper tại '.{name}'")
|
| 157 |
+
except Exception as e:
|
| 158 |
+
print("⚠️ Lỗi nạp SAILER wrapper:", repr(e), "→ fallback WavLM trắng.")
|
| 159 |
+
|
| 160 |
+
if wavlm is None:
|
| 161 |
+
from transformers import WavLMModel
|
| 162 |
+
wavlm = WavLMModel.from_pretrained("microsoft/wavlm-large")
|
| 163 |
+
print("ℹ️ Fallback: microsoft/wavlm-large.")
|
| 164 |
+
|
| 165 |
+
wavlm = wavlm.to(device)
|
| 166 |
+
WAVLM_DIM = int(wavlm.config.hidden_size)
|
| 167 |
+
|
| 168 |
+
# Nạp trọng số đã fine-tune từ checkpoint (đè lên kiến trúc vừa dựng)
|
| 169 |
+
miss, unexp = wavlm.load_state_dict(ckpt["wavlm"], strict=False)
|
| 170 |
+
print(f"🔁 load wavlm từ checkpoint: thiếu {len(miss)} / dư {len(unexp)} key (kỳ vọng ~0).")
|
| 171 |
+
if len(miss) > 20 or len(unexp) > 20:
|
| 172 |
+
print(" ⚠️ Lệch key nhiều → kiến trúc có thể không khớp checkpoint. Kiểm tra UNFREEZE/USE_AUDEERING.")
|
| 173 |
+
|
| 174 |
+
# Đóng băng partial: chỉ mở UNFREEZE_TOP_LAYERS lớp trên
|
| 175 |
+
for p in wavlm.parameters():
|
| 176 |
+
p.requires_grad = False
|
| 177 |
+
enc_layers = wavlm.encoder.layers
|
| 178 |
+
n_layers = len(enc_layers)
|
| 179 |
+
for layer in enc_layers[max(0, n_layers - UNFREEZE_TOP_LAYERS):]:
|
| 180 |
+
for p in layer.parameters():
|
| 181 |
+
p.requires_grad = True
|
| 182 |
+
n_train = sum(p.numel() for p in wavlm.parameters() if p.requires_grad)
|
| 183 |
+
print(f"WavLM: {n_layers} lớp · mở băng {min(UNFREEZE_TOP_LAYERS, n_layers)} → {n_train/1e6:.1f}M param train (dim {WAVLM_DIM})")
|
| 184 |
+
|
| 185 |
+
if USE_GRAD_CKPT:
|
| 186 |
+
wavlm.gradient_checkpointing_enable()
|
| 187 |
+
if hasattr(wavlm, "enable_input_require_grads"):
|
| 188 |
+
wavlm.enable_input_require_grads()
|
| 189 |
+
|
| 190 |
+
def masked_mean(hidden, attn_mask):
|
| 191 |
+
if attn_mask is None:
|
| 192 |
+
return hidden.mean(dim=1)
|
| 193 |
+
try:
|
| 194 |
+
fm = wavlm._get_feature_vector_attention_mask(hidden.shape[1], attn_mask)
|
| 195 |
+
except Exception:
|
| 196 |
+
return hidden.mean(dim=1)
|
| 197 |
+
fm = fm.unsqueeze(-1).to(hidden.dtype)
|
| 198 |
+
return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)
|
| 199 |
+
|
| 200 |
+
def wavlm_embed(input_values, attn_mask):
|
| 201 |
+
out = wavlm(input_values, attention_mask=attn_mask).last_hidden_state
|
| 202 |
+
return masked_mean(out, attn_mask)
|
| 203 |
+
|
| 204 |
+
# %% [markdown]
|
| 205 |
+
# ## 3. audeering FROZEN (đặc trưng phụ) — dùng cache nếu có
|
| 206 |
+
|
| 207 |
+
# %%
|
| 208 |
+
import numpy as np
|
| 209 |
+
import librosa
|
| 210 |
+
from tqdm.auto import tqdm
|
| 211 |
+
|
| 212 |
+
AUD_DIM = 0
|
| 213 |
+
aud_backbone = aud_head = aud_proc = None
|
| 214 |
+
if USE_AUDEERING:
|
| 215 |
+
from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor
|
| 216 |
+
from huggingface_hub import hf_hub_download
|
| 217 |
+
AUD_NAME = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
|
| 218 |
+
aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)
|
| 219 |
+
aud_cfg = Wav2Vec2Config.from_pretrained(AUD_NAME)
|
| 220 |
+
aud_backbone = Wav2Vec2Model(aud_cfg)
|
| 221 |
+
try:
|
| 222 |
+
_sd = __import__("safetensors.torch", fromlist=["load_file"]).load_file(
|
| 223 |
+
hf_hub_download(AUD_NAME, "model.safetensors"))
|
| 224 |
+
except Exception:
|
| 225 |
+
_sd = torch.load(hf_hub_download(AUD_NAME, "pytorch_model.bin"), map_location="cpu")
|
| 226 |
+
bb_sd = {k[len("wav2vec2."):]: v for k, v in _sd.items() if k.startswith("wav2vec2.")}
|
| 227 |
+
aud_backbone.load_state_dict(bb_sd, strict=False)
|
| 228 |
+
_hid = _sd["classifier.dense.weight"].shape[0]
|
| 229 |
+
_out = _sd["classifier.out_proj.weight"].shape[0]
|
| 230 |
+
aud_head = nn.Sequential(nn.Linear(_hid, _hid), nn.Tanh(), nn.Linear(_hid, _out))
|
| 231 |
+
aud_head[0].weight.data.copy_(_sd["classifier.dense.weight"]); aud_head[0].bias.data.copy_(_sd["classifier.dense.bias"])
|
| 232 |
+
aud_head[2].weight.data.copy_(_sd["classifier.out_proj.weight"]); aud_head[2].bias.data.copy_(_sd["classifier.out_proj.bias"])
|
| 233 |
+
aud_backbone = aud_backbone.to(device).eval()
|
| 234 |
+
aud_head = aud_head.to(device).eval()
|
| 235 |
+
AUD_DIM = _hid + 3
|
| 236 |
+
print(f"✅ audeering frozen ({AUD_DIM}-D)")
|
| 237 |
+
|
| 238 |
+
def load_wav(name_or_stem):
|
| 239 |
+
p = name_or_stem if os.path.isabs(str(name_or_stem)) else os.path.join(
|
| 240 |
+
WAV_DIR, name_or_stem if str(name_or_stem).endswith(".wav") else str(name_or_stem) + ".wav")
|
| 241 |
+
if not os.path.exists(p):
|
| 242 |
+
return None
|
| 243 |
+
wave, _ = librosa.load(p, sr=SR, mono=True)
|
| 244 |
+
return wave[: MAX_SECONDS * SR].astype(np.float32)
|
| 245 |
+
|
| 246 |
+
@torch.no_grad()
|
| 247 |
+
def extract_audeering(stems, tag):
|
| 248 |
+
if not USE_AUDEERING:
|
| 249 |
+
return {}
|
| 250 |
+
cache_path = os.path.join(CACHE_DIR, f"aud_{tag}.npz")
|
| 251 |
+
store = {}
|
| 252 |
+
if os.path.exists(cache_path):
|
| 253 |
+
z = np.load(cache_path, allow_pickle=True)
|
| 254 |
+
store = {k: z[k] for k in z.files}
|
| 255 |
+
print(f"[aud/{tag}] nạp cache: {len(store)}")
|
| 256 |
+
todo = [s for s in stems if s not in store]
|
| 257 |
+
for i, s in enumerate(tqdm(todo, desc=f"audeering {tag}")):
|
| 258 |
+
wave = load_wav(s)
|
| 259 |
+
if wave is None:
|
| 260 |
+
continue
|
| 261 |
+
x = aud_proc(wave, sampling_rate=SR).input_values[0]
|
| 262 |
+
x = torch.from_numpy(np.asarray(x, dtype=np.float32)).unsqueeze(0).to(device)
|
| 263 |
+
h = aud_backbone(x)[0].mean(dim=1)
|
| 264 |
+
out = aud_head(h)[0].cpu().numpy()
|
| 265 |
+
vad = np.array([1 + 4 * out[2], 1 + 4 * out[0], 1 + 4 * out[1]], dtype=np.float32)
|
| 266 |
+
store[s] = np.concatenate([h[0].cpu().numpy(), vad]).astype(np.float32)
|
| 267 |
+
if (i + 1) % 500 == 0:
|
| 268 |
+
np.savez(cache_path, **store)
|
| 269 |
+
if todo:
|
| 270 |
+
np.savez(cache_path, **store)
|
| 271 |
+
return store
|
| 272 |
+
|
| 273 |
+
# %% [markdown]
|
| 274 |
+
# ## 4. Đọc & gộp nhãn theo wavID
|
| 275 |
+
|
| 276 |
+
# %%
|
| 277 |
+
import pandas as pd
|
| 278 |
+
|
| 279 |
+
def load_target_emotions():
|
| 280 |
+
tgt = {}
|
| 281 |
+
with open(METADATA_CSV, encoding="utf-8") as f:
|
| 282 |
+
for ln in f:
|
| 283 |
+
parts = ln.strip().split("|")
|
| 284 |
+
if len(parts) >= 2:
|
| 285 |
+
tgt[stem(parts[0])] = norm_emotion(parts[1])
|
| 286 |
+
return tgt
|
| 287 |
+
|
| 288 |
+
def _col(cols_map, *names, df=None, default_idx=None):
|
| 289 |
+
for n in names:
|
| 290 |
+
if n in cols_map:
|
| 291 |
+
return cols_map[n]
|
| 292 |
+
return list(df.columns)[default_idx] if default_idx is not None else None
|
| 293 |
+
|
| 294 |
+
def parse_emocat_votes(cell):
|
| 295 |
+
v = np.zeros(len(EMOTIONS5), dtype=np.float32)
|
| 296 |
+
for tok in str(cell).replace("/", ",").replace(";", ",").replace("|", ",").replace(" ", ",").split(","):
|
| 297 |
+
e = norm_emotion(tok)
|
| 298 |
+
if e in EMOTIONS5:
|
| 299 |
+
v[EMOTIONS5.index(e)] += 1.0
|
| 300 |
+
return v
|
| 301 |
+
|
| 302 |
+
def load_train_labels():
|
| 303 |
+
df = pd.read_csv(TRAIN_CSV, sep="|")
|
| 304 |
+
cols = {c.lower().strip(): c for c in df.columns}
|
| 305 |
+
wav_col = _col(cols, "wavid", "wav", df=df, default_idx=1)
|
| 306 |
+
emos_col = _col(cols, "emos", "emo", "emomos")
|
| 307 |
+
val_col = _col(cols, "val", "valence"); aro_col = _col(cols, "aro", "arousal"); dom_col = _col(cols, "dom", "dominance")
|
| 308 |
+
cat_col = _col(cols, "emocat", "cat", "emotion")
|
| 309 |
+
assert emos_col, f"Không thấy cột eMOS (cột: {list(df.columns)})"
|
| 310 |
+
df["_stem"] = df[wav_col].map(stem)
|
| 311 |
+
rows = []
|
| 312 |
+
for sid, g in df.groupby("_stem"):
|
| 313 |
+
rec = {"wavID": sid, "emos": float(g[emos_col].mean())}
|
| 314 |
+
rec["val"] = float(g[val_col].mean()) if val_col else np.nan
|
| 315 |
+
rec["aro"] = float(g[aro_col].mean()) if aro_col else np.nan
|
| 316 |
+
rec["dom"] = float(g[dom_col].mean()) if dom_col else np.nan
|
| 317 |
+
votes = np.zeros(len(EMOTIONS5), dtype=np.float32)
|
| 318 |
+
if cat_col:
|
| 319 |
+
for cell in g[cat_col]:
|
| 320 |
+
votes += parse_emocat_votes(cell)
|
| 321 |
+
s = votes.sum()
|
| 322 |
+
cat = votes / s if s > 0 else np.full(len(EMOTIONS5), 0.2, dtype=np.float32)
|
| 323 |
+
for i in range(len(EMOTIONS5)):
|
| 324 |
+
rec[f"cat{i}"] = float(cat[i])
|
| 325 |
+
rows.append(rec)
|
| 326 |
+
return pd.DataFrame(rows)
|
| 327 |
+
|
| 328 |
+
target_map = load_target_emotions()
|
| 329 |
+
train_df = load_train_labels()
|
| 330 |
+
HAS_VAD = bool(train_df["val"].notna().any())
|
| 331 |
+
print(f"Target: {len(target_map)} | wav train (gộp): {len(train_df)} | có VAD: {HAS_VAD}")
|
| 332 |
+
|
| 333 |
+
# %% [markdown]
|
| 334 |
+
# ## 5. Dataset/loader — DÙNG thống kê chuẩn hóa TỪ CHECKPOINT (để khớp head đã train)
|
| 335 |
+
|
| 336 |
+
# %%
|
| 337 |
+
from torch.utils.data import Dataset, DataLoader
|
| 338 |
+
from sklearn.model_selection import train_test_split
|
| 339 |
+
|
| 340 |
+
train_stems = [s for s in train_df["wavID"] if target_map.get(s) is not None]
|
| 341 |
+
if LIMIT_TRAIN:
|
| 342 |
+
train_stems = train_stems[:LIMIT_TRAIN]
|
| 343 |
+
aud_tr = extract_audeering(train_stems, "train")
|
| 344 |
+
|
| 345 |
+
lab = train_df.set_index("wavID")
|
| 346 |
+
|
| 347 |
+
# QUAN TRỌNG: lấy mean/std từ checkpoint (head đã train theo thang này) thay vì tính lại.
|
| 348 |
+
emos_mu = float(ckpt["emos_mu"]); emos_sd = float(ckpt["emos_sd"])
|
| 349 |
+
vad_mu = np.asarray(ckpt["vad_mu"], dtype=np.float32); vad_sd = np.asarray(ckpt["vad_sd"], dtype=np.float32)
|
| 350 |
+
print(f"Dùng chuẩn hóa từ ckpt: emos μ={emos_mu:.3f} σ={emos_sd:.3f} | vad μ={np.round(vad_mu,2)}")
|
| 351 |
+
|
| 352 |
+
def onehot_target(tgt):
|
| 353 |
+
v = np.zeros(len(EMOTIONS5), dtype=np.float32)
|
| 354 |
+
if tgt in EMOTIONS5:
|
| 355 |
+
v[EMOTIONS5.index(tgt)] = 1.0
|
| 356 |
+
return v
|
| 357 |
+
|
| 358 |
+
class EmoDataset(Dataset):
|
| 359 |
+
def __init__(self, stems):
|
| 360 |
+
self.stems = [s for s in stems if (load_wav(s) is not None) and ((not USE_AUDEERING) or s in aud_tr)]
|
| 361 |
+
def __len__(self):
|
| 362 |
+
return len(self.stems)
|
| 363 |
+
def __getitem__(self, i):
|
| 364 |
+
s = self.stems[i]
|
| 365 |
+
wave = load_wav(s)
|
| 366 |
+
emos = (float(lab.loc[s, "emos"]) - emos_mu) / emos_sd
|
| 367 |
+
if HAS_VAD:
|
| 368 |
+
vad = (np.array([lab.loc[s, "val"], lab.loc[s, "aro"], lab.loc[s, "dom"]], np.float32) - vad_mu) / vad_sd
|
| 369 |
+
else:
|
| 370 |
+
vad = np.zeros(3, dtype=np.float32)
|
| 371 |
+
cat = np.array([lab.loc[s, f"cat{j}"] for j in range(len(EMOTIONS5))], dtype=np.float32)
|
| 372 |
+
aud = aud_tr[s] if USE_AUDEERING else np.zeros(0, dtype=np.float32)
|
| 373 |
+
return {"wave": wave, "tgt": onehot_target(target_map.get(s)), "aud": aud,
|
| 374 |
+
"emos": np.float32(emos), "vad": vad, "cat": cat,
|
| 375 |
+
"emos_raw": np.float32(lab.loc[s, "emos"]),
|
| 376 |
+
"vad_raw": np.array([lab.loc[s, "val"], lab.loc[s, "aro"], lab.loc[s, "dom"]], np.float32)}
|
| 377 |
+
|
| 378 |
+
def collate(batch):
|
| 379 |
+
lens = [len(b["wave"]) for b in batch]
|
| 380 |
+
L = max(lens)
|
| 381 |
+
waves = np.zeros((len(batch), L), dtype=np.float32)
|
| 382 |
+
mask = np.zeros((len(batch), L), dtype=np.float32)
|
| 383 |
+
for i, b in enumerate(batch):
|
| 384 |
+
waves[i, : len(b["wave"])] = b["wave"]; mask[i, : len(b["wave"])] = 1.0
|
| 385 |
+
return {
|
| 386 |
+
"input_values": torch.from_numpy(waves), "attn_mask": torch.from_numpy(mask).long(),
|
| 387 |
+
"tgt": torch.from_numpy(np.stack([b["tgt"] for b in batch])),
|
| 388 |
+
"aud": torch.from_numpy(np.stack([b["aud"] for b in batch])) if USE_AUDEERING else None,
|
| 389 |
+
"emos": torch.from_numpy(np.stack([b["emos"] for b in batch])).unsqueeze(1),
|
| 390 |
+
"vad": torch.from_numpy(np.stack([b["vad"] for b in batch])),
|
| 391 |
+
"cat": torch.from_numpy(np.stack([b["cat"] for b in batch])),
|
| 392 |
+
"emos_raw": np.stack([b["emos_raw"] for b in batch]),
|
| 393 |
+
"vad_raw": np.stack([b["vad_raw"] for b in batch]),
|
| 394 |
+
}
|
| 395 |
+
|
| 396 |
+
ds = EmoDataset(train_stems)
|
| 397 |
+
print("Dataset hợp lệ:", len(ds), "wav")
|
| 398 |
+
tr_i, va_i = train_test_split(np.arange(len(ds)), test_size=VAL_FRAC, random_state=SEED)
|
| 399 |
+
tr_loader = DataLoader(torch.utils.data.Subset(ds, tr_i), batch_size=BATCH, shuffle=True, collate_fn=collate, num_workers=2)
|
| 400 |
+
va_loader = DataLoader(torch.utils.data.Subset(ds, va_i), batch_size=BATCH, shuffle=False, collate_fn=collate, num_workers=2)
|
| 401 |
+
|
| 402 |
+
# %% [markdown]
|
| 403 |
+
# ## 6. Heads (NẠP từ checkpoint) + optimizer + train TIẾP
|
| 404 |
+
|
| 405 |
+
# %%
|
| 406 |
+
from scipy.stats import spearmanr
|
| 407 |
+
|
| 408 |
+
torch.manual_seed(SEED); np.random.seed(SEED)
|
| 409 |
+
N_EMO = len(EMOTIONS5)
|
| 410 |
+
TRUNK_IN = WAVLM_DIM + (AUD_DIM if USE_AUDEERING else 0)
|
| 411 |
+
|
| 412 |
+
class EmoHeads(nn.Module):
|
| 413 |
+
def __init__(self, d_in, trunk_h, head_h, p, n_emo):
|
| 414 |
+
super().__init__()
|
| 415 |
+
self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),
|
| 416 |
+
nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))
|
| 417 |
+
self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))
|
| 418 |
+
self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))
|
| 419 |
+
self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))
|
| 420 |
+
def forward(self, feat, tgt):
|
| 421 |
+
h = self.trunk(feat)
|
| 422 |
+
return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)
|
| 423 |
+
|
| 424 |
+
heads = EmoHeads(TRUNK_IN, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device)
|
| 425 |
+
hmiss, hunexp = heads.load_state_dict(ckpt["heads"], strict=False)
|
| 426 |
+
print(f"🔁 load heads từ checkpoint: thiếu {len(hmiss)} / dư {len(hunexp)} key (kỳ vọng 0).")
|
| 427 |
+
print(f"Trunk input = {TRUNK_IN} (wavlm {WAVLM_DIM} + aud {AUD_DIM if USE_AUDEERING else 0})")
|
| 428 |
+
|
| 429 |
+
TASKS = ["emos", "cat", "val", "aro", "dom"]
|
| 430 |
+
log_var = nn.Parameter(torch.zeros(len(TASKS), device=device))
|
| 431 |
+
bb_params = [p for p in wavlm.parameters() if p.requires_grad]
|
| 432 |
+
head_params = list(heads.parameters()) + ([log_var] if USE_UNCERTAINTY else [])
|
| 433 |
+
opt = torch.optim.AdamW([
|
| 434 |
+
{"params": bb_params, "lr": LR_BACKBONE * RESUME_LR_SCALE},
|
| 435 |
+
{"params": head_params, "lr": LR_HEAD * RESUME_LR_SCALE},
|
| 436 |
+
], weight_decay=WEIGHT_DECAY)
|
| 437 |
+
scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP and device == "cuda")
|
| 438 |
+
mse = nn.MSELoss()
|
| 439 |
+
|
| 440 |
+
def soft_ce(logits, target_dist):
|
| 441 |
+
return -(target_dist * F.log_softmax(logits, dim=1)).sum(1).mean()
|
| 442 |
+
|
| 443 |
+
def forward_batch(b):
|
| 444 |
+
feat_wavlm = wavlm_embed(b["input_values"].to(device), b["attn_mask"].to(device))
|
| 445 |
+
feat = torch.cat([feat_wavlm, b["aud"].to(device)], dim=1) if USE_AUDEERING else feat_wavlm
|
| 446 |
+
return heads(feat, b["tgt"].to(device))
|
| 447 |
+
|
| 448 |
+
def compute_loss(emos_p, cat_l, vad_p, b):
|
| 449 |
+
L = {}
|
| 450 |
+
L["emos"] = mse(emos_p, b["emos"].to(device))
|
| 451 |
+
L["cat"] = soft_ce(cat_l, b["cat"].to(device))
|
| 452 |
+
if HAS_VAD:
|
| 453 |
+
vt = b["vad"].to(device)
|
| 454 |
+
L["val"] = mse(vad_p[:, 0:1], vt[:, 0:1]); L["aro"] = mse(vad_p[:, 1:2], vt[:, 1:2]); L["dom"] = mse(vad_p[:, 2:3], vt[:, 2:3])
|
| 455 |
+
else:
|
| 456 |
+
z = torch.zeros((), device=device); L["val"] = L["aro"] = L["dom"] = z
|
| 457 |
+
if USE_UNCERTAINTY:
|
| 458 |
+
return sum(torch.exp(-log_var[i]) * L[t] + log_var[i] for i, t in enumerate(TASKS))
|
| 459 |
+
return sum(L.values())
|
| 460 |
+
|
| 461 |
+
@torch.no_grad()
|
| 462 |
+
def evaluate():
|
| 463 |
+
wavlm.eval(); heads.eval()
|
| 464 |
+
P = {"emos": [], "val": [], "aro": [], "dom": []}; Y = {"emos": [], "val": [], "aro": [], "dom": []}
|
| 465 |
+
catP, catY = [], []
|
| 466 |
+
for b in va_loader:
|
| 467 |
+
with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
|
| 468 |
+
emos_p, cat_l, vad_p = forward_batch(b)
|
| 469 |
+
P["emos"] += emos_p.float().cpu().numpy().ravel().tolist(); Y["emos"] += b["emos_raw"].tolist()
|
| 470 |
+
vad_p = vad_p.float().cpu().numpy()
|
| 471 |
+
for j, t in enumerate(["val", "aro", "dom"]):
|
| 472 |
+
P[t] += vad_p[:, j].tolist(); Y[t] += b["vad_raw"][:, j].tolist()
|
| 473 |
+
catP.append(F.softmax(cat_l, 1).float().cpu().numpy()); catY.append(b["cat"])
|
| 474 |
+
out = {}
|
| 475 |
+
for t in ["emos"] + (["val", "aro", "dom"] if HAS_VAD else []):
|
| 476 |
+
out[t] = spearmanr(P[t], Y[t]).correlation
|
| 477 |
+
q = np.concatenate(catP); p = np.concatenate(catY)
|
| 478 |
+
out["cat_err"] = float(np.abs(q - p).sum(1).mean())
|
| 479 |
+
return out
|
| 480 |
+
|
| 481 |
+
def mean_srcc(m):
|
| 482 |
+
keys = ["emos"] + (["val", "aro", "dom"] if HAS_VAD else [])
|
| 483 |
+
return float(np.mean([m[k] for k in keys]))
|
| 484 |
+
|
| 485 |
+
# Init best TỪ checkpoint hiện tại → chỉ lưu nếu train tiếp TỐT HƠN
|
| 486 |
+
m0 = evaluate(); best = mean_srcc(m0)
|
| 487 |
+
best_state = {"wavlm": {k: v.cpu().clone() for k, v in wavlm.state_dict().items()},
|
| 488 |
+
"heads": {k: v.cpu().clone() for k, v in heads.state_dict().items()}}
|
| 489 |
+
print(f"📍 Checkpoint hiện tại: mean SRCC = {best:.4f} | "
|
| 490 |
+
+ " ".join(f"{k}={m0[k]:.3f}" for k in ['emos','val','aro','dom'] if k in m0))
|
| 491 |
+
|
| 492 |
+
bad = 0
|
| 493 |
+
for ep in range(1, EPOCHS + 1):
|
| 494 |
+
wavlm.train(); heads.train()
|
| 495 |
+
opt.zero_grad(); run = 0.0; nb = 0
|
| 496 |
+
for step, b in enumerate(tqdm(tr_loader, desc=f"+epoch {ep}")):
|
| 497 |
+
with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
|
| 498 |
+
emos_p, cat_l, vad_p = forward_batch(b)
|
| 499 |
+
loss = compute_loss(emos_p, cat_l, vad_p, b) / ACCUM
|
| 500 |
+
scaler.scale(loss).backward()
|
| 501 |
+
if (step + 1) % ACCUM == 0:
|
| 502 |
+
scaler.step(opt); scaler.update(); opt.zero_grad()
|
| 503 |
+
run += loss.item() * ACCUM; nb += 1
|
| 504 |
+
m = evaluate(); sc = mean_srcc(m)
|
| 505 |
+
msg = " ".join(f"{k}={m[k]:.3f}" for k in ["emos", "val", "aro", "dom"] if k in m)
|
| 506 |
+
print(f"+epoch {ep:2d} | loss {run/max(nb,1):.4f} | {msg} | cat_err {m['cat_err']:.3f} | mean {sc:.4f} (best {max(best,sc):.4f})")
|
| 507 |
+
if sc > best:
|
| 508 |
+
best = sc
|
| 509 |
+
best_state = {"wavlm": {k: v.cpu().clone() for k, v in wavlm.state_dict().items()},
|
| 510 |
+
"heads": {k: v.cpu().clone() for k, v in heads.state_dict().items()}}
|
| 511 |
+
bad = 0
|
| 512 |
+
else:
|
| 513 |
+
bad += 1
|
| 514 |
+
if bad >= PATIENCE:
|
| 515 |
+
print(f"Early stop (resume) ở +epoch {ep}."); break
|
| 516 |
+
|
| 517 |
+
wavlm.load_state_dict(best_state["wavlm"]); heads.load_state_dict(best_state["heads"])
|
| 518 |
+
final = evaluate()
|
| 519 |
+
print("\n✅ VAL sau resume:")
|
| 520 |
+
print(f" EMOS={final['emos']:.4f} (ckpt {m0['emos']:.3f} · exp08 nộp {EXP08['emos']})")
|
| 521 |
+
if HAS_VAD:
|
| 522 |
+
print(f" VAL/ARO/DOM={final['val']:.4f}/{final['aro']:.4f}/{final['dom']:.4f} "
|
| 523 |
+
f"(exp08 nộp {EXP08['val']}/{EXP08['aro']}/{EXP08['dom']})")
|
| 524 |
+
print(f" mean SRCC: ckpt {mean_srcc(m0):.4f} → sau resume {mean_srcc(final):.4f} "
|
| 525 |
+
+ ("🚀 cải thiện" if mean_srcc(final) > mean_srcc(m0) + 1e-4 else "➖ không cải thiện (giữ ckpt cũ)"))
|
| 526 |
+
|
| 527 |
+
torch.save({"wavlm": best_state["wavlm"], "heads": best_state["heads"],
|
| 528 |
+
"emos_mu": emos_mu, "emos_sd": emos_sd, "vad_mu": vad_mu, "vad_sd": vad_sd,
|
| 529 |
+
"WAVLM_DIM": WAVLM_DIM, "AUD_DIM": AUD_DIM, "UNFREEZE_TOP_LAYERS": UNFREEZE_TOP_LAYERS,
|
| 530 |
+
"val_emos": final["emos"]}, os.path.join(OUT_DIR, "ft_emotion_full.pt"))
|
| 531 |
+
print("Đã lưu FULL (có backbone):", os.path.join(OUT_DIR, "ft_emotion_full.pt"))
|
| 532 |
+
|
| 533 |
+
# %% [markdown]
|
| 534 |
+
# ## 7. Dự đoán DEV → answer.txt (5 cột cảm xúc từ resume; QMOS mượn exp07 hoặc UTMOSv2)
|
| 535 |
+
|
| 536 |
+
# %%
|
| 537 |
+
def list_dev():
|
| 538 |
+
with open(DEV_SCP) as f:
|
| 539 |
+
return [ln.strip() for ln in f if ln.strip()]
|
| 540 |
+
|
| 541 |
+
dev_names = list_dev()
|
| 542 |
+
if LIMIT_DEV:
|
| 543 |
+
dev_names = dev_names[:LIMIT_DEV]
|
| 544 |
+
dev_stems = [stem(n) for n in dev_names]
|
| 545 |
+
print("DEV:", len(dev_names), "mẫu")
|
| 546 |
+
aud_dev = extract_audeering(dev_stems, "dev")
|
| 547 |
+
|
| 548 |
+
def load_exp07_qmos():
|
| 549 |
+
if EXP07_ANSWER and os.path.exists(EXP07_ANSWER):
|
| 550 |
+
import csv
|
| 551 |
+
d = {}
|
| 552 |
+
with open(EXP07_ANSWER) as f:
|
| 553 |
+
for row in csv.DictReader(f):
|
| 554 |
+
d[row["wav"]] = float(row["QMOS"]); d[stem(row["wav"])] = float(row["QMOS"])
|
| 555 |
+
print(f"✅ Mượn QMOS từ exp07 ({EXP07_ANSWER}): {len(d)//2} wav")
|
| 556 |
+
return d
|
| 557 |
+
return None
|
| 558 |
+
|
| 559 |
+
qmos_map = load_exp07_qmos()
|
| 560 |
+
if qmos_map is None:
|
| 561 |
+
print("ℹ️ Không có answer.txt exp07 → chấm QMOS bằng UTMOSv2 (T05, vô địch VMC2024).")
|
| 562 |
+
pip_install("git+https://github.com/sarulab-speech/UTMOSv2.git")
|
| 563 |
+
import utmosv2
|
| 564 |
+
v2 = utmosv2.create_model(pretrained=True)
|
| 565 |
+
qmos_map = {}
|
| 566 |
+
for n in tqdm(dev_names, desc="UTMOSv2"):
|
| 567 |
+
wav = os.path.join(WAV_DIR, n if str(n).endswith(".wav") else str(n) + ".wav")
|
| 568 |
+
if not os.path.exists(wav):
|
| 569 |
+
continue
|
| 570 |
+
out = v2.predict(input_path=wav)
|
| 571 |
+
qmos_map[n] = float(out["predicted_mos"]) if isinstance(out, dict) else float(out)
|
| 572 |
+
del v2; torch.cuda.empty_cache() if device == "cuda" else None
|
| 573 |
+
|
| 574 |
+
@torch.no_grad()
|
| 575 |
+
def predict_emotion(sid):
|
| 576 |
+
wave = load_wav(sid)
|
| 577 |
+
if wave is None or (USE_AUDEERING and sid not in aud_dev):
|
| 578 |
+
return None
|
| 579 |
+
wavlm.eval(); heads.eval()
|
| 580 |
+
iv = torch.from_numpy(wave).unsqueeze(0).to(device)
|
| 581 |
+
am = torch.ones((1, len(wave)), dtype=torch.long, device=device)
|
| 582 |
+
tgt = torch.from_numpy(onehot_target(target_map.get(sid))).unsqueeze(0).to(device)
|
| 583 |
+
with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
|
| 584 |
+
fw = wavlm_embed(iv, am)
|
| 585 |
+
feat = torch.cat([fw, torch.from_numpy(aud_dev[sid]).unsqueeze(0).to(device)], dim=1) if USE_AUDEERING else fw
|
| 586 |
+
emos_p, cat_l, vad_p = heads(feat, tgt)
|
| 587 |
+
emos = float(emos_p.item()) * emos_sd + emos_mu
|
| 588 |
+
cat5 = F.softmax(cat_l, 1)[0].float().cpu().numpy()
|
| 589 |
+
vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu
|
| 590 |
+
return emos, cat5, vad3
|
| 591 |
+
|
| 592 |
+
def fmt_cat(p5):
|
| 593 |
+
return "|".join(f"{e}:{p5[i]:.6g}" for i, e in enumerate(EMOTIONS5))
|
| 594 |
+
|
| 595 |
+
def build_answer(out_path):
|
| 596 |
+
n_real = n_def = 0
|
| 597 |
+
with open(out_path, "w") as f:
|
| 598 |
+
f.write("wav,QMOS,EMOS,CAT,VAL,ARO,DOM\n")
|
| 599 |
+
for name in tqdm(dev_names, desc="answer"):
|
| 600 |
+
sid = stem(name)
|
| 601 |
+
pr = predict_emotion(sid)
|
| 602 |
+
if pr is None:
|
| 603 |
+
emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0]); n_def += 1
|
| 604 |
+
else:
|
| 605 |
+
emos, cat5, vad3 = pr; n_real += 1
|
| 606 |
+
qmos = qmos_map.get(name, qmos_map.get(sid, 3.0))
|
| 607 |
+
f.write(f"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\n")
|
| 608 |
+
print(f"Ghi {len(dev_names)} dòng → {out_path} | cảm xúc thật {n_real}, mặc định {n_def}")
|
| 609 |
+
|
| 610 |
+
answer_path = os.path.join(OUT_DIR, "answer.txt")
|
| 611 |
+
build_answer(answer_path)
|
| 612 |
+
|
| 613 |
+
# %% [markdown]
|
| 614 |
+
# ## 8. Validate + zip
|
| 615 |
+
|
| 616 |
+
# %%
|
| 617 |
+
def validate(path):
|
| 618 |
+
import csv
|
| 619 |
+
with open(path) as f:
|
| 620 |
+
rows = list(csv.reader(f))
|
| 621 |
+
assert rows[0][0] == "wav" and "QMOS" in rows[0], "Header sai"
|
| 622 |
+
for i, r in enumerate(rows[1:], 2):
|
| 623 |
+
assert len(r) == len(rows[0]), f"Dòng {i} sai số cột"
|
| 624 |
+
print(f"OK: {len(rows)-1} dòng, header = {rows[0]}")
|
| 625 |
+
|
| 626 |
+
validate(answer_path)
|
| 627 |
+
os.system(f"cd {OUT_DIR} && zip -j submission_track2_exp08_resume.zip answer.txt && unzip -l submission_track2_exp08_resume.zip")
|
| 628 |
+
print("Sẵn sàng nộp:", os.path.join(OUT_DIR, "submission_track2_exp08_resume.zip"))
|
| 629 |
+
|
| 630 |
+
# %% [markdown]
|
| 631 |
+
# ## Ghi chú
|
| 632 |
+
# - **Đầu vào bắt buộc:** `RESUME_CKPT` = `ft_emotion_full.pt` (CÓ backbone). Bản `ft_emotion_meta.pt` cũ chỉ
|
| 633 |
+
# có heads → cell 2 sẽ assert lỗi nhắc dùng file đủ.
|
| 634 |
+
# - **Cache:** trỏ `CACHE_INPUT` tới dataset chứa `aud_train.npz`/`aud_dev.npz` → khỏi trích lại audeering.
|
| 635 |
+
# Nếu LIMIT khác lần trước, cache thiếu stem nào sẽ tự trích bù (resume theo stem).
|
| 636 |
+
# - **Chuẩn hóa lấy TỪ checkpoint** (`emos_mu/sd`, `vad_mu/sd`) → khớp thang head đã train (đừng tính lại).
|
| 637 |
+
# - **best init từ checkpoint** → chỉ lưu nếu train tiếp THỰC SỰ tốt hơn (không sợ tụt).
|
| 638 |
+
# - Nếu val chững: đặt `RESUME_LR_SCALE=0.5` (giảm LR) hoặc tăng `UNFREEZE_TOP_LAYERS` (lưu ý: mở thêm lớp
|
| 639 |
+
# thì lớp mới chưa được train trong checkpoint → cần nhiều epoch hơn).
|
| 640 |
+
# - QMOS: tốt nhất Add Input `answer.txt` exp07 (0.548). Để trộn cột chuẩn, xem kết quả exp08: 5 cột cảm xúc
|
| 641 |
+
# resume + QMOS exp07 → hệ thống mạnh nhất 6 cột.
|
| 642 |
+
# - Ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md`.
|
track2/exp09a_qmos_utmosv2_probe.ipynb
ADDED
|
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "0afb9ac3",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"# VMC2026 Track 2 — exp09a (PROBE: UTMOSv2 vs UTMOS cho QMOS) — Kaggle\n",
|
| 9 |
+
"\n",
|
| 10 |
+
"**Mục đích (rẻ, KHÔNG tốn lượt nộp):** trước khi fine-tune QMOS, kiểm tra xem\n",
|
| 11 |
+
"**UTMOSv2** (hệ thống **T05 — vô địch VoiceMOS Challenge 2024 Track 1**, naturalness MOS)\n",
|
| 12 |
+
"có **mạnh hơn UTMOS 2022** (đang dùng) trên dữ liệu Track 2 hay không.\n",
|
| 13 |
+
"\n",
|
| 14 |
+
"## Ý tưởng A/B không tốn lượt nộp\n",
|
| 15 |
+
"Tập **train** Track 2 CÓ nhãn `qMOS` thật (`sets/train.csv`). Ta:\n",
|
| 16 |
+
"1. Chấm một mẫu train bằng **UTMOS** (torch.hub `utmos22_strong`) — baseline đang dùng.\n",
|
| 17 |
+
"2. Chấm cùng mẫu đó bằng **UTMOSv2** (`sarulab-speech/UTMOSv2`, MIT).\n",
|
| 18 |
+
"3. So **SRCC mỗi model vs nhãn qMOS vàng** → biết model nào \"xếp hạng\" giống người chấm hơn.\n",
|
| 19 |
+
"\n",
|
| 20 |
+
"> SRCC chấm **thứ hạng** (scale-invariant) → khỏi lo lệch thang điểm. Mẫu ~2.000 wav là đủ ổn định.\n",
|
| 21 |
+
"\n",
|
| 22 |
+
"## Vì sao đáng thử\n",
|
| 23 |
+
"- UTMOSv2 = #1 ở 7/16 metric VMC2024 Track 1 (bỏ xa hạng 3) → bản kế nhiệm trực tiếp của UTMOS.\n",
|
| 24 |
+
"- **Lưu ý:** UTMOSv2 cũng train trên giọng *không* cảm xúc → vẫn có thể lệch domain; A/B này để\n",
|
| 25 |
+
" biết nó có **đáng** làm \"neo\" mạnh hơn cho head QMOS fine-tune (exp09) hay không.\n",
|
| 26 |
+
"\n",
|
| 27 |
+
"**Cách chạy:** GPU T4 + **Internet On** (UTMOSv2 cài từ git + tải checkpoint) → Add Input dataset\n",
|
| 28 |
+
"Track 2 → sửa `DATA_ROOT` → Run All. Lần đầu để `PROBE_N=300` cho nhanh, OK rồi tăng `2000`."
|
| 29 |
+
]
|
| 30 |
+
},
|
| 31 |
+
{
|
| 32 |
+
"cell_type": "markdown",
|
| 33 |
+
"id": "c4f6fdc3",
|
| 34 |
+
"metadata": {},
|
| 35 |
+
"source": [
|
| 36 |
+
"## 0. Cấu hình — SỬA Ở ĐÂY"
|
| 37 |
+
]
|
| 38 |
+
},
|
| 39 |
+
{
|
| 40 |
+
"cell_type": "code",
|
| 41 |
+
"execution_count": null,
|
| 42 |
+
"id": "df72dc96",
|
| 43 |
+
"metadata": {},
|
| 44 |
+
"outputs": [],
|
| 45 |
+
"source": [
|
| 46 |
+
"import os\n",
|
| 47 |
+
"\n",
|
| 48 |
+
"DATA_ROOT = \"/kaggle/input/vmc2026-track2-full/vmc2026-track2\" # << SỬA slug\n",
|
| 49 |
+
"WAV_DIR = f\"{DATA_ROOT}/wav\"\n",
|
| 50 |
+
"TRAIN_CSV = f\"{DATA_ROOT}/sets/train.csv\" # lisID|wavID|qMOS|emoCat|eMOS|val|dom|aro\n",
|
| 51 |
+
"DEV_SCP = f\"{DATA_ROOT}/sets/dev.scp\"\n",
|
| 52 |
+
"\n",
|
| 53 |
+
"OUT_DIR = \"/kaggle/working\"\n",
|
| 54 |
+
"CACHE_DIR = \"/kaggle/working/qmos_probe_cache\"\n",
|
| 55 |
+
"os.makedirs(CACHE_DIR, exist_ok=True)\n",
|
| 56 |
+
"\n",
|
| 57 |
+
"DEVICE = \"cuda\"\n",
|
| 58 |
+
"PROBE_N = 2000 # số wav train để A/B (lần đầu để 300 cho nhanh). SRCC ~2000 mẫu đã ổn định.\n",
|
| 59 |
+
"SEED = 42\n",
|
| 60 |
+
"\n",
|
| 61 |
+
"# (Tùy chọn) Nếu muốn TẠO LUÔN answer.txt đổi cột QMOS←UTMOSv2 để nộp xác nhận trên DEV:\n",
|
| 62 |
+
"# trỏ tới answer.txt của exp07 (giữ nguyên 5 cột cảm xúc, chỉ thay QMOS).\n",
|
| 63 |
+
"# Để None nếu chỉ muốn chạy A/B nội bộ.\n",
|
| 64 |
+
"EXP07_ANSWER = None # ví dụ: \"/kaggle/input/exp07-answer/answer.txt\"\n",
|
| 65 |
+
"\n",
|
| 66 |
+
"def stem(p):\n",
|
| 67 |
+
" return os.path.splitext(os.path.basename(str(p)))[0]\n",
|
| 68 |
+
"\n",
|
| 69 |
+
"for p in [WAV_DIR, TRAIN_CSV, DEV_SCP]:\n",
|
| 70 |
+
" print((\" ✅ \" if os.path.exists(p) else \" ❌ THIẾU \") + p)"
|
| 71 |
+
]
|
| 72 |
+
},
|
| 73 |
+
{
|
| 74 |
+
"cell_type": "markdown",
|
| 75 |
+
"id": "c18d2e88",
|
| 76 |
+
"metadata": {},
|
| 77 |
+
"source": [
|
| 78 |
+
"## 1. Cài đặt (UTMOS + UTMOSv2)"
|
| 79 |
+
]
|
| 80 |
+
},
|
| 81 |
+
{
|
| 82 |
+
"cell_type": "code",
|
| 83 |
+
"execution_count": null,
|
| 84 |
+
"id": "c21a3cb0",
|
| 85 |
+
"metadata": {},
|
| 86 |
+
"outputs": [],
|
| 87 |
+
"source": [
|
| 88 |
+
"import sys, subprocess\n",
|
| 89 |
+
"\n",
|
| 90 |
+
"def pip_install(*pkgs):\n",
|
| 91 |
+
" subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *pkgs], check=True)\n",
|
| 92 |
+
"\n",
|
| 93 |
+
"pip_install(\"speechmos\", \"librosa\", \"soundfile\", \"pandas\", \"scipy\", \"scikit-learn\", \"tqdm\")\n",
|
| 94 |
+
"# UTMOSv2 (T05) — cài từ git, cần Internet On. Checkpoint tự tải lần đầu.\n",
|
| 95 |
+
"pip_install(\"git+https://github.com/sarulab-speech/UTMOSv2.git\")"
|
| 96 |
+
]
|
| 97 |
+
},
|
| 98 |
+
{
|
| 99 |
+
"cell_type": "markdown",
|
| 100 |
+
"id": "ad56ee86",
|
| 101 |
+
"metadata": {},
|
| 102 |
+
"source": [
|
| 103 |
+
"## 2. Nhãn qMOS vàng (gộp trung bình theo wav)"
|
| 104 |
+
]
|
| 105 |
+
},
|
| 106 |
+
{
|
| 107 |
+
"cell_type": "code",
|
| 108 |
+
"execution_count": null,
|
| 109 |
+
"id": "66c28d52",
|
| 110 |
+
"metadata": {},
|
| 111 |
+
"outputs": [],
|
| 112 |
+
"source": [
|
| 113 |
+
"import numpy as np\n",
|
| 114 |
+
"import pandas as pd\n",
|
| 115 |
+
"\n",
|
| 116 |
+
"def load_qmos_labels():\n",
|
| 117 |
+
" \"\"\"train.csv (sep '|') → dict {stem: qMOS trung bình theo wav}.\"\"\"\n",
|
| 118 |
+
" df = pd.read_csv(TRAIN_CSV, sep=\"|\")\n",
|
| 119 |
+
" cols = {c.lower().strip(): c for c in df.columns}\n",
|
| 120 |
+
" wav_col = cols.get(\"wavid\") or cols.get(\"wav\") or list(df.columns)[1]\n",
|
| 121 |
+
" qmos_col = cols.get(\"qmos\") or cols.get(\"mos\")\n",
|
| 122 |
+
" assert qmos_col, f\"Không thấy cột qMOS (cột: {list(df.columns)})\"\n",
|
| 123 |
+
" df[\"_stem\"] = df[wav_col].map(stem)\n",
|
| 124 |
+
" g = df.groupby(\"_stem\")[qmos_col].mean()\n",
|
| 125 |
+
" return {s: float(v) for s, v in g.items()}\n",
|
| 126 |
+
"\n",
|
| 127 |
+
"qmos_gold = load_qmos_labels()\n",
|
| 128 |
+
"print(f\"Số wav train có nhãn qMOS: {len(qmos_gold)}\")\n",
|
| 129 |
+
"\n",
|
| 130 |
+
"# Chọn mẫu probe (chỉ giữ wav thật sự tồn tại trên đĩa)\n",
|
| 131 |
+
"rng = np.random.default_rng(SEED)\n",
|
| 132 |
+
"all_stems = [s for s in qmos_gold if os.path.exists(os.path.join(WAV_DIR, s + \".wav\"))]\n",
|
| 133 |
+
"rng.shuffle(all_stems)\n",
|
| 134 |
+
"probe_stems = all_stems[:PROBE_N]\n",
|
| 135 |
+
"print(f\"Mẫu probe: {len(probe_stems)} / {len(all_stems)} wav tồn tại\")"
|
| 136 |
+
]
|
| 137 |
+
},
|
| 138 |
+
{
|
| 139 |
+
"cell_type": "markdown",
|
| 140 |
+
"id": "13d793d1",
|
| 141 |
+
"metadata": {},
|
| 142 |
+
"source": [
|
| 143 |
+
"## 3. Hàm chấm: UTMOS (cũ) và UTMOSv2 (mới) — đều cache .npz"
|
| 144 |
+
]
|
| 145 |
+
},
|
| 146 |
+
{
|
| 147 |
+
"cell_type": "code",
|
| 148 |
+
"execution_count": null,
|
| 149 |
+
"id": "05dd3e90",
|
| 150 |
+
"metadata": {
|
| 151 |
+
"lines_to_next_cell": 1
|
| 152 |
+
},
|
| 153 |
+
"outputs": [],
|
| 154 |
+
"source": [
|
| 155 |
+
"import torch\n",
|
| 156 |
+
"from scipy.stats import spearmanr, pearsonr\n",
|
| 157 |
+
"\n",
|
| 158 |
+
"device = DEVICE if torch.cuda.is_available() else \"cpu\"\n",
|
| 159 |
+
"print(\"Device:\", device, (\"✅ \" + torch.cuda.get_device_name(0)) if device == \"cuda\" else \"⚠️ CPU\")\n",
|
| 160 |
+
"\n",
|
| 161 |
+
"def score_utmos(stems, tag):\n",
|
| 162 |
+
" \"\"\"UTMOS 2022 (torch.hub utmos22_strong). → dict {stem: score}. Cache.\"\"\"\n",
|
| 163 |
+
" import librosa\n",
|
| 164 |
+
" from tqdm.auto import tqdm\n",
|
| 165 |
+
" cache = os.path.join(CACHE_DIR, f\"utmos_{tag}.npz\")\n",
|
| 166 |
+
" store = {}\n",
|
| 167 |
+
" if os.path.exists(cache):\n",
|
| 168 |
+
" z = np.load(cache, allow_pickle=True)\n",
|
| 169 |
+
" store = {k: float(z[k]) for k in z.files}\n",
|
| 170 |
+
" print(f\"[utmos/{tag}] nạp cache: {len(store)}\")\n",
|
| 171 |
+
" todo = [s for s in stems if s not in store]\n",
|
| 172 |
+
" if todo:\n",
|
| 173 |
+
" predictor = torch.hub.load(\"tarepan/SpeechMOS:v1.2.0\", \"utmos22_strong\",\n",
|
| 174 |
+
" trust_repo=True).to(device).eval()\n",
|
| 175 |
+
" with torch.no_grad():\n",
|
| 176 |
+
" for i, s in enumerate(tqdm(todo, desc=f\"utmos {tag}\")):\n",
|
| 177 |
+
" wav = os.path.join(WAV_DIR, s + \".wav\")\n",
|
| 178 |
+
" wave, _ = librosa.load(wav, sr=16000, mono=True)\n",
|
| 179 |
+
" store[s] = float(predictor(torch.from_numpy(wave).unsqueeze(0).to(device),\n",
|
| 180 |
+
" sr=16000).mean().item())\n",
|
| 181 |
+
" if (i + 1) % 500 == 0:\n",
|
| 182 |
+
" np.savez(cache, **{k: np.float32(v) for k, v in store.items()})\n",
|
| 183 |
+
" np.savez(cache, **{k: np.float32(v) for k, v in store.items()})\n",
|
| 184 |
+
" del predictor\n",
|
| 185 |
+
" torch.cuda.empty_cache() if device == \"cuda\" else None\n",
|
| 186 |
+
" return store\n",
|
| 187 |
+
"\n",
|
| 188 |
+
"def score_utmosv2(stems, tag):\n",
|
| 189 |
+
" \"\"\"UTMOSv2 / T05 (sarulab-speech/UTMOSv2). → dict {stem: score}. Cache.\"\"\"\n",
|
| 190 |
+
" from tqdm.auto import tqdm\n",
|
| 191 |
+
" cache = os.path.join(CACHE_DIR, f\"utmosv2_{tag}.npz\")\n",
|
| 192 |
+
" store = {}\n",
|
| 193 |
+
" if os.path.exists(cache):\n",
|
| 194 |
+
" z = np.load(cache, allow_pickle=True)\n",
|
| 195 |
+
" store = {k: float(z[k]) for k in z.files}\n",
|
| 196 |
+
" print(f\"[utmosv2/{tag}] nạp cache: {len(store)}\")\n",
|
| 197 |
+
" todo = [s for s in stems if s not in store]\n",
|
| 198 |
+
" if todo:\n",
|
| 199 |
+
" import utmosv2\n",
|
| 200 |
+
" model = utmosv2.create_model(pretrained=True) # ensemble, checkpoint tự tải\n",
|
| 201 |
+
" for i, s in enumerate(tqdm(todo, desc=f\"utmosv2 {tag}\")):\n",
|
| 202 |
+
" wav = os.path.join(WAV_DIR, s + \".wav\")\n",
|
| 203 |
+
" out = model.predict(input_path=wav)\n",
|
| 204 |
+
" # predict trả về float (hoặc dict có 'predicted_mos') tùy phiên bản\n",
|
| 205 |
+
" store[s] = float(out[\"predicted_mos\"]) if isinstance(out, dict) else float(out)\n",
|
| 206 |
+
" if (i + 1) % 200 == 0:\n",
|
| 207 |
+
" np.savez(cache, **{k: np.float32(v) for k, v in store.items()})\n",
|
| 208 |
+
" np.savez(cache, **{k: np.float32(v) for k, v in store.items()})\n",
|
| 209 |
+
" del model\n",
|
| 210 |
+
" torch.cuda.empty_cache() if device == \"cuda\" else None\n",
|
| 211 |
+
" return store"
|
| 212 |
+
]
|
| 213 |
+
},
|
| 214 |
+
{
|
| 215 |
+
"cell_type": "markdown",
|
| 216 |
+
"id": "d04d250c",
|
| 217 |
+
"metadata": {},
|
| 218 |
+
"source": [
|
| 219 |
+
"## 4. Chạy A/B trên mẫu train → in SRCC mỗi model vs nhãn qMOS vàng"
|
| 220 |
+
]
|
| 221 |
+
},
|
| 222 |
+
{
|
| 223 |
+
"cell_type": "code",
|
| 224 |
+
"execution_count": null,
|
| 225 |
+
"id": "eb9fe414",
|
| 226 |
+
"metadata": {
|
| 227 |
+
"lines_to_next_cell": 1
|
| 228 |
+
},
|
| 229 |
+
"outputs": [],
|
| 230 |
+
"source": [
|
| 231 |
+
"utmos_s = score_utmos(probe_stems, \"probe\")\n",
|
| 232 |
+
"utmosv2_s = score_utmosv2(probe_stems, \"probe\")\n",
|
| 233 |
+
"\n",
|
| 234 |
+
"# Chỉ so trên các stem cả 2 model đều chấm được (để công bằng)\n",
|
| 235 |
+
"common = [s for s in probe_stems if s in utmos_s and s in utmosv2_s and s in qmos_gold]\n",
|
| 236 |
+
"y_gold = np.array([qmos_gold[s] for s in common])\n",
|
| 237 |
+
"p_v1 = np.array([utmos_s[s] for s in common])\n",
|
| 238 |
+
"p_v2 = np.array([utmosv2_s[s] for s in common])\n",
|
| 239 |
+
"print(f\"\\nSố mẫu so sánh chung: {len(common)}\")\n",
|
| 240 |
+
"\n",
|
| 241 |
+
"srcc_v1 = spearmanr(p_v1, y_gold).correlation\n",
|
| 242 |
+
"srcc_v2 = spearmanr(p_v2, y_gold).correlation\n",
|
| 243 |
+
"lcc_v1 = pearsonr(p_v1, y_gold)[0]\n",
|
| 244 |
+
"lcc_v2 = pearsonr(p_v2, y_gold)[0]\n",
|
| 245 |
+
"\n",
|
| 246 |
+
"print(\"\\n📊 A/B trên TRAIN (nhãn qMOS vàng) — UTT-SRCC là metric chính:\")\n",
|
| 247 |
+
"print(f\" UTMOS 2022 (đang dùng) : SRCC = {srcc_v1:.4f} | LCC = {lcc_v1:.4f}\")\n",
|
| 248 |
+
"print(f\" UTMOSv2 / T05 (mới) : SRCC = {srcc_v2:.4f} | LCC = {lcc_v2:.4f}\")\n",
|
| 249 |
+
"delta = srcc_v2 - srcc_v1\n",
|
| 250 |
+
"if delta > 0.01:\n",
|
| 251 |
+
" print(f\" ✅ UTMOSv2 THẮNG (+{delta:.4f} SRCC) → đáng dùng làm neo cho exp09 / đổi cột QMOS.\")\n",
|
| 252 |
+
"elif delta < -0.01:\n",
|
| 253 |
+
" print(f\" ⚠️ UTMOSv2 THUA ({delta:.4f} SRCC) → giữ UTMOS; lệch domain cảm xúc quá mạnh.\")\n",
|
| 254 |
+
"else:\n",
|
| 255 |
+
" print(f\" ➖ Ngang nhau ({delta:+.4f}) → ưu tiên model nào tiện hơn; chốt bằng fine-tune.\")\n",
|
| 256 |
+
"\n",
|
| 257 |
+
"# Mốc tham chiếu leaderboard: UTMOS zero-shot DEV = 0.414; head QMOS exp07 = 0.548.\n",
|
| 258 |
+
"# (SRCC train ≠ SRCC dev nhưng cùng xu hướng → dùng để quyết hướng, không phải điểm nộp.)\n",
|
| 259 |
+
"print(\"\\nℹ️ Mốc leaderboard DEV để đối chiếu: UTMOS zero-shot 0.414 · head QMOS exp07 0.548.\")"
|
| 260 |
+
]
|
| 261 |
+
},
|
| 262 |
+
{
|
| 263 |
+
"cell_type": "markdown",
|
| 264 |
+
"id": "d808b9d6",
|
| 265 |
+
"metadata": {},
|
| 266 |
+
"source": [
|
| 267 |
+
"## 5. (Tùy chọn) Tạo answer.txt đổi cột QMOS←UTMOSv2 để nộp xác nhận DEV\n",
|
| 268 |
+
"Chỉ chạy nếu `EXP07_ANSWER` trỏ tới answer.txt exp07. Giữ nguyên 5 cột cảm xúc, chỉ thay QMOS."
|
| 269 |
+
]
|
| 270 |
+
},
|
| 271 |
+
{
|
| 272 |
+
"cell_type": "code",
|
| 273 |
+
"execution_count": null,
|
| 274 |
+
"id": "b08a39f5",
|
| 275 |
+
"metadata": {},
|
| 276 |
+
"outputs": [],
|
| 277 |
+
"source": [
|
| 278 |
+
"def build_swapped_answer(exp07_answer_path, out_path):\n",
|
| 279 |
+
" \"\"\"Đọc answer.txt exp07 (wav,QMOS,EMOS,CAT,VAL,ARO,DOM), thay QMOS = UTMOSv2(dev).\"\"\"\n",
|
| 280 |
+
" import csv\n",
|
| 281 |
+
" with open(DEV_SCP) as f:\n",
|
| 282 |
+
" dev_names = [ln.strip() for ln in f if ln.strip()]\n",
|
| 283 |
+
" dev_stems = [stem(n) for n in dev_names]\n",
|
| 284 |
+
" utmosv2_dev = score_utmosv2(dev_stems, \"dev\") # chấm DEV bằng UTMOSv2 (cache riêng)\n",
|
| 285 |
+
"\n",
|
| 286 |
+
" with open(exp07_answer_path) as f:\n",
|
| 287 |
+
" rows = list(csv.reader(f))\n",
|
| 288 |
+
" header, body = rows[0], rows[1:]\n",
|
| 289 |
+
" qi = header.index(\"QMOS\")\n",
|
| 290 |
+
" n_swap = 0\n",
|
| 291 |
+
" with open(out_path, \"w\") as f:\n",
|
| 292 |
+
" f.write(\",\".join(header) + \"\\n\")\n",
|
| 293 |
+
" for r in body:\n",
|
| 294 |
+
" sid = stem(r[0])\n",
|
| 295 |
+
" if sid in utmosv2_dev:\n",
|
| 296 |
+
" r[qi] = f\"{utmosv2_dev[sid]:.6g}\"\n",
|
| 297 |
+
" n_swap += 1\n",
|
| 298 |
+
" f.write(\",\".join(r) + \"\\n\")\n",
|
| 299 |
+
" print(f\"Ghi {len(body)} dòng → {out_path} | đổi QMOS được {n_swap} dòng\")\n",
|
| 300 |
+
" return out_path\n",
|
| 301 |
+
"\n",
|
| 302 |
+
"if EXP07_ANSWER and os.path.exists(EXP07_ANSWER):\n",
|
| 303 |
+
" out = os.path.join(OUT_DIR, \"answer.txt\")\n",
|
| 304 |
+
" build_swapped_answer(EXP07_ANSWER, out)\n",
|
| 305 |
+
" os.system(f\"cd {OUT_DIR} && zip -j submission_track2_exp09a_utmosv2.zip answer.txt \"\n",
|
| 306 |
+
" f\"&& unzip -l submission_track2_exp09a_utmosv2.zip\")\n",
|
| 307 |
+
" print(\"Sẵn sàng nộp:\", os.path.join(OUT_DIR, \"submission_track2_exp09a_utmosv2.zip\"))\n",
|
| 308 |
+
"else:\n",
|
| 309 |
+
" print(\"Bỏ qua mục 5 (EXP07_ANSWER=None hoặc không tồn tại). Chỉ chạy A/B nội bộ.\")"
|
| 310 |
+
]
|
| 311 |
+
},
|
| 312 |
+
{
|
| 313 |
+
"cell_type": "markdown",
|
| 314 |
+
"id": "adc8ba21",
|
| 315 |
+
"metadata": {},
|
| 316 |
+
"source": [
|
| 317 |
+
"## Ghi chú\n",
|
| 318 |
+
"- **Đọc kết quả mục 4:** UTMOSv2 SRCC có > UTMOS không?\n",
|
| 319 |
+
" - **Thắng rõ** → dùng UTMOSv2 làm **neo** cho `exp09` (fine-tune WavLM trên nhãn qMOS) thay UTMOS;\n",
|
| 320 |
+
" và/hoặc nộp answer.txt đổi cột (mục 5) để xác nhận trên leaderboard DEV.\n",
|
| 321 |
+
" - **Thua/ngang** → giữ UTMOS làm neo; kết luận \"UTMOSv2 vẫn lệch domain cảm xúc\" (phát hiện cho paper).\n",
|
| 322 |
+
"- **Gotcha Kaggle:** UTMOSv2 cài từ git + tải checkpoint → **Internet On**. Bản nộp Internet-off cần\n",
|
| 323 |
+
" pre-download weights thành Kaggle Dataset.\n",
|
| 324 |
+
"- UTMOSv2 là **ensemble nhiều fold** → chậm hơn UTMOS. Nếu lâu, giảm `PROBE_N` hoặc chấm dần (có cache).\n",
|
| 325 |
+
"- License: UTMOSv2 **MIT** · UTMOS BSD-3. Ghi vào `docs/12_system_description.md`.\n",
|
| 326 |
+
"- Ghi config → kết qu�� → nhận xét vào `docs/04_experiments_log.md` (mục exp09a)."
|
| 327 |
+
]
|
| 328 |
+
}
|
| 329 |
+
],
|
| 330 |
+
"metadata": {
|
| 331 |
+
"jupytext": {
|
| 332 |
+
"cell_metadata_filter": "-all",
|
| 333 |
+
"main_language": "python",
|
| 334 |
+
"notebook_metadata_filter": "-all"
|
| 335 |
+
}
|
| 336 |
+
},
|
| 337 |
+
"nbformat": 4,
|
| 338 |
+
"nbformat_minor": 5
|
| 339 |
+
}
|
track2/exp09a_qmos_utmosv2_probe_pipeline.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %% [markdown]
|
| 2 |
+
# # VMC2026 Track 2 — exp09a (PROBE: UTMOSv2 vs UTMOS cho QMOS) — Kaggle
|
| 3 |
+
#
|
| 4 |
+
# **Mục đích (rẻ, KHÔNG tốn lượt nộp):** trước khi fine-tune QMOS, kiểm tra xem
|
| 5 |
+
# **UTMOSv2** (hệ thống **T05 — vô địch VoiceMOS Challenge 2024 Track 1**, naturalness MOS)
|
| 6 |
+
# có **mạnh hơn UTMOS 2022** (đang dùng) trên dữ liệu Track 2 hay không.
|
| 7 |
+
#
|
| 8 |
+
# ## Ý tưởng A/B không tốn lượt nộp
|
| 9 |
+
# Tập **train** Track 2 CÓ nhãn `qMOS` thật (`sets/train.csv`). Ta:
|
| 10 |
+
# 1. Chấm một mẫu train bằng **UTMOS** (torch.hub `utmos22_strong`) — baseline đang dùng.
|
| 11 |
+
# 2. Chấm cùng mẫu đó bằng **UTMOSv2** (`sarulab-speech/UTMOSv2`, MIT).
|
| 12 |
+
# 3. So **SRCC mỗi model vs nhãn qMOS vàng** → biết model nào "xếp hạng" giống người chấm hơn.
|
| 13 |
+
#
|
| 14 |
+
# > SRCC chấm **thứ hạng** (scale-invariant) → khỏi lo lệch thang điểm. Mẫu ~2.000 wav là đủ ổn định.
|
| 15 |
+
#
|
| 16 |
+
# ## Vì sao đáng thử
|
| 17 |
+
# - UTMOSv2 = #1 ở 7/16 metric VMC2024 Track 1 (bỏ xa hạng 3) → bản kế nhiệm trực tiếp của UTMOS.
|
| 18 |
+
# - **Lưu ý:** UTMOSv2 cũng train trên giọng *không* cảm xúc → vẫn có thể lệch domain; A/B này để
|
| 19 |
+
# biết nó có **đáng** làm "neo" mạnh hơn cho head QMOS fine-tune (exp09) hay không.
|
| 20 |
+
#
|
| 21 |
+
# **Cách chạy:** GPU T4 + **Internet On** (UTMOSv2 cài từ git + tải checkpoint) → Add Input dataset
|
| 22 |
+
# Track 2 → sửa `DATA_ROOT` → Run All. Lần đầu để `PROBE_N=300` cho nhanh, OK rồi tăng `2000`.
|
| 23 |
+
|
| 24 |
+
# %% [markdown]
|
| 25 |
+
# ## 0. Cấu hình — SỬA Ở ĐÂY
|
| 26 |
+
|
| 27 |
+
# %%
|
| 28 |
+
import os
|
| 29 |
+
|
| 30 |
+
DATA_ROOT = "/kaggle/input/vmc2026-track2-full/vmc2026-track2" # << SỬA slug
|
| 31 |
+
WAV_DIR = f"{DATA_ROOT}/wav"
|
| 32 |
+
TRAIN_CSV = f"{DATA_ROOT}/sets/train.csv" # lisID|wavID|qMOS|emoCat|eMOS|val|dom|aro
|
| 33 |
+
DEV_SCP = f"{DATA_ROOT}/sets/dev.scp"
|
| 34 |
+
|
| 35 |
+
OUT_DIR = "/kaggle/working"
|
| 36 |
+
CACHE_DIR = "/kaggle/working/qmos_probe_cache"
|
| 37 |
+
os.makedirs(CACHE_DIR, exist_ok=True)
|
| 38 |
+
|
| 39 |
+
DEVICE = "cuda"
|
| 40 |
+
PROBE_N = 2000 # số wav train để A/B (lần đầu để 300 cho nhanh). SRCC ~2000 mẫu đã ổn định.
|
| 41 |
+
SEED = 42
|
| 42 |
+
|
| 43 |
+
# (Tùy chọn) Nếu muốn TẠO LUÔN answer.txt đổi cột QMOS←UTMOSv2 để nộp xác nhận trên DEV:
|
| 44 |
+
# trỏ tới answer.txt của exp07 (giữ nguyên 5 cột cảm xúc, chỉ thay QMOS).
|
| 45 |
+
# Để None nếu chỉ muốn chạy A/B nội bộ.
|
| 46 |
+
EXP07_ANSWER = None # ví dụ: "/kaggle/input/exp07-answer/answer.txt"
|
| 47 |
+
|
| 48 |
+
def stem(p):
|
| 49 |
+
return os.path.splitext(os.path.basename(str(p)))[0]
|
| 50 |
+
|
| 51 |
+
for p in [WAV_DIR, TRAIN_CSV, DEV_SCP]:
|
| 52 |
+
print((" ✅ " if os.path.exists(p) else " ❌ THIẾU ") + p)
|
| 53 |
+
|
| 54 |
+
# %% [markdown]
|
| 55 |
+
# ## 1. Cài đặt (UTMOS + UTMOSv2)
|
| 56 |
+
|
| 57 |
+
# %%
|
| 58 |
+
import sys, subprocess
|
| 59 |
+
|
| 60 |
+
def pip_install(*pkgs):
|
| 61 |
+
subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs], check=True)
|
| 62 |
+
|
| 63 |
+
pip_install("speechmos", "librosa", "soundfile", "pandas", "scipy", "scikit-learn", "tqdm")
|
| 64 |
+
# UTMOSv2 (T05) — cài từ git, cần Internet On. Checkpoint tự tải lần đầu.
|
| 65 |
+
pip_install("git+https://github.com/sarulab-speech/UTMOSv2.git")
|
| 66 |
+
|
| 67 |
+
# %% [markdown]
|
| 68 |
+
# ## 2. Nhãn qMOS vàng (gộp trung bình theo wav)
|
| 69 |
+
|
| 70 |
+
# %%
|
| 71 |
+
import numpy as np
|
| 72 |
+
import pandas as pd
|
| 73 |
+
|
| 74 |
+
def load_qmos_labels():
|
| 75 |
+
"""train.csv (sep '|') → dict {stem: qMOS trung bình theo wav}."""
|
| 76 |
+
df = pd.read_csv(TRAIN_CSV, sep="|")
|
| 77 |
+
cols = {c.lower().strip(): c for c in df.columns}
|
| 78 |
+
wav_col = cols.get("wavid") or cols.get("wav") or list(df.columns)[1]
|
| 79 |
+
qmos_col = cols.get("qmos") or cols.get("mos")
|
| 80 |
+
assert qmos_col, f"Không thấy cột qMOS (cột: {list(df.columns)})"
|
| 81 |
+
df["_stem"] = df[wav_col].map(stem)
|
| 82 |
+
g = df.groupby("_stem")[qmos_col].mean()
|
| 83 |
+
return {s: float(v) for s, v in g.items()}
|
| 84 |
+
|
| 85 |
+
qmos_gold = load_qmos_labels()
|
| 86 |
+
print(f"Số wav train có nhãn qMOS: {len(qmos_gold)}")
|
| 87 |
+
|
| 88 |
+
# Chọn mẫu probe (chỉ giữ wav thật sự tồn tại trên đĩa)
|
| 89 |
+
rng = np.random.default_rng(SEED)
|
| 90 |
+
all_stems = [s for s in qmos_gold if os.path.exists(os.path.join(WAV_DIR, s + ".wav"))]
|
| 91 |
+
rng.shuffle(all_stems)
|
| 92 |
+
probe_stems = all_stems[:PROBE_N]
|
| 93 |
+
print(f"Mẫu probe: {len(probe_stems)} / {len(all_stems)} wav tồn tại")
|
| 94 |
+
|
| 95 |
+
# %% [markdown]
|
| 96 |
+
# ## 3. Hàm chấm: UTMOS (cũ) và UTMOSv2 (mới) — đều cache .npz
|
| 97 |
+
|
| 98 |
+
# %%
|
| 99 |
+
import torch
|
| 100 |
+
from scipy.stats import spearmanr, pearsonr
|
| 101 |
+
|
| 102 |
+
device = DEVICE if torch.cuda.is_available() else "cpu"
|
| 103 |
+
print("Device:", device, ("✅ " + torch.cuda.get_device_name(0)) if device == "cuda" else "⚠️ CPU")
|
| 104 |
+
|
| 105 |
+
def score_utmos(stems, tag):
|
| 106 |
+
"""UTMOS 2022 (torch.hub utmos22_strong). → dict {stem: score}. Cache."""
|
| 107 |
+
import librosa
|
| 108 |
+
from tqdm.auto import tqdm
|
| 109 |
+
cache = os.path.join(CACHE_DIR, f"utmos_{tag}.npz")
|
| 110 |
+
store = {}
|
| 111 |
+
if os.path.exists(cache):
|
| 112 |
+
z = np.load(cache, allow_pickle=True)
|
| 113 |
+
store = {k: float(z[k]) for k in z.files}
|
| 114 |
+
print(f"[utmos/{tag}] nạp cache: {len(store)}")
|
| 115 |
+
todo = [s for s in stems if s not in store]
|
| 116 |
+
if todo:
|
| 117 |
+
predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong",
|
| 118 |
+
trust_repo=True).to(device).eval()
|
| 119 |
+
with torch.no_grad():
|
| 120 |
+
for i, s in enumerate(tqdm(todo, desc=f"utmos {tag}")):
|
| 121 |
+
wav = os.path.join(WAV_DIR, s + ".wav")
|
| 122 |
+
wave, _ = librosa.load(wav, sr=16000, mono=True)
|
| 123 |
+
store[s] = float(predictor(torch.from_numpy(wave).unsqueeze(0).to(device),
|
| 124 |
+
sr=16000).mean().item())
|
| 125 |
+
if (i + 1) % 500 == 0:
|
| 126 |
+
np.savez(cache, **{k: np.float32(v) for k, v in store.items()})
|
| 127 |
+
np.savez(cache, **{k: np.float32(v) for k, v in store.items()})
|
| 128 |
+
del predictor
|
| 129 |
+
torch.cuda.empty_cache() if device == "cuda" else None
|
| 130 |
+
return store
|
| 131 |
+
|
| 132 |
+
def score_utmosv2(stems, tag):
|
| 133 |
+
"""UTMOSv2 / T05 (sarulab-speech/UTMOSv2). → dict {stem: score}. Cache."""
|
| 134 |
+
from tqdm.auto import tqdm
|
| 135 |
+
cache = os.path.join(CACHE_DIR, f"utmosv2_{tag}.npz")
|
| 136 |
+
store = {}
|
| 137 |
+
if os.path.exists(cache):
|
| 138 |
+
z = np.load(cache, allow_pickle=True)
|
| 139 |
+
store = {k: float(z[k]) for k in z.files}
|
| 140 |
+
print(f"[utmosv2/{tag}] nạp cache: {len(store)}")
|
| 141 |
+
todo = [s for s in stems if s not in store]
|
| 142 |
+
if todo:
|
| 143 |
+
import utmosv2
|
| 144 |
+
model = utmosv2.create_model(pretrained=True) # ensemble, checkpoint tự tải
|
| 145 |
+
for i, s in enumerate(tqdm(todo, desc=f"utmosv2 {tag}")):
|
| 146 |
+
wav = os.path.join(WAV_DIR, s + ".wav")
|
| 147 |
+
out = model.predict(input_path=wav)
|
| 148 |
+
# predict trả về float (hoặc dict có 'predicted_mos') tùy phiên bản
|
| 149 |
+
store[s] = float(out["predicted_mos"]) if isinstance(out, dict) else float(out)
|
| 150 |
+
if (i + 1) % 200 == 0:
|
| 151 |
+
np.savez(cache, **{k: np.float32(v) for k, v in store.items()})
|
| 152 |
+
np.savez(cache, **{k: np.float32(v) for k, v in store.items()})
|
| 153 |
+
del model
|
| 154 |
+
torch.cuda.empty_cache() if device == "cuda" else None
|
| 155 |
+
return store
|
| 156 |
+
|
| 157 |
+
# %% [markdown]
|
| 158 |
+
# ## 4. Chạy A/B trên mẫu train → in SRCC mỗi model vs nhãn qMOS vàng
|
| 159 |
+
|
| 160 |
+
# %%
|
| 161 |
+
utmos_s = score_utmos(probe_stems, "probe")
|
| 162 |
+
utmosv2_s = score_utmosv2(probe_stems, "probe")
|
| 163 |
+
|
| 164 |
+
# Chỉ so trên các stem cả 2 model đều chấm được (để công bằng)
|
| 165 |
+
common = [s for s in probe_stems if s in utmos_s and s in utmosv2_s and s in qmos_gold]
|
| 166 |
+
y_gold = np.array([qmos_gold[s] for s in common])
|
| 167 |
+
p_v1 = np.array([utmos_s[s] for s in common])
|
| 168 |
+
p_v2 = np.array([utmosv2_s[s] for s in common])
|
| 169 |
+
print(f"\nSố mẫu so sánh chung: {len(common)}")
|
| 170 |
+
|
| 171 |
+
srcc_v1 = spearmanr(p_v1, y_gold).correlation
|
| 172 |
+
srcc_v2 = spearmanr(p_v2, y_gold).correlation
|
| 173 |
+
lcc_v1 = pearsonr(p_v1, y_gold)[0]
|
| 174 |
+
lcc_v2 = pearsonr(p_v2, y_gold)[0]
|
| 175 |
+
|
| 176 |
+
print("\n📊 A/B trên TRAIN (nhãn qMOS vàng) — UTT-SRCC là metric chính:")
|
| 177 |
+
print(f" UTMOS 2022 (đang dùng) : SRCC = {srcc_v1:.4f} | LCC = {lcc_v1:.4f}")
|
| 178 |
+
print(f" UTMOSv2 / T05 (mới) : SRCC = {srcc_v2:.4f} | LCC = {lcc_v2:.4f}")
|
| 179 |
+
delta = srcc_v2 - srcc_v1
|
| 180 |
+
if delta > 0.01:
|
| 181 |
+
print(f" ✅ UTMOSv2 THẮNG (+{delta:.4f} SRCC) → đáng dùng làm neo cho exp09 / đổi cột QMOS.")
|
| 182 |
+
elif delta < -0.01:
|
| 183 |
+
print(f" ⚠️ UTMOSv2 THUA ({delta:.4f} SRCC) → giữ UTMOS; lệch domain cảm xúc quá mạnh.")
|
| 184 |
+
else:
|
| 185 |
+
print(f" ➖ Ngang nhau ({delta:+.4f}) → ưu tiên model nào tiện hơn; chốt bằng fine-tune.")
|
| 186 |
+
|
| 187 |
+
# Mốc tham chiếu leaderboard: UTMOS zero-shot DEV = 0.414; head QMOS exp07 = 0.548.
|
| 188 |
+
# (SRCC train ≠ SRCC dev nhưng cùng xu hướng → dùng để quyết hướng, không phải điểm nộp.)
|
| 189 |
+
print("\nℹ️ Mốc leaderboard DEV để đối chiếu: UTMOS zero-shot 0.414 · head QMOS exp07 0.548.")
|
| 190 |
+
|
| 191 |
+
# %% [markdown]
|
| 192 |
+
# ## 5. (Tùy chọn) Tạo answer.txt đổi cột QMOS←UTMOSv2 để nộp xác nhận DEV
|
| 193 |
+
# Chỉ chạy nếu `EXP07_ANSWER` trỏ tới answer.txt exp07. Giữ nguyên 5 cột cảm xúc, chỉ thay QMOS.
|
| 194 |
+
|
| 195 |
+
# %%
|
| 196 |
+
def build_swapped_answer(exp07_answer_path, out_path):
|
| 197 |
+
"""Đọc answer.txt exp07 (wav,QMOS,EMOS,CAT,VAL,ARO,DOM), thay QMOS = UTMOSv2(dev)."""
|
| 198 |
+
import csv
|
| 199 |
+
with open(DEV_SCP) as f:
|
| 200 |
+
dev_names = [ln.strip() for ln in f if ln.strip()]
|
| 201 |
+
dev_stems = [stem(n) for n in dev_names]
|
| 202 |
+
utmosv2_dev = score_utmosv2(dev_stems, "dev") # chấm DEV bằng UTMOSv2 (cache riêng)
|
| 203 |
+
|
| 204 |
+
with open(exp07_answer_path) as f:
|
| 205 |
+
rows = list(csv.reader(f))
|
| 206 |
+
header, body = rows[0], rows[1:]
|
| 207 |
+
qi = header.index("QMOS")
|
| 208 |
+
n_swap = 0
|
| 209 |
+
with open(out_path, "w") as f:
|
| 210 |
+
f.write(",".join(header) + "\n")
|
| 211 |
+
for r in body:
|
| 212 |
+
sid = stem(r[0])
|
| 213 |
+
if sid in utmosv2_dev:
|
| 214 |
+
r[qi] = f"{utmosv2_dev[sid]:.6g}"
|
| 215 |
+
n_swap += 1
|
| 216 |
+
f.write(",".join(r) + "\n")
|
| 217 |
+
print(f"Ghi {len(body)} dòng → {out_path} | đổi QMOS được {n_swap} dòng")
|
| 218 |
+
return out_path
|
| 219 |
+
|
| 220 |
+
if EXP07_ANSWER and os.path.exists(EXP07_ANSWER):
|
| 221 |
+
out = os.path.join(OUT_DIR, "answer.txt")
|
| 222 |
+
build_swapped_answer(EXP07_ANSWER, out)
|
| 223 |
+
os.system(f"cd {OUT_DIR} && zip -j submission_track2_exp09a_utmosv2.zip answer.txt "
|
| 224 |
+
f"&& unzip -l submission_track2_exp09a_utmosv2.zip")
|
| 225 |
+
print("Sẵn sàng nộp:", os.path.join(OUT_DIR, "submission_track2_exp09a_utmosv2.zip"))
|
| 226 |
+
else:
|
| 227 |
+
print("Bỏ qua mục 5 (EXP07_ANSWER=None hoặc không tồn tại). Chỉ chạy A/B nội bộ.")
|
| 228 |
+
|
| 229 |
+
# %% [markdown]
|
| 230 |
+
# ## Ghi chú
|
| 231 |
+
# - **Đọc kết quả mục 4:** UTMOSv2 SRCC có > UTMOS không?
|
| 232 |
+
# - **Thắng rõ** → dùng UTMOSv2 làm **neo** cho `exp09` (fine-tune WavLM trên nhãn qMOS) thay UTMOS;
|
| 233 |
+
# và/hoặc nộp answer.txt đổi cột (mục 5) để xác nhận trên leaderboard DEV.
|
| 234 |
+
# - **Thua/ngang** → giữ UTMOS làm neo; kết luận "UTMOSv2 vẫn lệch domain cảm xúc" (phát hiện cho paper).
|
| 235 |
+
# - **Gotcha Kaggle:** UTMOSv2 cài từ git + tải checkpoint → **Internet On**. Bản nộp Internet-off cần
|
| 236 |
+
# pre-download weights thành Kaggle Dataset.
|
| 237 |
+
# - UTMOSv2 là **ensemble nhiều fold** → chậm hơn UTMOS. Nếu lâu, giảm `PROBE_N` hoặc chấm dần (có cache).
|
| 238 |
+
# - License: UTMOSv2 **MIT** · UTMOS BSD-3. Ghi vào `docs/12_system_description.md`.
|
| 239 |
+
# - Ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (mục exp09a).
|
track2/exp10_finetune_audeering.ipynb
ADDED
|
@@ -0,0 +1,691 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "678096c6",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"# VMC2026 Track 2 — exp10 (fine-tune AUDEERING riêng + ensemble VAD với exp08) — Kaggle T4\n",
|
| 9 |
+
"\n",
|
| 10 |
+
"**Ý tưởng (Hướng A — an toàn cho T4):** thay vì nhồi 2 backbone large vào 1 model (dễ OOM),\n",
|
| 11 |
+
"ta fine-tune **audeering wav2vec2-large** RIÊNG (1 backbone → vừa T4), rồi **ensemble cột VAD**\n",
|
| 12 |
+
"với exp08 (WavLM fine-tune). Mỗi lần chỉ 1 backbone trong VRAM → không OOM.\n",
|
| 13 |
+
"\n",
|
| 14 |
+
"```\n",
|
| 15 |
+
" [exp08] WavLM fine-tune ─► VAD_wavlm ┐\n",
|
| 16 |
+
" ├─ trung bình ─► VAD cuối (mạnh hơn cả 2)\n",
|
| 17 |
+
" [exp10] audeering fine-tune ─► VAD_aud ┘\n",
|
| 18 |
+
"```\n",
|
| 19 |
+
"audeering vốn là model **dimensional (chuyên VAD)** → fine-tune nó để bổ trợ VAD cho exp08.\n",
|
| 20 |
+
"\n",
|
| 21 |
+
"**Cách chạy:** GPU T4 + Internet On → sửa slug cell 0 → Run All. Lần đầu `LIMIT_TRAIN=300`.\n",
|
| 22 |
+
"Để ensemble: Add Input answer.txt exp08 → trỏ `EXP08_ANSWER`."
|
| 23 |
+
]
|
| 24 |
+
},
|
| 25 |
+
{
|
| 26 |
+
"cell_type": "markdown",
|
| 27 |
+
"id": "291f23e3",
|
| 28 |
+
"metadata": {},
|
| 29 |
+
"source": [
|
| 30 |
+
"## 0. Cấu hình"
|
| 31 |
+
]
|
| 32 |
+
},
|
| 33 |
+
{
|
| 34 |
+
"cell_type": "code",
|
| 35 |
+
"execution_count": null,
|
| 36 |
+
"id": "41d619c3",
|
| 37 |
+
"metadata": {},
|
| 38 |
+
"outputs": [],
|
| 39 |
+
"source": [
|
| 40 |
+
"import os\n",
|
| 41 |
+
"\n",
|
| 42 |
+
"DATA_ROOT = \"/kaggle/input/datasets/minhtoan2/vmc2026-track2-full\" # << SỬA slug\n",
|
| 43 |
+
"WAV_DIR = f\"{DATA_ROOT}/wav\"\n",
|
| 44 |
+
"METADATA_CSV = f\"{DATA_ROOT}/metadata.csv\"\n",
|
| 45 |
+
"TRAIN_CSV = f\"{DATA_ROOT}/sets/train.csv\"\n",
|
| 46 |
+
"DEV_SCP = f\"{DATA_ROOT}/sets/dev.scp\"\n",
|
| 47 |
+
"\n",
|
| 48 |
+
"OUT_DIR = \"/kaggle/working\"\n",
|
| 49 |
+
"\n",
|
| 50 |
+
"# QMOS mượn exp07 (0.548); ensemble VAD với answer.txt exp08.\n",
|
| 51 |
+
"EXP07_ANSWER = \"/kaggle/input/exp07-answer/answer.txt\" # << (tùy chọn) mượn QMOS; không có → UTMOSv2\n",
|
| 52 |
+
"EXP08_ANSWER = \"/kaggle/input/exp08-answer/answer.txt\" # << (tùy chọn) để ENSEMBLE VAD; không có → chỉ ra answer audeering\n",
|
| 53 |
+
"\n",
|
| 54 |
+
"# ── Fine-tune audeering (1 backbone) ─────────────────────────────────────────\n",
|
| 55 |
+
"DEVICE = \"cuda\"\n",
|
| 56 |
+
"SR = 16000\n",
|
| 57 |
+
"MAX_SECONDS = 8\n",
|
| 58 |
+
"UNFREEZE_TOP_LAYERS = 6 # số lớp encoder audeering mở băng (T4 thừa sức 1 backbone)\n",
|
| 59 |
+
"TRUNK_HIDDEN = 512\n",
|
| 60 |
+
"HEAD_HIDDEN = 128\n",
|
| 61 |
+
"DROPOUT = 0.3\n",
|
| 62 |
+
"LR_BACKBONE = 1e-5\n",
|
| 63 |
+
"LR_HEAD = 1e-3\n",
|
| 64 |
+
"WEIGHT_DECAY = 1e-5\n",
|
| 65 |
+
"EPOCHS = 12\n",
|
| 66 |
+
"PATIENCE = 4\n",
|
| 67 |
+
"BATCH = 4\n",
|
| 68 |
+
"ACCUM = 8\n",
|
| 69 |
+
"VAL_FRAC = 0.10\n",
|
| 70 |
+
"SEED = 42\n",
|
| 71 |
+
"USE_AMP = True\n",
|
| 72 |
+
"USE_GRAD_CKPT = True\n",
|
| 73 |
+
"USE_UNCERTAINTY = True\n",
|
| 74 |
+
"\n",
|
| 75 |
+
"# Ensemble: cột nào lấy TRUNG BÌNH giữa exp08 và exp10; cột khác giữ từ exp08.\n",
|
| 76 |
+
"ENSEMBLE_COLS = [\"VAL\", \"ARO\", \"DOM\"] # audeering mạnh VAD → ensemble VAD. Thêm \"EMOS\" nếu muốn.\n",
|
| 77 |
+
"\n",
|
| 78 |
+
"LIMIT_TRAIN = 300 # << LẦN ĐẦU 300; chạy thật None\n",
|
| 79 |
+
"LIMIT_DEV = 20 # << LẦN ĐẦU 20; chạy thật None\n",
|
| 80 |
+
"\n",
|
| 81 |
+
"EXP08 = {\"emos\": 0.811, \"cat_err\": 0.133, \"val\": 0.659, \"aro\": 0.793, \"dom\": 0.751}\n",
|
| 82 |
+
"\n",
|
| 83 |
+
"EMOTIONS5 = [\"angry\", \"happy\", \"neutral\", \"sad\", \"surprised\"]\n",
|
| 84 |
+
"_EMO_ALIAS = {\n",
|
| 85 |
+
" \"angry\": \"angry\", \"anger\": \"angry\",\n",
|
| 86 |
+
" \"happy\": \"happy\", \"happiness\": \"happy\", \"joy\": \"happy\",\n",
|
| 87 |
+
" \"neutral\": \"neutral\", \"calm\": \"neutral\",\n",
|
| 88 |
+
" \"sad\": \"sad\", \"sadness\": \"sad\",\n",
|
| 89 |
+
" \"surprise\": \"surprised\", \"surprised\": \"surprised\", \"surprising\": \"surprised\",\n",
|
| 90 |
+
"}\n",
|
| 91 |
+
"\n",
|
| 92 |
+
"def norm_emotion(label):\n",
|
| 93 |
+
" key = str(label).strip().lower()\n",
|
| 94 |
+
" return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)\n",
|
| 95 |
+
"\n",
|
| 96 |
+
"def stem(p):\n",
|
| 97 |
+
" return os.path.splitext(os.path.basename(str(p)))[0]\n",
|
| 98 |
+
"\n",
|
| 99 |
+
"print(\"DATA_ROOT:\", DATA_ROOT)\n",
|
| 100 |
+
"for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP]:\n",
|
| 101 |
+
" print((\" ✅ \" if os.path.exists(p) else \" ❌ THIẾU \") + p)"
|
| 102 |
+
]
|
| 103 |
+
},
|
| 104 |
+
{
|
| 105 |
+
"cell_type": "markdown",
|
| 106 |
+
"id": "0adb988b",
|
| 107 |
+
"metadata": {},
|
| 108 |
+
"source": [
|
| 109 |
+
"## 1. Cài đặt"
|
| 110 |
+
]
|
| 111 |
+
},
|
| 112 |
+
{
|
| 113 |
+
"cell_type": "code",
|
| 114 |
+
"execution_count": null,
|
| 115 |
+
"id": "1713d69b",
|
| 116 |
+
"metadata": {},
|
| 117 |
+
"outputs": [],
|
| 118 |
+
"source": [
|
| 119 |
+
"import sys, subprocess\n",
|
| 120 |
+
"\n",
|
| 121 |
+
"def pip_install(*pkgs):\n",
|
| 122 |
+
" subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *pkgs], check=True)\n",
|
| 123 |
+
"\n",
|
| 124 |
+
"pip_install(\"transformers\", \"huggingface_hub\", \"safetensors\", \"speechmos\",\n",
|
| 125 |
+
" \"librosa\", \"soundfile\", \"scipy\", \"scikit-learn\", \"pandas\", \"tqdm\")"
|
| 126 |
+
]
|
| 127 |
+
},
|
| 128 |
+
{
|
| 129 |
+
"cell_type": "markdown",
|
| 130 |
+
"id": "4fcb8b30",
|
| 131 |
+
"metadata": {},
|
| 132 |
+
"source": [
|
| 133 |
+
"## 2. Nạp audeering wav2vec2-large làm backbone FINE-TUNE\n",
|
| 134 |
+
"Nạp backbone tay (tránh lỗi subclass `Wav2Vec2PreTrainedModel` ở transformers mới) rồi mở băng lớp trên."
|
| 135 |
+
]
|
| 136 |
+
},
|
| 137 |
+
{
|
| 138 |
+
"cell_type": "code",
|
| 139 |
+
"execution_count": null,
|
| 140 |
+
"id": "f0a39dab",
|
| 141 |
+
"metadata": {
|
| 142 |
+
"lines_to_next_cell": 1
|
| 143 |
+
},
|
| 144 |
+
"outputs": [],
|
| 145 |
+
"source": [
|
| 146 |
+
"import torch\n",
|
| 147 |
+
"import torch.nn as nn\n",
|
| 148 |
+
"import torch.nn.functional as F\n",
|
| 149 |
+
"import numpy as np\n",
|
| 150 |
+
"\n",
|
| 151 |
+
"device = DEVICE if torch.cuda.is_available() else \"cpu\"\n",
|
| 152 |
+
"print(\"Device:\", device, (\"✅ \" + torch.cuda.get_device_name(0)) if device == \"cuda\" else \"⚠️ CPU (rất chậm!)\")\n",
|
| 153 |
+
"\n",
|
| 154 |
+
"from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor\n",
|
| 155 |
+
"from huggingface_hub import hf_hub_download\n",
|
| 156 |
+
"\n",
|
| 157 |
+
"AUD_NAME = \"audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim\"\n",
|
| 158 |
+
"aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)\n",
|
| 159 |
+
"aud_cfg = Wav2Vec2Config.from_pretrained(AUD_NAME)\n",
|
| 160 |
+
"aud = Wav2Vec2Model(aud_cfg)\n",
|
| 161 |
+
"try:\n",
|
| 162 |
+
" _sd = __import__(\"safetensors.torch\", fromlist=[\"load_file\"]).load_file(\n",
|
| 163 |
+
" hf_hub_download(AUD_NAME, \"model.safetensors\"))\n",
|
| 164 |
+
"except Exception:\n",
|
| 165 |
+
" _sd = torch.load(hf_hub_download(AUD_NAME, \"pytorch_model.bin\"), map_location=\"cpu\")\n",
|
| 166 |
+
"bb_sd = {k[len(\"wav2vec2.\"):]: v for k, v in _sd.items() if k.startswith(\"wav2vec2.\")}\n",
|
| 167 |
+
"miss, unexp = aud.load_state_dict(bb_sd, strict=False)\n",
|
| 168 |
+
"print(f\"audeering backbone: thiếu {len(miss)} / dư {len(unexp)} key (strict=False)\")\n",
|
| 169 |
+
"aud = aud.to(device)\n",
|
| 170 |
+
"AUD_DIM = int(aud.config.hidden_size)\n",
|
| 171 |
+
"\n",
|
| 172 |
+
"# Đóng băng tất cả, mở băng UNFREEZE_TOP_LAYERS lớp encoder trên cùng\n",
|
| 173 |
+
"for p in aud.parameters():\n",
|
| 174 |
+
" p.requires_grad = False\n",
|
| 175 |
+
"enc_layers = aud.encoder.layers\n",
|
| 176 |
+
"n_layers = len(enc_layers)\n",
|
| 177 |
+
"for layer in enc_layers[max(0, n_layers - UNFREEZE_TOP_LAYERS):]:\n",
|
| 178 |
+
" for p in layer.parameters():\n",
|
| 179 |
+
" p.requires_grad = True\n",
|
| 180 |
+
"n_train = sum(p.numel() for p in aud.parameters() if p.requires_grad)\n",
|
| 181 |
+
"print(f\"audeering: {n_layers} lớp · mở băng {min(UNFREEZE_TOP_LAYERS, n_layers)} → {n_train/1e6:.1f}M param train (dim {AUD_DIM})\")\n",
|
| 182 |
+
"\n",
|
| 183 |
+
"if USE_GRAD_CKPT:\n",
|
| 184 |
+
" aud.gradient_checkpointing_enable()\n",
|
| 185 |
+
" if hasattr(aud, \"enable_input_require_grads\"):\n",
|
| 186 |
+
" aud.enable_input_require_grads()\n",
|
| 187 |
+
"\n",
|
| 188 |
+
"def masked_mean(hidden, attn_mask):\n",
|
| 189 |
+
" if attn_mask is None:\n",
|
| 190 |
+
" return hidden.mean(dim=1)\n",
|
| 191 |
+
" try:\n",
|
| 192 |
+
" fm = aud._get_feature_vector_attention_mask(hidden.shape[1], attn_mask)\n",
|
| 193 |
+
" except Exception:\n",
|
| 194 |
+
" return hidden.mean(dim=1)\n",
|
| 195 |
+
" fm = fm.unsqueeze(-1).to(hidden.dtype)\n",
|
| 196 |
+
" return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)\n",
|
| 197 |
+
"\n",
|
| 198 |
+
"def aud_embed(input_values, attn_mask):\n",
|
| 199 |
+
" out = aud(input_values, attention_mask=attn_mask).last_hidden_state\n",
|
| 200 |
+
" return masked_mean(out, attn_mask)"
|
| 201 |
+
]
|
| 202 |
+
},
|
| 203 |
+
{
|
| 204 |
+
"cell_type": "markdown",
|
| 205 |
+
"id": "54e993b9",
|
| 206 |
+
"metadata": {},
|
| 207 |
+
"source": [
|
| 208 |
+
"## 3. Nhãn (gộp theo wavID) — như exp08"
|
| 209 |
+
]
|
| 210 |
+
},
|
| 211 |
+
{
|
| 212 |
+
"cell_type": "code",
|
| 213 |
+
"execution_count": null,
|
| 214 |
+
"id": "46cc0e42",
|
| 215 |
+
"metadata": {},
|
| 216 |
+
"outputs": [],
|
| 217 |
+
"source": [
|
| 218 |
+
"import librosa\n",
|
| 219 |
+
"import pandas as pd\n",
|
| 220 |
+
"from tqdm.auto import tqdm\n",
|
| 221 |
+
"\n",
|
| 222 |
+
"def load_target_emotions():\n",
|
| 223 |
+
" tgt = {}\n",
|
| 224 |
+
" with open(METADATA_CSV, encoding=\"utf-8\") as f:\n",
|
| 225 |
+
" for ln in f:\n",
|
| 226 |
+
" parts = ln.strip().split(\"|\")\n",
|
| 227 |
+
" if len(parts) >= 2:\n",
|
| 228 |
+
" tgt[stem(parts[0])] = norm_emotion(parts[1])\n",
|
| 229 |
+
" return tgt\n",
|
| 230 |
+
"\n",
|
| 231 |
+
"def _col(cols_map, *names, df=None, default_idx=None):\n",
|
| 232 |
+
" for n in names:\n",
|
| 233 |
+
" if n in cols_map:\n",
|
| 234 |
+
" return cols_map[n]\n",
|
| 235 |
+
" return list(df.columns)[default_idx] if default_idx is not None else None\n",
|
| 236 |
+
"\n",
|
| 237 |
+
"def parse_emocat_votes(cell):\n",
|
| 238 |
+
" v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
|
| 239 |
+
" for tok in str(cell).replace(\"/\", \",\").replace(\";\", \",\").replace(\"|\", \",\").replace(\" \", \",\").split(\",\"):\n",
|
| 240 |
+
" e = norm_emotion(tok)\n",
|
| 241 |
+
" if e in EMOTIONS5:\n",
|
| 242 |
+
" v[EMOTIONS5.index(e)] += 1.0\n",
|
| 243 |
+
" return v\n",
|
| 244 |
+
"\n",
|
| 245 |
+
"def load_train_labels():\n",
|
| 246 |
+
" df = pd.read_csv(TRAIN_CSV, sep=\"|\")\n",
|
| 247 |
+
" cols = {c.lower().strip(): c for c in df.columns}\n",
|
| 248 |
+
" wav_col = _col(cols, \"wavid\", \"wav\", df=df, default_idx=1)\n",
|
| 249 |
+
" emos_col = _col(cols, \"emos\", \"emo\", \"emomos\")\n",
|
| 250 |
+
" val_col = _col(cols, \"val\", \"valence\"); aro_col = _col(cols, \"aro\", \"arousal\"); dom_col = _col(cols, \"dom\", \"dominance\")\n",
|
| 251 |
+
" cat_col = _col(cols, \"emocat\", \"cat\", \"emotion\")\n",
|
| 252 |
+
" assert emos_col, f\"Không thấy cột eMOS (cột: {list(df.columns)})\"\n",
|
| 253 |
+
" df[\"_stem\"] = df[wav_col].map(stem)\n",
|
| 254 |
+
" rows = []\n",
|
| 255 |
+
" for sid, g in df.groupby(\"_stem\"):\n",
|
| 256 |
+
" rec = {\"wavID\": sid, \"emos\": float(g[emos_col].mean())}\n",
|
| 257 |
+
" rec[\"val\"] = float(g[val_col].mean()) if val_col else np.nan\n",
|
| 258 |
+
" rec[\"aro\"] = float(g[aro_col].mean()) if aro_col else np.nan\n",
|
| 259 |
+
" rec[\"dom\"] = float(g[dom_col].mean()) if dom_col else np.nan\n",
|
| 260 |
+
" votes = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
|
| 261 |
+
" if cat_col:\n",
|
| 262 |
+
" for cell in g[cat_col]:\n",
|
| 263 |
+
" votes += parse_emocat_votes(cell)\n",
|
| 264 |
+
" s = votes.sum()\n",
|
| 265 |
+
" cat = votes / s if s > 0 else np.full(len(EMOTIONS5), 0.2, dtype=np.float32)\n",
|
| 266 |
+
" for i in range(len(EMOTIONS5)):\n",
|
| 267 |
+
" rec[f\"cat{i}\"] = float(cat[i])\n",
|
| 268 |
+
" rows.append(rec)\n",
|
| 269 |
+
" return pd.DataFrame(rows)\n",
|
| 270 |
+
"\n",
|
| 271 |
+
"target_map = load_target_emotions()\n",
|
| 272 |
+
"train_df = load_train_labels()\n",
|
| 273 |
+
"HAS_VAD = bool(train_df[\"val\"].notna().any())\n",
|
| 274 |
+
"print(f\"Target: {len(target_map)} | wav train (gộp): {len(train_df)} | có VAD: {HAS_VAD}\")"
|
| 275 |
+
]
|
| 276 |
+
},
|
| 277 |
+
{
|
| 278 |
+
"cell_type": "markdown",
|
| 279 |
+
"id": "0e0bf0bb",
|
| 280 |
+
"metadata": {},
|
| 281 |
+
"source": [
|
| 282 |
+
"## 4. Dataset/loader (input_values qua audeering processor) + chuẩn hóa nhãn"
|
| 283 |
+
]
|
| 284 |
+
},
|
| 285 |
+
{
|
| 286 |
+
"cell_type": "code",
|
| 287 |
+
"execution_count": null,
|
| 288 |
+
"id": "65768ef9",
|
| 289 |
+
"metadata": {},
|
| 290 |
+
"outputs": [],
|
| 291 |
+
"source": [
|
| 292 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 293 |
+
"from sklearn.model_selection import train_test_split\n",
|
| 294 |
+
"\n",
|
| 295 |
+
"train_stems = [s for s in train_df[\"wavID\"] if target_map.get(s) is not None]\n",
|
| 296 |
+
"if LIMIT_TRAIN:\n",
|
| 297 |
+
" train_stems = train_stems[:LIMIT_TRAIN]\n",
|
| 298 |
+
"lab = train_df.set_index(\"wavID\")\n",
|
| 299 |
+
"\n",
|
| 300 |
+
"def _zfit(arr):\n",
|
| 301 |
+
" a = np.asarray(arr, dtype=np.float32)\n",
|
| 302 |
+
" return float(np.nanmean(a)), float(np.nanstd(a) + 1e-6)\n",
|
| 303 |
+
"\n",
|
| 304 |
+
"emos_mu, emos_sd = _zfit([lab.loc[s, \"emos\"] for s in train_stems])\n",
|
| 305 |
+
"if HAS_VAD:\n",
|
| 306 |
+
" vad_mu = np.array([_zfit([lab.loc[s, c] for s in train_stems])[0] for c in [\"val\", \"aro\", \"dom\"]], dtype=np.float32)\n",
|
| 307 |
+
" vad_sd = np.array([_zfit([lab.loc[s, c] for s in train_stems])[1] for c in [\"val\", \"aro\", \"dom\"]], dtype=np.float32)\n",
|
| 308 |
+
"else:\n",
|
| 309 |
+
" vad_mu = np.zeros(3, dtype=np.float32); vad_sd = np.ones(3, dtype=np.float32)\n",
|
| 310 |
+
"\n",
|
| 311 |
+
"def onehot_target(tgt):\n",
|
| 312 |
+
" v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
|
| 313 |
+
" if tgt in EMOTIONS5:\n",
|
| 314 |
+
" v[EMOTIONS5.index(tgt)] = 1.0\n",
|
| 315 |
+
" return v\n",
|
| 316 |
+
"\n",
|
| 317 |
+
"def load_iv(sid):\n",
|
| 318 |
+
" \"\"\"Đọc wav → chuẩn hóa bằng audeering processor → input_values (1D float32).\"\"\"\n",
|
| 319 |
+
" p = os.path.join(WAV_DIR, sid if str(sid).endswith(\".wav\") else str(sid) + \".wav\")\n",
|
| 320 |
+
" if not os.path.exists(p):\n",
|
| 321 |
+
" return None\n",
|
| 322 |
+
" wave, _ = librosa.load(p, sr=SR, mono=True)\n",
|
| 323 |
+
" wave = wave[: MAX_SECONDS * SR]\n",
|
| 324 |
+
" iv = aud_proc(wave, sampling_rate=SR).input_values[0]\n",
|
| 325 |
+
" return np.asarray(iv, dtype=np.float32)\n",
|
| 326 |
+
"\n",
|
| 327 |
+
"class AudDataset(Dataset):\n",
|
| 328 |
+
" def __init__(self, stems):\n",
|
| 329 |
+
" self.stems = [s for s in stems if load_iv(s) is not None]\n",
|
| 330 |
+
" def __len__(self):\n",
|
| 331 |
+
" return len(self.stems)\n",
|
| 332 |
+
" def __getitem__(self, i):\n",
|
| 333 |
+
" s = self.stems[i]\n",
|
| 334 |
+
" iv = load_iv(s)\n",
|
| 335 |
+
" emos = (float(lab.loc[s, \"emos\"]) - emos_mu) / emos_sd\n",
|
| 336 |
+
" if HAS_VAD:\n",
|
| 337 |
+
" vad = (np.array([lab.loc[s, \"val\"], lab.loc[s, \"aro\"], lab.loc[s, \"dom\"]], np.float32) - vad_mu) / vad_sd\n",
|
| 338 |
+
" else:\n",
|
| 339 |
+
" vad = np.zeros(3, dtype=np.float32)\n",
|
| 340 |
+
" cat = np.array([lab.loc[s, f\"cat{j}\"] for j in range(len(EMOTIONS5))], dtype=np.float32)\n",
|
| 341 |
+
" return {\"iv\": iv, \"tgt\": onehot_target(target_map.get(s)),\n",
|
| 342 |
+
" \"emos\": np.float32(emos), \"vad\": vad, \"cat\": cat,\n",
|
| 343 |
+
" \"emos_raw\": np.float32(lab.loc[s, \"emos\"]),\n",
|
| 344 |
+
" \"vad_raw\": np.array([lab.loc[s, \"val\"], lab.loc[s, \"aro\"], lab.loc[s, \"dom\"]], np.float32)}\n",
|
| 345 |
+
"\n",
|
| 346 |
+
"def collate(batch):\n",
|
| 347 |
+
" L = max(len(b[\"iv\"]) for b in batch)\n",
|
| 348 |
+
" ivs = np.zeros((len(batch), L), dtype=np.float32)\n",
|
| 349 |
+
" mask = np.zeros((len(batch), L), dtype=np.float32)\n",
|
| 350 |
+
" for i, b in enumerate(batch):\n",
|
| 351 |
+
" ivs[i, : len(b[\"iv\"])] = b[\"iv\"]; mask[i, : len(b[\"iv\"])] = 1.0\n",
|
| 352 |
+
" return {\n",
|
| 353 |
+
" \"input_values\": torch.from_numpy(ivs), \"attn_mask\": torch.from_numpy(mask).long(),\n",
|
| 354 |
+
" \"tgt\": torch.from_numpy(np.stack([b[\"tgt\"] for b in batch])),\n",
|
| 355 |
+
" \"emos\": torch.from_numpy(np.stack([b[\"emos\"] for b in batch])).unsqueeze(1),\n",
|
| 356 |
+
" \"vad\": torch.from_numpy(np.stack([b[\"vad\"] for b in batch])),\n",
|
| 357 |
+
" \"cat\": torch.from_numpy(np.stack([b[\"cat\"] for b in batch])),\n",
|
| 358 |
+
" \"emos_raw\": np.stack([b[\"emos_raw\"] for b in batch]),\n",
|
| 359 |
+
" \"vad_raw\": np.stack([b[\"vad_raw\"] for b in batch]),\n",
|
| 360 |
+
" }\n",
|
| 361 |
+
"\n",
|
| 362 |
+
"ds = AudDataset(train_stems)\n",
|
| 363 |
+
"print(\"Dataset hợp lệ:\", len(ds), \"wav\")\n",
|
| 364 |
+
"tr_i, va_i = train_test_split(np.arange(len(ds)), test_size=VAL_FRAC, random_state=SEED)\n",
|
| 365 |
+
"tr_loader = DataLoader(torch.utils.data.Subset(ds, tr_i), batch_size=BATCH, shuffle=True, collate_fn=collate, num_workers=2)\n",
|
| 366 |
+
"va_loader = DataLoader(torch.utils.data.Subset(ds, va_i), batch_size=BATCH, shuffle=False, collate_fn=collate, num_workers=2)"
|
| 367 |
+
]
|
| 368 |
+
},
|
| 369 |
+
{
|
| 370 |
+
"cell_type": "markdown",
|
| 371 |
+
"id": "697d9ca3",
|
| 372 |
+
"metadata": {},
|
| 373 |
+
"source": [
|
| 374 |
+
"## 5. Heads + train loop (lưu ft_audeering_full.pt mỗi best)"
|
| 375 |
+
]
|
| 376 |
+
},
|
| 377 |
+
{
|
| 378 |
+
"cell_type": "code",
|
| 379 |
+
"execution_count": null,
|
| 380 |
+
"id": "8fe7ec40",
|
| 381 |
+
"metadata": {
|
| 382 |
+
"lines_to_next_cell": 1
|
| 383 |
+
},
|
| 384 |
+
"outputs": [],
|
| 385 |
+
"source": [
|
| 386 |
+
"from scipy.stats import spearmanr\n",
|
| 387 |
+
"\n",
|
| 388 |
+
"torch.manual_seed(SEED); np.random.seed(SEED)\n",
|
| 389 |
+
"N_EMO = len(EMOTIONS5)\n",
|
| 390 |
+
"\n",
|
| 391 |
+
"class EmoHeads(nn.Module):\n",
|
| 392 |
+
" def __init__(self, d_in, trunk_h, head_h, p, n_emo):\n",
|
| 393 |
+
" super().__init__()\n",
|
| 394 |
+
" self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),\n",
|
| 395 |
+
" nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))\n",
|
| 396 |
+
" self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))\n",
|
| 397 |
+
" self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))\n",
|
| 398 |
+
" self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))\n",
|
| 399 |
+
" def forward(self, feat, tgt):\n",
|
| 400 |
+
" h = self.trunk(feat)\n",
|
| 401 |
+
" return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)\n",
|
| 402 |
+
"\n",
|
| 403 |
+
"heads = EmoHeads(AUD_DIM, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device)\n",
|
| 404 |
+
"\n",
|
| 405 |
+
"TASKS = [\"emos\", \"cat\", \"val\", \"aro\", \"dom\"]\n",
|
| 406 |
+
"log_var = nn.Parameter(torch.zeros(len(TASKS), device=device))\n",
|
| 407 |
+
"bb_params = [p for p in aud.parameters() if p.requires_grad]\n",
|
| 408 |
+
"head_params = list(heads.parameters()) + ([log_var] if USE_UNCERTAINTY else [])\n",
|
| 409 |
+
"opt = torch.optim.AdamW([{\"params\": bb_params, \"lr\": LR_BACKBONE},\n",
|
| 410 |
+
" {\"params\": head_params, \"lr\": LR_HEAD}], weight_decay=WEIGHT_DECAY)\n",
|
| 411 |
+
"scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP and device == \"cuda\")\n",
|
| 412 |
+
"mse = nn.MSELoss()\n",
|
| 413 |
+
"\n",
|
| 414 |
+
"def soft_ce(logits, target_dist):\n",
|
| 415 |
+
" return -(target_dist * F.log_softmax(logits, dim=1)).sum(1).mean()\n",
|
| 416 |
+
"\n",
|
| 417 |
+
"def forward_batch(b):\n",
|
| 418 |
+
" feat = aud_embed(b[\"input_values\"].to(device), b[\"attn_mask\"].to(device))\n",
|
| 419 |
+
" return heads(feat, b[\"tgt\"].to(device))\n",
|
| 420 |
+
"\n",
|
| 421 |
+
"def compute_loss(emos_p, cat_l, vad_p, b):\n",
|
| 422 |
+
" L = {}\n",
|
| 423 |
+
" L[\"emos\"] = mse(emos_p, b[\"emos\"].to(device))\n",
|
| 424 |
+
" L[\"cat\"] = soft_ce(cat_l, b[\"cat\"].to(device))\n",
|
| 425 |
+
" if HAS_VAD:\n",
|
| 426 |
+
" vt = b[\"vad\"].to(device)\n",
|
| 427 |
+
" L[\"val\"] = mse(vad_p[:, 0:1], vt[:, 0:1]); L[\"aro\"] = mse(vad_p[:, 1:2], vt[:, 1:2]); L[\"dom\"] = mse(vad_p[:, 2:3], vt[:, 2:3])\n",
|
| 428 |
+
" else:\n",
|
| 429 |
+
" z = torch.zeros((), device=device); L[\"val\"] = L[\"aro\"] = L[\"dom\"] = z\n",
|
| 430 |
+
" if USE_UNCERTAINTY:\n",
|
| 431 |
+
" return sum(torch.exp(-log_var[i]) * L[t] + log_var[i] for i, t in enumerate(TASKS))\n",
|
| 432 |
+
" return sum(L.values())\n",
|
| 433 |
+
"\n",
|
| 434 |
+
"@torch.no_grad()\n",
|
| 435 |
+
"def evaluate():\n",
|
| 436 |
+
" aud.eval(); heads.eval()\n",
|
| 437 |
+
" P = {\"emos\": [], \"val\": [], \"aro\": [], \"dom\": []}; Y = {\"emos\": [], \"val\": [], \"aro\": [], \"dom\": []}\n",
|
| 438 |
+
" catP, catY = [], []\n",
|
| 439 |
+
" for b in va_loader:\n",
|
| 440 |
+
" with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
|
| 441 |
+
" emos_p, cat_l, vad_p = forward_batch(b)\n",
|
| 442 |
+
" P[\"emos\"] += emos_p.float().cpu().numpy().ravel().tolist(); Y[\"emos\"] += b[\"emos_raw\"].tolist()\n",
|
| 443 |
+
" vad_p = vad_p.float().cpu().numpy()\n",
|
| 444 |
+
" for j, t in enumerate([\"val\", \"aro\", \"dom\"]):\n",
|
| 445 |
+
" P[t] += vad_p[:, j].tolist(); Y[t] += b[\"vad_raw\"][:, j].tolist()\n",
|
| 446 |
+
" catP.append(F.softmax(cat_l, 1).float().cpu().numpy()); catY.append(b[\"cat\"])\n",
|
| 447 |
+
" out = {}\n",
|
| 448 |
+
" for t in [\"emos\"] + ([\"val\", \"aro\", \"dom\"] if HAS_VAD else []):\n",
|
| 449 |
+
" out[t] = spearmanr(P[t], Y[t]).correlation\n",
|
| 450 |
+
" q = np.concatenate(catP); p = np.concatenate(catY)\n",
|
| 451 |
+
" out[\"cat_err\"] = float(np.abs(q - p).sum(1).mean())\n",
|
| 452 |
+
" return out\n",
|
| 453 |
+
"\n",
|
| 454 |
+
"def mean_srcc(m):\n",
|
| 455 |
+
" keys = [\"emos\"] + ([\"val\", \"aro\", \"dom\"] if HAS_VAD else [])\n",
|
| 456 |
+
" return float(np.mean([m[k] for k in keys]))\n",
|
| 457 |
+
"\n",
|
| 458 |
+
"CKPT_PATH = os.path.join(OUT_DIR, \"ft_audeering_full.pt\")\n",
|
| 459 |
+
"def save_full_ckpt(state, val_emos=float(\"nan\")):\n",
|
| 460 |
+
" torch.save({\"aud\": state[\"aud\"], \"heads\": state[\"heads\"],\n",
|
| 461 |
+
" \"emos_mu\": emos_mu, \"emos_sd\": emos_sd, \"vad_mu\": vad_mu, \"vad_sd\": vad_sd,\n",
|
| 462 |
+
" \"AUD_DIM\": AUD_DIM, \"UNFREEZE_TOP_LAYERS\": UNFREEZE_TOP_LAYERS,\n",
|
| 463 |
+
" \"val_emos\": float(val_emos)}, CKPT_PATH)\n",
|
| 464 |
+
"\n",
|
| 465 |
+
"best, best_state, bad = -1e9, None, 0\n",
|
| 466 |
+
"for ep in range(1, EPOCHS + 1):\n",
|
| 467 |
+
" aud.train(); heads.train()\n",
|
| 468 |
+
" opt.zero_grad(); run = 0.0; nb = 0\n",
|
| 469 |
+
" for step, b in enumerate(tqdm(tr_loader, desc=f\"epoch {ep}\")):\n",
|
| 470 |
+
" with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
|
| 471 |
+
" emos_p, cat_l, vad_p = forward_batch(b)\n",
|
| 472 |
+
" loss = compute_loss(emos_p, cat_l, vad_p, b) / ACCUM\n",
|
| 473 |
+
" scaler.scale(loss).backward()\n",
|
| 474 |
+
" if (step + 1) % ACCUM == 0:\n",
|
| 475 |
+
" scaler.step(opt); scaler.update(); opt.zero_grad()\n",
|
| 476 |
+
" run += loss.item() * ACCUM; nb += 1\n",
|
| 477 |
+
" m = evaluate(); sc = mean_srcc(m)\n",
|
| 478 |
+
" msg = \" \".join(f\"{k}={m[k]:.3f}\" for k in [\"emos\", \"val\", \"aro\", \"dom\"] if k in m)\n",
|
| 479 |
+
" print(f\"epoch {ep:2d} | loss {run/max(nb,1):.4f} | {msg} | cat_err {m['cat_err']:.3f} | mean {sc:.4f} (best {max(best,sc):.4f})\")\n",
|
| 480 |
+
" if sc > best:\n",
|
| 481 |
+
" best = sc\n",
|
| 482 |
+
" best_state = {\"aud\": {k: v.cpu().clone() for k, v in aud.state_dict().items()},\n",
|
| 483 |
+
" \"heads\": {k: v.cpu().clone() for k, v in heads.state_dict().items()}}\n",
|
| 484 |
+
" save_full_ckpt(best_state, m[\"emos\"])\n",
|
| 485 |
+
" print(f\" 💾 lưu best → {CKPT_PATH} (epoch {ep}, mean {sc:.4f})\")\n",
|
| 486 |
+
" bad = 0\n",
|
| 487 |
+
" else:\n",
|
| 488 |
+
" bad += 1\n",
|
| 489 |
+
" if bad >= PATIENCE:\n",
|
| 490 |
+
" print(f\"Early stop ở epoch {ep}.\"); break\n",
|
| 491 |
+
"\n",
|
| 492 |
+
"if best_state:\n",
|
| 493 |
+
" aud.load_state_dict(best_state[\"aud\"]); heads.load_state_dict(best_state[\"heads\"])\n",
|
| 494 |
+
"final = evaluate()\n",
|
| 495 |
+
"print(\"\\n✅ VAL (nội bộ) — exp10 (fine-tune audeering):\")\n",
|
| 496 |
+
"print(f\" EMOS={final['emos']:.4f}\", end=\"\")\n",
|
| 497 |
+
"if HAS_VAD:\n",
|
| 498 |
+
" print(f\" | VAL/ARO/DOM={final['val']:.4f}/{final['aro']:.4f}/{final['dom']:.4f} (exp08 {EXP08['val']}/{EXP08['aro']}/{EXP08['dom']})\")\n",
|
| 499 |
+
"else:\n",
|
| 500 |
+
" print()\n",
|
| 501 |
+
"print(f\" → so exp08: audeering {'mạnh' if HAS_VAD and final['val'] > EXP08['val'] else 'yếu/ngang'} ở VAL. \"\n",
|
| 502 |
+
" f\"Ensemble sẽ lấy trung bình 2 model.\")\n",
|
| 503 |
+
"save_full_ckpt(best_state if best_state else {\"aud\": aud.state_dict(), \"heads\": heads.state_dict()}, final[\"emos\"])\n",
|
| 504 |
+
"print(f\"✅ Đã lưu {CKPT_PATH}. NHỚ Save Version!\")"
|
| 505 |
+
]
|
| 506 |
+
},
|
| 507 |
+
{
|
| 508 |
+
"cell_type": "markdown",
|
| 509 |
+
"id": "4911ff48",
|
| 510 |
+
"metadata": {},
|
| 511 |
+
"source": [
|
| 512 |
+
"## 6. Dự đoán DEV → predictions + answer_audeering.txt"
|
| 513 |
+
]
|
| 514 |
+
},
|
| 515 |
+
{
|
| 516 |
+
"cell_type": "code",
|
| 517 |
+
"execution_count": null,
|
| 518 |
+
"id": "4ed9a022",
|
| 519 |
+
"metadata": {},
|
| 520 |
+
"outputs": [],
|
| 521 |
+
"source": [
|
| 522 |
+
"def list_dev():\n",
|
| 523 |
+
" with open(DEV_SCP) as f:\n",
|
| 524 |
+
" return [ln.strip() for ln in f if ln.strip()]\n",
|
| 525 |
+
"\n",
|
| 526 |
+
"dev_names = list_dev()\n",
|
| 527 |
+
"if LIMIT_DEV:\n",
|
| 528 |
+
" dev_names = dev_names[:LIMIT_DEV]\n",
|
| 529 |
+
"print(\"DEV:\", len(dev_names), \"mẫu\")\n",
|
| 530 |
+
"\n",
|
| 531 |
+
"def load_exp07_qmos():\n",
|
| 532 |
+
" if EXP07_ANSWER and os.path.exists(EXP07_ANSWER):\n",
|
| 533 |
+
" import csv\n",
|
| 534 |
+
" d = {}\n",
|
| 535 |
+
" with open(EXP07_ANSWER) as f:\n",
|
| 536 |
+
" for row in csv.DictReader(f):\n",
|
| 537 |
+
" d[row[\"wav\"]] = float(row[\"QMOS\"]); d[stem(row[\"wav\"])] = float(row[\"QMOS\"])\n",
|
| 538 |
+
" print(f\"✅ Mượn QMOS từ exp07: {len(d)//2} wav\")\n",
|
| 539 |
+
" return d\n",
|
| 540 |
+
" return None\n",
|
| 541 |
+
"\n",
|
| 542 |
+
"qmos_map = load_exp07_qmos()\n",
|
| 543 |
+
"if qmos_map is None:\n",
|
| 544 |
+
" print(\"ℹ️ Không có exp07 → QMOS bằng UTMOSv2.\")\n",
|
| 545 |
+
" pip_install(\"git+https://github.com/sarulab-speech/UTMOSv2.git\")\n",
|
| 546 |
+
" import utmosv2\n",
|
| 547 |
+
" v2 = utmosv2.create_model(pretrained=True)\n",
|
| 548 |
+
" qmos_map = {}\n",
|
| 549 |
+
" for n in tqdm(dev_names, desc=\"UTMOSv2\"):\n",
|
| 550 |
+
" wav = os.path.join(WAV_DIR, n if str(n).endswith(\".wav\") else str(n) + \".wav\")\n",
|
| 551 |
+
" if os.path.exists(wav):\n",
|
| 552 |
+
" o = v2.predict(input_path=wav)\n",
|
| 553 |
+
" qmos_map[n] = float(o[\"predicted_mos\"]) if isinstance(o, dict) else float(o)\n",
|
| 554 |
+
" del v2; torch.cuda.empty_cache() if device == \"cuda\" else None\n",
|
| 555 |
+
"\n",
|
| 556 |
+
"@torch.no_grad()\n",
|
| 557 |
+
"def predict_emotion(sid):\n",
|
| 558 |
+
" iv = load_iv(sid)\n",
|
| 559 |
+
" if iv is None:\n",
|
| 560 |
+
" return None\n",
|
| 561 |
+
" aud.eval(); heads.eval()\n",
|
| 562 |
+
" ivt = torch.from_numpy(iv).unsqueeze(0).to(device)\n",
|
| 563 |
+
" am = torch.ones((1, len(iv)), dtype=torch.long, device=device)\n",
|
| 564 |
+
" tgt = torch.from_numpy(onehot_target(target_map.get(sid))).unsqueeze(0).to(device)\n",
|
| 565 |
+
" with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
|
| 566 |
+
" feat = aud_embed(ivt, am)\n",
|
| 567 |
+
" emos_p, cat_l, vad_p = heads(feat, tgt)\n",
|
| 568 |
+
" emos = float(emos_p.item()) * emos_sd + emos_mu\n",
|
| 569 |
+
" cat5 = F.softmax(cat_l, 1)[0].float().cpu().numpy()\n",
|
| 570 |
+
" vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu\n",
|
| 571 |
+
" return emos, cat5, vad3\n",
|
| 572 |
+
"\n",
|
| 573 |
+
"def fmt_cat(p5):\n",
|
| 574 |
+
" return \"|\".join(f\"{e}:{p5[i]:.6g}\" for i, e in enumerate(EMOTIONS5))\n",
|
| 575 |
+
"\n",
|
| 576 |
+
"dev_pred = {} # name -> (emos, cat5, vad3)\n",
|
| 577 |
+
"with open(os.path.join(OUT_DIR, \"answer_audeering.txt\"), \"w\") as f:\n",
|
| 578 |
+
" f.write(\"wav,QMOS,EMOS,CAT,VAL,ARO,DOM\\n\")\n",
|
| 579 |
+
" for name in tqdm(dev_names, desc=\"answer_aud\"):\n",
|
| 580 |
+
" sid = stem(name)\n",
|
| 581 |
+
" pr = predict_emotion(sid)\n",
|
| 582 |
+
" if pr is None:\n",
|
| 583 |
+
" emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0])\n",
|
| 584 |
+
" else:\n",
|
| 585 |
+
" emos, cat5, vad3 = pr\n",
|
| 586 |
+
" dev_pred[name] = (emos, cat5, vad3)\n",
|
| 587 |
+
" qmos = qmos_map.get(name, qmos_map.get(sid, 3.0))\n",
|
| 588 |
+
" f.write(f\"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\\n\")\n",
|
| 589 |
+
"print(\"Đã ghi answer_audeering.txt\")"
|
| 590 |
+
]
|
| 591 |
+
},
|
| 592 |
+
{
|
| 593 |
+
"cell_type": "markdown",
|
| 594 |
+
"id": "72f8f313",
|
| 595 |
+
"metadata": {},
|
| 596 |
+
"source": [
|
| 597 |
+
"## 7. ENSEMBLE với exp08 → answer.txt cuối (trung bình cột VAD)\n",
|
| 598 |
+
"Lấy answer.txt exp08 làm nền; cột trong `ENSEMBLE_COLS` = trung bình (exp08 + exp10). Còn lại giữ exp08."
|
| 599 |
+
]
|
| 600 |
+
},
|
| 601 |
+
{
|
| 602 |
+
"cell_type": "code",
|
| 603 |
+
"execution_count": null,
|
| 604 |
+
"id": "a33720f9",
|
| 605 |
+
"metadata": {
|
| 606 |
+
"lines_to_next_cell": 1
|
| 607 |
+
},
|
| 608 |
+
"outputs": [],
|
| 609 |
+
"source": [
|
| 610 |
+
"import csv\n",
|
| 611 |
+
"COL_IDX = {\"QMOS\": 1, \"EMOS\": 2, \"VAL\": 4, \"ARO\": 5, \"DOM\": 6} # vị trí cột trong answer.txt\n",
|
| 612 |
+
"AUD_VAL = {\"EMOS\": lambda p: p[0], \"VAL\": lambda p: p[2][0], \"ARO\": lambda p: p[2][1], \"DOM\": lambda p: p[2][2]}\n",
|
| 613 |
+
"\n",
|
| 614 |
+
"answer_path = os.path.join(OUT_DIR, \"answer.txt\")\n",
|
| 615 |
+
"if EXP08_ANSWER and os.path.exists(EXP08_ANSWER):\n",
|
| 616 |
+
" with open(EXP08_ANSWER) as f:\n",
|
| 617 |
+
" rows = list(csv.reader(f))\n",
|
| 618 |
+
" header, body = rows[0], rows[1:]\n",
|
| 619 |
+
" n_ens = 0\n",
|
| 620 |
+
" with open(answer_path, \"w\") as f:\n",
|
| 621 |
+
" f.write(\",\".join(header) + \"\\n\")\n",
|
| 622 |
+
" for r in body:\n",
|
| 623 |
+
" name = r[0]; sid = stem(name)\n",
|
| 624 |
+
" pr = dev_pred.get(name) or dev_pred.get(sid)\n",
|
| 625 |
+
" if pr is not None:\n",
|
| 626 |
+
" for col in ENSEMBLE_COLS:\n",
|
| 627 |
+
" if col in COL_IDX and col in AUD_VAL:\n",
|
| 628 |
+
" v08 = float(r[COL_IDX[col]]); vaud = float(AUD_VAL[col](pr))\n",
|
| 629 |
+
" r[COL_IDX[col]] = f\"{0.5*(v08+vaud):.6g}\"\n",
|
| 630 |
+
" n_ens += 1\n",
|
| 631 |
+
" f.write(\",\".join(r) + \"\\n\")\n",
|
| 632 |
+
" print(f\"✅ Ensemble {ENSEMBLE_COLS}: {n_ens} dòng → {answer_path} (nền exp08 + trung bình audeering)\")\n",
|
| 633 |
+
"else:\n",
|
| 634 |
+
" print(\"ℹ️ Không có EXP08_ANSWER → answer.txt = answer_audeering.txt (chỉ audeering, chưa ensemble).\")\n",
|
| 635 |
+
" import shutil\n",
|
| 636 |
+
" shutil.copy(os.path.join(OUT_DIR, \"answer_audeering.txt\"), answer_path)"
|
| 637 |
+
]
|
| 638 |
+
},
|
| 639 |
+
{
|
| 640 |
+
"cell_type": "markdown",
|
| 641 |
+
"id": "5089b7fa",
|
| 642 |
+
"metadata": {},
|
| 643 |
+
"source": [
|
| 644 |
+
"## 8. Validate + zip"
|
| 645 |
+
]
|
| 646 |
+
},
|
| 647 |
+
{
|
| 648 |
+
"cell_type": "code",
|
| 649 |
+
"execution_count": null,
|
| 650 |
+
"id": "66272528",
|
| 651 |
+
"metadata": {},
|
| 652 |
+
"outputs": [],
|
| 653 |
+
"source": [
|
| 654 |
+
"def validate(path):\n",
|
| 655 |
+
" with open(path) as f:\n",
|
| 656 |
+
" rows = list(csv.reader(f))\n",
|
| 657 |
+
" assert rows[0][0] == \"wav\" and \"QMOS\" in rows[0], \"Header sai\"\n",
|
| 658 |
+
" for i, r in enumerate(rows[1:], 2):\n",
|
| 659 |
+
" assert len(r) == len(rows[0]), f\"Dòng {i} sai số cột\"\n",
|
| 660 |
+
" print(f\"OK: {len(rows)-1} dòng, header = {rows[0]}\")\n",
|
| 661 |
+
"\n",
|
| 662 |
+
"validate(answer_path)\n",
|
| 663 |
+
"os.system(f\"cd {OUT_DIR} && zip -j submission_track2_exp10_ensemble.zip answer.txt && unzip -l submission_track2_exp10_ensemble.zip\")\n",
|
| 664 |
+
"print(\"Sẵn sàng nộp:\", os.path.join(OUT_DIR, \"submission_track2_exp10_ensemble.zip\"))"
|
| 665 |
+
]
|
| 666 |
+
},
|
| 667 |
+
{
|
| 668 |
+
"cell_type": "markdown",
|
| 669 |
+
"id": "994723d2",
|
| 670 |
+
"metadata": {},
|
| 671 |
+
"source": [
|
| 672 |
+
"## Ghi chú\n",
|
| 673 |
+
"- **Hướng A (T4-an toàn):** fine-tune audeering RIÊNG (1 backbone) → ensemble VAD với exp08 → KHÔNG OOM.\n",
|
| 674 |
+
"- **Đọc mục 5:** audeering VAL/ARO/DOM có ≥ exp08 không? Nếu ngang/hơn → ensemble đáng giá.\n",
|
| 675 |
+
"- **Ensemble (mục 7):** mặc định trung bình VAL/ARO/DOM. Thêm \"EMOS\" vào `ENSEMBLE_COLS` nếu audeering EMOS tốt.\n",
|
| 676 |
+
"- **Checkpoint:** lưu `ft_audeering_full.pt` mỗi best (kernel chết vẫn còn). Save Version sau khi xong.\n",
|
| 677 |
+
"- QMOS vẫn mượn exp07 (0.548). So sánh: nộp answer.txt ensemble vs exp08 thuần để xem ensemble có nhích VAD.\n",
|
| 678 |
+
"- Ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (exp10)."
|
| 679 |
+
]
|
| 680 |
+
}
|
| 681 |
+
],
|
| 682 |
+
"metadata": {
|
| 683 |
+
"jupytext": {
|
| 684 |
+
"cell_metadata_filter": "-all",
|
| 685 |
+
"main_language": "python",
|
| 686 |
+
"notebook_metadata_filter": "-all"
|
| 687 |
+
}
|
| 688 |
+
},
|
| 689 |
+
"nbformat": 4,
|
| 690 |
+
"nbformat_minor": 5
|
| 691 |
+
}
|
track2/exp10_finetune_audeering_pipeline.py
ADDED
|
@@ -0,0 +1,553 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %% [markdown]
|
| 2 |
+
# # VMC2026 Track 2 — exp10 (fine-tune AUDEERING riêng + ensemble VAD với exp08) — Kaggle T4
|
| 3 |
+
#
|
| 4 |
+
# **Ý tưởng (Hướng A — an toàn cho T4):** thay vì nhồi 2 backbone large vào 1 model (dễ OOM),
|
| 5 |
+
# ta fine-tune **audeering wav2vec2-large** RIÊNG (1 backbone → vừa T4), rồi **ensemble cột VAD**
|
| 6 |
+
# với exp08 (WavLM fine-tune). Mỗi lần chỉ 1 backbone trong VRAM → không OOM.
|
| 7 |
+
#
|
| 8 |
+
# ```
|
| 9 |
+
# [exp08] WavLM fine-tune ─► VAD_wavlm ┐
|
| 10 |
+
# ├─ trung bình ─► VAD cuối (mạnh hơn cả 2)
|
| 11 |
+
# [exp10] audeering fine-tune ─► VAD_aud ┘
|
| 12 |
+
# ```
|
| 13 |
+
# audeering vốn là model **dimensional (chuyên VAD)** → fine-tune nó để bổ trợ VAD cho exp08.
|
| 14 |
+
#
|
| 15 |
+
# **Cách chạy:** GPU T4 + Internet On → sửa slug cell 0 → Run All. Lần đầu `LIMIT_TRAIN=300`.
|
| 16 |
+
# Để ensemble: Add Input answer.txt exp08 → trỏ `EXP08_ANSWER`.
|
| 17 |
+
|
| 18 |
+
# %% [markdown]
|
| 19 |
+
# ## 0. Cấu hình
|
| 20 |
+
|
| 21 |
+
# %%
|
| 22 |
+
import os
|
| 23 |
+
|
| 24 |
+
DATA_ROOT = "/kaggle/input/datasets/minhtoan2/vmc2026-track2-full" # << SỬA slug
|
| 25 |
+
WAV_DIR = f"{DATA_ROOT}/wav"
|
| 26 |
+
METADATA_CSV = f"{DATA_ROOT}/metadata.csv"
|
| 27 |
+
TRAIN_CSV = f"{DATA_ROOT}/sets/train.csv"
|
| 28 |
+
DEV_SCP = f"{DATA_ROOT}/sets/dev.scp"
|
| 29 |
+
|
| 30 |
+
OUT_DIR = "/kaggle/working"
|
| 31 |
+
|
| 32 |
+
# QMOS mượn exp07 (0.548); ensemble VAD với answer.txt exp08.
|
| 33 |
+
EXP07_ANSWER = "/kaggle/input/exp07-answer/answer.txt" # << (tùy chọn) mượn QMOS; không có → UTMOSv2
|
| 34 |
+
EXP08_ANSWER = "/kaggle/input/exp08-answer/answer.txt" # << (tùy chọn) để ENSEMBLE VAD; không có → chỉ ra answer audeering
|
| 35 |
+
|
| 36 |
+
# ── Fine-tune audeering (1 backbone) ─────────────────────────────────────────
|
| 37 |
+
DEVICE = "cuda"
|
| 38 |
+
SR = 16000
|
| 39 |
+
MAX_SECONDS = 8
|
| 40 |
+
UNFREEZE_TOP_LAYERS = 6 # số lớp encoder audeering mở băng (T4 thừa sức 1 backbone)
|
| 41 |
+
TRUNK_HIDDEN = 512
|
| 42 |
+
HEAD_HIDDEN = 128
|
| 43 |
+
DROPOUT = 0.3
|
| 44 |
+
LR_BACKBONE = 1e-5
|
| 45 |
+
LR_HEAD = 1e-3
|
| 46 |
+
WEIGHT_DECAY = 1e-5
|
| 47 |
+
EPOCHS = 12
|
| 48 |
+
PATIENCE = 4
|
| 49 |
+
BATCH = 4
|
| 50 |
+
ACCUM = 8
|
| 51 |
+
VAL_FRAC = 0.10
|
| 52 |
+
SEED = 42
|
| 53 |
+
USE_AMP = True
|
| 54 |
+
USE_GRAD_CKPT = True
|
| 55 |
+
USE_UNCERTAINTY = True
|
| 56 |
+
|
| 57 |
+
# Ensemble: cột nào lấy TRUNG BÌNH giữa exp08 và exp10; cột khác giữ từ exp08.
|
| 58 |
+
ENSEMBLE_COLS = ["VAL", "ARO", "DOM"] # audeering mạnh VAD → ensemble VAD. Thêm "EMOS" nếu muốn.
|
| 59 |
+
|
| 60 |
+
LIMIT_TRAIN = 300 # << LẦN ĐẦU 300; chạy thật None
|
| 61 |
+
LIMIT_DEV = 20 # << LẦN ĐẦU 20; chạy thật None
|
| 62 |
+
|
| 63 |
+
EXP08 = {"emos": 0.811, "cat_err": 0.133, "val": 0.659, "aro": 0.793, "dom": 0.751}
|
| 64 |
+
|
| 65 |
+
EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"]
|
| 66 |
+
_EMO_ALIAS = {
|
| 67 |
+
"angry": "angry", "anger": "angry",
|
| 68 |
+
"happy": "happy", "happiness": "happy", "joy": "happy",
|
| 69 |
+
"neutral": "neutral", "calm": "neutral",
|
| 70 |
+
"sad": "sad", "sadness": "sad",
|
| 71 |
+
"surprise": "surprised", "surprised": "surprised", "surprising": "surprised",
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
def norm_emotion(label):
|
| 75 |
+
key = str(label).strip().lower()
|
| 76 |
+
return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)
|
| 77 |
+
|
| 78 |
+
def stem(p):
|
| 79 |
+
return os.path.splitext(os.path.basename(str(p)))[0]
|
| 80 |
+
|
| 81 |
+
print("DATA_ROOT:", DATA_ROOT)
|
| 82 |
+
for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP]:
|
| 83 |
+
print((" ✅ " if os.path.exists(p) else " ❌ THIẾU ") + p)
|
| 84 |
+
|
| 85 |
+
# %% [markdown]
|
| 86 |
+
# ## 1. Cài đặt
|
| 87 |
+
|
| 88 |
+
# %%
|
| 89 |
+
import sys, subprocess
|
| 90 |
+
|
| 91 |
+
def pip_install(*pkgs):
|
| 92 |
+
subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs], check=True)
|
| 93 |
+
|
| 94 |
+
pip_install("transformers", "huggingface_hub", "safetensors", "speechmos",
|
| 95 |
+
"librosa", "soundfile", "scipy", "scikit-learn", "pandas", "tqdm")
|
| 96 |
+
|
| 97 |
+
# %% [markdown]
|
| 98 |
+
# ## 2. Nạp audeering wav2vec2-large làm backbone FINE-TUNE
|
| 99 |
+
# Nạp backbone tay (tránh lỗi subclass `Wav2Vec2PreTrainedModel` ở transformers mới) rồi mở băng lớp trên.
|
| 100 |
+
|
| 101 |
+
# %%
|
| 102 |
+
import torch
|
| 103 |
+
import torch.nn as nn
|
| 104 |
+
import torch.nn.functional as F
|
| 105 |
+
import numpy as np
|
| 106 |
+
|
| 107 |
+
device = DEVICE if torch.cuda.is_available() else "cpu"
|
| 108 |
+
print("Device:", device, ("✅ " + torch.cuda.get_device_name(0)) if device == "cuda" else "⚠️ CPU (rất chậm!)")
|
| 109 |
+
|
| 110 |
+
from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor
|
| 111 |
+
from huggingface_hub import hf_hub_download
|
| 112 |
+
|
| 113 |
+
AUD_NAME = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
|
| 114 |
+
aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)
|
| 115 |
+
aud_cfg = Wav2Vec2Config.from_pretrained(AUD_NAME)
|
| 116 |
+
aud = Wav2Vec2Model(aud_cfg)
|
| 117 |
+
try:
|
| 118 |
+
_sd = __import__("safetensors.torch", fromlist=["load_file"]).load_file(
|
| 119 |
+
hf_hub_download(AUD_NAME, "model.safetensors"))
|
| 120 |
+
except Exception:
|
| 121 |
+
_sd = torch.load(hf_hub_download(AUD_NAME, "pytorch_model.bin"), map_location="cpu")
|
| 122 |
+
bb_sd = {k[len("wav2vec2."):]: v for k, v in _sd.items() if k.startswith("wav2vec2.")}
|
| 123 |
+
miss, unexp = aud.load_state_dict(bb_sd, strict=False)
|
| 124 |
+
print(f"audeering backbone: thiếu {len(miss)} / dư {len(unexp)} key (strict=False)")
|
| 125 |
+
aud = aud.to(device)
|
| 126 |
+
AUD_DIM = int(aud.config.hidden_size)
|
| 127 |
+
|
| 128 |
+
# Đóng băng tất cả, mở băng UNFREEZE_TOP_LAYERS lớp encoder trên cùng
|
| 129 |
+
for p in aud.parameters():
|
| 130 |
+
p.requires_grad = False
|
| 131 |
+
enc_layers = aud.encoder.layers
|
| 132 |
+
n_layers = len(enc_layers)
|
| 133 |
+
for layer in enc_layers[max(0, n_layers - UNFREEZE_TOP_LAYERS):]:
|
| 134 |
+
for p in layer.parameters():
|
| 135 |
+
p.requires_grad = True
|
| 136 |
+
n_train = sum(p.numel() for p in aud.parameters() if p.requires_grad)
|
| 137 |
+
print(f"audeering: {n_layers} lớp · mở băng {min(UNFREEZE_TOP_LAYERS, n_layers)} → {n_train/1e6:.1f}M param train (dim {AUD_DIM})")
|
| 138 |
+
|
| 139 |
+
if USE_GRAD_CKPT:
|
| 140 |
+
aud.gradient_checkpointing_enable()
|
| 141 |
+
if hasattr(aud, "enable_input_require_grads"):
|
| 142 |
+
aud.enable_input_require_grads()
|
| 143 |
+
|
| 144 |
+
def masked_mean(hidden, attn_mask):
|
| 145 |
+
if attn_mask is None:
|
| 146 |
+
return hidden.mean(dim=1)
|
| 147 |
+
try:
|
| 148 |
+
fm = aud._get_feature_vector_attention_mask(hidden.shape[1], attn_mask)
|
| 149 |
+
except Exception:
|
| 150 |
+
return hidden.mean(dim=1)
|
| 151 |
+
fm = fm.unsqueeze(-1).to(hidden.dtype)
|
| 152 |
+
return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)
|
| 153 |
+
|
| 154 |
+
def aud_embed(input_values, attn_mask):
|
| 155 |
+
out = aud(input_values, attention_mask=attn_mask).last_hidden_state
|
| 156 |
+
return masked_mean(out, attn_mask)
|
| 157 |
+
|
| 158 |
+
# %% [markdown]
|
| 159 |
+
# ## 3. Nhãn (gộp theo wavID) — như exp08
|
| 160 |
+
|
| 161 |
+
# %%
|
| 162 |
+
import librosa
|
| 163 |
+
import pandas as pd
|
| 164 |
+
from tqdm.auto import tqdm
|
| 165 |
+
|
| 166 |
+
def load_target_emotions():
|
| 167 |
+
tgt = {}
|
| 168 |
+
with open(METADATA_CSV, encoding="utf-8") as f:
|
| 169 |
+
for ln in f:
|
| 170 |
+
parts = ln.strip().split("|")
|
| 171 |
+
if len(parts) >= 2:
|
| 172 |
+
tgt[stem(parts[0])] = norm_emotion(parts[1])
|
| 173 |
+
return tgt
|
| 174 |
+
|
| 175 |
+
def _col(cols_map, *names, df=None, default_idx=None):
|
| 176 |
+
for n in names:
|
| 177 |
+
if n in cols_map:
|
| 178 |
+
return cols_map[n]
|
| 179 |
+
return list(df.columns)[default_idx] if default_idx is not None else None
|
| 180 |
+
|
| 181 |
+
def parse_emocat_votes(cell):
|
| 182 |
+
v = np.zeros(len(EMOTIONS5), dtype=np.float32)
|
| 183 |
+
for tok in str(cell).replace("/", ",").replace(";", ",").replace("|", ",").replace(" ", ",").split(","):
|
| 184 |
+
e = norm_emotion(tok)
|
| 185 |
+
if e in EMOTIONS5:
|
| 186 |
+
v[EMOTIONS5.index(e)] += 1.0
|
| 187 |
+
return v
|
| 188 |
+
|
| 189 |
+
def load_train_labels():
|
| 190 |
+
df = pd.read_csv(TRAIN_CSV, sep="|")
|
| 191 |
+
cols = {c.lower().strip(): c for c in df.columns}
|
| 192 |
+
wav_col = _col(cols, "wavid", "wav", df=df, default_idx=1)
|
| 193 |
+
emos_col = _col(cols, "emos", "emo", "emomos")
|
| 194 |
+
val_col = _col(cols, "val", "valence"); aro_col = _col(cols, "aro", "arousal"); dom_col = _col(cols, "dom", "dominance")
|
| 195 |
+
cat_col = _col(cols, "emocat", "cat", "emotion")
|
| 196 |
+
assert emos_col, f"Không thấy cột eMOS (cột: {list(df.columns)})"
|
| 197 |
+
df["_stem"] = df[wav_col].map(stem)
|
| 198 |
+
rows = []
|
| 199 |
+
for sid, g in df.groupby("_stem"):
|
| 200 |
+
rec = {"wavID": sid, "emos": float(g[emos_col].mean())}
|
| 201 |
+
rec["val"] = float(g[val_col].mean()) if val_col else np.nan
|
| 202 |
+
rec["aro"] = float(g[aro_col].mean()) if aro_col else np.nan
|
| 203 |
+
rec["dom"] = float(g[dom_col].mean()) if dom_col else np.nan
|
| 204 |
+
votes = np.zeros(len(EMOTIONS5), dtype=np.float32)
|
| 205 |
+
if cat_col:
|
| 206 |
+
for cell in g[cat_col]:
|
| 207 |
+
votes += parse_emocat_votes(cell)
|
| 208 |
+
s = votes.sum()
|
| 209 |
+
cat = votes / s if s > 0 else np.full(len(EMOTIONS5), 0.2, dtype=np.float32)
|
| 210 |
+
for i in range(len(EMOTIONS5)):
|
| 211 |
+
rec[f"cat{i}"] = float(cat[i])
|
| 212 |
+
rows.append(rec)
|
| 213 |
+
return pd.DataFrame(rows)
|
| 214 |
+
|
| 215 |
+
target_map = load_target_emotions()
|
| 216 |
+
train_df = load_train_labels()
|
| 217 |
+
HAS_VAD = bool(train_df["val"].notna().any())
|
| 218 |
+
print(f"Target: {len(target_map)} | wav train (gộp): {len(train_df)} | có VAD: {HAS_VAD}")
|
| 219 |
+
|
| 220 |
+
# %% [markdown]
|
| 221 |
+
# ## 4. Dataset/loader (input_values qua audeering processor) + chuẩn hóa nhãn
|
| 222 |
+
|
| 223 |
+
# %%
|
| 224 |
+
from torch.utils.data import Dataset, DataLoader
|
| 225 |
+
from sklearn.model_selection import train_test_split
|
| 226 |
+
|
| 227 |
+
train_stems = [s for s in train_df["wavID"] if target_map.get(s) is not None]
|
| 228 |
+
if LIMIT_TRAIN:
|
| 229 |
+
train_stems = train_stems[:LIMIT_TRAIN]
|
| 230 |
+
lab = train_df.set_index("wavID")
|
| 231 |
+
|
| 232 |
+
def _zfit(arr):
|
| 233 |
+
a = np.asarray(arr, dtype=np.float32)
|
| 234 |
+
return float(np.nanmean(a)), float(np.nanstd(a) + 1e-6)
|
| 235 |
+
|
| 236 |
+
emos_mu, emos_sd = _zfit([lab.loc[s, "emos"] for s in train_stems])
|
| 237 |
+
if HAS_VAD:
|
| 238 |
+
vad_mu = np.array([_zfit([lab.loc[s, c] for s in train_stems])[0] for c in ["val", "aro", "dom"]], dtype=np.float32)
|
| 239 |
+
vad_sd = np.array([_zfit([lab.loc[s, c] for s in train_stems])[1] for c in ["val", "aro", "dom"]], dtype=np.float32)
|
| 240 |
+
else:
|
| 241 |
+
vad_mu = np.zeros(3, dtype=np.float32); vad_sd = np.ones(3, dtype=np.float32)
|
| 242 |
+
|
| 243 |
+
def onehot_target(tgt):
|
| 244 |
+
v = np.zeros(len(EMOTIONS5), dtype=np.float32)
|
| 245 |
+
if tgt in EMOTIONS5:
|
| 246 |
+
v[EMOTIONS5.index(tgt)] = 1.0
|
| 247 |
+
return v
|
| 248 |
+
|
| 249 |
+
def load_iv(sid):
|
| 250 |
+
"""Đọc wav → chuẩn hóa bằng audeering processor → input_values (1D float32)."""
|
| 251 |
+
p = os.path.join(WAV_DIR, sid if str(sid).endswith(".wav") else str(sid) + ".wav")
|
| 252 |
+
if not os.path.exists(p):
|
| 253 |
+
return None
|
| 254 |
+
wave, _ = librosa.load(p, sr=SR, mono=True)
|
| 255 |
+
wave = wave[: MAX_SECONDS * SR]
|
| 256 |
+
iv = aud_proc(wave, sampling_rate=SR).input_values[0]
|
| 257 |
+
return np.asarray(iv, dtype=np.float32)
|
| 258 |
+
|
| 259 |
+
class AudDataset(Dataset):
|
| 260 |
+
def __init__(self, stems):
|
| 261 |
+
self.stems = [s for s in stems if load_iv(s) is not None]
|
| 262 |
+
def __len__(self):
|
| 263 |
+
return len(self.stems)
|
| 264 |
+
def __getitem__(self, i):
|
| 265 |
+
s = self.stems[i]
|
| 266 |
+
iv = load_iv(s)
|
| 267 |
+
emos = (float(lab.loc[s, "emos"]) - emos_mu) / emos_sd
|
| 268 |
+
if HAS_VAD:
|
| 269 |
+
vad = (np.array([lab.loc[s, "val"], lab.loc[s, "aro"], lab.loc[s, "dom"]], np.float32) - vad_mu) / vad_sd
|
| 270 |
+
else:
|
| 271 |
+
vad = np.zeros(3, dtype=np.float32)
|
| 272 |
+
cat = np.array([lab.loc[s, f"cat{j}"] for j in range(len(EMOTIONS5))], dtype=np.float32)
|
| 273 |
+
return {"iv": iv, "tgt": onehot_target(target_map.get(s)),
|
| 274 |
+
"emos": np.float32(emos), "vad": vad, "cat": cat,
|
| 275 |
+
"emos_raw": np.float32(lab.loc[s, "emos"]),
|
| 276 |
+
"vad_raw": np.array([lab.loc[s, "val"], lab.loc[s, "aro"], lab.loc[s, "dom"]], np.float32)}
|
| 277 |
+
|
| 278 |
+
def collate(batch):
|
| 279 |
+
L = max(len(b["iv"]) for b in batch)
|
| 280 |
+
ivs = np.zeros((len(batch), L), dtype=np.float32)
|
| 281 |
+
mask = np.zeros((len(batch), L), dtype=np.float32)
|
| 282 |
+
for i, b in enumerate(batch):
|
| 283 |
+
ivs[i, : len(b["iv"])] = b["iv"]; mask[i, : len(b["iv"])] = 1.0
|
| 284 |
+
return {
|
| 285 |
+
"input_values": torch.from_numpy(ivs), "attn_mask": torch.from_numpy(mask).long(),
|
| 286 |
+
"tgt": torch.from_numpy(np.stack([b["tgt"] for b in batch])),
|
| 287 |
+
"emos": torch.from_numpy(np.stack([b["emos"] for b in batch])).unsqueeze(1),
|
| 288 |
+
"vad": torch.from_numpy(np.stack([b["vad"] for b in batch])),
|
| 289 |
+
"cat": torch.from_numpy(np.stack([b["cat"] for b in batch])),
|
| 290 |
+
"emos_raw": np.stack([b["emos_raw"] for b in batch]),
|
| 291 |
+
"vad_raw": np.stack([b["vad_raw"] for b in batch]),
|
| 292 |
+
}
|
| 293 |
+
|
| 294 |
+
ds = AudDataset(train_stems)
|
| 295 |
+
print("Dataset hợp lệ:", len(ds), "wav")
|
| 296 |
+
tr_i, va_i = train_test_split(np.arange(len(ds)), test_size=VAL_FRAC, random_state=SEED)
|
| 297 |
+
tr_loader = DataLoader(torch.utils.data.Subset(ds, tr_i), batch_size=BATCH, shuffle=True, collate_fn=collate, num_workers=2)
|
| 298 |
+
va_loader = DataLoader(torch.utils.data.Subset(ds, va_i), batch_size=BATCH, shuffle=False, collate_fn=collate, num_workers=2)
|
| 299 |
+
|
| 300 |
+
# %% [markdown]
|
| 301 |
+
# ## 5. Heads + train loop (lưu ft_audeering_full.pt mỗi best)
|
| 302 |
+
|
| 303 |
+
# %%
|
| 304 |
+
from scipy.stats import spearmanr
|
| 305 |
+
|
| 306 |
+
torch.manual_seed(SEED); np.random.seed(SEED)
|
| 307 |
+
N_EMO = len(EMOTIONS5)
|
| 308 |
+
|
| 309 |
+
class EmoHeads(nn.Module):
|
| 310 |
+
def __init__(self, d_in, trunk_h, head_h, p, n_emo):
|
| 311 |
+
super().__init__()
|
| 312 |
+
self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),
|
| 313 |
+
nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))
|
| 314 |
+
self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))
|
| 315 |
+
self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))
|
| 316 |
+
self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))
|
| 317 |
+
def forward(self, feat, tgt):
|
| 318 |
+
h = self.trunk(feat)
|
| 319 |
+
return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)
|
| 320 |
+
|
| 321 |
+
heads = EmoHeads(AUD_DIM, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device)
|
| 322 |
+
|
| 323 |
+
TASKS = ["emos", "cat", "val", "aro", "dom"]
|
| 324 |
+
log_var = nn.Parameter(torch.zeros(len(TASKS), device=device))
|
| 325 |
+
bb_params = [p for p in aud.parameters() if p.requires_grad]
|
| 326 |
+
head_params = list(heads.parameters()) + ([log_var] if USE_UNCERTAINTY else [])
|
| 327 |
+
opt = torch.optim.AdamW([{"params": bb_params, "lr": LR_BACKBONE},
|
| 328 |
+
{"params": head_params, "lr": LR_HEAD}], weight_decay=WEIGHT_DECAY)
|
| 329 |
+
scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP and device == "cuda")
|
| 330 |
+
mse = nn.MSELoss()
|
| 331 |
+
|
| 332 |
+
def soft_ce(logits, target_dist):
|
| 333 |
+
return -(target_dist * F.log_softmax(logits, dim=1)).sum(1).mean()
|
| 334 |
+
|
| 335 |
+
def forward_batch(b):
|
| 336 |
+
feat = aud_embed(b["input_values"].to(device), b["attn_mask"].to(device))
|
| 337 |
+
return heads(feat, b["tgt"].to(device))
|
| 338 |
+
|
| 339 |
+
def compute_loss(emos_p, cat_l, vad_p, b):
|
| 340 |
+
L = {}
|
| 341 |
+
L["emos"] = mse(emos_p, b["emos"].to(device))
|
| 342 |
+
L["cat"] = soft_ce(cat_l, b["cat"].to(device))
|
| 343 |
+
if HAS_VAD:
|
| 344 |
+
vt = b["vad"].to(device)
|
| 345 |
+
L["val"] = mse(vad_p[:, 0:1], vt[:, 0:1]); L["aro"] = mse(vad_p[:, 1:2], vt[:, 1:2]); L["dom"] = mse(vad_p[:, 2:3], vt[:, 2:3])
|
| 346 |
+
else:
|
| 347 |
+
z = torch.zeros((), device=device); L["val"] = L["aro"] = L["dom"] = z
|
| 348 |
+
if USE_UNCERTAINTY:
|
| 349 |
+
return sum(torch.exp(-log_var[i]) * L[t] + log_var[i] for i, t in enumerate(TASKS))
|
| 350 |
+
return sum(L.values())
|
| 351 |
+
|
| 352 |
+
@torch.no_grad()
|
| 353 |
+
def evaluate():
|
| 354 |
+
aud.eval(); heads.eval()
|
| 355 |
+
P = {"emos": [], "val": [], "aro": [], "dom": []}; Y = {"emos": [], "val": [], "aro": [], "dom": []}
|
| 356 |
+
catP, catY = [], []
|
| 357 |
+
for b in va_loader:
|
| 358 |
+
with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
|
| 359 |
+
emos_p, cat_l, vad_p = forward_batch(b)
|
| 360 |
+
P["emos"] += emos_p.float().cpu().numpy().ravel().tolist(); Y["emos"] += b["emos_raw"].tolist()
|
| 361 |
+
vad_p = vad_p.float().cpu().numpy()
|
| 362 |
+
for j, t in enumerate(["val", "aro", "dom"]):
|
| 363 |
+
P[t] += vad_p[:, j].tolist(); Y[t] += b["vad_raw"][:, j].tolist()
|
| 364 |
+
catP.append(F.softmax(cat_l, 1).float().cpu().numpy()); catY.append(b["cat"])
|
| 365 |
+
out = {}
|
| 366 |
+
for t in ["emos"] + (["val", "aro", "dom"] if HAS_VAD else []):
|
| 367 |
+
out[t] = spearmanr(P[t], Y[t]).correlation
|
| 368 |
+
q = np.concatenate(catP); p = np.concatenate(catY)
|
| 369 |
+
out["cat_err"] = float(np.abs(q - p).sum(1).mean())
|
| 370 |
+
return out
|
| 371 |
+
|
| 372 |
+
def mean_srcc(m):
|
| 373 |
+
keys = ["emos"] + (["val", "aro", "dom"] if HAS_VAD else [])
|
| 374 |
+
return float(np.mean([m[k] for k in keys]))
|
| 375 |
+
|
| 376 |
+
CKPT_PATH = os.path.join(OUT_DIR, "ft_audeering_full.pt")
|
| 377 |
+
def save_full_ckpt(state, val_emos=float("nan")):
|
| 378 |
+
torch.save({"aud": state["aud"], "heads": state["heads"],
|
| 379 |
+
"emos_mu": emos_mu, "emos_sd": emos_sd, "vad_mu": vad_mu, "vad_sd": vad_sd,
|
| 380 |
+
"AUD_DIM": AUD_DIM, "UNFREEZE_TOP_LAYERS": UNFREEZE_TOP_LAYERS,
|
| 381 |
+
"val_emos": float(val_emos)}, CKPT_PATH)
|
| 382 |
+
|
| 383 |
+
best, best_state, bad = -1e9, None, 0
|
| 384 |
+
for ep in range(1, EPOCHS + 1):
|
| 385 |
+
aud.train(); heads.train()
|
| 386 |
+
opt.zero_grad(); run = 0.0; nb = 0
|
| 387 |
+
for step, b in enumerate(tqdm(tr_loader, desc=f"epoch {ep}")):
|
| 388 |
+
with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
|
| 389 |
+
emos_p, cat_l, vad_p = forward_batch(b)
|
| 390 |
+
loss = compute_loss(emos_p, cat_l, vad_p, b) / ACCUM
|
| 391 |
+
scaler.scale(loss).backward()
|
| 392 |
+
if (step + 1) % ACCUM == 0:
|
| 393 |
+
scaler.step(opt); scaler.update(); opt.zero_grad()
|
| 394 |
+
run += loss.item() * ACCUM; nb += 1
|
| 395 |
+
m = evaluate(); sc = mean_srcc(m)
|
| 396 |
+
msg = " ".join(f"{k}={m[k]:.3f}" for k in ["emos", "val", "aro", "dom"] if k in m)
|
| 397 |
+
print(f"epoch {ep:2d} | loss {run/max(nb,1):.4f} | {msg} | cat_err {m['cat_err']:.3f} | mean {sc:.4f} (best {max(best,sc):.4f})")
|
| 398 |
+
if sc > best:
|
| 399 |
+
best = sc
|
| 400 |
+
best_state = {"aud": {k: v.cpu().clone() for k, v in aud.state_dict().items()},
|
| 401 |
+
"heads": {k: v.cpu().clone() for k, v in heads.state_dict().items()}}
|
| 402 |
+
save_full_ckpt(best_state, m["emos"])
|
| 403 |
+
print(f" 💾 lưu best → {CKPT_PATH} (epoch {ep}, mean {sc:.4f})")
|
| 404 |
+
bad = 0
|
| 405 |
+
else:
|
| 406 |
+
bad += 1
|
| 407 |
+
if bad >= PATIENCE:
|
| 408 |
+
print(f"Early stop ở epoch {ep}."); break
|
| 409 |
+
|
| 410 |
+
if best_state:
|
| 411 |
+
aud.load_state_dict(best_state["aud"]); heads.load_state_dict(best_state["heads"])
|
| 412 |
+
final = evaluate()
|
| 413 |
+
print("\n✅ VAL (nội bộ) — exp10 (fine-tune audeering):")
|
| 414 |
+
print(f" EMOS={final['emos']:.4f}", end="")
|
| 415 |
+
if HAS_VAD:
|
| 416 |
+
print(f" | VAL/ARO/DOM={final['val']:.4f}/{final['aro']:.4f}/{final['dom']:.4f} (exp08 {EXP08['val']}/{EXP08['aro']}/{EXP08['dom']})")
|
| 417 |
+
else:
|
| 418 |
+
print()
|
| 419 |
+
print(f" → so exp08: audeering {'mạnh' if HAS_VAD and final['val'] > EXP08['val'] else 'yếu/ngang'} ở VAL. "
|
| 420 |
+
f"Ensemble sẽ lấy trung bình 2 model.")
|
| 421 |
+
save_full_ckpt(best_state if best_state else {"aud": aud.state_dict(), "heads": heads.state_dict()}, final["emos"])
|
| 422 |
+
print(f"✅ Đã lưu {CKPT_PATH}. NHỚ Save Version!")
|
| 423 |
+
|
| 424 |
+
# %% [markdown]
|
| 425 |
+
# ## 6. Dự đoán DEV → predictions + answer_audeering.txt
|
| 426 |
+
|
| 427 |
+
# %%
|
| 428 |
+
def list_dev():
|
| 429 |
+
with open(DEV_SCP) as f:
|
| 430 |
+
return [ln.strip() for ln in f if ln.strip()]
|
| 431 |
+
|
| 432 |
+
dev_names = list_dev()
|
| 433 |
+
if LIMIT_DEV:
|
| 434 |
+
dev_names = dev_names[:LIMIT_DEV]
|
| 435 |
+
print("DEV:", len(dev_names), "mẫu")
|
| 436 |
+
|
| 437 |
+
def load_exp07_qmos():
|
| 438 |
+
if EXP07_ANSWER and os.path.exists(EXP07_ANSWER):
|
| 439 |
+
import csv
|
| 440 |
+
d = {}
|
| 441 |
+
with open(EXP07_ANSWER) as f:
|
| 442 |
+
for row in csv.DictReader(f):
|
| 443 |
+
d[row["wav"]] = float(row["QMOS"]); d[stem(row["wav"])] = float(row["QMOS"])
|
| 444 |
+
print(f"✅ Mượn QMOS từ exp07: {len(d)//2} wav")
|
| 445 |
+
return d
|
| 446 |
+
return None
|
| 447 |
+
|
| 448 |
+
qmos_map = load_exp07_qmos()
|
| 449 |
+
if qmos_map is None:
|
| 450 |
+
print("ℹ️ Không có exp07 → QMOS bằng UTMOSv2.")
|
| 451 |
+
pip_install("git+https://github.com/sarulab-speech/UTMOSv2.git")
|
| 452 |
+
import utmosv2
|
| 453 |
+
v2 = utmosv2.create_model(pretrained=True)
|
| 454 |
+
qmos_map = {}
|
| 455 |
+
for n in tqdm(dev_names, desc="UTMOSv2"):
|
| 456 |
+
wav = os.path.join(WAV_DIR, n if str(n).endswith(".wav") else str(n) + ".wav")
|
| 457 |
+
if os.path.exists(wav):
|
| 458 |
+
o = v2.predict(input_path=wav)
|
| 459 |
+
qmos_map[n] = float(o["predicted_mos"]) if isinstance(o, dict) else float(o)
|
| 460 |
+
del v2; torch.cuda.empty_cache() if device == "cuda" else None
|
| 461 |
+
|
| 462 |
+
@torch.no_grad()
|
| 463 |
+
def predict_emotion(sid):
|
| 464 |
+
iv = load_iv(sid)
|
| 465 |
+
if iv is None:
|
| 466 |
+
return None
|
| 467 |
+
aud.eval(); heads.eval()
|
| 468 |
+
ivt = torch.from_numpy(iv).unsqueeze(0).to(device)
|
| 469 |
+
am = torch.ones((1, len(iv)), dtype=torch.long, device=device)
|
| 470 |
+
tgt = torch.from_numpy(onehot_target(target_map.get(sid))).unsqueeze(0).to(device)
|
| 471 |
+
with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
|
| 472 |
+
feat = aud_embed(ivt, am)
|
| 473 |
+
emos_p, cat_l, vad_p = heads(feat, tgt)
|
| 474 |
+
emos = float(emos_p.item()) * emos_sd + emos_mu
|
| 475 |
+
cat5 = F.softmax(cat_l, 1)[0].float().cpu().numpy()
|
| 476 |
+
vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu
|
| 477 |
+
return emos, cat5, vad3
|
| 478 |
+
|
| 479 |
+
def fmt_cat(p5):
|
| 480 |
+
return "|".join(f"{e}:{p5[i]:.6g}" for i, e in enumerate(EMOTIONS5))
|
| 481 |
+
|
| 482 |
+
dev_pred = {} # name -> (emos, cat5, vad3)
|
| 483 |
+
with open(os.path.join(OUT_DIR, "answer_audeering.txt"), "w") as f:
|
| 484 |
+
f.write("wav,QMOS,EMOS,CAT,VAL,ARO,DOM\n")
|
| 485 |
+
for name in tqdm(dev_names, desc="answer_aud"):
|
| 486 |
+
sid = stem(name)
|
| 487 |
+
pr = predict_emotion(sid)
|
| 488 |
+
if pr is None:
|
| 489 |
+
emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0])
|
| 490 |
+
else:
|
| 491 |
+
emos, cat5, vad3 = pr
|
| 492 |
+
dev_pred[name] = (emos, cat5, vad3)
|
| 493 |
+
qmos = qmos_map.get(name, qmos_map.get(sid, 3.0))
|
| 494 |
+
f.write(f"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\n")
|
| 495 |
+
print("Đã ghi answer_audeering.txt")
|
| 496 |
+
|
| 497 |
+
# %% [markdown]
|
| 498 |
+
# ## 7. ENSEMBLE với exp08 → answer.txt cuối (trung bình cột VAD)
|
| 499 |
+
# Lấy answer.txt exp08 làm nền; cột trong `ENSEMBLE_COLS` = trung bình (exp08 + exp10). Còn lại giữ exp08.
|
| 500 |
+
|
| 501 |
+
# %%
|
| 502 |
+
import csv
|
| 503 |
+
COL_IDX = {"QMOS": 1, "EMOS": 2, "VAL": 4, "ARO": 5, "DOM": 6} # vị trí cột trong answer.txt
|
| 504 |
+
AUD_VAL = {"EMOS": lambda p: p[0], "VAL": lambda p: p[2][0], "ARO": lambda p: p[2][1], "DOM": lambda p: p[2][2]}
|
| 505 |
+
|
| 506 |
+
answer_path = os.path.join(OUT_DIR, "answer.txt")
|
| 507 |
+
if EXP08_ANSWER and os.path.exists(EXP08_ANSWER):
|
| 508 |
+
with open(EXP08_ANSWER) as f:
|
| 509 |
+
rows = list(csv.reader(f))
|
| 510 |
+
header, body = rows[0], rows[1:]
|
| 511 |
+
n_ens = 0
|
| 512 |
+
with open(answer_path, "w") as f:
|
| 513 |
+
f.write(",".join(header) + "\n")
|
| 514 |
+
for r in body:
|
| 515 |
+
name = r[0]; sid = stem(name)
|
| 516 |
+
pr = dev_pred.get(name) or dev_pred.get(sid)
|
| 517 |
+
if pr is not None:
|
| 518 |
+
for col in ENSEMBLE_COLS:
|
| 519 |
+
if col in COL_IDX and col in AUD_VAL:
|
| 520 |
+
v08 = float(r[COL_IDX[col]]); vaud = float(AUD_VAL[col](pr))
|
| 521 |
+
r[COL_IDX[col]] = f"{0.5*(v08+vaud):.6g}"
|
| 522 |
+
n_ens += 1
|
| 523 |
+
f.write(",".join(r) + "\n")
|
| 524 |
+
print(f"✅ Ensemble {ENSEMBLE_COLS}: {n_ens} dòng → {answer_path} (nền exp08 + trung bình audeering)")
|
| 525 |
+
else:
|
| 526 |
+
print("ℹ️ Không có EXP08_ANSWER → answer.txt = answer_audeering.txt (chỉ audeering, chưa ensemble).")
|
| 527 |
+
import shutil
|
| 528 |
+
shutil.copy(os.path.join(OUT_DIR, "answer_audeering.txt"), answer_path)
|
| 529 |
+
|
| 530 |
+
# %% [markdown]
|
| 531 |
+
# ## 8. Validate + zip
|
| 532 |
+
|
| 533 |
+
# %%
|
| 534 |
+
def validate(path):
|
| 535 |
+
with open(path) as f:
|
| 536 |
+
rows = list(csv.reader(f))
|
| 537 |
+
assert rows[0][0] == "wav" and "QMOS" in rows[0], "Header sai"
|
| 538 |
+
for i, r in enumerate(rows[1:], 2):
|
| 539 |
+
assert len(r) == len(rows[0]), f"Dòng {i} sai số cột"
|
| 540 |
+
print(f"OK: {len(rows)-1} dòng, header = {rows[0]}")
|
| 541 |
+
|
| 542 |
+
validate(answer_path)
|
| 543 |
+
os.system(f"cd {OUT_DIR} && zip -j submission_track2_exp10_ensemble.zip answer.txt && unzip -l submission_track2_exp10_ensemble.zip")
|
| 544 |
+
print("Sẵn sàng nộp:", os.path.join(OUT_DIR, "submission_track2_exp10_ensemble.zip"))
|
| 545 |
+
|
| 546 |
+
# %% [markdown]
|
| 547 |
+
# ## Ghi chú
|
| 548 |
+
# - **Hướng A (T4-an toàn):** fine-tune audeering RIÊNG (1 backbone) → ensemble VAD với exp08 → KHÔNG OOM.
|
| 549 |
+
# - **Đọc mục 5:** audeering VAL/ARO/DOM có ≥ exp08 không? Nếu ngang/hơn → ensemble đáng giá.
|
| 550 |
+
# - **Ensemble (mục 7):** mặc định trung bình VAL/ARO/DOM. Thêm "EMOS" vào `ENSEMBLE_COLS` nếu audeering EMOS tốt.
|
| 551 |
+
# - **Checkpoint:** lưu `ft_audeering_full.pt` mỗi best (kernel chết vẫn còn). Save Version sau khi xong.
|
| 552 |
+
# - QMOS vẫn mượn exp07 (0.548). So sánh: nộp answer.txt ensemble vs exp08 thuần để xem ensemble có nhích VAD.
|
| 553 |
+
# - Ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (exp10).
|
track2/exp11_finetune_joint.ipynb
ADDED
|
@@ -0,0 +1,805 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "a2dce1b4",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"# VMC2026 Track 2 — exp11 (FINE-TUNE ĐỒNG THỜI WavLM + audeering, FUSION 1 model) — Kaggle T4\n",
|
| 9 |
+
"\n",
|
| 10 |
+
"**Khác exp08:** exp08 chỉ fine-tune WavLM, audeering **đóng băng** (frozen, cache). exp11 **MỞ BĂNG CẢ HAI**\n",
|
| 11 |
+
"backbone và fuse đặc trưng **trong cùng 1 model** → cả hai cùng học cho bài MOS cảm xúc 2026.\n",
|
| 12 |
+
"\n",
|
| 13 |
+
"```\n",
|
| 14 |
+
" wav ─┬─► WavLM-large (warm-start exp08, TRAINABLE: mở băng N lớp trên) ─► pool ─► emb_wavlm ┐\n",
|
| 15 |
+
" └─► audeering MSP (TRAINABLE: mở băng N lớp trên) ─► pool ─► [emb_aud(1024) | vad3] ──────┼─► TRUNK ─┬─► EMOS (+target)\n",
|
| 16 |
+
" ┘ ├─► CAT (5)\n",
|
| 17 |
+
" └─► VAD (3)\n",
|
| 18 |
+
" QMOS: KHÔNG train ở đây → mượn cột QMOS exp07 (0.548) hoặc UTMOSv2.\n",
|
| 19 |
+
"```\n",
|
| 20 |
+
"\n",
|
| 21 |
+
"## Vì sao \"feature fusion + fine-tune cả 2\" (khác ensemble exp10)\n",
|
| 22 |
+
"- **exp10 = ensemble:** 2 model RIÊNG → trung bình cột VAD ở mức answer. An toàn nhưng 2 model không \"nói chuyện\".\n",
|
| 23 |
+
"- **exp11 = fusion:** 1 model, 2 backbone fuse Ở TRONG → trunk học phối hợp cả hai góc nhìn (WavLM categorical +\n",
|
| 24 |
+
" audeering dimensional) → kỳ vọng mạnh hơn nếu không OOM/overfit.\n",
|
| 25 |
+
"\n",
|
| 26 |
+
"## ⚠️ ĐÁNH ĐỔI PHẢI BIẾT — đây là cấu hình NẶNG nhất (2 backbone large cùng có gradient)\n",
|
| 27 |
+
"- **Rủi ro OOM cao trên T4 (16GB).** Đã bật sẵn mọi cách giảm bộ nhớ: `BATCH=1` + grad-accum,\n",
|
| 28 |
+
" gradient-checkpointing CẢ 2 backbone, AMP fp16, `MAX_SECONDS=6`, mở băng ÍT lớp (mặc định 4 mỗi backbone).\n",
|
| 29 |
+
"- Nếu vẫn OOM: giảm `UNFREEZE_WAVLM`/`UNFREEZE_AUD` → 2, giảm `MAX_SECONDS` → 5, tăng `ACCUM`.\n",
|
| 30 |
+
"- **Chậm + đốt giờ GPU** (2 backbone forward+backward, không cache được). **LẦN ĐẦU BẮT BUỘC `LIMIT_TRAIN=300`,\n",
|
| 31 |
+
" `LIMIT_DEV=20`** để chỉnh trơn rồi mới `None`.\n",
|
| 32 |
+
"- **Lưới an toàn:** đừng đốt lượt nộp — chỉ nộp khi exp11 thắng exp08 (0.811) TRÊN VAL NỘI BỘ.\n",
|
| 33 |
+
"\n",
|
| 34 |
+
"**Cách chạy:** GPU **T4** + Internet **On** → Add Input (data + checkpoint exp08 + [tùy chọn] answer exp07) →\n",
|
| 35 |
+
"sửa slug cell 0 → Run All. Ghi config→kết quả→nhận xét vào `docs/04_experiments_log.md` (exp11)."
|
| 36 |
+
]
|
| 37 |
+
},
|
| 38 |
+
{
|
| 39 |
+
"cell_type": "markdown",
|
| 40 |
+
"id": "f6d884a7",
|
| 41 |
+
"metadata": {},
|
| 42 |
+
"source": [
|
| 43 |
+
"## 0. Cấu hình — SỬA Ở ĐÂY"
|
| 44 |
+
]
|
| 45 |
+
},
|
| 46 |
+
{
|
| 47 |
+
"cell_type": "code",
|
| 48 |
+
"execution_count": null,
|
| 49 |
+
"id": "6eca25f7",
|
| 50 |
+
"metadata": {},
|
| 51 |
+
"outputs": [],
|
| 52 |
+
"source": [
|
| 53 |
+
"import os\n",
|
| 54 |
+
"\n",
|
| 55 |
+
"DATA_ROOT = \"/kaggle/input/datasets/minhtoan2/vmc2026-track2-full/vmc2026-track2\" # << SỬA slug cho khớp Add Input\n",
|
| 56 |
+
"WAV_DIR = f\"{DATA_ROOT}/wav\"\n",
|
| 57 |
+
"METADATA_CSV = f\"{DATA_ROOT}/metadata.csv\" # wavID|emotion|transcript (KHÔNG header)\n",
|
| 58 |
+
"TRAIN_CSV = f\"{DATA_ROOT}/sets/train.csv\" # lisID|wavID|qMOS|emoCat|eMOS|val|dom|aro\n",
|
| 59 |
+
"DEV_SCP = f\"{DATA_ROOT}/sets/dev.scp\"\n",
|
| 60 |
+
"\n",
|
| 61 |
+
"# ── Warm-start / RESUME: trỏ tới 1 trong 2 loại checkpoint ───────────────────\n",
|
| 62 |
+
"# • ft_emotion_full_20epoch.pt (exp08): có 'wavlm'+'heads' → WARM-START (audeering từ pretrained gốc).\n",
|
| 63 |
+
"# • ft_joint_full.pt (exp11): có thêm 'aud'+'aud_head' → RESUME ĐỦ (khôi phục cả 2 backbone đã fine-tune).\n",
|
| 64 |
+
"# Notebook TỰ nhận biết theo key trong checkpoint. Để \"\" nếu train WavLM từ SAILER trắng.\n",
|
| 65 |
+
"WARMSTART_CKPT = \"/kaggle/input/ft-joint-full/ft_joint_full.pt\" # << exp08 ckpt (warm-start) HOẶC exp11 ckpt (resume)\n",
|
| 66 |
+
"\n",
|
| 67 |
+
"# Mượn cột QMOS exp07 (0.548). Không có → UTMOSv2.\n",
|
| 68 |
+
"EXP07_ANSWER = \"/kaggle/input/exp07-answer/answer.txt\" # << (tùy chọn) answer.txt exp07; không có → UTMOSv2\n",
|
| 69 |
+
"\n",
|
| 70 |
+
"OUT_DIR = \"/kaggle/working\"\n",
|
| 71 |
+
"\n",
|
| 72 |
+
"# ── Fine-tune / siêu tham số (CẤU HÌNH NẶNG — đã tối ưu cho T4) ───────────────\n",
|
| 73 |
+
"DEVICE = \"cuda\"\n",
|
| 74 |
+
"SR = 16000\n",
|
| 75 |
+
"MAX_SECONDS = 6 # ↓ so exp08 (8) để tiết kiệm VRAM (2 backbone)\n",
|
| 76 |
+
"UNFREEZE_WAVLM = 4 # số lớp encoder WavLM mở băng (OOM → 2)\n",
|
| 77 |
+
"UNFREEZE_AUD = 4 # số lớp encoder audeering mở băng (OOM → 2)\n",
|
| 78 |
+
"TRUNK_HIDDEN = 512 # PHẢI khớp checkpoint exp08 nếu warm-start heads\n",
|
| 79 |
+
"HEAD_HIDDEN = 128 # PHẢI khớp checkpoint exp08\n",
|
| 80 |
+
"DROPOUT = 0.3\n",
|
| 81 |
+
"LR_BACKBONE = 1e-5 # LR chung cho 2 backbone\n",
|
| 82 |
+
"LR_HEAD = 1e-3\n",
|
| 83 |
+
"RESUME_LR_SCALE = 1.0 # <1.0 để GIẢM LR khi resume (vd 0.5 nếu val đã chững) — nhân vào cả 2 nhóm LR\n",
|
| 84 |
+
"WEIGHT_DECAY = 1e-5\n",
|
| 85 |
+
"EPOCHS = 12\n",
|
| 86 |
+
"PATIENCE = 4 # dừng khi val không lên; LUÔN giữ best\n",
|
| 87 |
+
"BATCH = 1 # ⚠️ 2 backbone large → batch nhỏ\n",
|
| 88 |
+
"ACCUM = 16 # effective batch = 16\n",
|
| 89 |
+
"VAL_FRAC = 0.10\n",
|
| 90 |
+
"SEED = 42\n",
|
| 91 |
+
"USE_AMP = True\n",
|
| 92 |
+
"USE_GRAD_CKPT = True\n",
|
| 93 |
+
"USE_UNCERTAINTY = True\n",
|
| 94 |
+
"\n",
|
| 95 |
+
"LIMIT_TRAIN = 300 # << LẦN ĐẦU 300; chạy thật None\n",
|
| 96 |
+
"LIMIT_DEV = 20 # << LẦN ĐẦU 20; chạy thật None\n",
|
| 97 |
+
"\n",
|
| 98 |
+
"EXP08 = {\"emos\": 0.811, \"cat_err\": 0.133, \"val\": 0.659, \"aro\": 0.793, \"dom\": 0.751} # mốc để so\n",
|
| 99 |
+
"\n",
|
| 100 |
+
"EMOTIONS5 = [\"angry\", \"happy\", \"neutral\", \"sad\", \"surprised\"]\n",
|
| 101 |
+
"_EMO_ALIAS = {\n",
|
| 102 |
+
" \"angry\": \"angry\", \"anger\": \"angry\",\n",
|
| 103 |
+
" \"happy\": \"happy\", \"happiness\": \"happy\", \"joy\": \"happy\",\n",
|
| 104 |
+
" \"neutral\": \"neutral\", \"calm\": \"neutral\",\n",
|
| 105 |
+
" \"sad\": \"sad\", \"sadness\": \"sad\",\n",
|
| 106 |
+
" \"surprise\": \"surprised\", \"surprised\": \"surprised\", \"surprising\": \"surprised\",\n",
|
| 107 |
+
"}\n",
|
| 108 |
+
"\n",
|
| 109 |
+
"def norm_emotion(label):\n",
|
| 110 |
+
" key = str(label).strip().lower()\n",
|
| 111 |
+
" return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)\n",
|
| 112 |
+
"\n",
|
| 113 |
+
"def stem(p):\n",
|
| 114 |
+
" return os.path.splitext(os.path.basename(str(p)))[0]\n",
|
| 115 |
+
"\n",
|
| 116 |
+
"print(\"DATA_ROOT:\", DATA_ROOT)\n",
|
| 117 |
+
"for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP]:\n",
|
| 118 |
+
" print((\" ✅ \" if os.path.exists(p) else \" ❌ THIẾU \") + p)\n",
|
| 119 |
+
"print((\" ✅ \" if (WARMSTART_CKPT and os.path.exists(WARMSTART_CKPT)) else \" ⚠️ KHÔNG có \") + str(WARMSTART_CKPT)\n",
|
| 120 |
+
" + (\" → warm-start\" if (WARMSTART_CKPT and os.path.exists(WARMSTART_CKPT)) else \" → train từ SAILER trắng\"))"
|
| 121 |
+
]
|
| 122 |
+
},
|
| 123 |
+
{
|
| 124 |
+
"cell_type": "markdown",
|
| 125 |
+
"id": "b15a7e01",
|
| 126 |
+
"metadata": {},
|
| 127 |
+
"source": [
|
| 128 |
+
"## 1. Cài đặt + tải code SAILER (dựng đúng kiến trúc WavLM)"
|
| 129 |
+
]
|
| 130 |
+
},
|
| 131 |
+
{
|
| 132 |
+
"cell_type": "code",
|
| 133 |
+
"execution_count": null,
|
| 134 |
+
"id": "2aacc36b",
|
| 135 |
+
"metadata": {},
|
| 136 |
+
"outputs": [],
|
| 137 |
+
"source": [
|
| 138 |
+
"import sys, subprocess\n",
|
| 139 |
+
"\n",
|
| 140 |
+
"def pip_install(*pkgs):\n",
|
| 141 |
+
" subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *pkgs], check=True)\n",
|
| 142 |
+
"\n",
|
| 143 |
+
"pip_install(\"transformers\", \"huggingface_hub\", \"safetensors\", \"loralib\", \"speechbrain\",\n",
|
| 144 |
+
" \"speechmos\", \"librosa\", \"soundfile\", \"scipy\", \"scikit-learn\", \"pandas\", \"tqdm\")\n",
|
| 145 |
+
"\n",
|
| 146 |
+
"REPO_DIR = \"/kaggle/working/vox-profile-release\"\n",
|
| 147 |
+
"if not os.path.exists(REPO_DIR):\n",
|
| 148 |
+
" subprocess.run([\"git\", \"clone\", \"--depth\", \"1\",\n",
|
| 149 |
+
" \"https://github.com/tiantiaf0627/vox-profile-release.git\", REPO_DIR], check=True)\n",
|
| 150 |
+
"if REPO_DIR not in sys.path:\n",
|
| 151 |
+
" sys.path.insert(0, REPO_DIR)"
|
| 152 |
+
]
|
| 153 |
+
},
|
| 154 |
+
{
|
| 155 |
+
"cell_type": "markdown",
|
| 156 |
+
"id": "3021edeb",
|
| 157 |
+
"metadata": {},
|
| 158 |
+
"source": [
|
| 159 |
+
"## 2A. WavLM TRAINABLE (warm-start SAILER / checkpoint exp08)"
|
| 160 |
+
]
|
| 161 |
+
},
|
| 162 |
+
{
|
| 163 |
+
"cell_type": "code",
|
| 164 |
+
"execution_count": null,
|
| 165 |
+
"id": "d2459502",
|
| 166 |
+
"metadata": {
|
| 167 |
+
"lines_to_next_cell": 1
|
| 168 |
+
},
|
| 169 |
+
"outputs": [],
|
| 170 |
+
"source": [
|
| 171 |
+
"import torch\n",
|
| 172 |
+
"import torch.nn as nn\n",
|
| 173 |
+
"import torch.nn.functional as F\n",
|
| 174 |
+
"import numpy as np\n",
|
| 175 |
+
"\n",
|
| 176 |
+
"device = DEVICE if torch.cuda.is_available() else \"cpu\"\n",
|
| 177 |
+
"print(\"Device:\", device, (\"✅ \" + torch.cuda.get_device_name(0)) if device == \"cuda\" else \"⚠️ CPU (rất chậm!)\")\n",
|
| 178 |
+
"\n",
|
| 179 |
+
"# Nạp checkpoint exp08 (nếu có) — lấy cả 'wavlm', 'heads', thống kê chuẩn hóa\n",
|
| 180 |
+
"ckpt = None\n",
|
| 181 |
+
"if WARMSTART_CKPT and os.path.exists(WARMSTART_CKPT):\n",
|
| 182 |
+
" ckpt = torch.load(WARMSTART_CKPT, map_location=\"cpu\", weights_only=False)\n",
|
| 183 |
+
" print(\"✅ Nạp checkpoint warm-start:\", WARMSTART_CKPT, \"| keys:\", list(ckpt.keys()))\n",
|
| 184 |
+
" if \"wavlm\" not in ckpt:\n",
|
| 185 |
+
" print(\" ⚠️ Checkpoint KHÔNG có 'wavlm' (chỉ heads?) → vẫn dựng WavLM từ SAILER, chỉ warm-start heads nếu khớp.\")\n",
|
| 186 |
+
"\n",
|
| 187 |
+
"def find_hf_backbone(module):\n",
|
| 188 |
+
" cands = []\n",
|
| 189 |
+
" for name, m in module.named_modules():\n",
|
| 190 |
+
" enc = getattr(m, \"encoder\", None)\n",
|
| 191 |
+
" if getattr(m, \"feature_extractor\", None) is not None and enc is not None \\\n",
|
| 192 |
+
" and getattr(enc, \"layers\", None) is not None:\n",
|
| 193 |
+
" cands.append((name, m))\n",
|
| 194 |
+
" if not cands:\n",
|
| 195 |
+
" return None, None\n",
|
| 196 |
+
" cands.sort(key=lambda nm: sum(p.numel() for p in nm[1].parameters()), reverse=True)\n",
|
| 197 |
+
" return cands[0]\n",
|
| 198 |
+
"\n",
|
| 199 |
+
"wavlm = None\n",
|
| 200 |
+
"try:\n",
|
| 201 |
+
" from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402\n",
|
| 202 |
+
" _wrapper = WavLMWrapper.from_pretrained(\"tiantiaf/wavlm-large-categorical-emotion\")\n",
|
| 203 |
+
" name, wavlm = find_hf_backbone(_wrapper)\n",
|
| 204 |
+
" if wavlm is not None:\n",
|
| 205 |
+
" print(f\"✅ Dựng backbone WavLM từ SAILER wrapper tại '.{name}'\")\n",
|
| 206 |
+
"except Exception as e:\n",
|
| 207 |
+
" print(\"⚠️ Lỗi nạp SAILER wrapper:\", repr(e), \"→ fallback WavLM trắng.\")\n",
|
| 208 |
+
"\n",
|
| 209 |
+
"if wavlm is None:\n",
|
| 210 |
+
" from transformers import WavLMModel\n",
|
| 211 |
+
" wavlm = WavLMModel.from_pretrained(\"microsoft/wavlm-large\")\n",
|
| 212 |
+
" print(\"ℹ️ Fallback: microsoft/wavlm-large.\")\n",
|
| 213 |
+
"\n",
|
| 214 |
+
"wavlm = wavlm.to(device)\n",
|
| 215 |
+
"WAVLM_DIM = int(wavlm.config.hidden_size)\n",
|
| 216 |
+
"wavlm.config.layerdrop = 0.0 # ⚠️ tắt layerdrop khi dùng gradient-checkpointing (tránh CheckpointError)\n",
|
| 217 |
+
"\n",
|
| 218 |
+
"# Đè trọng số đã fine-tune từ checkpoint exp08 (nếu có)\n",
|
| 219 |
+
"if ckpt is not None and \"wavlm\" in ckpt:\n",
|
| 220 |
+
" miss, unexp = wavlm.load_state_dict(ckpt[\"wavlm\"], strict=False)\n",
|
| 221 |
+
" print(f\"🔁 load wavlm từ checkpoint exp08: thiếu {len(miss)} / dư {len(unexp)} key (kỳ vọng ~0).\")\n",
|
| 222 |
+
"\n",
|
| 223 |
+
"# Đóng băng partial: chỉ mở UNFREEZE_WAVLM lớp trên\n",
|
| 224 |
+
"for p in wavlm.parameters():\n",
|
| 225 |
+
" p.requires_grad = False\n",
|
| 226 |
+
"_wl = wavlm.encoder.layers\n",
|
| 227 |
+
"for layer in _wl[max(0, len(_wl) - UNFREEZE_WAVLM):]:\n",
|
| 228 |
+
" for p in layer.parameters():\n",
|
| 229 |
+
" p.requires_grad = True\n",
|
| 230 |
+
"print(f\"WavLM: {len(_wl)} lớp · mở băng {min(UNFREEZE_WAVLM, len(_wl))} → \"\n",
|
| 231 |
+
" f\"{sum(p.numel() for p in wavlm.parameters() if p.requires_grad)/1e6:.1f}M param train (dim {WAVLM_DIM})\")\n",
|
| 232 |
+
"\n",
|
| 233 |
+
"if USE_GRAD_CKPT:\n",
|
| 234 |
+
" wavlm.gradient_checkpointing_enable()\n",
|
| 235 |
+
" if hasattr(wavlm, \"enable_input_require_grads\"):\n",
|
| 236 |
+
" wavlm.enable_input_require_grads()\n",
|
| 237 |
+
"\n",
|
| 238 |
+
"def masked_mean(hidden, attn_mask, model):\n",
|
| 239 |
+
" if attn_mask is None:\n",
|
| 240 |
+
" return hidden.mean(dim=1)\n",
|
| 241 |
+
" try:\n",
|
| 242 |
+
" fm = model._get_feature_vector_attention_mask(hidden.shape[1], attn_mask)\n",
|
| 243 |
+
" except Exception:\n",
|
| 244 |
+
" return hidden.mean(dim=1)\n",
|
| 245 |
+
" fm = fm.unsqueeze(-1).to(hidden.dtype)\n",
|
| 246 |
+
" return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)\n",
|
| 247 |
+
"\n",
|
| 248 |
+
"def wavlm_embed(input_values, attn_mask):\n",
|
| 249 |
+
" out = wavlm(input_values, attention_mask=attn_mask).last_hidden_state\n",
|
| 250 |
+
" return masked_mean(out, attn_mask, wavlm)"
|
| 251 |
+
]
|
| 252 |
+
},
|
| 253 |
+
{
|
| 254 |
+
"cell_type": "markdown",
|
| 255 |
+
"id": "20a0c88d",
|
| 256 |
+
"metadata": {},
|
| 257 |
+
"source": [
|
| 258 |
+
"## 2B. audeering TRAINABLE (mở băng — khác exp08 là frozen)\n",
|
| 259 |
+
"Nạp backbone tay + head dimensional gốc; mở băng `UNFREEZE_AUD` lớp trên. Đặc trưng fuse = [hidden(1024) | vad3]."
|
| 260 |
+
]
|
| 261 |
+
},
|
| 262 |
+
{
|
| 263 |
+
"cell_type": "code",
|
| 264 |
+
"execution_count": null,
|
| 265 |
+
"id": "9360d566",
|
| 266 |
+
"metadata": {
|
| 267 |
+
"lines_to_next_cell": 1
|
| 268 |
+
},
|
| 269 |
+
"outputs": [],
|
| 270 |
+
"source": [
|
| 271 |
+
"from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor\n",
|
| 272 |
+
"from huggingface_hub import hf_hub_download\n",
|
| 273 |
+
"\n",
|
| 274 |
+
"AUD_NAME = \"audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim\"\n",
|
| 275 |
+
"aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)\n",
|
| 276 |
+
"aud_cfg = Wav2Vec2Config.from_pretrained(AUD_NAME)\n",
|
| 277 |
+
"aud = Wav2Vec2Model(aud_cfg)\n",
|
| 278 |
+
"try:\n",
|
| 279 |
+
" _sd = __import__(\"safetensors.torch\", fromlist=[\"load_file\"]).load_file(\n",
|
| 280 |
+
" hf_hub_download(AUD_NAME, \"model.safetensors\"))\n",
|
| 281 |
+
"except Exception:\n",
|
| 282 |
+
" _sd = torch.load(hf_hub_download(AUD_NAME, \"pytorch_model.bin\"), map_location=\"cpu\")\n",
|
| 283 |
+
"bb_sd = {k[len(\"wav2vec2.\"):]: v for k, v in _sd.items() if k.startswith(\"wav2vec2.\")}\n",
|
| 284 |
+
"aud.load_state_dict(bb_sd, strict=False)\n",
|
| 285 |
+
"_hid = _sd[\"classifier.dense.weight\"].shape[0]\n",
|
| 286 |
+
"aud_head = nn.Sequential(nn.Linear(_hid, _hid), nn.Tanh(), nn.Linear(_hid, _sd[\"classifier.out_proj.weight\"].shape[0]))\n",
|
| 287 |
+
"aud_head[0].weight.data.copy_(_sd[\"classifier.dense.weight\"]); aud_head[0].bias.data.copy_(_sd[\"classifier.dense.bias\"])\n",
|
| 288 |
+
"aud_head[2].weight.data.copy_(_sd[\"classifier.out_proj.weight\"]); aud_head[2].bias.data.copy_(_sd[\"classifier.out_proj.bias\"])\n",
|
| 289 |
+
"aud = aud.to(device); aud_head = aud_head.to(device)\n",
|
| 290 |
+
"aud.config.layerdrop = 0.0 # ⚠️ tắt layerdrop khi dùng gradient-checkpointing (tránh CheckpointError)\n",
|
| 291 |
+
"AUD_DIM = _hid + 3 # = 1027 (khớp exp08 để warm-start heads)\n",
|
| 292 |
+
"\n",
|
| 293 |
+
"# RESUME: nếu checkpoint là ft_joint_full.pt (có 'aud') → khôi phục audeering ĐÃ fine-tune (đè pretrained)\n",
|
| 294 |
+
"if ckpt is not None and \"aud\" in ckpt:\n",
|
| 295 |
+
" amiss, aunexp = aud.load_state_dict(ckpt[\"aud\"], strict=False)\n",
|
| 296 |
+
" print(f\"🔁 RESUME audeering từ checkpoint: thiếu {len(amiss)} / dư {len(aunexp)} key (kỳ vọng ~0).\")\n",
|
| 297 |
+
" if \"aud_head\" in ckpt:\n",
|
| 298 |
+
" aud_head.load_state_dict(ckpt[\"aud_head\"]); print(\"🔁 RESUME aud_head từ checkpoint.\")\n",
|
| 299 |
+
"else:\n",
|
| 300 |
+
" print(\"ℹ️ Checkpoint không có 'aud' → audeering khởi từ pretrained gốc (chế độ warm-start exp08).\")\n",
|
| 301 |
+
"\n",
|
| 302 |
+
"# Đóng băng partial audeering: mở UNFREEZE_AUD lớp trên + head dimensional luôn trainable\n",
|
| 303 |
+
"for p in aud.parameters():\n",
|
| 304 |
+
" p.requires_grad = False\n",
|
| 305 |
+
"_al = aud.encoder.layers\n",
|
| 306 |
+
"for layer in _al[max(0, len(_al) - UNFREEZE_AUD):]:\n",
|
| 307 |
+
" for p in layer.parameters():\n",
|
| 308 |
+
" p.requires_grad = True\n",
|
| 309 |
+
"for p in aud_head.parameters():\n",
|
| 310 |
+
" p.requires_grad = True\n",
|
| 311 |
+
"print(f\"audeering: {len(_al)} lớp · mở băng {min(UNFREEZE_AUD, len(_al))} → \"\n",
|
| 312 |
+
" f\"{sum(p.numel() for p in aud.parameters() if p.requires_grad)/1e6:.1f}M param train (hidden {_hid}, fuse dim {AUD_DIM})\")\n",
|
| 313 |
+
"\n",
|
| 314 |
+
"if USE_GRAD_CKPT:\n",
|
| 315 |
+
" aud.gradient_checkpointing_enable()\n",
|
| 316 |
+
" if hasattr(aud, \"enable_input_require_grads\"):\n",
|
| 317 |
+
" aud.enable_input_require_grads()\n",
|
| 318 |
+
"\n",
|
| 319 |
+
"def aud_embed(input_values, attn_mask):\n",
|
| 320 |
+
" \"\"\"Trả về [hidden(1024) | vad3] — vad3 từ head dimensional gốc, theo thứ tự VAL,ARO,DOM.\"\"\"\n",
|
| 321 |
+
" h = masked_mean(aud(input_values, attention_mask=attn_mask).last_hidden_state, attn_mask, aud)\n",
|
| 322 |
+
" out = aud_head(h) # [B,3] thứ tự gốc audeering: (arousal, dominance, valence)\n",
|
| 323 |
+
" vad = torch.stack([1 + 4 * out[:, 2], 1 + 4 * out[:, 0], 1 + 4 * out[:, 1]], dim=1) # → VAL,ARO,DOM\n",
|
| 324 |
+
" return torch.cat([h, vad], dim=1)"
|
| 325 |
+
]
|
| 326 |
+
},
|
| 327 |
+
{
|
| 328 |
+
"cell_type": "markdown",
|
| 329 |
+
"id": "7550bb8e",
|
| 330 |
+
"metadata": {},
|
| 331 |
+
"source": [
|
| 332 |
+
"## 3. Đọc & gộp nhãn theo wavID (như exp08)"
|
| 333 |
+
]
|
| 334 |
+
},
|
| 335 |
+
{
|
| 336 |
+
"cell_type": "code",
|
| 337 |
+
"execution_count": null,
|
| 338 |
+
"id": "86f83d8a",
|
| 339 |
+
"metadata": {},
|
| 340 |
+
"outputs": [],
|
| 341 |
+
"source": [
|
| 342 |
+
"import librosa\n",
|
| 343 |
+
"import pandas as pd\n",
|
| 344 |
+
"from tqdm.auto import tqdm\n",
|
| 345 |
+
"\n",
|
| 346 |
+
"def load_target_emotions():\n",
|
| 347 |
+
" tgt = {}\n",
|
| 348 |
+
" with open(METADATA_CSV, encoding=\"utf-8\") as f:\n",
|
| 349 |
+
" for ln in f:\n",
|
| 350 |
+
" parts = ln.strip().split(\"|\")\n",
|
| 351 |
+
" if len(parts) >= 2:\n",
|
| 352 |
+
" tgt[stem(parts[0])] = norm_emotion(parts[1])\n",
|
| 353 |
+
" return tgt\n",
|
| 354 |
+
"\n",
|
| 355 |
+
"def _col(cols_map, *names, df=None, default_idx=None):\n",
|
| 356 |
+
" for n in names:\n",
|
| 357 |
+
" if n in cols_map:\n",
|
| 358 |
+
" return cols_map[n]\n",
|
| 359 |
+
" return list(df.columns)[default_idx] if default_idx is not None else None\n",
|
| 360 |
+
"\n",
|
| 361 |
+
"def parse_emocat_votes(cell):\n",
|
| 362 |
+
" v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
|
| 363 |
+
" for tok in str(cell).replace(\"/\", \",\").replace(\";\", \",\").replace(\"|\", \",\").replace(\" \", \",\").split(\",\"):\n",
|
| 364 |
+
" e = norm_emotion(tok)\n",
|
| 365 |
+
" if e in EMOTIONS5:\n",
|
| 366 |
+
" v[EMOTIONS5.index(e)] += 1.0\n",
|
| 367 |
+
" return v\n",
|
| 368 |
+
"\n",
|
| 369 |
+
"def load_train_labels():\n",
|
| 370 |
+
" df = pd.read_csv(TRAIN_CSV, sep=\"|\")\n",
|
| 371 |
+
" cols = {c.lower().strip(): c for c in df.columns}\n",
|
| 372 |
+
" wav_col = _col(cols, \"wavid\", \"wav\", df=df, default_idx=1)\n",
|
| 373 |
+
" emos_col = _col(cols, \"emos\", \"emo\", \"emomos\")\n",
|
| 374 |
+
" val_col = _col(cols, \"val\", \"valence\"); aro_col = _col(cols, \"aro\", \"arousal\"); dom_col = _col(cols, \"dom\", \"dominance\")\n",
|
| 375 |
+
" cat_col = _col(cols, \"emocat\", \"cat\", \"emotion\")\n",
|
| 376 |
+
" assert emos_col, f\"Không thấy cột eMOS (cột: {list(df.columns)})\"\n",
|
| 377 |
+
" df[\"_stem\"] = df[wav_col].map(stem)\n",
|
| 378 |
+
" rows = []\n",
|
| 379 |
+
" for sid, g in df.groupby(\"_stem\"):\n",
|
| 380 |
+
" rec = {\"wavID\": sid, \"emos\": float(g[emos_col].mean())}\n",
|
| 381 |
+
" rec[\"val\"] = float(g[val_col].mean()) if val_col else np.nan\n",
|
| 382 |
+
" rec[\"aro\"] = float(g[aro_col].mean()) if aro_col else np.nan\n",
|
| 383 |
+
" rec[\"dom\"] = float(g[dom_col].mean()) if dom_col else np.nan\n",
|
| 384 |
+
" votes = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
|
| 385 |
+
" if cat_col:\n",
|
| 386 |
+
" for cell in g[cat_col]:\n",
|
| 387 |
+
" votes += parse_emocat_votes(cell)\n",
|
| 388 |
+
" s = votes.sum()\n",
|
| 389 |
+
" cat = votes / s if s > 0 else np.full(len(EMOTIONS5), 0.2, dtype=np.float32)\n",
|
| 390 |
+
" for i in range(len(EMOTIONS5)):\n",
|
| 391 |
+
" rec[f\"cat{i}\"] = float(cat[i])\n",
|
| 392 |
+
" rows.append(rec)\n",
|
| 393 |
+
" return pd.DataFrame(rows)\n",
|
| 394 |
+
"\n",
|
| 395 |
+
"target_map = load_target_emotions()\n",
|
| 396 |
+
"train_df = load_train_labels()\n",
|
| 397 |
+
"HAS_VAD = bool(train_df[\"val\"].notna().any())\n",
|
| 398 |
+
"print(f\"Target: {len(target_map)} | wav train (gộp): {len(train_df)} | có VAD: {HAS_VAD}\")"
|
| 399 |
+
]
|
| 400 |
+
},
|
| 401 |
+
{
|
| 402 |
+
"cell_type": "markdown",
|
| 403 |
+
"id": "02e003af",
|
| 404 |
+
"metadata": {},
|
| 405 |
+
"source": [
|
| 406 |
+
"## 4. Dataset/loader — trả về CẢ raw wave (cho WavLM) + input_values audeering\n",
|
| 407 |
+
"Hai backbone cần đầu vào khác nhau: WavLM nhận wave thô; audeering nhận wave đã chuẩn hóa bởi processor.\n",
|
| 408 |
+
"Cùng độ dài → dùng chung attention mask theo mức sample."
|
| 409 |
+
]
|
| 410 |
+
},
|
| 411 |
+
{
|
| 412 |
+
"cell_type": "code",
|
| 413 |
+
"execution_count": null,
|
| 414 |
+
"id": "f91d8e80",
|
| 415 |
+
"metadata": {},
|
| 416 |
+
"outputs": [],
|
| 417 |
+
"source": [
|
| 418 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 419 |
+
"from sklearn.model_selection import train_test_split\n",
|
| 420 |
+
"\n",
|
| 421 |
+
"train_stems = [s for s in train_df[\"wavID\"] if target_map.get(s) is not None]\n",
|
| 422 |
+
"if LIMIT_TRAIN:\n",
|
| 423 |
+
" train_stems = train_stems[:LIMIT_TRAIN]\n",
|
| 424 |
+
"lab = train_df.set_index(\"wavID\")\n",
|
| 425 |
+
"\n",
|
| 426 |
+
"# Chuẩn hóa: lấy TỪ checkpoint nếu warm-start (để khớp head đã train); không thì fit từ data.\n",
|
| 427 |
+
"if ckpt is not None and \"emos_mu\" in ckpt:\n",
|
| 428 |
+
" emos_mu = float(ckpt[\"emos_mu\"]); emos_sd = float(ckpt[\"emos_sd\"])\n",
|
| 429 |
+
" vad_mu = np.asarray(ckpt[\"vad_mu\"], dtype=np.float32); vad_sd = np.asarray(ckpt[\"vad_sd\"], dtype=np.float32)\n",
|
| 430 |
+
" print(f\"Chuẩn hóa TỪ ckpt: emos μ={emos_mu:.3f} σ={emos_sd:.3f} | vad μ={np.round(vad_mu,2)}\")\n",
|
| 431 |
+
"else:\n",
|
| 432 |
+
" def _zfit(a):\n",
|
| 433 |
+
" a = np.asarray(a, dtype=np.float32); return float(np.nanmean(a)), float(np.nanstd(a) + 1e-6)\n",
|
| 434 |
+
" emos_mu, emos_sd = _zfit([lab.loc[s, \"emos\"] for s in train_stems])\n",
|
| 435 |
+
" if HAS_VAD:\n",
|
| 436 |
+
" vad_mu = np.array([_zfit([lab.loc[s, c] for s in train_stems])[0] for c in [\"val\", \"aro\", \"dom\"]], np.float32)\n",
|
| 437 |
+
" vad_sd = np.array([_zfit([lab.loc[s, c] for s in train_stems])[1] for c in [\"val\", \"aro\", \"dom\"]], np.float32)\n",
|
| 438 |
+
" else:\n",
|
| 439 |
+
" vad_mu = np.zeros(3, np.float32); vad_sd = np.ones(3, np.float32)\n",
|
| 440 |
+
" print(f\"Chuẩn hóa fit từ data: emos μ={emos_mu:.3f} σ={emos_sd:.3f}\")\n",
|
| 441 |
+
"\n",
|
| 442 |
+
"def onehot_target(tgt):\n",
|
| 443 |
+
" v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
|
| 444 |
+
" if tgt in EMOTIONS5:\n",
|
| 445 |
+
" v[EMOTIONS5.index(tgt)] = 1.0\n",
|
| 446 |
+
" return v\n",
|
| 447 |
+
"\n",
|
| 448 |
+
"def load_pair(sid):\n",
|
| 449 |
+
" \"\"\"Trả về (wave_thô, iv_audeering) cùng độ dài; None nếu thiếu file.\"\"\"\n",
|
| 450 |
+
" p = os.path.join(WAV_DIR, sid if str(sid).endswith(\".wav\") else str(sid) + \".wav\")\n",
|
| 451 |
+
" if not os.path.exists(p):\n",
|
| 452 |
+
" return None\n",
|
| 453 |
+
" wave, _ = librosa.load(p, sr=SR, mono=True)\n",
|
| 454 |
+
" wave = wave[: MAX_SECONDS * SR].astype(np.float32)\n",
|
| 455 |
+
" iv = np.asarray(aud_proc(wave, sampling_rate=SR).input_values[0], dtype=np.float32)\n",
|
| 456 |
+
" return wave, iv\n",
|
| 457 |
+
"\n",
|
| 458 |
+
"class JointDataset(Dataset):\n",
|
| 459 |
+
" def __init__(self, stems):\n",
|
| 460 |
+
" self.stems = [s for s in stems if load_pair(s) is not None]\n",
|
| 461 |
+
" def __len__(self):\n",
|
| 462 |
+
" return len(self.stems)\n",
|
| 463 |
+
" def __getitem__(self, i):\n",
|
| 464 |
+
" s = self.stems[i]\n",
|
| 465 |
+
" wave, iv = load_pair(s)\n",
|
| 466 |
+
" emos = (float(lab.loc[s, \"emos\"]) - emos_mu) / emos_sd\n",
|
| 467 |
+
" if HAS_VAD:\n",
|
| 468 |
+
" vad = (np.array([lab.loc[s, \"val\"], lab.loc[s, \"aro\"], lab.loc[s, \"dom\"]], np.float32) - vad_mu) / vad_sd\n",
|
| 469 |
+
" else:\n",
|
| 470 |
+
" vad = np.zeros(3, dtype=np.float32)\n",
|
| 471 |
+
" cat = np.array([lab.loc[s, f\"cat{j}\"] for j in range(len(EMOTIONS5))], dtype=np.float32)\n",
|
| 472 |
+
" return {\"wave\": wave, \"iv\": iv, \"tgt\": onehot_target(target_map.get(s)),\n",
|
| 473 |
+
" \"emos\": np.float32(emos), \"vad\": vad, \"cat\": cat,\n",
|
| 474 |
+
" \"emos_raw\": np.float32(lab.loc[s, \"emos\"]),\n",
|
| 475 |
+
" \"vad_raw\": np.array([lab.loc[s, \"val\"], lab.loc[s, \"aro\"], lab.loc[s, \"dom\"]], np.float32)}\n",
|
| 476 |
+
"\n",
|
| 477 |
+
"def collate(batch):\n",
|
| 478 |
+
" L = max(len(b[\"wave\"]) for b in batch)\n",
|
| 479 |
+
" waves = np.zeros((len(batch), L), dtype=np.float32)\n",
|
| 480 |
+
" ivs = np.zeros((len(batch), L), dtype=np.float32)\n",
|
| 481 |
+
" mask = np.zeros((len(batch), L), dtype=np.float32)\n",
|
| 482 |
+
" for i, b in enumerate(batch):\n",
|
| 483 |
+
" n = len(b[\"wave\"])\n",
|
| 484 |
+
" waves[i, :n] = b[\"wave\"]; ivs[i, :len(b[\"iv\"])] = b[\"iv\"]; mask[i, :n] = 1.0\n",
|
| 485 |
+
" return {\n",
|
| 486 |
+
" \"wave\": torch.from_numpy(waves), \"iv\": torch.from_numpy(ivs), \"attn_mask\": torch.from_numpy(mask).long(),\n",
|
| 487 |
+
" \"tgt\": torch.from_numpy(np.stack([b[\"tgt\"] for b in batch])),\n",
|
| 488 |
+
" \"emos\": torch.from_numpy(np.stack([b[\"emos\"] for b in batch])).unsqueeze(1),\n",
|
| 489 |
+
" \"vad\": torch.from_numpy(np.stack([b[\"vad\"] for b in batch])),\n",
|
| 490 |
+
" \"cat\": torch.from_numpy(np.stack([b[\"cat\"] for b in batch])),\n",
|
| 491 |
+
" \"emos_raw\": np.stack([b[\"emos_raw\"] for b in batch]),\n",
|
| 492 |
+
" \"vad_raw\": np.stack([b[\"vad_raw\"] for b in batch]),\n",
|
| 493 |
+
" }\n",
|
| 494 |
+
"\n",
|
| 495 |
+
"ds = JointDataset(train_stems)\n",
|
| 496 |
+
"print(\"Dataset hợp lệ:\", len(ds), \"wav\")\n",
|
| 497 |
+
"tr_i, va_i = train_test_split(np.arange(len(ds)), test_size=VAL_FRAC, random_state=SEED)\n",
|
| 498 |
+
"tr_loader = DataLoader(torch.utils.data.Subset(ds, tr_i), batch_size=BATCH, shuffle=True, collate_fn=collate, num_workers=2)\n",
|
| 499 |
+
"va_loader = DataLoader(torch.utils.data.Subset(ds, va_i), batch_size=BATCH, shuffle=False, collate_fn=collate, num_workers=2)"
|
| 500 |
+
]
|
| 501 |
+
},
|
| 502 |
+
{
|
| 503 |
+
"cell_type": "markdown",
|
| 504 |
+
"id": "0f85c871",
|
| 505 |
+
"metadata": {},
|
| 506 |
+
"source": [
|
| 507 |
+
"## 5. Heads (warm-start exp08 nếu khớp) + optimizer 2 backbone + train loop"
|
| 508 |
+
]
|
| 509 |
+
},
|
| 510 |
+
{
|
| 511 |
+
"cell_type": "code",
|
| 512 |
+
"execution_count": null,
|
| 513 |
+
"id": "b0a71176",
|
| 514 |
+
"metadata": {
|
| 515 |
+
"lines_to_next_cell": 1
|
| 516 |
+
},
|
| 517 |
+
"outputs": [],
|
| 518 |
+
"source": [
|
| 519 |
+
"from scipy.stats import spearmanr\n",
|
| 520 |
+
"\n",
|
| 521 |
+
"torch.manual_seed(SEED); np.random.seed(SEED)\n",
|
| 522 |
+
"N_EMO = len(EMOTIONS5)\n",
|
| 523 |
+
"TRUNK_IN = WAVLM_DIM + AUD_DIM\n",
|
| 524 |
+
"\n",
|
| 525 |
+
"class EmoHeads(nn.Module):\n",
|
| 526 |
+
" def __init__(self, d_in, trunk_h, head_h, p, n_emo):\n",
|
| 527 |
+
" super().__init__()\n",
|
| 528 |
+
" self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),\n",
|
| 529 |
+
" nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))\n",
|
| 530 |
+
" self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))\n",
|
| 531 |
+
" self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))\n",
|
| 532 |
+
" self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))\n",
|
| 533 |
+
" def forward(self, feat, tgt):\n",
|
| 534 |
+
" h = self.trunk(feat)\n",
|
| 535 |
+
" return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)\n",
|
| 536 |
+
"\n",
|
| 537 |
+
"heads = EmoHeads(TRUNK_IN, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device)\n",
|
| 538 |
+
"if ckpt is not None and \"heads\" in ckpt:\n",
|
| 539 |
+
" hmiss, hunexp = heads.load_state_dict(ckpt[\"heads\"], strict=False)\n",
|
| 540 |
+
" if len(hmiss) == 0 and len(hunexp) == 0:\n",
|
| 541 |
+
" print(\"🔁 warm-start heads từ exp08: KHỚP hoàn toàn.\")\n",
|
| 542 |
+
" else:\n",
|
| 543 |
+
" print(f\"⚠️ heads exp08 lệch (thiếu {len(hmiss)}/dư {len(hunexp)}) → có thể TRUNK_IN khác. Heads init mới phần lệch.\")\n",
|
| 544 |
+
"print(f\"Trunk input = {TRUNK_IN} (wavlm {WAVLM_DIM} + aud {AUD_DIM})\")\n",
|
| 545 |
+
"\n",
|
| 546 |
+
"TASKS = [\"emos\", \"cat\", \"val\", \"aro\", \"dom\"]\n",
|
| 547 |
+
"log_var = nn.Parameter(torch.zeros(len(TASKS), device=device))\n",
|
| 548 |
+
"bb_params = [p for p in wavlm.parameters() if p.requires_grad] + \\\n",
|
| 549 |
+
" [p for p in aud.parameters() if p.requires_grad] + list(aud_head.parameters())\n",
|
| 550 |
+
"head_params = list(heads.parameters()) + ([log_var] if USE_UNCERTAINTY else [])\n",
|
| 551 |
+
"opt = torch.optim.AdamW([{\"params\": bb_params, \"lr\": LR_BACKBONE * RESUME_LR_SCALE},\n",
|
| 552 |
+
" {\"params\": head_params, \"lr\": LR_HEAD * RESUME_LR_SCALE}], weight_decay=WEIGHT_DECAY)\n",
|
| 553 |
+
"scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP and device == \"cuda\")\n",
|
| 554 |
+
"mse = nn.MSELoss()\n",
|
| 555 |
+
"\n",
|
| 556 |
+
"def soft_ce(logits, target_dist):\n",
|
| 557 |
+
" return -(target_dist * F.log_softmax(logits, dim=1)).sum(1).mean()\n",
|
| 558 |
+
"\n",
|
| 559 |
+
"def forward_batch(b):\n",
|
| 560 |
+
" am = b[\"attn_mask\"].to(device)\n",
|
| 561 |
+
" fw = wavlm_embed(b[\"wave\"].to(device), am)\n",
|
| 562 |
+
" fa = aud_embed(b[\"iv\"].to(device), am)\n",
|
| 563 |
+
" return heads(torch.cat([fw, fa], dim=1), b[\"tgt\"].to(device))\n",
|
| 564 |
+
"\n",
|
| 565 |
+
"def compute_loss(emos_p, cat_l, vad_p, b):\n",
|
| 566 |
+
" L = {}\n",
|
| 567 |
+
" L[\"emos\"] = mse(emos_p, b[\"emos\"].to(device))\n",
|
| 568 |
+
" L[\"cat\"] = soft_ce(cat_l, b[\"cat\"].to(device))\n",
|
| 569 |
+
" if HAS_VAD:\n",
|
| 570 |
+
" vt = b[\"vad\"].to(device)\n",
|
| 571 |
+
" L[\"val\"] = mse(vad_p[:, 0:1], vt[:, 0:1]); L[\"aro\"] = mse(vad_p[:, 1:2], vt[:, 1:2]); L[\"dom\"] = mse(vad_p[:, 2:3], vt[:, 2:3])\n",
|
| 572 |
+
" else:\n",
|
| 573 |
+
" z = torch.zeros((), device=device); L[\"val\"] = L[\"aro\"] = L[\"dom\"] = z\n",
|
| 574 |
+
" if USE_UNCERTAINTY:\n",
|
| 575 |
+
" return sum(torch.exp(-log_var[i]) * L[t] + log_var[i] for i, t in enumerate(TASKS))\n",
|
| 576 |
+
" return sum(L.values())\n",
|
| 577 |
+
"\n",
|
| 578 |
+
"def set_train(flag):\n",
|
| 579 |
+
" wavlm.train(flag); aud.train(flag); aud_head.train(flag); heads.train(flag)\n",
|
| 580 |
+
"\n",
|
| 581 |
+
"@torch.no_grad()\n",
|
| 582 |
+
"def evaluate():\n",
|
| 583 |
+
" set_train(False)\n",
|
| 584 |
+
" P = {\"emos\": [], \"val\": [], \"aro\": [], \"dom\": []}; Y = {\"emos\": [], \"val\": [], \"aro\": [], \"dom\": []}\n",
|
| 585 |
+
" catP, catY = [], []\n",
|
| 586 |
+
" for b in va_loader:\n",
|
| 587 |
+
" with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
|
| 588 |
+
" emos_p, cat_l, vad_p = forward_batch(b)\n",
|
| 589 |
+
" P[\"emos\"] += emos_p.float().cpu().numpy().ravel().tolist(); Y[\"emos\"] += b[\"emos_raw\"].tolist()\n",
|
| 590 |
+
" vad_p = vad_p.float().cpu().numpy()\n",
|
| 591 |
+
" for j, t in enumerate([\"val\", \"aro\", \"dom\"]):\n",
|
| 592 |
+
" P[t] += vad_p[:, j].tolist(); Y[t] += b[\"vad_raw\"][:, j].tolist()\n",
|
| 593 |
+
" catP.append(F.softmax(cat_l, 1).float().cpu().numpy()); catY.append(b[\"cat\"])\n",
|
| 594 |
+
" out = {}\n",
|
| 595 |
+
" for t in [\"emos\"] + ([\"val\", \"aro\", \"dom\"] if HAS_VAD else []):\n",
|
| 596 |
+
" out[t] = spearmanr(P[t], Y[t]).correlation\n",
|
| 597 |
+
" q = np.concatenate(catP); p = np.concatenate(catY)\n",
|
| 598 |
+
" out[\"cat_err\"] = float(np.abs(q - p).sum(1).mean())\n",
|
| 599 |
+
" return out\n",
|
| 600 |
+
"\n",
|
| 601 |
+
"def mean_srcc(m):\n",
|
| 602 |
+
" keys = [\"emos\"] + ([\"val\", \"aro\", \"dom\"] if HAS_VAD else [])\n",
|
| 603 |
+
" return float(np.mean([m[k] for k in keys]))\n",
|
| 604 |
+
"\n",
|
| 605 |
+
"def snapshot():\n",
|
| 606 |
+
" return {\"wavlm\": {k: v.cpu().clone() for k, v in wavlm.state_dict().items()},\n",
|
| 607 |
+
" \"aud\": {k: v.cpu().clone() for k, v in aud.state_dict().items()},\n",
|
| 608 |
+
" \"aud_head\": {k: v.cpu().clone() for k, v in aud_head.state_dict().items()},\n",
|
| 609 |
+
" \"heads\": {k: v.cpu().clone() for k, v in heads.state_dict().items()}}\n",
|
| 610 |
+
"\n",
|
| 611 |
+
"CKPT_PATH = os.path.join(OUT_DIR, \"ft_joint_full.pt\")\n",
|
| 612 |
+
"def save_full(state, val_emos=float(\"nan\")):\n",
|
| 613 |
+
" torch.save({**state, \"emos_mu\": emos_mu, \"emos_sd\": emos_sd, \"vad_mu\": vad_mu, \"vad_sd\": vad_sd,\n",
|
| 614 |
+
" \"WAVLM_DIM\": WAVLM_DIM, \"AUD_DIM\": AUD_DIM,\n",
|
| 615 |
+
" \"UNFREEZE_WAVLM\": UNFREEZE_WAVLM, \"UNFREEZE_AUD\": UNFREEZE_AUD,\n",
|
| 616 |
+
" \"val_emos\": float(val_emos)}, CKPT_PATH)\n",
|
| 617 |
+
"\n",
|
| 618 |
+
"# Init best từ trạng thái warm-start hiện tại → chỉ lưu nếu train tốt hơn\n",
|
| 619 |
+
"m0 = evaluate(); best = mean_srcc(m0); best_state = snapshot(); save_full(best_state, m0.get(\"emos\", float(\"nan\")))\n",
|
| 620 |
+
"print(f\"📍 Khởi điểm (warm-start): mean SRCC = {best:.4f} | \"\n",
|
| 621 |
+
" + \" \".join(f\"{k}={m0[k]:.3f}\" for k in ['emos','val','aro','dom'] if k in m0))\n",
|
| 622 |
+
"\n",
|
| 623 |
+
"bad = 0\n",
|
| 624 |
+
"for ep in range(1, EPOCHS + 1):\n",
|
| 625 |
+
" set_train(True)\n",
|
| 626 |
+
" opt.zero_grad(); run = 0.0; nb = 0\n",
|
| 627 |
+
" for step, b in enumerate(tqdm(tr_loader, desc=f\"epoch {ep}\")):\n",
|
| 628 |
+
" with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
|
| 629 |
+
" emos_p, cat_l, vad_p = forward_batch(b)\n",
|
| 630 |
+
" loss = compute_loss(emos_p, cat_l, vad_p, b) / ACCUM\n",
|
| 631 |
+
" scaler.scale(loss).backward()\n",
|
| 632 |
+
" if (step + 1) % ACCUM == 0:\n",
|
| 633 |
+
" scaler.step(opt); scaler.update(); opt.zero_grad()\n",
|
| 634 |
+
" run += loss.item() * ACCUM; nb += 1\n",
|
| 635 |
+
" m = evaluate(); sc = mean_srcc(m)\n",
|
| 636 |
+
" msg = \" \".join(f\"{k}={m[k]:.3f}\" for k in [\"emos\", \"val\", \"aro\", \"dom\"] if k in m)\n",
|
| 637 |
+
" print(f\"epoch {ep:2d} | loss {run/max(nb,1):.4f} | {msg} | cat_err {m['cat_err']:.3f} | mean {sc:.4f} (best {max(best,sc):.4f})\")\n",
|
| 638 |
+
" if sc > best:\n",
|
| 639 |
+
" best = sc; best_state = snapshot(); save_full(best_state, m[\"emos\"])\n",
|
| 640 |
+
" print(f\" 💾 lưu best → {CKPT_PATH} (epoch {ep}, mean {sc:.4f})\"); bad = 0\n",
|
| 641 |
+
" else:\n",
|
| 642 |
+
" bad += 1\n",
|
| 643 |
+
" if bad >= PATIENCE:\n",
|
| 644 |
+
" print(f\"Early stop ở epoch {ep}.\"); break\n",
|
| 645 |
+
"\n",
|
| 646 |
+
"# Nạp lại best\n",
|
| 647 |
+
"wavlm.load_state_dict(best_state[\"wavlm\"]); aud.load_state_dict(best_state[\"aud\"])\n",
|
| 648 |
+
"aud_head.load_state_dict(best_state[\"aud_head\"]); heads.load_state_dict(best_state[\"heads\"])\n",
|
| 649 |
+
"final = evaluate()\n",
|
| 650 |
+
"print(\"\\n✅ VAL (nội bộ) — exp11 (fine-tune CẢ 2 + fusion):\")\n",
|
| 651 |
+
"print(f\" EMOS={final['emos']:.4f} (exp08 {EXP08['emos']})\")\n",
|
| 652 |
+
"if HAS_VAD:\n",
|
| 653 |
+
" print(f\" VAL/ARO/DOM={final['val']:.4f}/{final['aro']:.4f}/{final['dom']:.4f} \"\n",
|
| 654 |
+
" f\"(exp08 {EXP08['val']}/{EXP08['aro']}/{EXP08['dom']})\")\n",
|
| 655 |
+
"print(f\" mean SRCC: warm-start {mean_srcc(m0):.4f} → exp11 {mean_srcc(final):.4f} \"\n",
|
| 656 |
+
" + (\"🚀 cải thiện\" if mean_srcc(final) > mean_srcc(m0) + 1e-4 else \"➖ không cải thiện\"))\n",
|
| 657 |
+
"save_full(best_state, final.get(\"emos\", float(\"nan\")))\n",
|
| 658 |
+
"print(\"Đã lưu FULL:\", CKPT_PATH, \"→ NHỚ Save Version!\")"
|
| 659 |
+
]
|
| 660 |
+
},
|
| 661 |
+
{
|
| 662 |
+
"cell_type": "markdown",
|
| 663 |
+
"id": "dcb57395",
|
| 664 |
+
"metadata": {},
|
| 665 |
+
"source": [
|
| 666 |
+
"## 6. Dự đoán DEV → answer.txt (5 cột cảm xúc từ exp11; QMOS mượn exp07 / UTMOSv2)"
|
| 667 |
+
]
|
| 668 |
+
},
|
| 669 |
+
{
|
| 670 |
+
"cell_type": "code",
|
| 671 |
+
"execution_count": null,
|
| 672 |
+
"id": "fcae1d4c",
|
| 673 |
+
"metadata": {
|
| 674 |
+
"lines_to_next_cell": 1
|
| 675 |
+
},
|
| 676 |
+
"outputs": [],
|
| 677 |
+
"source": [
|
| 678 |
+
"def list_dev():\n",
|
| 679 |
+
" with open(DEV_SCP) as f:\n",
|
| 680 |
+
" return [ln.strip() for ln in f if ln.strip()]\n",
|
| 681 |
+
"\n",
|
| 682 |
+
"dev_names = list_dev()\n",
|
| 683 |
+
"if LIMIT_DEV:\n",
|
| 684 |
+
" dev_names = dev_names[:LIMIT_DEV]\n",
|
| 685 |
+
"print(\"DEV:\", len(dev_names), \"mẫu\")\n",
|
| 686 |
+
"\n",
|
| 687 |
+
"def load_exp07_qmos():\n",
|
| 688 |
+
" if EXP07_ANSWER and os.path.exists(EXP07_ANSWER):\n",
|
| 689 |
+
" import csv\n",
|
| 690 |
+
" d = {}\n",
|
| 691 |
+
" with open(EXP07_ANSWER) as f:\n",
|
| 692 |
+
" for row in csv.DictReader(f):\n",
|
| 693 |
+
" d[row[\"wav\"]] = float(row[\"QMOS\"]); d[stem(row[\"wav\"])] = float(row[\"QMOS\"])\n",
|
| 694 |
+
" print(f\"✅ Mượn QMOS exp07: {len(d)//2} wav\")\n",
|
| 695 |
+
" return d\n",
|
| 696 |
+
" return None\n",
|
| 697 |
+
"\n",
|
| 698 |
+
"qmos_map = load_exp07_qmos()\n",
|
| 699 |
+
"if qmos_map is None:\n",
|
| 700 |
+
" print(\"ℹ️ Không có exp07 → QMOS bằng UTMOSv2.\")\n",
|
| 701 |
+
" pip_install(\"git+https://github.com/sarulab-speech/UTMOSv2.git\")\n",
|
| 702 |
+
" import utmosv2\n",
|
| 703 |
+
" v2 = utmosv2.create_model(pretrained=True)\n",
|
| 704 |
+
" qmos_map = {}\n",
|
| 705 |
+
" for n in tqdm(dev_names, desc=\"UTMOSv2\"):\n",
|
| 706 |
+
" wav = os.path.join(WAV_DIR, n if str(n).endswith(\".wav\") else str(n) + \".wav\")\n",
|
| 707 |
+
" if os.path.exists(wav):\n",
|
| 708 |
+
" o = v2.predict(input_path=wav)\n",
|
| 709 |
+
" qmos_map[n] = float(o[\"predicted_mos\"]) if isinstance(o, dict) else float(o)\n",
|
| 710 |
+
" del v2; torch.cuda.empty_cache() if device == \"cuda\" else None\n",
|
| 711 |
+
"\n",
|
| 712 |
+
"@torch.no_grad()\n",
|
| 713 |
+
"def predict_emotion(sid):\n",
|
| 714 |
+
" pair = load_pair(sid)\n",
|
| 715 |
+
" if pair is None:\n",
|
| 716 |
+
" return None\n",
|
| 717 |
+
" wave, iv = pair\n",
|
| 718 |
+
" set_train(False)\n",
|
| 719 |
+
" w = torch.from_numpy(wave).unsqueeze(0).to(device)\n",
|
| 720 |
+
" ivt = torch.from_numpy(iv).unsqueeze(0).to(device)\n",
|
| 721 |
+
" am = torch.ones((1, len(wave)), dtype=torch.long, device=device)\n",
|
| 722 |
+
" tgt = torch.from_numpy(onehot_target(target_map.get(sid))).unsqueeze(0).to(device)\n",
|
| 723 |
+
" with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
|
| 724 |
+
" feat = torch.cat([wavlm_embed(w, am), aud_embed(ivt, am)], dim=1)\n",
|
| 725 |
+
" emos_p, cat_l, vad_p = heads(feat, tgt)\n",
|
| 726 |
+
" emos = float(emos_p.item()) * emos_sd + emos_mu\n",
|
| 727 |
+
" cat5 = F.softmax(cat_l, 1)[0].float().cpu().numpy()\n",
|
| 728 |
+
" vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu\n",
|
| 729 |
+
" return emos, cat5, vad3\n",
|
| 730 |
+
"\n",
|
| 731 |
+
"def fmt_cat(p5):\n",
|
| 732 |
+
" return \"|\".join(f\"{e}:{p5[i]:.6g}\" for i, e in enumerate(EMOTIONS5))\n",
|
| 733 |
+
"\n",
|
| 734 |
+
"answer_path = os.path.join(OUT_DIR, \"answer.txt\")\n",
|
| 735 |
+
"n_real = n_def = 0\n",
|
| 736 |
+
"with open(answer_path, \"w\") as f:\n",
|
| 737 |
+
" f.write(\"wav,QMOS,EMOS,CAT,VAL,ARO,DOM\\n\")\n",
|
| 738 |
+
" for name in tqdm(dev_names, desc=\"answer\"):\n",
|
| 739 |
+
" sid = stem(name)\n",
|
| 740 |
+
" pr = predict_emotion(sid)\n",
|
| 741 |
+
" if pr is None:\n",
|
| 742 |
+
" emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0]); n_def += 1\n",
|
| 743 |
+
" else:\n",
|
| 744 |
+
" emos, cat5, vad3 = pr; n_real += 1\n",
|
| 745 |
+
" qmos = qmos_map.get(name, qmos_map.get(sid, 3.0))\n",
|
| 746 |
+
" f.write(f\"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\\n\")\n",
|
| 747 |
+
"print(f\"Ghi {len(dev_names)} dòng → {answer_path} | cảm xúc thật {n_real}, mặc định {n_def}\")"
|
| 748 |
+
]
|
| 749 |
+
},
|
| 750 |
+
{
|
| 751 |
+
"cell_type": "markdown",
|
| 752 |
+
"id": "7b04afd9",
|
| 753 |
+
"metadata": {},
|
| 754 |
+
"source": [
|
| 755 |
+
"## 7. Validate + zip"
|
| 756 |
+
]
|
| 757 |
+
},
|
| 758 |
+
{
|
| 759 |
+
"cell_type": "code",
|
| 760 |
+
"execution_count": null,
|
| 761 |
+
"id": "0dfcf6ef",
|
| 762 |
+
"metadata": {},
|
| 763 |
+
"outputs": [],
|
| 764 |
+
"source": [
|
| 765 |
+
"def validate(path):\n",
|
| 766 |
+
" import csv\n",
|
| 767 |
+
" with open(path) as f:\n",
|
| 768 |
+
" rows = list(csv.reader(f))\n",
|
| 769 |
+
" assert rows[0][0] == \"wav\" and \"QMOS\" in rows[0], \"Header sai\"\n",
|
| 770 |
+
" for i, r in enumerate(rows[1:], 2):\n",
|
| 771 |
+
" assert len(r) == len(rows[0]), f\"Dòng {i} sai số cột\"\n",
|
| 772 |
+
" print(f\"OK: {len(rows)-1} dòng, header = {rows[0]}\")\n",
|
| 773 |
+
"\n",
|
| 774 |
+
"validate(answer_path)\n",
|
| 775 |
+
"os.system(f\"cd {OUT_DIR} && zip -j submission_track2_exp11_joint.zip answer.txt && unzip -l submission_track2_exp11_joint.zip\")\n",
|
| 776 |
+
"print(\"Sẵn sàng nộp:\", os.path.join(OUT_DIR, \"submission_track2_exp11_joint.zip\"))"
|
| 777 |
+
]
|
| 778 |
+
},
|
| 779 |
+
{
|
| 780 |
+
"cell_type": "markdown",
|
| 781 |
+
"id": "8b7adc9b",
|
| 782 |
+
"metadata": {},
|
| 783 |
+
"source": [
|
| 784 |
+
"## Ghi chú\n",
|
| 785 |
+
"- **exp11 = fine-tune CẢ WavLM + audeering, FUSION 1 model** (khác exp08 audeering frozen, khác exp10 ensemble).\n",
|
| 786 |
+
"- **Warm-start:** WavLM + heads từ `ft_emotion_full_20epoch.pt` (exp08) → bắt đầu từ điểm tốt; audeering từ\n",
|
| 787 |
+
" pretrained gốc, mở băng để học thêm. Khởi điểm = đúng exp08 → train chỉ có thể tốt lên (giữ best).\n",
|
| 788 |
+
"- **OOM:** đây là cấu hình nặng nhất. Nếu CUDA OOM → giảm `UNFREEZE_WAVLM`/`UNFREEZE_AUD` (4→2),\n",
|
| 789 |
+
" `MAX_SECONDS` (6→5), giữ `BATCH=1` + tăng `ACCUM`.\n",
|
| 790 |
+
"- **Checkpoint:** lưu `ft_joint_full.pt` mỗi best (đủ cả 2 backbone + heads) → kernel chết vẫn còn. Save Version!\n",
|
| 791 |
+
"- **QMOS** vẫn mượn exp07 (0.548). So sánh nộp: exp11 vs exp08(0.811) vs exp10(ensemble) → chọn bản tốt nhất.\n",
|
| 792 |
+
"- Ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (exp11)."
|
| 793 |
+
]
|
| 794 |
+
}
|
| 795 |
+
],
|
| 796 |
+
"metadata": {
|
| 797 |
+
"jupytext": {
|
| 798 |
+
"cell_metadata_filter": "-all",
|
| 799 |
+
"main_language": "python",
|
| 800 |
+
"notebook_metadata_filter": "-all"
|
| 801 |
+
}
|
| 802 |
+
},
|
| 803 |
+
"nbformat": 4,
|
| 804 |
+
"nbformat_minor": 5
|
| 805 |
+
}
|
track2/exp11_finetune_joint_pipeline.py
ADDED
|
@@ -0,0 +1,665 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %% [markdown]
|
| 2 |
+
# # VMC2026 Track 2 — exp11 (FINE-TUNE ĐỒNG THỜI WavLM + audeering, FUSION 1 model) — Kaggle T4
|
| 3 |
+
#
|
| 4 |
+
# **Khác exp08:** exp08 chỉ fine-tune WavLM, audeering **đóng băng** (frozen, cache). exp11 **MỞ BĂNG CẢ HAI**
|
| 5 |
+
# backbone và fuse đặc trưng **trong cùng 1 model** → cả hai cùng học cho bài MOS cảm xúc 2026.
|
| 6 |
+
#
|
| 7 |
+
# ```
|
| 8 |
+
# wav ─┬─► WavLM-large (warm-start exp08, TRAINABLE: mở băng N lớp trên) ─► pool ─► emb_wavlm ┐
|
| 9 |
+
# └─► audeering MSP (TRAINABLE: mở băng N lớp trên) ─► pool ─► [emb_aud(1024) | vad3] ──────┼─► TRUNK ─┬─► EMOS (+target)
|
| 10 |
+
# ┘ ├─► CAT (5)
|
| 11 |
+
# └─► VAD (3)
|
| 12 |
+
# QMOS: KHÔNG train ở đây → mượn cột QMOS exp07 (0.548) hoặc UTMOSv2.
|
| 13 |
+
# ```
|
| 14 |
+
#
|
| 15 |
+
# ## Vì sao "feature fusion + fine-tune cả 2" (khác ensemble exp10)
|
| 16 |
+
# - **exp10 = ensemble:** 2 model RIÊNG → trung bình cột VAD ở mức answer. An toàn nhưng 2 model không "nói chuyện".
|
| 17 |
+
# - **exp11 = fusion:** 1 model, 2 backbone fuse Ở TRONG → trunk học phối hợp cả hai góc nhìn (WavLM categorical +
|
| 18 |
+
# audeering dimensional) → kỳ vọng mạnh hơn nếu không OOM/overfit.
|
| 19 |
+
#
|
| 20 |
+
# ## ⚠️ ĐÁNH ĐỔI PHẢI BIẾT — đây là cấu hình NẶNG nhất (2 backbone large cùng có gradient)
|
| 21 |
+
# - **Rủi ro OOM cao trên T4 (16GB).** Đã bật sẵn mọi cách giảm bộ nhớ: `BATCH=1` + grad-accum,
|
| 22 |
+
# gradient-checkpointing CẢ 2 backbone, AMP fp16, `MAX_SECONDS=6`, mở băng ÍT lớp (mặc định 4 mỗi backbone).
|
| 23 |
+
# - Nếu vẫn OOM: giảm `UNFREEZE_WAVLM`/`UNFREEZE_AUD` → 2, giảm `MAX_SECONDS` → 5, tăng `ACCUM`.
|
| 24 |
+
# - **Chậm + đốt giờ GPU** (2 backbone forward+backward, không cache được). **LẦN ĐẦU BẮT BUỘC `LIMIT_TRAIN=300`,
|
| 25 |
+
# `LIMIT_DEV=20`** để chỉnh trơn rồi mới `None`.
|
| 26 |
+
# - **Lưới an toàn:** đừng đốt lượt nộp — chỉ nộp khi exp11 thắng exp08 (0.811) TRÊN VAL NỘI BỘ.
|
| 27 |
+
#
|
| 28 |
+
# **Cách chạy:** GPU **T4** + Internet **On** → Add Input (data + checkpoint exp08 + [tùy chọn] answer exp07) →
|
| 29 |
+
# sửa slug cell 0 → Run All. Ghi config→kết quả→nhận xét vào `docs/04_experiments_log.md` (exp11).
|
| 30 |
+
|
| 31 |
+
# %% [markdown]
|
| 32 |
+
# ## 0. Cấu hình — SỬA Ở ĐÂY
|
| 33 |
+
|
| 34 |
+
# %%
|
| 35 |
+
import os
|
| 36 |
+
|
| 37 |
+
DATA_ROOT = "/kaggle/input/datasets/minhtoan2/vmc2026-track2-full/vmc2026-track2" # << SỬA slug cho khớp Add Input
|
| 38 |
+
WAV_DIR = f"{DATA_ROOT}/wav"
|
| 39 |
+
METADATA_CSV = f"{DATA_ROOT}/metadata.csv" # wavID|emotion|transcript (KHÔNG header)
|
| 40 |
+
TRAIN_CSV = f"{DATA_ROOT}/sets/train.csv" # lisID|wavID|qMOS|emoCat|eMOS|val|dom|aro
|
| 41 |
+
DEV_SCP = f"{DATA_ROOT}/sets/dev.scp"
|
| 42 |
+
|
| 43 |
+
# ── Warm-start / RESUME: trỏ tới 1 trong 2 loại checkpoint ───────────────────
|
| 44 |
+
# • ft_emotion_full_20epoch.pt (exp08): có 'wavlm'+'heads' → WARM-START (audeering từ pretrained gốc).
|
| 45 |
+
# • ft_joint_full.pt (exp11): có thêm 'aud'+'aud_head' → RESUME ĐỦ (khôi phục cả 2 backbone đã fine-tune).
|
| 46 |
+
# Notebook TỰ nhận biết theo key trong checkpoint. Để "" nếu train WavLM từ SAILER trắng.
|
| 47 |
+
WARMSTART_CKPT = "/kaggle/input/ft-joint-full/ft_joint_full.pt" # << exp08 ckpt (warm-start) HOẶC exp11 ckpt (resume)
|
| 48 |
+
|
| 49 |
+
# Mượn cột QMOS exp07 (0.548). Không có → UTMOSv2.
|
| 50 |
+
EXP07_ANSWER = "/kaggle/input/exp07-answer/answer.txt" # << (tùy chọn) answer.txt exp07; không có → UTMOSv2
|
| 51 |
+
|
| 52 |
+
OUT_DIR = "/kaggle/working"
|
| 53 |
+
|
| 54 |
+
# ── Fine-tune / siêu tham số (CẤU HÌNH NẶNG — đã tối ưu cho T4) ───────────────
|
| 55 |
+
DEVICE = "cuda"
|
| 56 |
+
SR = 16000
|
| 57 |
+
MAX_SECONDS = 6 # ↓ so exp08 (8) để tiết kiệm VRAM (2 backbone)
|
| 58 |
+
UNFREEZE_WAVLM = 4 # số lớp encoder WavLM mở băng (OOM → 2)
|
| 59 |
+
UNFREEZE_AUD = 4 # số lớp encoder audeering mở băng (OOM → 2)
|
| 60 |
+
TRUNK_HIDDEN = 512 # PHẢI khớp checkpoint exp08 nếu warm-start heads
|
| 61 |
+
HEAD_HIDDEN = 128 # PHẢI khớp checkpoint exp08
|
| 62 |
+
DROPOUT = 0.3
|
| 63 |
+
LR_BACKBONE = 1e-5 # LR chung cho 2 backbone
|
| 64 |
+
LR_HEAD = 1e-3
|
| 65 |
+
RESUME_LR_SCALE = 1.0 # <1.0 để GIẢM LR khi resume (vd 0.5 nếu val đã chững) — nhân vào cả 2 nhóm LR
|
| 66 |
+
WEIGHT_DECAY = 1e-5
|
| 67 |
+
EPOCHS = 12
|
| 68 |
+
PATIENCE = 4 # dừng khi val không lên; LUÔN giữ best
|
| 69 |
+
BATCH = 1 # ⚠️ 2 backbone large → batch nhỏ
|
| 70 |
+
ACCUM = 16 # effective batch = 16
|
| 71 |
+
VAL_FRAC = 0.10
|
| 72 |
+
SEED = 42
|
| 73 |
+
USE_AMP = True
|
| 74 |
+
USE_GRAD_CKPT = True
|
| 75 |
+
USE_UNCERTAINTY = True
|
| 76 |
+
|
| 77 |
+
LIMIT_TRAIN = 300 # << LẦN ĐẦU 300; chạy thật None
|
| 78 |
+
LIMIT_DEV = 20 # << LẦN ĐẦU 20; ch���y thật None
|
| 79 |
+
|
| 80 |
+
EXP08 = {"emos": 0.811, "cat_err": 0.133, "val": 0.659, "aro": 0.793, "dom": 0.751} # mốc để so
|
| 81 |
+
|
| 82 |
+
EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"]
|
| 83 |
+
_EMO_ALIAS = {
|
| 84 |
+
"angry": "angry", "anger": "angry",
|
| 85 |
+
"happy": "happy", "happiness": "happy", "joy": "happy",
|
| 86 |
+
"neutral": "neutral", "calm": "neutral",
|
| 87 |
+
"sad": "sad", "sadness": "sad",
|
| 88 |
+
"surprise": "surprised", "surprised": "surprised", "surprising": "surprised",
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
def norm_emotion(label):
|
| 92 |
+
key = str(label).strip().lower()
|
| 93 |
+
return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)
|
| 94 |
+
|
| 95 |
+
def stem(p):
|
| 96 |
+
return os.path.splitext(os.path.basename(str(p)))[0]
|
| 97 |
+
|
| 98 |
+
print("DATA_ROOT:", DATA_ROOT)
|
| 99 |
+
for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP]:
|
| 100 |
+
print((" ✅ " if os.path.exists(p) else " ❌ THIẾU ") + p)
|
| 101 |
+
print((" ✅ " if (WARMSTART_CKPT and os.path.exists(WARMSTART_CKPT)) else " ⚠️ KHÔNG có ") + str(WARMSTART_CKPT)
|
| 102 |
+
+ (" → warm-start" if (WARMSTART_CKPT and os.path.exists(WARMSTART_CKPT)) else " → train từ SAILER trắng"))
|
| 103 |
+
|
| 104 |
+
# %% [markdown]
|
| 105 |
+
# ## 1. Cài đặt + tải code SAILER (dựng đúng kiến trúc WavLM)
|
| 106 |
+
|
| 107 |
+
# %%
|
| 108 |
+
import sys, subprocess
|
| 109 |
+
|
| 110 |
+
def pip_install(*pkgs):
|
| 111 |
+
subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs], check=True)
|
| 112 |
+
|
| 113 |
+
pip_install("transformers", "huggingface_hub", "safetensors", "loralib", "speechbrain",
|
| 114 |
+
"speechmos", "librosa", "soundfile", "scipy", "scikit-learn", "pandas", "tqdm")
|
| 115 |
+
|
| 116 |
+
REPO_DIR = "/kaggle/working/vox-profile-release"
|
| 117 |
+
if not os.path.exists(REPO_DIR):
|
| 118 |
+
subprocess.run(["git", "clone", "--depth", "1",
|
| 119 |
+
"https://github.com/tiantiaf0627/vox-profile-release.git", REPO_DIR], check=True)
|
| 120 |
+
if REPO_DIR not in sys.path:
|
| 121 |
+
sys.path.insert(0, REPO_DIR)
|
| 122 |
+
|
| 123 |
+
# %% [markdown]
|
| 124 |
+
# ## 2A. WavLM TRAINABLE (warm-start SAILER / checkpoint exp08)
|
| 125 |
+
|
| 126 |
+
# %%
|
| 127 |
+
import torch
|
| 128 |
+
import torch.nn as nn
|
| 129 |
+
import torch.nn.functional as F
|
| 130 |
+
import numpy as np
|
| 131 |
+
|
| 132 |
+
device = DEVICE if torch.cuda.is_available() else "cpu"
|
| 133 |
+
print("Device:", device, ("✅ " + torch.cuda.get_device_name(0)) if device == "cuda" else "⚠️ CPU (rất chậm!)")
|
| 134 |
+
|
| 135 |
+
# Nạp checkpoint exp08 (nếu có) — lấy cả 'wavlm', 'heads', thống kê chuẩn hóa
|
| 136 |
+
ckpt = None
|
| 137 |
+
if WARMSTART_CKPT and os.path.exists(WARMSTART_CKPT):
|
| 138 |
+
ckpt = torch.load(WARMSTART_CKPT, map_location="cpu", weights_only=False)
|
| 139 |
+
print("✅ Nạp checkpoint warm-start:", WARMSTART_CKPT, "| keys:", list(ckpt.keys()))
|
| 140 |
+
if "wavlm" not in ckpt:
|
| 141 |
+
print(" ⚠️ Checkpoint KHÔNG có 'wavlm' (chỉ heads?) → vẫn dựng WavLM từ SAILER, chỉ warm-start heads nếu khớp.")
|
| 142 |
+
|
| 143 |
+
def find_hf_backbone(module):
|
| 144 |
+
cands = []
|
| 145 |
+
for name, m in module.named_modules():
|
| 146 |
+
enc = getattr(m, "encoder", None)
|
| 147 |
+
if getattr(m, "feature_extractor", None) is not None and enc is not None \
|
| 148 |
+
and getattr(enc, "layers", None) is not None:
|
| 149 |
+
cands.append((name, m))
|
| 150 |
+
if not cands:
|
| 151 |
+
return None, None
|
| 152 |
+
cands.sort(key=lambda nm: sum(p.numel() for p in nm[1].parameters()), reverse=True)
|
| 153 |
+
return cands[0]
|
| 154 |
+
|
| 155 |
+
wavlm = None
|
| 156 |
+
try:
|
| 157 |
+
from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402
|
| 158 |
+
_wrapper = WavLMWrapper.from_pretrained("tiantiaf/wavlm-large-categorical-emotion")
|
| 159 |
+
name, wavlm = find_hf_backbone(_wrapper)
|
| 160 |
+
if wavlm is not None:
|
| 161 |
+
print(f"✅ Dựng backbone WavLM từ SAILER wrapper tại '.{name}'")
|
| 162 |
+
except Exception as e:
|
| 163 |
+
print("⚠️ Lỗi nạp SAILER wrapper:", repr(e), "→ fallback WavLM trắng.")
|
| 164 |
+
|
| 165 |
+
if wavlm is None:
|
| 166 |
+
from transformers import WavLMModel
|
| 167 |
+
wavlm = WavLMModel.from_pretrained("microsoft/wavlm-large")
|
| 168 |
+
print("ℹ️ Fallback: microsoft/wavlm-large.")
|
| 169 |
+
|
| 170 |
+
wavlm = wavlm.to(device)
|
| 171 |
+
WAVLM_DIM = int(wavlm.config.hidden_size)
|
| 172 |
+
wavlm.config.layerdrop = 0.0 # ⚠️ tắt layerdrop khi dùng gradient-checkpointing (tránh CheckpointError)
|
| 173 |
+
|
| 174 |
+
# Đè trọng số đã fine-tune từ checkpoint exp08 (nếu có)
|
| 175 |
+
if ckpt is not None and "wavlm" in ckpt:
|
| 176 |
+
miss, unexp = wavlm.load_state_dict(ckpt["wavlm"], strict=False)
|
| 177 |
+
print(f"🔁 load wavlm từ checkpoint exp08: thiếu {len(miss)} / dư {len(unexp)} key (kỳ vọng ~0).")
|
| 178 |
+
|
| 179 |
+
# Đóng băng partial: chỉ mở UNFREEZE_WAVLM lớp trên
|
| 180 |
+
for p in wavlm.parameters():
|
| 181 |
+
p.requires_grad = False
|
| 182 |
+
_wl = wavlm.encoder.layers
|
| 183 |
+
for layer in _wl[max(0, len(_wl) - UNFREEZE_WAVLM):]:
|
| 184 |
+
for p in layer.parameters():
|
| 185 |
+
p.requires_grad = True
|
| 186 |
+
print(f"WavLM: {len(_wl)} lớp · mở băng {min(UNFREEZE_WAVLM, len(_wl))} → "
|
| 187 |
+
f"{sum(p.numel() for p in wavlm.parameters() if p.requires_grad)/1e6:.1f}M param train (dim {WAVLM_DIM})")
|
| 188 |
+
|
| 189 |
+
if USE_GRAD_CKPT:
|
| 190 |
+
wavlm.gradient_checkpointing_enable()
|
| 191 |
+
if hasattr(wavlm, "enable_input_require_grads"):
|
| 192 |
+
wavlm.enable_input_require_grads()
|
| 193 |
+
|
| 194 |
+
def masked_mean(hidden, attn_mask, model):
|
| 195 |
+
if attn_mask is None:
|
| 196 |
+
return hidden.mean(dim=1)
|
| 197 |
+
try:
|
| 198 |
+
fm = model._get_feature_vector_attention_mask(hidden.shape[1], attn_mask)
|
| 199 |
+
except Exception:
|
| 200 |
+
return hidden.mean(dim=1)
|
| 201 |
+
fm = fm.unsqueeze(-1).to(hidden.dtype)
|
| 202 |
+
return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)
|
| 203 |
+
|
| 204 |
+
def wavlm_embed(input_values, attn_mask):
|
| 205 |
+
out = wavlm(input_values, attention_mask=attn_mask).last_hidden_state
|
| 206 |
+
return masked_mean(out, attn_mask, wavlm)
|
| 207 |
+
|
| 208 |
+
# %% [markdown]
|
| 209 |
+
# ## 2B. audeering TRAINABLE (mở băng — khác exp08 là frozen)
|
| 210 |
+
# Nạp backbone tay + head dimensional gốc; mở băng `UNFREEZE_AUD` lớp trên. Đặc trưng fuse = [hidden(1024) | vad3].
|
| 211 |
+
|
| 212 |
+
# %%
|
| 213 |
+
from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor
|
| 214 |
+
from huggingface_hub import hf_hub_download
|
| 215 |
+
|
| 216 |
+
AUD_NAME = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
|
| 217 |
+
aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)
|
| 218 |
+
aud_cfg = Wav2Vec2Config.from_pretrained(AUD_NAME)
|
| 219 |
+
aud = Wav2Vec2Model(aud_cfg)
|
| 220 |
+
try:
|
| 221 |
+
_sd = __import__("safetensors.torch", fromlist=["load_file"]).load_file(
|
| 222 |
+
hf_hub_download(AUD_NAME, "model.safetensors"))
|
| 223 |
+
except Exception:
|
| 224 |
+
_sd = torch.load(hf_hub_download(AUD_NAME, "pytorch_model.bin"), map_location="cpu")
|
| 225 |
+
bb_sd = {k[len("wav2vec2."):]: v for k, v in _sd.items() if k.startswith("wav2vec2.")}
|
| 226 |
+
aud.load_state_dict(bb_sd, strict=False)
|
| 227 |
+
_hid = _sd["classifier.dense.weight"].shape[0]
|
| 228 |
+
aud_head = nn.Sequential(nn.Linear(_hid, _hid), nn.Tanh(), nn.Linear(_hid, _sd["classifier.out_proj.weight"].shape[0]))
|
| 229 |
+
aud_head[0].weight.data.copy_(_sd["classifier.dense.weight"]); aud_head[0].bias.data.copy_(_sd["classifier.dense.bias"])
|
| 230 |
+
aud_head[2].weight.data.copy_(_sd["classifier.out_proj.weight"]); aud_head[2].bias.data.copy_(_sd["classifier.out_proj.bias"])
|
| 231 |
+
aud = aud.to(device); aud_head = aud_head.to(device)
|
| 232 |
+
aud.config.layerdrop = 0.0 # ⚠️ tắt layerdrop khi dùng gradient-checkpointing (tránh CheckpointError)
|
| 233 |
+
AUD_DIM = _hid + 3 # = 1027 (khớp exp08 để warm-start heads)
|
| 234 |
+
|
| 235 |
+
# RESUME: nếu checkpoint là ft_joint_full.pt (có 'aud') → khôi phục audeering ĐÃ fine-tune (đè pretrained)
|
| 236 |
+
if ckpt is not None and "aud" in ckpt:
|
| 237 |
+
amiss, aunexp = aud.load_state_dict(ckpt["aud"], strict=False)
|
| 238 |
+
print(f"🔁 RESUME audeering từ checkpoint: thiếu {len(amiss)} / dư {len(aunexp)} key (kỳ vọng ~0).")
|
| 239 |
+
if "aud_head" in ckpt:
|
| 240 |
+
aud_head.load_state_dict(ckpt["aud_head"]); print("🔁 RESUME aud_head từ checkpoint.")
|
| 241 |
+
else:
|
| 242 |
+
print("ℹ️ Checkpoint không có 'aud' → audeering khởi từ pretrained gốc (chế độ warm-start exp08).")
|
| 243 |
+
|
| 244 |
+
# Đóng băng partial audeering: mở UNFREEZE_AUD lớp trên + head dimensional luôn trainable
|
| 245 |
+
for p in aud.parameters():
|
| 246 |
+
p.requires_grad = False
|
| 247 |
+
_al = aud.encoder.layers
|
| 248 |
+
for layer in _al[max(0, len(_al) - UNFREEZE_AUD):]:
|
| 249 |
+
for p in layer.parameters():
|
| 250 |
+
p.requires_grad = True
|
| 251 |
+
for p in aud_head.parameters():
|
| 252 |
+
p.requires_grad = True
|
| 253 |
+
print(f"audeering: {len(_al)} lớp · mở băng {min(UNFREEZE_AUD, len(_al))} → "
|
| 254 |
+
f"{sum(p.numel() for p in aud.parameters() if p.requires_grad)/1e6:.1f}M param train (hidden {_hid}, fuse dim {AUD_DIM})")
|
| 255 |
+
|
| 256 |
+
if USE_GRAD_CKPT:
|
| 257 |
+
aud.gradient_checkpointing_enable()
|
| 258 |
+
if hasattr(aud, "enable_input_require_grads"):
|
| 259 |
+
aud.enable_input_require_grads()
|
| 260 |
+
|
| 261 |
+
def aud_embed(input_values, attn_mask):
|
| 262 |
+
"""Trả về [hidden(1024) | vad3] — vad3 từ head dimensional gốc, theo thứ tự VAL,ARO,DOM."""
|
| 263 |
+
h = masked_mean(aud(input_values, attention_mask=attn_mask).last_hidden_state, attn_mask, aud)
|
| 264 |
+
out = aud_head(h) # [B,3] thứ tự gốc audeering: (arousal, dominance, valence)
|
| 265 |
+
vad = torch.stack([1 + 4 * out[:, 2], 1 + 4 * out[:, 0], 1 + 4 * out[:, 1]], dim=1) # → VAL,ARO,DOM
|
| 266 |
+
return torch.cat([h, vad], dim=1)
|
| 267 |
+
|
| 268 |
+
# %% [markdown]
|
| 269 |
+
# ## 3. Đọc & gộp nhãn theo wavID (như exp08)
|
| 270 |
+
|
| 271 |
+
# %%
|
| 272 |
+
import librosa
|
| 273 |
+
import pandas as pd
|
| 274 |
+
from tqdm.auto import tqdm
|
| 275 |
+
|
| 276 |
+
def load_target_emotions():
|
| 277 |
+
tgt = {}
|
| 278 |
+
with open(METADATA_CSV, encoding="utf-8") as f:
|
| 279 |
+
for ln in f:
|
| 280 |
+
parts = ln.strip().split("|")
|
| 281 |
+
if len(parts) >= 2:
|
| 282 |
+
tgt[stem(parts[0])] = norm_emotion(parts[1])
|
| 283 |
+
return tgt
|
| 284 |
+
|
| 285 |
+
def _col(cols_map, *names, df=None, default_idx=None):
|
| 286 |
+
for n in names:
|
| 287 |
+
if n in cols_map:
|
| 288 |
+
return cols_map[n]
|
| 289 |
+
return list(df.columns)[default_idx] if default_idx is not None else None
|
| 290 |
+
|
| 291 |
+
def parse_emocat_votes(cell):
|
| 292 |
+
v = np.zeros(len(EMOTIONS5), dtype=np.float32)
|
| 293 |
+
for tok in str(cell).replace("/", ",").replace(";", ",").replace("|", ",").replace(" ", ",").split(","):
|
| 294 |
+
e = norm_emotion(tok)
|
| 295 |
+
if e in EMOTIONS5:
|
| 296 |
+
v[EMOTIONS5.index(e)] += 1.0
|
| 297 |
+
return v
|
| 298 |
+
|
| 299 |
+
def load_train_labels():
|
| 300 |
+
df = pd.read_csv(TRAIN_CSV, sep="|")
|
| 301 |
+
cols = {c.lower().strip(): c for c in df.columns}
|
| 302 |
+
wav_col = _col(cols, "wavid", "wav", df=df, default_idx=1)
|
| 303 |
+
emos_col = _col(cols, "emos", "emo", "emomos")
|
| 304 |
+
val_col = _col(cols, "val", "valence"); aro_col = _col(cols, "aro", "arousal"); dom_col = _col(cols, "dom", "dominance")
|
| 305 |
+
cat_col = _col(cols, "emocat", "cat", "emotion")
|
| 306 |
+
assert emos_col, f"Không thấy cột eMOS (cột: {list(df.columns)})"
|
| 307 |
+
df["_stem"] = df[wav_col].map(stem)
|
| 308 |
+
rows = []
|
| 309 |
+
for sid, g in df.groupby("_stem"):
|
| 310 |
+
rec = {"wavID": sid, "emos": float(g[emos_col].mean())}
|
| 311 |
+
rec["val"] = float(g[val_col].mean()) if val_col else np.nan
|
| 312 |
+
rec["aro"] = float(g[aro_col].mean()) if aro_col else np.nan
|
| 313 |
+
rec["dom"] = float(g[dom_col].mean()) if dom_col else np.nan
|
| 314 |
+
votes = np.zeros(len(EMOTIONS5), dtype=np.float32)
|
| 315 |
+
if cat_col:
|
| 316 |
+
for cell in g[cat_col]:
|
| 317 |
+
votes += parse_emocat_votes(cell)
|
| 318 |
+
s = votes.sum()
|
| 319 |
+
cat = votes / s if s > 0 else np.full(len(EMOTIONS5), 0.2, dtype=np.float32)
|
| 320 |
+
for i in range(len(EMOTIONS5)):
|
| 321 |
+
rec[f"cat{i}"] = float(cat[i])
|
| 322 |
+
rows.append(rec)
|
| 323 |
+
return pd.DataFrame(rows)
|
| 324 |
+
|
| 325 |
+
target_map = load_target_emotions()
|
| 326 |
+
train_df = load_train_labels()
|
| 327 |
+
HAS_VAD = bool(train_df["val"].notna().any())
|
| 328 |
+
print(f"Target: {len(target_map)} | wav train (gộp): {len(train_df)} | có VAD: {HAS_VAD}")
|
| 329 |
+
|
| 330 |
+
# %% [markdown]
|
| 331 |
+
# ## 4. Dataset/loader — trả về CẢ raw wave (cho WavLM) + input_values audeering
|
| 332 |
+
# Hai backbone cần đầu vào khác nhau: WavLM nhận wave thô; audeering nhận wave đã chuẩn hóa bởi processor.
|
| 333 |
+
# Cùng độ dài → dùng chung attention mask theo mức sample.
|
| 334 |
+
|
| 335 |
+
# %%
|
| 336 |
+
from torch.utils.data import Dataset, DataLoader
|
| 337 |
+
from sklearn.model_selection import train_test_split
|
| 338 |
+
|
| 339 |
+
train_stems = [s for s in train_df["wavID"] if target_map.get(s) is not None]
|
| 340 |
+
if LIMIT_TRAIN:
|
| 341 |
+
train_stems = train_stems[:LIMIT_TRAIN]
|
| 342 |
+
lab = train_df.set_index("wavID")
|
| 343 |
+
|
| 344 |
+
# Chuẩn hóa: lấy TỪ checkpoint nếu warm-start (để khớp head đã train); không thì fit từ data.
|
| 345 |
+
if ckpt is not None and "emos_mu" in ckpt:
|
| 346 |
+
emos_mu = float(ckpt["emos_mu"]); emos_sd = float(ckpt["emos_sd"])
|
| 347 |
+
vad_mu = np.asarray(ckpt["vad_mu"], dtype=np.float32); vad_sd = np.asarray(ckpt["vad_sd"], dtype=np.float32)
|
| 348 |
+
print(f"Chuẩn hóa TỪ ckpt: emos μ={emos_mu:.3f} σ={emos_sd:.3f} | vad μ={np.round(vad_mu,2)}")
|
| 349 |
+
else:
|
| 350 |
+
def _zfit(a):
|
| 351 |
+
a = np.asarray(a, dtype=np.float32); return float(np.nanmean(a)), float(np.nanstd(a) + 1e-6)
|
| 352 |
+
emos_mu, emos_sd = _zfit([lab.loc[s, "emos"] for s in train_stems])
|
| 353 |
+
if HAS_VAD:
|
| 354 |
+
vad_mu = np.array([_zfit([lab.loc[s, c] for s in train_stems])[0] for c in ["val", "aro", "dom"]], np.float32)
|
| 355 |
+
vad_sd = np.array([_zfit([lab.loc[s, c] for s in train_stems])[1] for c in ["val", "aro", "dom"]], np.float32)
|
| 356 |
+
else:
|
| 357 |
+
vad_mu = np.zeros(3, np.float32); vad_sd = np.ones(3, np.float32)
|
| 358 |
+
print(f"Chuẩn hóa fit từ data: emos μ={emos_mu:.3f} σ={emos_sd:.3f}")
|
| 359 |
+
|
| 360 |
+
def onehot_target(tgt):
|
| 361 |
+
v = np.zeros(len(EMOTIONS5), dtype=np.float32)
|
| 362 |
+
if tgt in EMOTIONS5:
|
| 363 |
+
v[EMOTIONS5.index(tgt)] = 1.0
|
| 364 |
+
return v
|
| 365 |
+
|
| 366 |
+
def load_pair(sid):
|
| 367 |
+
"""Trả về (wave_thô, iv_audeering) cùng độ dài; None nếu thiếu file."""
|
| 368 |
+
p = os.path.join(WAV_DIR, sid if str(sid).endswith(".wav") else str(sid) + ".wav")
|
| 369 |
+
if not os.path.exists(p):
|
| 370 |
+
return None
|
| 371 |
+
wave, _ = librosa.load(p, sr=SR, mono=True)
|
| 372 |
+
wave = wave[: MAX_SECONDS * SR].astype(np.float32)
|
| 373 |
+
iv = np.asarray(aud_proc(wave, sampling_rate=SR).input_values[0], dtype=np.float32)
|
| 374 |
+
return wave, iv
|
| 375 |
+
|
| 376 |
+
class JointDataset(Dataset):
|
| 377 |
+
def __init__(self, stems):
|
| 378 |
+
self.stems = [s for s in stems if load_pair(s) is not None]
|
| 379 |
+
def __len__(self):
|
| 380 |
+
return len(self.stems)
|
| 381 |
+
def __getitem__(self, i):
|
| 382 |
+
s = self.stems[i]
|
| 383 |
+
wave, iv = load_pair(s)
|
| 384 |
+
emos = (float(lab.loc[s, "emos"]) - emos_mu) / emos_sd
|
| 385 |
+
if HAS_VAD:
|
| 386 |
+
vad = (np.array([lab.loc[s, "val"], lab.loc[s, "aro"], lab.loc[s, "dom"]], np.float32) - vad_mu) / vad_sd
|
| 387 |
+
else:
|
| 388 |
+
vad = np.zeros(3, dtype=np.float32)
|
| 389 |
+
cat = np.array([lab.loc[s, f"cat{j}"] for j in range(len(EMOTIONS5))], dtype=np.float32)
|
| 390 |
+
return {"wave": wave, "iv": iv, "tgt": onehot_target(target_map.get(s)),
|
| 391 |
+
"emos": np.float32(emos), "vad": vad, "cat": cat,
|
| 392 |
+
"emos_raw": np.float32(lab.loc[s, "emos"]),
|
| 393 |
+
"vad_raw": np.array([lab.loc[s, "val"], lab.loc[s, "aro"], lab.loc[s, "dom"]], np.float32)}
|
| 394 |
+
|
| 395 |
+
def collate(batch):
|
| 396 |
+
L = max(len(b["wave"]) for b in batch)
|
| 397 |
+
waves = np.zeros((len(batch), L), dtype=np.float32)
|
| 398 |
+
ivs = np.zeros((len(batch), L), dtype=np.float32)
|
| 399 |
+
mask = np.zeros((len(batch), L), dtype=np.float32)
|
| 400 |
+
for i, b in enumerate(batch):
|
| 401 |
+
n = len(b["wave"])
|
| 402 |
+
waves[i, :n] = b["wave"]; ivs[i, :len(b["iv"])] = b["iv"]; mask[i, :n] = 1.0
|
| 403 |
+
return {
|
| 404 |
+
"wave": torch.from_numpy(waves), "iv": torch.from_numpy(ivs), "attn_mask": torch.from_numpy(mask).long(),
|
| 405 |
+
"tgt": torch.from_numpy(np.stack([b["tgt"] for b in batch])),
|
| 406 |
+
"emos": torch.from_numpy(np.stack([b["emos"] for b in batch])).unsqueeze(1),
|
| 407 |
+
"vad": torch.from_numpy(np.stack([b["vad"] for b in batch])),
|
| 408 |
+
"cat": torch.from_numpy(np.stack([b["cat"] for b in batch])),
|
| 409 |
+
"emos_raw": np.stack([b["emos_raw"] for b in batch]),
|
| 410 |
+
"vad_raw": np.stack([b["vad_raw"] for b in batch]),
|
| 411 |
+
}
|
| 412 |
+
|
| 413 |
+
ds = JointDataset(train_stems)
|
| 414 |
+
print("Dataset hợp lệ:", len(ds), "wav")
|
| 415 |
+
tr_i, va_i = train_test_split(np.arange(len(ds)), test_size=VAL_FRAC, random_state=SEED)
|
| 416 |
+
tr_loader = DataLoader(torch.utils.data.Subset(ds, tr_i), batch_size=BATCH, shuffle=True, collate_fn=collate, num_workers=2)
|
| 417 |
+
va_loader = DataLoader(torch.utils.data.Subset(ds, va_i), batch_size=BATCH, shuffle=False, collate_fn=collate, num_workers=2)
|
| 418 |
+
|
| 419 |
+
# %% [markdown]
|
| 420 |
+
# ## 5. Heads (warm-start exp08 nếu khớp) + optimizer 2 backbone + train loop
|
| 421 |
+
|
| 422 |
+
# %%
|
| 423 |
+
from scipy.stats import spearmanr
|
| 424 |
+
|
| 425 |
+
torch.manual_seed(SEED); np.random.seed(SEED)
|
| 426 |
+
N_EMO = len(EMOTIONS5)
|
| 427 |
+
TRUNK_IN = WAVLM_DIM + AUD_DIM
|
| 428 |
+
|
| 429 |
+
class EmoHeads(nn.Module):
|
| 430 |
+
def __init__(self, d_in, trunk_h, head_h, p, n_emo):
|
| 431 |
+
super().__init__()
|
| 432 |
+
self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),
|
| 433 |
+
nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))
|
| 434 |
+
self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))
|
| 435 |
+
self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))
|
| 436 |
+
self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))
|
| 437 |
+
def forward(self, feat, tgt):
|
| 438 |
+
h = self.trunk(feat)
|
| 439 |
+
return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)
|
| 440 |
+
|
| 441 |
+
heads = EmoHeads(TRUNK_IN, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device)
|
| 442 |
+
if ckpt is not None and "heads" in ckpt:
|
| 443 |
+
hmiss, hunexp = heads.load_state_dict(ckpt["heads"], strict=False)
|
| 444 |
+
if len(hmiss) == 0 and len(hunexp) == 0:
|
| 445 |
+
print("🔁 warm-start heads từ exp08: KHỚP hoàn toàn.")
|
| 446 |
+
else:
|
| 447 |
+
print(f"⚠️ heads exp08 lệch (thiếu {len(hmiss)}/dư {len(hunexp)}) → có thể TRUNK_IN khác. Heads init mới phần lệch.")
|
| 448 |
+
print(f"Trunk input = {TRUNK_IN} (wavlm {WAVLM_DIM} + aud {AUD_DIM})")
|
| 449 |
+
|
| 450 |
+
TASKS = ["emos", "cat", "val", "aro", "dom"]
|
| 451 |
+
log_var = nn.Parameter(torch.zeros(len(TASKS), device=device))
|
| 452 |
+
bb_params = [p for p in wavlm.parameters() if p.requires_grad] + \
|
| 453 |
+
[p for p in aud.parameters() if p.requires_grad] + list(aud_head.parameters())
|
| 454 |
+
head_params = list(heads.parameters()) + ([log_var] if USE_UNCERTAINTY else [])
|
| 455 |
+
opt = torch.optim.AdamW([{"params": bb_params, "lr": LR_BACKBONE * RESUME_LR_SCALE},
|
| 456 |
+
{"params": head_params, "lr": LR_HEAD * RESUME_LR_SCALE}], weight_decay=WEIGHT_DECAY)
|
| 457 |
+
scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP and device == "cuda")
|
| 458 |
+
mse = nn.MSELoss()
|
| 459 |
+
|
| 460 |
+
def soft_ce(logits, target_dist):
|
| 461 |
+
return -(target_dist * F.log_softmax(logits, dim=1)).sum(1).mean()
|
| 462 |
+
|
| 463 |
+
def forward_batch(b):
|
| 464 |
+
am = b["attn_mask"].to(device)
|
| 465 |
+
fw = wavlm_embed(b["wave"].to(device), am)
|
| 466 |
+
fa = aud_embed(b["iv"].to(device), am)
|
| 467 |
+
return heads(torch.cat([fw, fa], dim=1), b["tgt"].to(device))
|
| 468 |
+
|
| 469 |
+
def compute_loss(emos_p, cat_l, vad_p, b):
|
| 470 |
+
L = {}
|
| 471 |
+
L["emos"] = mse(emos_p, b["emos"].to(device))
|
| 472 |
+
L["cat"] = soft_ce(cat_l, b["cat"].to(device))
|
| 473 |
+
if HAS_VAD:
|
| 474 |
+
vt = b["vad"].to(device)
|
| 475 |
+
L["val"] = mse(vad_p[:, 0:1], vt[:, 0:1]); L["aro"] = mse(vad_p[:, 1:2], vt[:, 1:2]); L["dom"] = mse(vad_p[:, 2:3], vt[:, 2:3])
|
| 476 |
+
else:
|
| 477 |
+
z = torch.zeros((), device=device); L["val"] = L["aro"] = L["dom"] = z
|
| 478 |
+
if USE_UNCERTAINTY:
|
| 479 |
+
return sum(torch.exp(-log_var[i]) * L[t] + log_var[i] for i, t in enumerate(TASKS))
|
| 480 |
+
return sum(L.values())
|
| 481 |
+
|
| 482 |
+
def set_train(flag):
|
| 483 |
+
wavlm.train(flag); aud.train(flag); aud_head.train(flag); heads.train(flag)
|
| 484 |
+
|
| 485 |
+
@torch.no_grad()
|
| 486 |
+
def evaluate():
|
| 487 |
+
set_train(False)
|
| 488 |
+
P = {"emos": [], "val": [], "aro": [], "dom": []}; Y = {"emos": [], "val": [], "aro": [], "dom": []}
|
| 489 |
+
catP, catY = [], []
|
| 490 |
+
for b in va_loader:
|
| 491 |
+
with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
|
| 492 |
+
emos_p, cat_l, vad_p = forward_batch(b)
|
| 493 |
+
P["emos"] += emos_p.float().cpu().numpy().ravel().tolist(); Y["emos"] += b["emos_raw"].tolist()
|
| 494 |
+
vad_p = vad_p.float().cpu().numpy()
|
| 495 |
+
for j, t in enumerate(["val", "aro", "dom"]):
|
| 496 |
+
P[t] += vad_p[:, j].tolist(); Y[t] += b["vad_raw"][:, j].tolist()
|
| 497 |
+
catP.append(F.softmax(cat_l, 1).float().cpu().numpy()); catY.append(b["cat"])
|
| 498 |
+
out = {}
|
| 499 |
+
for t in ["emos"] + (["val", "aro", "dom"] if HAS_VAD else []):
|
| 500 |
+
out[t] = spearmanr(P[t], Y[t]).correlation
|
| 501 |
+
q = np.concatenate(catP); p = np.concatenate(catY)
|
| 502 |
+
out["cat_err"] = float(np.abs(q - p).sum(1).mean())
|
| 503 |
+
return out
|
| 504 |
+
|
| 505 |
+
def mean_srcc(m):
|
| 506 |
+
keys = ["emos"] + (["val", "aro", "dom"] if HAS_VAD else [])
|
| 507 |
+
return float(np.mean([m[k] for k in keys]))
|
| 508 |
+
|
| 509 |
+
def snapshot():
|
| 510 |
+
return {"wavlm": {k: v.cpu().clone() for k, v in wavlm.state_dict().items()},
|
| 511 |
+
"aud": {k: v.cpu().clone() for k, v in aud.state_dict().items()},
|
| 512 |
+
"aud_head": {k: v.cpu().clone() for k, v in aud_head.state_dict().items()},
|
| 513 |
+
"heads": {k: v.cpu().clone() for k, v in heads.state_dict().items()}}
|
| 514 |
+
|
| 515 |
+
CKPT_PATH = os.path.join(OUT_DIR, "ft_joint_full.pt")
|
| 516 |
+
def save_full(state, val_emos=float("nan")):
|
| 517 |
+
torch.save({**state, "emos_mu": emos_mu, "emos_sd": emos_sd, "vad_mu": vad_mu, "vad_sd": vad_sd,
|
| 518 |
+
"WAVLM_DIM": WAVLM_DIM, "AUD_DIM": AUD_DIM,
|
| 519 |
+
"UNFREEZE_WAVLM": UNFREEZE_WAVLM, "UNFREEZE_AUD": UNFREEZE_AUD,
|
| 520 |
+
"val_emos": float(val_emos)}, CKPT_PATH)
|
| 521 |
+
|
| 522 |
+
# Init best từ trạng thái warm-start hiện tại → chỉ lưu nếu train tốt hơn
|
| 523 |
+
m0 = evaluate(); best = mean_srcc(m0); best_state = snapshot(); save_full(best_state, m0.get("emos", float("nan")))
|
| 524 |
+
print(f"📍 Khởi điểm (warm-start): mean SRCC = {best:.4f} | "
|
| 525 |
+
+ " ".join(f"{k}={m0[k]:.3f}" for k in ['emos','val','aro','dom'] if k in m0))
|
| 526 |
+
|
| 527 |
+
bad = 0
|
| 528 |
+
for ep in range(1, EPOCHS + 1):
|
| 529 |
+
set_train(True)
|
| 530 |
+
opt.zero_grad(); run = 0.0; nb = 0
|
| 531 |
+
for step, b in enumerate(tqdm(tr_loader, desc=f"epoch {ep}")):
|
| 532 |
+
with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
|
| 533 |
+
emos_p, cat_l, vad_p = forward_batch(b)
|
| 534 |
+
loss = compute_loss(emos_p, cat_l, vad_p, b) / ACCUM
|
| 535 |
+
scaler.scale(loss).backward()
|
| 536 |
+
if (step + 1) % ACCUM == 0:
|
| 537 |
+
scaler.step(opt); scaler.update(); opt.zero_grad()
|
| 538 |
+
run += loss.item() * ACCUM; nb += 1
|
| 539 |
+
m = evaluate(); sc = mean_srcc(m)
|
| 540 |
+
msg = " ".join(f"{k}={m[k]:.3f}" for k in ["emos", "val", "aro", "dom"] if k in m)
|
| 541 |
+
print(f"epoch {ep:2d} | loss {run/max(nb,1):.4f} | {msg} | cat_err {m['cat_err']:.3f} | mean {sc:.4f} (best {max(best,sc):.4f})")
|
| 542 |
+
if sc > best:
|
| 543 |
+
best = sc; best_state = snapshot(); save_full(best_state, m["emos"])
|
| 544 |
+
print(f" 💾 lưu best → {CKPT_PATH} (epoch {ep}, mean {sc:.4f})"); bad = 0
|
| 545 |
+
else:
|
| 546 |
+
bad += 1
|
| 547 |
+
if bad >= PATIENCE:
|
| 548 |
+
print(f"Early stop ở epoch {ep}."); break
|
| 549 |
+
|
| 550 |
+
# Nạp lại best
|
| 551 |
+
wavlm.load_state_dict(best_state["wavlm"]); aud.load_state_dict(best_state["aud"])
|
| 552 |
+
aud_head.load_state_dict(best_state["aud_head"]); heads.load_state_dict(best_state["heads"])
|
| 553 |
+
final = evaluate()
|
| 554 |
+
print("\n✅ VAL (nội bộ) — exp11 (fine-tune CẢ 2 + fusion):")
|
| 555 |
+
print(f" EMOS={final['emos']:.4f} (exp08 {EXP08['emos']})")
|
| 556 |
+
if HAS_VAD:
|
| 557 |
+
print(f" VAL/ARO/DOM={final['val']:.4f}/{final['aro']:.4f}/{final['dom']:.4f} "
|
| 558 |
+
f"(exp08 {EXP08['val']}/{EXP08['aro']}/{EXP08['dom']})")
|
| 559 |
+
print(f" mean SRCC: warm-start {mean_srcc(m0):.4f} → exp11 {mean_srcc(final):.4f} "
|
| 560 |
+
+ ("🚀 cải thiện" if mean_srcc(final) > mean_srcc(m0) + 1e-4 else "➖ không cải thiện"))
|
| 561 |
+
save_full(best_state, final.get("emos", float("nan")))
|
| 562 |
+
print("Đã lưu FULL:", CKPT_PATH, "→ NHỚ Save Version!")
|
| 563 |
+
|
| 564 |
+
# %% [markdown]
|
| 565 |
+
# ## 6. Dự đoán DEV → answer.txt (5 cột cảm xúc từ exp11; QMOS mượn exp07 / UTMOSv2)
|
| 566 |
+
|
| 567 |
+
# %%
|
| 568 |
+
def list_dev():
|
| 569 |
+
with open(DEV_SCP) as f:
|
| 570 |
+
return [ln.strip() for ln in f if ln.strip()]
|
| 571 |
+
|
| 572 |
+
dev_names = list_dev()
|
| 573 |
+
if LIMIT_DEV:
|
| 574 |
+
dev_names = dev_names[:LIMIT_DEV]
|
| 575 |
+
print("DEV:", len(dev_names), "mẫu")
|
| 576 |
+
|
| 577 |
+
def load_exp07_qmos():
|
| 578 |
+
if EXP07_ANSWER and os.path.exists(EXP07_ANSWER):
|
| 579 |
+
import csv
|
| 580 |
+
d = {}
|
| 581 |
+
with open(EXP07_ANSWER) as f:
|
| 582 |
+
for row in csv.DictReader(f):
|
| 583 |
+
d[row["wav"]] = float(row["QMOS"]); d[stem(row["wav"])] = float(row["QMOS"])
|
| 584 |
+
print(f"✅ Mượn QMOS exp07: {len(d)//2} wav")
|
| 585 |
+
return d
|
| 586 |
+
return None
|
| 587 |
+
|
| 588 |
+
qmos_map = load_exp07_qmos()
|
| 589 |
+
if qmos_map is None:
|
| 590 |
+
print("ℹ️ Không có exp07 → QMOS bằng UTMOSv2.")
|
| 591 |
+
pip_install("git+https://github.com/sarulab-speech/UTMOSv2.git")
|
| 592 |
+
import utmosv2
|
| 593 |
+
v2 = utmosv2.create_model(pretrained=True)
|
| 594 |
+
qmos_map = {}
|
| 595 |
+
for n in tqdm(dev_names, desc="UTMOSv2"):
|
| 596 |
+
wav = os.path.join(WAV_DIR, n if str(n).endswith(".wav") else str(n) + ".wav")
|
| 597 |
+
if os.path.exists(wav):
|
| 598 |
+
o = v2.predict(input_path=wav)
|
| 599 |
+
qmos_map[n] = float(o["predicted_mos"]) if isinstance(o, dict) else float(o)
|
| 600 |
+
del v2; torch.cuda.empty_cache() if device == "cuda" else None
|
| 601 |
+
|
| 602 |
+
@torch.no_grad()
|
| 603 |
+
def predict_emotion(sid):
|
| 604 |
+
pair = load_pair(sid)
|
| 605 |
+
if pair is None:
|
| 606 |
+
return None
|
| 607 |
+
wave, iv = pair
|
| 608 |
+
set_train(False)
|
| 609 |
+
w = torch.from_numpy(wave).unsqueeze(0).to(device)
|
| 610 |
+
ivt = torch.from_numpy(iv).unsqueeze(0).to(device)
|
| 611 |
+
am = torch.ones((1, len(wave)), dtype=torch.long, device=device)
|
| 612 |
+
tgt = torch.from_numpy(onehot_target(target_map.get(sid))).unsqueeze(0).to(device)
|
| 613 |
+
with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
|
| 614 |
+
feat = torch.cat([wavlm_embed(w, am), aud_embed(ivt, am)], dim=1)
|
| 615 |
+
emos_p, cat_l, vad_p = heads(feat, tgt)
|
| 616 |
+
emos = float(emos_p.item()) * emos_sd + emos_mu
|
| 617 |
+
cat5 = F.softmax(cat_l, 1)[0].float().cpu().numpy()
|
| 618 |
+
vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu
|
| 619 |
+
return emos, cat5, vad3
|
| 620 |
+
|
| 621 |
+
def fmt_cat(p5):
|
| 622 |
+
return "|".join(f"{e}:{p5[i]:.6g}" for i, e in enumerate(EMOTIONS5))
|
| 623 |
+
|
| 624 |
+
answer_path = os.path.join(OUT_DIR, "answer.txt")
|
| 625 |
+
n_real = n_def = 0
|
| 626 |
+
with open(answer_path, "w") as f:
|
| 627 |
+
f.write("wav,QMOS,EMOS,CAT,VAL,ARO,DOM\n")
|
| 628 |
+
for name in tqdm(dev_names, desc="answer"):
|
| 629 |
+
sid = stem(name)
|
| 630 |
+
pr = predict_emotion(sid)
|
| 631 |
+
if pr is None:
|
| 632 |
+
emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0]); n_def += 1
|
| 633 |
+
else:
|
| 634 |
+
emos, cat5, vad3 = pr; n_real += 1
|
| 635 |
+
qmos = qmos_map.get(name, qmos_map.get(sid, 3.0))
|
| 636 |
+
f.write(f"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\n")
|
| 637 |
+
print(f"Ghi {len(dev_names)} dòng → {answer_path} | cảm xúc thật {n_real}, mặc định {n_def}")
|
| 638 |
+
|
| 639 |
+
# %% [markdown]
|
| 640 |
+
# ## 7. Validate + zip
|
| 641 |
+
|
| 642 |
+
# %%
|
| 643 |
+
def validate(path):
|
| 644 |
+
import csv
|
| 645 |
+
with open(path) as f:
|
| 646 |
+
rows = list(csv.reader(f))
|
| 647 |
+
assert rows[0][0] == "wav" and "QMOS" in rows[0], "Header sai"
|
| 648 |
+
for i, r in enumerate(rows[1:], 2):
|
| 649 |
+
assert len(r) == len(rows[0]), f"Dòng {i} sai số cột"
|
| 650 |
+
print(f"OK: {len(rows)-1} dòng, header = {rows[0]}")
|
| 651 |
+
|
| 652 |
+
validate(answer_path)
|
| 653 |
+
os.system(f"cd {OUT_DIR} && zip -j submission_track2_exp11_joint.zip answer.txt && unzip -l submission_track2_exp11_joint.zip")
|
| 654 |
+
print("Sẵn sàng nộp:", os.path.join(OUT_DIR, "submission_track2_exp11_joint.zip"))
|
| 655 |
+
|
| 656 |
+
# %% [markdown]
|
| 657 |
+
# ## Ghi chú
|
| 658 |
+
# - **exp11 = fine-tune CẢ WavLM + audeering, FUSION 1 model** (khác exp08 audeering frozen, khác exp10 ensemble).
|
| 659 |
+
# - **Warm-start:** WavLM + heads từ `ft_emotion_full_20epoch.pt` (exp08) → bắt đầu từ điểm tốt; audeering từ
|
| 660 |
+
# pretrained gốc, mở băng để học thêm. Khởi điểm = đúng exp08 → train chỉ có thể tốt lên (giữ best).
|
| 661 |
+
# - **OOM:** đây là cấu hình nặng nhất. Nếu CUDA OOM → giảm `UNFREEZE_WAVLM`/`UNFREEZE_AUD` (4→2),
|
| 662 |
+
# `MAX_SECONDS` (6→5), giữ `BATCH=1` + tăng `ACCUM`.
|
| 663 |
+
# - **Checkpoint:** lưu `ft_joint_full.pt` mỗi best (đủ cả 2 backbone + heads) → kernel chết vẫn còn. Save Version!
|
| 664 |
+
# - **QMOS** vẫn mượn exp07 (0.548). So sánh nộp: exp11 vs exp08(0.811) vs exp10(ensemble) → chọn bản tốt nhất.
|
| 665 |
+
# - Ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (exp11).
|
track2/exp12_wavlm_scratch.ipynb
ADDED
|
@@ -0,0 +1,690 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "73aea642",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"# VMC2026 Track 2 — exp12 (WavLM: SCRATCH vs BASE vs SAILER — ablation khởi tạo) — Kaggle T4\n",
|
| 9 |
+
"\n",
|
| 10 |
+
"**Mục đích:** kiểm chứng giả thuyết của mentor — *\"với 12k data, train from scratch có tốt hơn fine-tune không?\"*\n",
|
| 11 |
+
"Một notebook, đổi cờ `INIT_MODE` để chạy 3 cách khởi tạo backbone WavLM, so trên CÙNG kiến trúc/data:\n",
|
| 12 |
+
"\n",
|
| 13 |
+
"| INIT_MODE | Khởi tạo WavLM | Train gì | Ý nghĩa |\n",
|
| 14 |
+
"|---|---|---|---|\n",
|
| 15 |
+
"| `scratch` | **ngẫu nhiên** (không pretrain) | **toàn bộ** backbone | \"from scratch\" đúng nghĩa mentor nói |\n",
|
| 16 |
+
"| `base` | microsoft/wavlm-large (pretrain SSL, KHÔNG cảm xúc) | mở băng N lớp trên | đo lợi ích của SAILER warm-start |\n",
|
| 17 |
+
"| `sailer` | warm-start cảm xúc (như exp08) | mở băng N lớp trên | bản mạnh hiện tại |\n",
|
| 18 |
+
"\n",
|
| 19 |
+
"**Chỉ WavLM** (bỏ audeering) để cô lập đúng biến \"khởi tạo\". QMOS mượn exp07 / UTMOSv2.\n",
|
| 20 |
+
"\n",
|
| 21 |
+
"## ⚠️ Kỳ vọng trung thực (để đọc kết quả đúng)\n",
|
| 22 |
+
"- `scratch` gần như CHẮC CHẮN yếu hơn `base`/`sailer`: 12k mẫu là quá ít để dạy WavLM \"nghe\" từ đầu\n",
|
| 23 |
+
" (SSL pretrain dùng ~94.000 GIỜ audio). Đây là ablation để **chứng minh bằng số**, không phải để vượt.\n",
|
| 24 |
+
"- `scratch` phải mở băng TOÀN BỘ (mới có gì để học) → **nặng + chậm + dễ OOM** trên T4. Dùng LIMIT nhỏ trước.\n",
|
| 25 |
+
"- So sánh bằng **VAL nội bộ** giữa 3 mode đã đủ kết luận; muốn chắc thì nộp mode tốt nhất lên DEV.\n",
|
| 26 |
+
"\n",
|
| 27 |
+
"**Cách chạy:** GPU T4 + Internet On → sửa cell 0 (`INIT_MODE` + slug) → Run All. Chạy 3 lần đổi INIT_MODE."
|
| 28 |
+
]
|
| 29 |
+
},
|
| 30 |
+
{
|
| 31 |
+
"cell_type": "markdown",
|
| 32 |
+
"id": "c0242f5c",
|
| 33 |
+
"metadata": {},
|
| 34 |
+
"source": [
|
| 35 |
+
"## 0. Cấu hình — SỬA Ở ĐÂY"
|
| 36 |
+
]
|
| 37 |
+
},
|
| 38 |
+
{
|
| 39 |
+
"cell_type": "code",
|
| 40 |
+
"execution_count": null,
|
| 41 |
+
"id": "a167d9d3",
|
| 42 |
+
"metadata": {},
|
| 43 |
+
"outputs": [],
|
| 44 |
+
"source": [
|
| 45 |
+
"import os\n",
|
| 46 |
+
"\n",
|
| 47 |
+
"INIT_MODE = \"sailer\" # << \"scratch\" | \"base\" | \"sailer\" (đổi rồi chạy lại để so) — \"sailer\" = WavLM warm-start cảm xúc\n",
|
| 48 |
+
"\n",
|
| 49 |
+
"DATA_ROOT = \"/kaggle/input/datasets/minhtoan2/vmc2026-track2-full\" # << SỬA slug cho khớp Add Input\n",
|
| 50 |
+
"WAV_DIR = f\"{DATA_ROOT}/wav\"\n",
|
| 51 |
+
"METADATA_CSV = f\"{DATA_ROOT}/metadata.csv\"\n",
|
| 52 |
+
"TRAIN_CSV = f\"{DATA_ROOT}/sets/train.csv\"\n",
|
| 53 |
+
"DEV_SCP = f\"{DATA_ROOT}/sets/dev.scp\"\n",
|
| 54 |
+
"\n",
|
| 55 |
+
"EXP07_ANSWER = \"/kaggle/input/exp07-answer/answer.txt\" # << (tùy chọn) mượn QMOS 0.548; không có → UTMOSv2\n",
|
| 56 |
+
"OUT_DIR = \"/kaggle/working\"\n",
|
| 57 |
+
"\n",
|
| 58 |
+
"# ── Siêu tham số ─────────────────────────────────────────────────────────────\n",
|
| 59 |
+
"DEVICE = \"cuda\"\n",
|
| 60 |
+
"SR = 16000\n",
|
| 61 |
+
"MAX_SECONDS = 6\n",
|
| 62 |
+
"TRUNK_HIDDEN = 512\n",
|
| 63 |
+
"HEAD_HIDDEN = 128\n",
|
| 64 |
+
"DROPOUT = 0.3\n",
|
| 65 |
+
"WEIGHT_DECAY = 1e-5\n",
|
| 66 |
+
"EPOCHS = 15\n",
|
| 67 |
+
"PATIENCE = 5\n",
|
| 68 |
+
"BATCH = 4\n",
|
| 69 |
+
"ACCUM = 8\n",
|
| 70 |
+
"VAL_FRAC = 0.10\n",
|
| 71 |
+
"SEED = 42\n",
|
| 72 |
+
"USE_AMP = True\n",
|
| 73 |
+
"USE_GRAD_CKPT = True\n",
|
| 74 |
+
"USE_UNCERTAINTY = True\n",
|
| 75 |
+
"\n",
|
| 76 |
+
"# Khởi tạo & LR & mở băng — TỰ đặt theo INIT_MODE (scratch cần LR lớn + mở băng toàn bộ)\n",
|
| 77 |
+
"if INIT_MODE == \"scratch\":\n",
|
| 78 |
+
" UNFREEZE_TOP_LAYERS = \"all\" # random init → phải train tất cả mới học được\n",
|
| 79 |
+
" LR_BACKBONE = 1e-4 # random init cần bước lớn hơn fine-tune\n",
|
| 80 |
+
"elif INIT_MODE in (\"base\", \"sailer\"):\n",
|
| 81 |
+
" UNFREEZE_TOP_LAYERS = 6 # fine-tune: chỉ mở băng N lớp trên (tiết kiệm VRAM, chống overfit)\n",
|
| 82 |
+
" LR_BACKBONE = 1e-5\n",
|
| 83 |
+
"else:\n",
|
| 84 |
+
" raise ValueError(f\"INIT_MODE lạ: {INIT_MODE}\")\n",
|
| 85 |
+
"LR_HEAD = 1e-3\n",
|
| 86 |
+
"\n",
|
| 87 |
+
"LIMIT_TRAIN = 300 # << LẦN ĐẦU 300; chạy thật None\n",
|
| 88 |
+
"LIMIT_DEV = 20 # << LẦN ĐẦU 20; chạy thật None\n",
|
| 89 |
+
"\n",
|
| 90 |
+
"EXP08 = {\"emos\": 0.811, \"cat_err\": 0.133, \"val\": 0.659, \"aro\": 0.793, \"dom\": 0.751} # mốc DEV để tham khảo\n",
|
| 91 |
+
"\n",
|
| 92 |
+
"EMOTIONS5 = [\"angry\", \"happy\", \"neutral\", \"sad\", \"surprised\"]\n",
|
| 93 |
+
"_EMO_ALIAS = {\n",
|
| 94 |
+
" \"angry\": \"angry\", \"anger\": \"angry\",\n",
|
| 95 |
+
" \"happy\": \"happy\", \"happiness\": \"happy\", \"joy\": \"happy\",\n",
|
| 96 |
+
" \"neutral\": \"neutral\", \"calm\": \"neutral\",\n",
|
| 97 |
+
" \"sad\": \"sad\", \"sadness\": \"sad\",\n",
|
| 98 |
+
" \"surprise\": \"surprised\", \"surprised\": \"surprised\", \"surprising\": \"surprised\",\n",
|
| 99 |
+
"}\n",
|
| 100 |
+
"\n",
|
| 101 |
+
"def norm_emotion(label):\n",
|
| 102 |
+
" key = str(label).strip().lower()\n",
|
| 103 |
+
" return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)\n",
|
| 104 |
+
"\n",
|
| 105 |
+
"def stem(p):\n",
|
| 106 |
+
" return os.path.splitext(os.path.basename(str(p)))[0]\n",
|
| 107 |
+
"\n",
|
| 108 |
+
"print(f\"INIT_MODE = {INIT_MODE} | UNFREEZE = {UNFREEZE_TOP_LAYERS} | LR_BACKBONE = {LR_BACKBONE}\")\n",
|
| 109 |
+
"print(\"DATA_ROOT:\", DATA_ROOT)\n",
|
| 110 |
+
"for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP]:\n",
|
| 111 |
+
" print((\" ✅ \" if os.path.exists(p) else \" ❌ THIẾU \") + p)"
|
| 112 |
+
]
|
| 113 |
+
},
|
| 114 |
+
{
|
| 115 |
+
"cell_type": "markdown",
|
| 116 |
+
"id": "46cd2554",
|
| 117 |
+
"metadata": {},
|
| 118 |
+
"source": [
|
| 119 |
+
"## 1. Cài đặt (clone SAILER chỉ khi INIT_MODE='sailer')"
|
| 120 |
+
]
|
| 121 |
+
},
|
| 122 |
+
{
|
| 123 |
+
"cell_type": "code",
|
| 124 |
+
"execution_count": null,
|
| 125 |
+
"id": "51808707",
|
| 126 |
+
"metadata": {},
|
| 127 |
+
"outputs": [],
|
| 128 |
+
"source": [
|
| 129 |
+
"import sys, subprocess\n",
|
| 130 |
+
"import numpy as _np\n",
|
| 131 |
+
"\n",
|
| 132 |
+
"# ⚠️ KHÓA numpy = bản Kaggle đang có → pip KHÔNG được nâng/hạ numpy → tránh \"SystemError: bad call flags\"\n",
|
| 133 |
+
"# (lỗi import torch do numpy lệch phiên bản với torch đã biên dịch sẵn).\n",
|
| 134 |
+
"_NPIN = f\"numpy=={_np.__version__}\"\n",
|
| 135 |
+
"print(\"Khóa numpy ở:\", _NPIN)\n",
|
| 136 |
+
"\n",
|
| 137 |
+
"def pip_install(*pkgs):\n",
|
| 138 |
+
" subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *pkgs, _NPIN], check=True)\n",
|
| 139 |
+
"\n",
|
| 140 |
+
"# Kaggle đã có sẵn torch/transformers/librosa/scipy/sklearn/pandas/tqdm/huggingface_hub/safetensors.\n",
|
| 141 |
+
"# Chỉ cài thêm vài gói speech còn thiếu (kèm khóa numpy ở trên).\n",
|
| 142 |
+
"pip_install(\"loralib\", \"speechmos\", \"soundfile\")\n",
|
| 143 |
+
"if INIT_MODE == \"sailer\":\n",
|
| 144 |
+
" pip_install(\"speechbrain\")\n",
|
| 145 |
+
"\n",
|
| 146 |
+
"if INIT_MODE == \"sailer\":\n",
|
| 147 |
+
" REPO_DIR = \"/kaggle/working/vox-profile-release\"\n",
|
| 148 |
+
" if not os.path.exists(REPO_DIR):\n",
|
| 149 |
+
" subprocess.run([\"git\", \"clone\", \"--depth\", \"1\",\n",
|
| 150 |
+
" \"https://github.com/tiantiaf0627/vox-profile-release.git\", REPO_DIR], check=True)\n",
|
| 151 |
+
" if REPO_DIR not in sys.path:\n",
|
| 152 |
+
" sys.path.insert(0, REPO_DIR)"
|
| 153 |
+
]
|
| 154 |
+
},
|
| 155 |
+
{
|
| 156 |
+
"cell_type": "markdown",
|
| 157 |
+
"id": "5288727c",
|
| 158 |
+
"metadata": {},
|
| 159 |
+
"source": [
|
| 160 |
+
"## 2. Dựng WavLM theo INIT_MODE"
|
| 161 |
+
]
|
| 162 |
+
},
|
| 163 |
+
{
|
| 164 |
+
"cell_type": "code",
|
| 165 |
+
"execution_count": null,
|
| 166 |
+
"id": "c828dcd3",
|
| 167 |
+
"metadata": {
|
| 168 |
+
"lines_to_next_cell": 1
|
| 169 |
+
},
|
| 170 |
+
"outputs": [],
|
| 171 |
+
"source": [
|
| 172 |
+
"import torch\n",
|
| 173 |
+
"import torch.nn as nn\n",
|
| 174 |
+
"import torch.nn.functional as F\n",
|
| 175 |
+
"import numpy as np\n",
|
| 176 |
+
"\n",
|
| 177 |
+
"device = DEVICE if torch.cuda.is_available() else \"cpu\"\n",
|
| 178 |
+
"print(\"Device:\", device, (\"✅ \" + torch.cuda.get_device_name(0)) if device == \"cuda\" else \"⚠️ CPU (rất chậm!)\")\n",
|
| 179 |
+
"\n",
|
| 180 |
+
"from transformers import WavLMModel, WavLMConfig\n",
|
| 181 |
+
"\n",
|
| 182 |
+
"def find_hf_backbone(module):\n",
|
| 183 |
+
" cands = []\n",
|
| 184 |
+
" for name, m in module.named_modules():\n",
|
| 185 |
+
" enc = getattr(m, \"encoder\", None)\n",
|
| 186 |
+
" if getattr(m, \"feature_extractor\", None) is not None and enc is not None \\\n",
|
| 187 |
+
" and getattr(enc, \"layers\", None) is not None:\n",
|
| 188 |
+
" cands.append((name, m))\n",
|
| 189 |
+
" if not cands:\n",
|
| 190 |
+
" return None, None\n",
|
| 191 |
+
" cands.sort(key=lambda nm: sum(p.numel() for p in nm[1].parameters()), reverse=True)\n",
|
| 192 |
+
" return cands[0]\n",
|
| 193 |
+
"\n",
|
| 194 |
+
"wavlm = None\n",
|
| 195 |
+
"if INIT_MODE == \"scratch\":\n",
|
| 196 |
+
" # Random init NHƯNG giữ ĐÚNG kiến trúc large (để công bằng với base/sailer)\n",
|
| 197 |
+
" cfg = WavLMConfig.from_pretrained(\"microsoft/wavlm-large\")\n",
|
| 198 |
+
" wavlm = WavLMModel(cfg) # KHÔNG load trọng số → ngẫu nhiên\n",
|
| 199 |
+
" print(\"🎲 WavLM-large khởi tạo NGẪU NHIÊN (from scratch, không pretrain).\")\n",
|
| 200 |
+
"elif INIT_MODE == \"base\":\n",
|
| 201 |
+
" wavlm = WavLMModel.from_pretrained(\"microsoft/wavlm-large\")\n",
|
| 202 |
+
" print(\"📦 WavLM-large pretrain SSL (chưa học cảm xúc).\")\n",
|
| 203 |
+
"elif INIT_MODE == \"sailer\":\n",
|
| 204 |
+
" try:\n",
|
| 205 |
+
" from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402\n",
|
| 206 |
+
" _wrapper = WavLMWrapper.from_pretrained(\"tiantiaf/wavlm-large-categorical-emotion\")\n",
|
| 207 |
+
" name, wavlm = find_hf_backbone(_wrapper)\n",
|
| 208 |
+
" print(f\"🔥 WavLM warm-start SAILER (cảm xúc) tại '.{name}'\")\n",
|
| 209 |
+
" except Exception as e:\n",
|
| 210 |
+
" print(\"⚠️ Lỗi nạp SAILER:\", repr(e), \"→ fallback base pretrained.\")\n",
|
| 211 |
+
" wavlm = WavLMModel.from_pretrained(\"microsoft/wavlm-large\")\n",
|
| 212 |
+
"\n",
|
| 213 |
+
"wavlm = wavlm.to(device)\n",
|
| 214 |
+
"WAVLM_DIM = int(wavlm.config.hidden_size)\n",
|
| 215 |
+
"wavlm.config.layerdrop = 0.0 # ⚠️ BẮT BUỘC khi dùng gradient-checkpointing (tránh CheckpointError do bỏ lớp ngẫu nhiên)\n",
|
| 216 |
+
"\n",
|
| 217 |
+
"# Mở băng theo cấu hình\n",
|
| 218 |
+
"if UNFREEZE_TOP_LAYERS == \"all\":\n",
|
| 219 |
+
" for p in wavlm.parameters():\n",
|
| 220 |
+
" p.requires_grad = True\n",
|
| 221 |
+
" n_open = \"ALL\"\n",
|
| 222 |
+
"else:\n",
|
| 223 |
+
" for p in wavlm.parameters():\n",
|
| 224 |
+
" p.requires_grad = False\n",
|
| 225 |
+
" _wl = wavlm.encoder.layers\n",
|
| 226 |
+
" for layer in _wl[max(0, len(_wl) - UNFREEZE_TOP_LAYERS):]:\n",
|
| 227 |
+
" for p in layer.parameters():\n",
|
| 228 |
+
" p.requires_grad = True\n",
|
| 229 |
+
" n_open = f\"top {min(UNFREEZE_TOP_LAYERS, len(_wl))}/{len(_wl)}\"\n",
|
| 230 |
+
"print(f\"WavLM mở băng: {n_open} → {sum(p.numel() for p in wavlm.parameters() if p.requires_grad)/1e6:.1f}M param train (dim {WAVLM_DIM})\")\n",
|
| 231 |
+
"\n",
|
| 232 |
+
"if USE_GRAD_CKPT:\n",
|
| 233 |
+
" wavlm.gradient_checkpointing_enable()\n",
|
| 234 |
+
" if hasattr(wavlm, \"enable_input_require_grads\"):\n",
|
| 235 |
+
" wavlm.enable_input_require_grads()\n",
|
| 236 |
+
"\n",
|
| 237 |
+
"def masked_mean(hidden, attn_mask):\n",
|
| 238 |
+
" if attn_mask is None:\n",
|
| 239 |
+
" return hidden.mean(dim=1)\n",
|
| 240 |
+
" try:\n",
|
| 241 |
+
" fm = wavlm._get_feature_vector_attention_mask(hidden.shape[1], attn_mask)\n",
|
| 242 |
+
" except Exception:\n",
|
| 243 |
+
" return hidden.mean(dim=1)\n",
|
| 244 |
+
" fm = fm.unsqueeze(-1).to(hidden.dtype)\n",
|
| 245 |
+
" return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)\n",
|
| 246 |
+
"\n",
|
| 247 |
+
"def wavlm_embed(input_values, attn_mask):\n",
|
| 248 |
+
" out = wavlm(input_values, attention_mask=attn_mask).last_hidden_state\n",
|
| 249 |
+
" return masked_mean(out, attn_mask)"
|
| 250 |
+
]
|
| 251 |
+
},
|
| 252 |
+
{
|
| 253 |
+
"cell_type": "markdown",
|
| 254 |
+
"id": "6b963eb2",
|
| 255 |
+
"metadata": {},
|
| 256 |
+
"source": [
|
| 257 |
+
"## 3. Đọc & gộp nhãn theo wavID"
|
| 258 |
+
]
|
| 259 |
+
},
|
| 260 |
+
{
|
| 261 |
+
"cell_type": "code",
|
| 262 |
+
"execution_count": null,
|
| 263 |
+
"id": "4ba4667b",
|
| 264 |
+
"metadata": {},
|
| 265 |
+
"outputs": [],
|
| 266 |
+
"source": [
|
| 267 |
+
"import librosa\n",
|
| 268 |
+
"import pandas as pd\n",
|
| 269 |
+
"from tqdm.auto import tqdm\n",
|
| 270 |
+
"\n",
|
| 271 |
+
"def load_target_emotions():\n",
|
| 272 |
+
" tgt = {}\n",
|
| 273 |
+
" with open(METADATA_CSV, encoding=\"utf-8\") as f:\n",
|
| 274 |
+
" for ln in f:\n",
|
| 275 |
+
" parts = ln.strip().split(\"|\")\n",
|
| 276 |
+
" if len(parts) >= 2:\n",
|
| 277 |
+
" tgt[stem(parts[0])] = norm_emotion(parts[1])\n",
|
| 278 |
+
" return tgt\n",
|
| 279 |
+
"\n",
|
| 280 |
+
"def _col(cols_map, *names, df=None, default_idx=None):\n",
|
| 281 |
+
" for n in names:\n",
|
| 282 |
+
" if n in cols_map:\n",
|
| 283 |
+
" return cols_map[n]\n",
|
| 284 |
+
" return list(df.columns)[default_idx] if default_idx is not None else None\n",
|
| 285 |
+
"\n",
|
| 286 |
+
"def parse_emocat_votes(cell):\n",
|
| 287 |
+
" v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
|
| 288 |
+
" for tok in str(cell).replace(\"/\", \",\").replace(\";\", \",\").replace(\"|\", \",\").replace(\" \", \",\").split(\",\"):\n",
|
| 289 |
+
" e = norm_emotion(tok)\n",
|
| 290 |
+
" if e in EMOTIONS5:\n",
|
| 291 |
+
" v[EMOTIONS5.index(e)] += 1.0\n",
|
| 292 |
+
" return v\n",
|
| 293 |
+
"\n",
|
| 294 |
+
"def load_train_labels():\n",
|
| 295 |
+
" df = pd.read_csv(TRAIN_CSV, sep=\"|\")\n",
|
| 296 |
+
" cols = {c.lower().strip(): c for c in df.columns}\n",
|
| 297 |
+
" wav_col = _col(cols, \"wavid\", \"wav\", df=df, default_idx=1)\n",
|
| 298 |
+
" emos_col = _col(cols, \"emos\", \"emo\", \"emomos\")\n",
|
| 299 |
+
" val_col = _col(cols, \"val\", \"valence\"); aro_col = _col(cols, \"aro\", \"arousal\"); dom_col = _col(cols, \"dom\", \"dominance\")\n",
|
| 300 |
+
" cat_col = _col(cols, \"emocat\", \"cat\", \"emotion\")\n",
|
| 301 |
+
" assert emos_col, f\"Không thấy cột eMOS (cột: {list(df.columns)})\"\n",
|
| 302 |
+
" df[\"_stem\"] = df[wav_col].map(stem)\n",
|
| 303 |
+
" rows = []\n",
|
| 304 |
+
" for sid, g in df.groupby(\"_stem\"):\n",
|
| 305 |
+
" rec = {\"wavID\": sid, \"emos\": float(g[emos_col].mean())}\n",
|
| 306 |
+
" rec[\"val\"] = float(g[val_col].mean()) if val_col else np.nan\n",
|
| 307 |
+
" rec[\"aro\"] = float(g[aro_col].mean()) if aro_col else np.nan\n",
|
| 308 |
+
" rec[\"dom\"] = float(g[dom_col].mean()) if dom_col else np.nan\n",
|
| 309 |
+
" votes = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
|
| 310 |
+
" if cat_col:\n",
|
| 311 |
+
" for cell in g[cat_col]:\n",
|
| 312 |
+
" votes += parse_emocat_votes(cell)\n",
|
| 313 |
+
" s = votes.sum()\n",
|
| 314 |
+
" cat = votes / s if s > 0 else np.full(len(EMOTIONS5), 0.2, dtype=np.float32)\n",
|
| 315 |
+
" for i in range(len(EMOTIONS5)):\n",
|
| 316 |
+
" rec[f\"cat{i}\"] = float(cat[i])\n",
|
| 317 |
+
" rows.append(rec)\n",
|
| 318 |
+
" return pd.DataFrame(rows)\n",
|
| 319 |
+
"\n",
|
| 320 |
+
"target_map = load_target_emotions()\n",
|
| 321 |
+
"train_df = load_train_labels()\n",
|
| 322 |
+
"HAS_VAD = bool(train_df[\"val\"].notna().any())\n",
|
| 323 |
+
"print(f\"Target: {len(target_map)} | wav train (gộp): {len(train_df)} | có VAD: {HAS_VAD}\")"
|
| 324 |
+
]
|
| 325 |
+
},
|
| 326 |
+
{
|
| 327 |
+
"cell_type": "markdown",
|
| 328 |
+
"id": "7efc0957",
|
| 329 |
+
"metadata": {},
|
| 330 |
+
"source": [
|
| 331 |
+
"## 4. Dataset/loader (chỉ raw wave cho WavLM)"
|
| 332 |
+
]
|
| 333 |
+
},
|
| 334 |
+
{
|
| 335 |
+
"cell_type": "code",
|
| 336 |
+
"execution_count": null,
|
| 337 |
+
"id": "a0ae0f55",
|
| 338 |
+
"metadata": {},
|
| 339 |
+
"outputs": [],
|
| 340 |
+
"source": [
|
| 341 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 342 |
+
"from sklearn.model_selection import train_test_split\n",
|
| 343 |
+
"\n",
|
| 344 |
+
"train_stems = [s for s in train_df[\"wavID\"] if target_map.get(s) is not None]\n",
|
| 345 |
+
"if LIMIT_TRAIN:\n",
|
| 346 |
+
" train_stems = train_stems[:LIMIT_TRAIN]\n",
|
| 347 |
+
"lab = train_df.set_index(\"wavID\")\n",
|
| 348 |
+
"\n",
|
| 349 |
+
"def _zfit(a):\n",
|
| 350 |
+
" a = np.asarray(a, dtype=np.float32); return float(np.nanmean(a)), float(np.nanstd(a) + 1e-6)\n",
|
| 351 |
+
"emos_mu, emos_sd = _zfit([lab.loc[s, \"emos\"] for s in train_stems])\n",
|
| 352 |
+
"if HAS_VAD:\n",
|
| 353 |
+
" vad_mu = np.array([_zfit([lab.loc[s, c] for s in train_stems])[0] for c in [\"val\", \"aro\", \"dom\"]], np.float32)\n",
|
| 354 |
+
" vad_sd = np.array([_zfit([lab.loc[s, c] for s in train_stems])[1] for c in [\"val\", \"aro\", \"dom\"]], np.float32)\n",
|
| 355 |
+
"else:\n",
|
| 356 |
+
" vad_mu = np.zeros(3, np.float32); vad_sd = np.ones(3, np.float32)\n",
|
| 357 |
+
"print(f\"Chuẩn hóa: emos μ={emos_mu:.3f} σ={emos_sd:.3f}\")\n",
|
| 358 |
+
"\n",
|
| 359 |
+
"def onehot_target(tgt):\n",
|
| 360 |
+
" v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
|
| 361 |
+
" if tgt in EMOTIONS5:\n",
|
| 362 |
+
" v[EMOTIONS5.index(tgt)] = 1.0\n",
|
| 363 |
+
" return v\n",
|
| 364 |
+
"\n",
|
| 365 |
+
"def load_wav(sid):\n",
|
| 366 |
+
" p = os.path.join(WAV_DIR, sid if str(sid).endswith(\".wav\") else str(sid) + \".wav\")\n",
|
| 367 |
+
" if not os.path.exists(p):\n",
|
| 368 |
+
" return None\n",
|
| 369 |
+
" wave, _ = librosa.load(p, sr=SR, mono=True)\n",
|
| 370 |
+
" return wave[: MAX_SECONDS * SR].astype(np.float32)\n",
|
| 371 |
+
"\n",
|
| 372 |
+
"class EmoDataset(Dataset):\n",
|
| 373 |
+
" def __init__(self, stems):\n",
|
| 374 |
+
" self.stems = [s for s in stems if load_wav(s) is not None]\n",
|
| 375 |
+
" def __len__(self):\n",
|
| 376 |
+
" return len(self.stems)\n",
|
| 377 |
+
" def __getitem__(self, i):\n",
|
| 378 |
+
" s = self.stems[i]\n",
|
| 379 |
+
" wave = load_wav(s)\n",
|
| 380 |
+
" emos = (float(lab.loc[s, \"emos\"]) - emos_mu) / emos_sd\n",
|
| 381 |
+
" if HAS_VAD:\n",
|
| 382 |
+
" vad = (np.array([lab.loc[s, \"val\"], lab.loc[s, \"aro\"], lab.loc[s, \"dom\"]], np.float32) - vad_mu) / vad_sd\n",
|
| 383 |
+
" else:\n",
|
| 384 |
+
" vad = np.zeros(3, dtype=np.float32)\n",
|
| 385 |
+
" cat = np.array([lab.loc[s, f\"cat{j}\"] for j in range(len(EMOTIONS5))], dtype=np.float32)\n",
|
| 386 |
+
" return {\"wave\": wave, \"tgt\": onehot_target(target_map.get(s)),\n",
|
| 387 |
+
" \"emos\": np.float32(emos), \"vad\": vad, \"cat\": cat,\n",
|
| 388 |
+
" \"emos_raw\": np.float32(lab.loc[s, \"emos\"]),\n",
|
| 389 |
+
" \"vad_raw\": np.array([lab.loc[s, \"val\"], lab.loc[s, \"aro\"], lab.loc[s, \"dom\"]], np.float32)}\n",
|
| 390 |
+
"\n",
|
| 391 |
+
"def collate(batch):\n",
|
| 392 |
+
" L = max(len(b[\"wave\"]) for b in batch)\n",
|
| 393 |
+
" waves = np.zeros((len(batch), L), dtype=np.float32)\n",
|
| 394 |
+
" mask = np.zeros((len(batch), L), dtype=np.float32)\n",
|
| 395 |
+
" for i, b in enumerate(batch):\n",
|
| 396 |
+
" waves[i, : len(b[\"wave\"])] = b[\"wave\"]; mask[i, : len(b[\"wave\"])] = 1.0\n",
|
| 397 |
+
" return {\n",
|
| 398 |
+
" \"input_values\": torch.from_numpy(waves), \"attn_mask\": torch.from_numpy(mask).long(),\n",
|
| 399 |
+
" \"tgt\": torch.from_numpy(np.stack([b[\"tgt\"] for b in batch])),\n",
|
| 400 |
+
" \"emos\": torch.from_numpy(np.stack([b[\"emos\"] for b in batch])).unsqueeze(1),\n",
|
| 401 |
+
" \"vad\": torch.from_numpy(np.stack([b[\"vad\"] for b in batch])),\n",
|
| 402 |
+
" \"cat\": torch.from_numpy(np.stack([b[\"cat\"] for b in batch])),\n",
|
| 403 |
+
" \"emos_raw\": np.stack([b[\"emos_raw\"] for b in batch]),\n",
|
| 404 |
+
" \"vad_raw\": np.stack([b[\"vad_raw\"] for b in batch]),\n",
|
| 405 |
+
" }\n",
|
| 406 |
+
"\n",
|
| 407 |
+
"ds = EmoDataset(train_stems)\n",
|
| 408 |
+
"print(\"Dataset hợp lệ:\", len(ds), \"wav\")\n",
|
| 409 |
+
"tr_i, va_i = train_test_split(np.arange(len(ds)), test_size=VAL_FRAC, random_state=SEED)\n",
|
| 410 |
+
"tr_loader = DataLoader(torch.utils.data.Subset(ds, tr_i), batch_size=BATCH, shuffle=True, collate_fn=collate, num_workers=2)\n",
|
| 411 |
+
"va_loader = DataLoader(torch.utils.data.Subset(ds, va_i), batch_size=BATCH, shuffle=False, collate_fn=collate, num_workers=2)"
|
| 412 |
+
]
|
| 413 |
+
},
|
| 414 |
+
{
|
| 415 |
+
"cell_type": "markdown",
|
| 416 |
+
"id": "d249653b",
|
| 417 |
+
"metadata": {},
|
| 418 |
+
"source": [
|
| 419 |
+
"## 5. Heads + train loop"
|
| 420 |
+
]
|
| 421 |
+
},
|
| 422 |
+
{
|
| 423 |
+
"cell_type": "code",
|
| 424 |
+
"execution_count": null,
|
| 425 |
+
"id": "fac929f9",
|
| 426 |
+
"metadata": {
|
| 427 |
+
"lines_to_next_cell": 1
|
| 428 |
+
},
|
| 429 |
+
"outputs": [],
|
| 430 |
+
"source": [
|
| 431 |
+
"from scipy.stats import spearmanr\n",
|
| 432 |
+
"\n",
|
| 433 |
+
"torch.manual_seed(SEED); np.random.seed(SEED)\n",
|
| 434 |
+
"N_EMO = len(EMOTIONS5)\n",
|
| 435 |
+
"\n",
|
| 436 |
+
"class EmoHeads(nn.Module):\n",
|
| 437 |
+
" def __init__(self, d_in, trunk_h, head_h, p, n_emo):\n",
|
| 438 |
+
" super().__init__()\n",
|
| 439 |
+
" self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),\n",
|
| 440 |
+
" nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))\n",
|
| 441 |
+
" self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))\n",
|
| 442 |
+
" self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))\n",
|
| 443 |
+
" self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))\n",
|
| 444 |
+
" def forward(self, feat, tgt):\n",
|
| 445 |
+
" h = self.trunk(feat)\n",
|
| 446 |
+
" return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)\n",
|
| 447 |
+
"\n",
|
| 448 |
+
"heads = EmoHeads(WAVLM_DIM, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device)\n",
|
| 449 |
+
"\n",
|
| 450 |
+
"TASKS = [\"emos\", \"cat\", \"val\", \"aro\", \"dom\"]\n",
|
| 451 |
+
"log_var = nn.Parameter(torch.zeros(len(TASKS), device=device))\n",
|
| 452 |
+
"bb_params = [p for p in wavlm.parameters() if p.requires_grad]\n",
|
| 453 |
+
"head_params = list(heads.parameters()) + ([log_var] if USE_UNCERTAINTY else [])\n",
|
| 454 |
+
"opt = torch.optim.AdamW([{\"params\": bb_params, \"lr\": LR_BACKBONE},\n",
|
| 455 |
+
" {\"params\": head_params, \"lr\": LR_HEAD}], weight_decay=WEIGHT_DECAY)\n",
|
| 456 |
+
"scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP and device == \"cuda\")\n",
|
| 457 |
+
"mse = nn.MSELoss()\n",
|
| 458 |
+
"\n",
|
| 459 |
+
"def soft_ce(logits, target_dist):\n",
|
| 460 |
+
" return -(target_dist * F.log_softmax(logits, dim=1)).sum(1).mean()\n",
|
| 461 |
+
"\n",
|
| 462 |
+
"def forward_batch(b):\n",
|
| 463 |
+
" feat = wavlm_embed(b[\"input_values\"].to(device), b[\"attn_mask\"].to(device))\n",
|
| 464 |
+
" return heads(feat, b[\"tgt\"].to(device))\n",
|
| 465 |
+
"\n",
|
| 466 |
+
"def compute_loss(emos_p, cat_l, vad_p, b):\n",
|
| 467 |
+
" L = {}\n",
|
| 468 |
+
" L[\"emos\"] = mse(emos_p, b[\"emos\"].to(device))\n",
|
| 469 |
+
" L[\"cat\"] = soft_ce(cat_l, b[\"cat\"].to(device))\n",
|
| 470 |
+
" if HAS_VAD:\n",
|
| 471 |
+
" vt = b[\"vad\"].to(device)\n",
|
| 472 |
+
" L[\"val\"] = mse(vad_p[:, 0:1], vt[:, 0:1]); L[\"aro\"] = mse(vad_p[:, 1:2], vt[:, 1:2]); L[\"dom\"] = mse(vad_p[:, 2:3], vt[:, 2:3])\n",
|
| 473 |
+
" else:\n",
|
| 474 |
+
" z = torch.zeros((), device=device); L[\"val\"] = L[\"aro\"] = L[\"dom\"] = z\n",
|
| 475 |
+
" if USE_UNCERTAINTY:\n",
|
| 476 |
+
" return sum(torch.exp(-log_var[i]) * L[t] + log_var[i] for i, t in enumerate(TASKS))\n",
|
| 477 |
+
" return sum(L.values())\n",
|
| 478 |
+
"\n",
|
| 479 |
+
"@torch.no_grad()\n",
|
| 480 |
+
"def evaluate():\n",
|
| 481 |
+
" wavlm.eval(); heads.eval()\n",
|
| 482 |
+
" P = {\"emos\": [], \"val\": [], \"aro\": [], \"dom\": []}; Y = {\"emos\": [], \"val\": [], \"aro\": [], \"dom\": []}\n",
|
| 483 |
+
" catP, catY = [], []\n",
|
| 484 |
+
" for b in va_loader:\n",
|
| 485 |
+
" with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
|
| 486 |
+
" emos_p, cat_l, vad_p = forward_batch(b)\n",
|
| 487 |
+
" P[\"emos\"] += emos_p.float().cpu().numpy().ravel().tolist(); Y[\"emos\"] += b[\"emos_raw\"].tolist()\n",
|
| 488 |
+
" vad_p = vad_p.float().cpu().numpy()\n",
|
| 489 |
+
" for j, t in enumerate([\"val\", \"aro\", \"dom\"]):\n",
|
| 490 |
+
" P[t] += vad_p[:, j].tolist(); Y[t] += b[\"vad_raw\"][:, j].tolist()\n",
|
| 491 |
+
" catP.append(F.softmax(cat_l, 1).float().cpu().numpy()); catY.append(b[\"cat\"])\n",
|
| 492 |
+
" out = {}\n",
|
| 493 |
+
" for t in [\"emos\"] + ([\"val\", \"aro\", \"dom\"] if HAS_VAD else []):\n",
|
| 494 |
+
" out[t] = spearmanr(P[t], Y[t]).correlation\n",
|
| 495 |
+
" q = np.concatenate(catP); p = np.concatenate(catY)\n",
|
| 496 |
+
" out[\"cat_err\"] = float(np.abs(q - p).sum(1).mean())\n",
|
| 497 |
+
" return out\n",
|
| 498 |
+
"\n",
|
| 499 |
+
"def mean_srcc(m):\n",
|
| 500 |
+
" keys = [\"emos\"] + ([\"val\", \"aro\", \"dom\"] if HAS_VAD else [])\n",
|
| 501 |
+
" return float(np.mean([m[k] for k in keys]))\n",
|
| 502 |
+
"\n",
|
| 503 |
+
"CKPT_PATH = os.path.join(OUT_DIR, f\"ft_wavlm_{INIT_MODE}.pt\")\n",
|
| 504 |
+
"def save_full(state, val_emos=float(\"nan\")):\n",
|
| 505 |
+
" torch.save({\"wavlm\": state[\"wavlm\"], \"heads\": state[\"heads\"], \"INIT_MODE\": INIT_MODE,\n",
|
| 506 |
+
" \"emos_mu\": emos_mu, \"emos_sd\": emos_sd, \"vad_mu\": vad_mu, \"vad_sd\": vad_sd,\n",
|
| 507 |
+
" \"WAVLM_DIM\": WAVLM_DIM, \"val_emos\": float(val_emos)}, CKPT_PATH)\n",
|
| 508 |
+
"\n",
|
| 509 |
+
"best, best_state, bad = -1e9, None, 0\n",
|
| 510 |
+
"for ep in range(1, EPOCHS + 1):\n",
|
| 511 |
+
" wavlm.train(); heads.train()\n",
|
| 512 |
+
" opt.zero_grad(); run = 0.0; nb = 0\n",
|
| 513 |
+
" for step, b in enumerate(tqdm(tr_loader, desc=f\"[{INIT_MODE}] epoch {ep}\")):\n",
|
| 514 |
+
" with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
|
| 515 |
+
" emos_p, cat_l, vad_p = forward_batch(b)\n",
|
| 516 |
+
" loss = compute_loss(emos_p, cat_l, vad_p, b) / ACCUM\n",
|
| 517 |
+
" scaler.scale(loss).backward()\n",
|
| 518 |
+
" if (step + 1) % ACCUM == 0:\n",
|
| 519 |
+
" scaler.step(opt); scaler.update(); opt.zero_grad()\n",
|
| 520 |
+
" run += loss.item() * ACCUM; nb += 1\n",
|
| 521 |
+
" m = evaluate(); sc = mean_srcc(m)\n",
|
| 522 |
+
" msg = \" \".join(f\"{k}={m[k]:.3f}\" for k in [\"emos\", \"val\", \"aro\", \"dom\"] if k in m)\n",
|
| 523 |
+
" print(f\"[{INIT_MODE}] epoch {ep:2d} | loss {run/max(nb,1):.4f} | {msg} | cat_err {m['cat_err']:.3f} | mean {sc:.4f} (best {max(best,sc):.4f})\")\n",
|
| 524 |
+
" if sc > best:\n",
|
| 525 |
+
" best = sc\n",
|
| 526 |
+
" best_state = {\"wavlm\": {k: v.cpu().clone() for k, v in wavlm.state_dict().items()},\n",
|
| 527 |
+
" \"heads\": {k: v.cpu().clone() for k, v in heads.state_dict().items()}}\n",
|
| 528 |
+
" save_full(best_state, m[\"emos\"]); bad = 0\n",
|
| 529 |
+
" print(f\" 💾 lưu best → {CKPT_PATH} (epoch {ep}, mean {sc:.4f})\")\n",
|
| 530 |
+
" else:\n",
|
| 531 |
+
" bad += 1\n",
|
| 532 |
+
" if bad >= PATIENCE:\n",
|
| 533 |
+
" print(f\"Early stop ở epoch {ep}.\"); break\n",
|
| 534 |
+
"\n",
|
| 535 |
+
"if best_state:\n",
|
| 536 |
+
" wavlm.load_state_dict(best_state[\"wavlm\"]); heads.load_state_dict(best_state[\"heads\"])\n",
|
| 537 |
+
"final = evaluate()\n",
|
| 538 |
+
"print(f\"\\n✅ VAL (nội bộ) — exp12 INIT_MODE={INIT_MODE}:\")\n",
|
| 539 |
+
"print(f\" EMOS={final['emos']:.4f}\", end=\"\")\n",
|
| 540 |
+
"if HAS_VAD:\n",
|
| 541 |
+
" print(f\" | VAL/ARO/DOM={final['val']:.4f}/{final['aro']:.4f}/{final['dom']:.4f}\")\n",
|
| 542 |
+
"else:\n",
|
| 543 |
+
" print()\n",
|
| 544 |
+
"print(f\" cat_err={final['cat_err']:.4f} | mean SRCC={mean_srcc(final):.4f}\")\n",
|
| 545 |
+
"print(f\" (Mốc DEV exp08 để tham khảo: EMOS {EXP08['emos']}, VAD {EXP08['val']}/{EXP08['aro']}/{EXP08['dom']})\")\n",
|
| 546 |
+
"print(\" ➜ GHI con số này vào bảng ablation 04_ rồi đổi INIT_MODE chạy lại để so 3 mode.\")"
|
| 547 |
+
]
|
| 548 |
+
},
|
| 549 |
+
{
|
| 550 |
+
"cell_type": "markdown",
|
| 551 |
+
"id": "0874af79",
|
| 552 |
+
"metadata": {},
|
| 553 |
+
"source": [
|
| 554 |
+
"## 6. Dự đoán DEV → answer.txt (QMOS mượn exp07 / UTMOSv2)"
|
| 555 |
+
]
|
| 556 |
+
},
|
| 557 |
+
{
|
| 558 |
+
"cell_type": "code",
|
| 559 |
+
"execution_count": null,
|
| 560 |
+
"id": "b21df9af",
|
| 561 |
+
"metadata": {
|
| 562 |
+
"lines_to_next_cell": 1
|
| 563 |
+
},
|
| 564 |
+
"outputs": [],
|
| 565 |
+
"source": [
|
| 566 |
+
"def list_dev():\n",
|
| 567 |
+
" with open(DEV_SCP) as f:\n",
|
| 568 |
+
" return [ln.strip() for ln in f if ln.strip()]\n",
|
| 569 |
+
"\n",
|
| 570 |
+
"dev_names = list_dev()\n",
|
| 571 |
+
"if LIMIT_DEV:\n",
|
| 572 |
+
" dev_names = dev_names[:LIMIT_DEV]\n",
|
| 573 |
+
"print(\"DEV:\", len(dev_names), \"mẫu\")\n",
|
| 574 |
+
"\n",
|
| 575 |
+
"def load_exp07_qmos():\n",
|
| 576 |
+
" if EXP07_ANSWER and os.path.exists(EXP07_ANSWER):\n",
|
| 577 |
+
" import csv\n",
|
| 578 |
+
" d = {}\n",
|
| 579 |
+
" with open(EXP07_ANSWER) as f:\n",
|
| 580 |
+
" for row in csv.DictReader(f):\n",
|
| 581 |
+
" d[row[\"wav\"]] = float(row[\"QMOS\"]); d[stem(row[\"wav\"])] = float(row[\"QMOS\"])\n",
|
| 582 |
+
" print(f\"✅ Mượn QMOS exp07: {len(d)//2} wav\")\n",
|
| 583 |
+
" return d\n",
|
| 584 |
+
" return None\n",
|
| 585 |
+
"\n",
|
| 586 |
+
"qmos_map = load_exp07_qmos()\n",
|
| 587 |
+
"if qmos_map is None:\n",
|
| 588 |
+
" print(\"ℹ️ Không có exp07 → QMOS bằng UTMOSv2.\")\n",
|
| 589 |
+
" pip_install(\"git+https://github.com/sarulab-speech/UTMOSv2.git\")\n",
|
| 590 |
+
" import utmosv2\n",
|
| 591 |
+
" v2 = utmosv2.create_model(pretrained=True)\n",
|
| 592 |
+
" qmos_map = {}\n",
|
| 593 |
+
" for n in tqdm(dev_names, desc=\"UTMOSv2\"):\n",
|
| 594 |
+
" wav = os.path.join(WAV_DIR, n if str(n).endswith(\".wav\") else str(n) + \".wav\")\n",
|
| 595 |
+
" if os.path.exists(wav):\n",
|
| 596 |
+
" o = v2.predict(input_path=wav)\n",
|
| 597 |
+
" qmos_map[n] = float(o[\"predicted_mos\"]) if isinstance(o, dict) else float(o)\n",
|
| 598 |
+
" del v2; torch.cuda.empty_cache() if device == \"cuda\" else None\n",
|
| 599 |
+
"\n",
|
| 600 |
+
"@torch.no_grad()\n",
|
| 601 |
+
"def predict_emotion(sid):\n",
|
| 602 |
+
" wave = load_wav(sid)\n",
|
| 603 |
+
" if wave is None:\n",
|
| 604 |
+
" return None\n",
|
| 605 |
+
" wavlm.eval(); heads.eval()\n",
|
| 606 |
+
" iv = torch.from_numpy(wave).unsqueeze(0).to(device)\n",
|
| 607 |
+
" am = torch.ones((1, len(wave)), dtype=torch.long, device=device)\n",
|
| 608 |
+
" tgt = torch.from_numpy(onehot_target(target_map.get(sid))).unsqueeze(0).to(device)\n",
|
| 609 |
+
" with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
|
| 610 |
+
" feat = wavlm_embed(iv, am)\n",
|
| 611 |
+
" emos_p, cat_l, vad_p = heads(feat, tgt)\n",
|
| 612 |
+
" emos = float(emos_p.item()) * emos_sd + emos_mu\n",
|
| 613 |
+
" cat5 = F.softmax(cat_l, 1)[0].float().cpu().numpy()\n",
|
| 614 |
+
" vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu\n",
|
| 615 |
+
" return emos, cat5, vad3\n",
|
| 616 |
+
"\n",
|
| 617 |
+
"def fmt_cat(p5):\n",
|
| 618 |
+
" return \"|\".join(f\"{e}:{p5[i]:.6g}\" for i, e in enumerate(EMOTIONS5))\n",
|
| 619 |
+
"\n",
|
| 620 |
+
"answer_path = os.path.join(OUT_DIR, f\"answer_{INIT_MODE}.txt\")\n",
|
| 621 |
+
"n_real = n_def = 0\n",
|
| 622 |
+
"with open(answer_path, \"w\") as f:\n",
|
| 623 |
+
" f.write(\"wav,QMOS,EMOS,CAT,VAL,ARO,DOM\\n\")\n",
|
| 624 |
+
" for name in tqdm(dev_names, desc=f\"answer[{INIT_MODE}]\"):\n",
|
| 625 |
+
" sid = stem(name)\n",
|
| 626 |
+
" pr = predict_emotion(sid)\n",
|
| 627 |
+
" if pr is None:\n",
|
| 628 |
+
" emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0]); n_def += 1\n",
|
| 629 |
+
" else:\n",
|
| 630 |
+
" emos, cat5, vad3 = pr; n_real += 1\n",
|
| 631 |
+
" qmos = qmos_map.get(name, qmos_map.get(sid, 3.0))\n",
|
| 632 |
+
" f.write(f\"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\\n\")\n",
|
| 633 |
+
"print(f\"Ghi {len(dev_names)} dòng → {answer_path} | thật {n_real}, mặc định {n_def}\")"
|
| 634 |
+
]
|
| 635 |
+
},
|
| 636 |
+
{
|
| 637 |
+
"cell_type": "markdown",
|
| 638 |
+
"id": "eaebc2e5",
|
| 639 |
+
"metadata": {},
|
| 640 |
+
"source": [
|
| 641 |
+
"## 7. Validate + zip"
|
| 642 |
+
]
|
| 643 |
+
},
|
| 644 |
+
{
|
| 645 |
+
"cell_type": "code",
|
| 646 |
+
"execution_count": null,
|
| 647 |
+
"id": "43e0440b",
|
| 648 |
+
"metadata": {},
|
| 649 |
+
"outputs": [],
|
| 650 |
+
"source": [
|
| 651 |
+
"def validate(path):\n",
|
| 652 |
+
" import csv\n",
|
| 653 |
+
" with open(path) as f:\n",
|
| 654 |
+
" rows = list(csv.reader(f))\n",
|
| 655 |
+
" assert rows[0][0] == \"wav\" and \"QMOS\" in rows[0], \"Header sai\"\n",
|
| 656 |
+
" for i, r in enumerate(rows[1:], 2):\n",
|
| 657 |
+
" assert len(r) == len(rows[0]), f\"Dòng {i} sai số cột\"\n",
|
| 658 |
+
" print(f\"OK: {len(rows)-1} dòng, header = {rows[0]}\")\n",
|
| 659 |
+
"\n",
|
| 660 |
+
"validate(answer_path)\n",
|
| 661 |
+
"os.system(f\"cd {OUT_DIR} && cp answer_{INIT_MODE}.txt answer.txt && zip -j submission_track2_exp12_{INIT_MODE}.zip answer.txt && unzip -l submission_track2_exp12_{INIT_MODE}.zip\")\n",
|
| 662 |
+
"print(\"Sẵn sàng nộp:\", os.path.join(OUT_DIR, f\"submission_track2_exp12_{INIT_MODE}.zip\"))"
|
| 663 |
+
]
|
| 664 |
+
},
|
| 665 |
+
{
|
| 666 |
+
"cell_type": "markdown",
|
| 667 |
+
"id": "ade69063",
|
| 668 |
+
"metadata": {},
|
| 669 |
+
"source": [
|
| 670 |
+
"## Ghi chú\n",
|
| 671 |
+
"- **Chạy 3 lần** đổi `INIT_MODE` (\"scratch\"→\"base\"→\"sailer\"), ghi `mean SRCC` mỗi lần vào BẢNG ABLATION\n",
|
| 672 |
+
" trong `docs/04_experiments_log.md` → trả lời mentor bằng số: from-scratch tốt hơn fine-tune không?\n",
|
| 673 |
+
"- **scratch nặng:** mở băng toàn bộ WavLM-large. Nếu OOM → giảm `BATCH` (4→2), `MAX_SECONDS` (6→5),\n",
|
| 674 |
+
" hoặc đổi sang `microsoft/wavlm-base-plus` (sửa cell 2) cho khả thi (lưu ý: khác kiến trúc → so kém công bằng hơn).\n",
|
| 675 |
+
"- **scratch chậm + cần nhiều epoch hơn** (random init): để `EPOCHS=15`, `PATIENCE=5`. Vẫn nhiều khả năng < base/sailer.\n",
|
| 676 |
+
"- **Đừng nhầm VAL nội bộ với DEV.** So 3 mode bằng VAL nội bộ đã đủ kết luận; muốn chắc thì nộp mode tốt nhất.\n",
|
| 677 |
+
"- Checkpoint lưu `ft_wavlm_<mode>.pt`. Save Version sau mỗi lần chạy."
|
| 678 |
+
]
|
| 679 |
+
}
|
| 680 |
+
],
|
| 681 |
+
"metadata": {
|
| 682 |
+
"jupytext": {
|
| 683 |
+
"cell_metadata_filter": "-all",
|
| 684 |
+
"main_language": "python",
|
| 685 |
+
"notebook_metadata_filter": "-all"
|
| 686 |
+
}
|
| 687 |
+
},
|
| 688 |
+
"nbformat": 4,
|
| 689 |
+
"nbformat_minor": 5
|
| 690 |
+
}
|
track2/exp12_wavlm_scratch_pipeline.py
ADDED
|
@@ -0,0 +1,564 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %% [markdown]
|
| 2 |
+
# # VMC2026 Track 2 — exp12 (WavLM: SCRATCH vs BASE vs SAILER — ablation khởi tạo) — Kaggle T4
|
| 3 |
+
#
|
| 4 |
+
# **Mục đích:** kiểm chứng giả thuyết của mentor — *"với 12k data, train from scratch có tốt hơn fine-tune không?"*
|
| 5 |
+
# Một notebook, đổi cờ `INIT_MODE` để chạy 3 cách khởi tạo backbone WavLM, so trên CÙNG kiến trúc/data:
|
| 6 |
+
#
|
| 7 |
+
# | INIT_MODE | Khởi tạo WavLM | Train gì | Ý nghĩa |
|
| 8 |
+
# |---|---|---|---|
|
| 9 |
+
# | `scratch` | **ngẫu nhiên** (không pretrain) | **toàn bộ** backbone | "from scratch" đúng nghĩa mentor nói |
|
| 10 |
+
# | `base` | microsoft/wavlm-large (pretrain SSL, KHÔNG cảm xúc) | mở băng N lớp trên | đo lợi ích của SAILER warm-start |
|
| 11 |
+
# | `sailer` | warm-start cảm xúc (như exp08) | mở băng N lớp trên | bản mạnh hiện tại |
|
| 12 |
+
#
|
| 13 |
+
# **Chỉ WavLM** (bỏ audeering) để cô lập đúng biến "khởi tạo". QMOS mượn exp07 / UTMOSv2.
|
| 14 |
+
#
|
| 15 |
+
# ## ⚠️ Kỳ vọng trung thực (để đọc kết quả đúng)
|
| 16 |
+
# - `scratch` gần như CHẮC CHẮN yếu hơn `base`/`sailer`: 12k mẫu là quá ít để dạy WavLM "nghe" từ đầu
|
| 17 |
+
# (SSL pretrain dùng ~94.000 GIỜ audio). Đây là ablation để **chứng minh bằng số**, không phải để vượt.
|
| 18 |
+
# - `scratch` phải mở băng TOÀN BỘ (mới có gì để học) → **nặng + chậm + dễ OOM** trên T4. Dùng LIMIT nhỏ trước.
|
| 19 |
+
# - So sánh bằng **VAL nội bộ** giữa 3 mode đã đủ kết luận; muốn chắc thì nộp mode tốt nhất lên DEV.
|
| 20 |
+
#
|
| 21 |
+
# **Cách chạy:** GPU T4 + Internet On → sửa cell 0 (`INIT_MODE` + slug) → Run All. Chạy 3 lần đổi INIT_MODE.
|
| 22 |
+
|
| 23 |
+
# %% [markdown]
|
| 24 |
+
# ## 0. Cấu hình — SỬA Ở ĐÂY
|
| 25 |
+
|
| 26 |
+
# %%
|
| 27 |
+
import os
|
| 28 |
+
|
| 29 |
+
INIT_MODE = "sailer" # << "scratch" | "base" | "sailer" (đổi rồi chạy lại để so) — "sailer" = WavLM warm-start cảm xúc
|
| 30 |
+
|
| 31 |
+
DATA_ROOT = "/kaggle/input/datasets/minhtoan2/vmc2026-track2-full" # << SỬA slug cho khớp Add Input
|
| 32 |
+
WAV_DIR = f"{DATA_ROOT}/wav"
|
| 33 |
+
METADATA_CSV = f"{DATA_ROOT}/metadata.csv"
|
| 34 |
+
TRAIN_CSV = f"{DATA_ROOT}/sets/train.csv"
|
| 35 |
+
DEV_SCP = f"{DATA_ROOT}/sets/dev.scp"
|
| 36 |
+
|
| 37 |
+
EXP07_ANSWER = "/kaggle/input/exp07-answer/answer.txt" # << (tùy chọn) mượn QMOS 0.548; không có → UTMOSv2
|
| 38 |
+
OUT_DIR = "/kaggle/working"
|
| 39 |
+
|
| 40 |
+
# ── Siêu tham số ─────────────────────────────────────────────────────────────
|
| 41 |
+
DEVICE = "cuda"
|
| 42 |
+
SR = 16000
|
| 43 |
+
MAX_SECONDS = 6
|
| 44 |
+
TRUNK_HIDDEN = 512
|
| 45 |
+
HEAD_HIDDEN = 128
|
| 46 |
+
DROPOUT = 0.3
|
| 47 |
+
WEIGHT_DECAY = 1e-5
|
| 48 |
+
EPOCHS = 15
|
| 49 |
+
PATIENCE = 5
|
| 50 |
+
BATCH = 4
|
| 51 |
+
ACCUM = 8
|
| 52 |
+
VAL_FRAC = 0.10
|
| 53 |
+
SEED = 42
|
| 54 |
+
USE_AMP = True
|
| 55 |
+
USE_GRAD_CKPT = True
|
| 56 |
+
USE_UNCERTAINTY = True
|
| 57 |
+
|
| 58 |
+
# Khởi tạo & LR & mở băng — TỰ đặt theo INIT_MODE (scratch cần LR lớn + mở băng toàn bộ)
|
| 59 |
+
if INIT_MODE == "scratch":
|
| 60 |
+
UNFREEZE_TOP_LAYERS = "all" # random init → phải train tất cả mới học được
|
| 61 |
+
LR_BACKBONE = 1e-4 # random init cần bước lớn hơn fine-tune
|
| 62 |
+
elif INIT_MODE in ("base", "sailer"):
|
| 63 |
+
UNFREEZE_TOP_LAYERS = 6 # fine-tune: chỉ mở băng N lớp trên (tiết kiệm VRAM, chống overfit)
|
| 64 |
+
LR_BACKBONE = 1e-5
|
| 65 |
+
else:
|
| 66 |
+
raise ValueError(f"INIT_MODE lạ: {INIT_MODE}")
|
| 67 |
+
LR_HEAD = 1e-3
|
| 68 |
+
|
| 69 |
+
LIMIT_TRAIN = 300 # << LẦN ĐẦU 300; chạy thật None
|
| 70 |
+
LIMIT_DEV = 20 # << LẦN ĐẦU 20; chạy thật None
|
| 71 |
+
|
| 72 |
+
EXP08 = {"emos": 0.811, "cat_err": 0.133, "val": 0.659, "aro": 0.793, "dom": 0.751} # mốc DEV để tham khảo
|
| 73 |
+
|
| 74 |
+
EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"]
|
| 75 |
+
_EMO_ALIAS = {
|
| 76 |
+
"angry": "angry", "anger": "angry",
|
| 77 |
+
"happy": "happy", "happiness": "happy", "joy": "happy",
|
| 78 |
+
"neutral": "neutral", "calm": "neutral",
|
| 79 |
+
"sad": "sad", "sadness": "sad",
|
| 80 |
+
"surprise": "surprised", "surprised": "surprised", "surprising": "surprised",
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
def norm_emotion(label):
|
| 84 |
+
key = str(label).strip().lower()
|
| 85 |
+
return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)
|
| 86 |
+
|
| 87 |
+
def stem(p):
|
| 88 |
+
return os.path.splitext(os.path.basename(str(p)))[0]
|
| 89 |
+
|
| 90 |
+
print(f"INIT_MODE = {INIT_MODE} | UNFREEZE = {UNFREEZE_TOP_LAYERS} | LR_BACKBONE = {LR_BACKBONE}")
|
| 91 |
+
print("DATA_ROOT:", DATA_ROOT)
|
| 92 |
+
for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP]:
|
| 93 |
+
print((" ✅ " if os.path.exists(p) else " ❌ THIẾU ") + p)
|
| 94 |
+
|
| 95 |
+
# %% [markdown]
|
| 96 |
+
# ## 1. Cài đặt (clone SAILER chỉ khi INIT_MODE='sailer')
|
| 97 |
+
|
| 98 |
+
# %%
|
| 99 |
+
import sys, subprocess
|
| 100 |
+
import numpy as _np
|
| 101 |
+
|
| 102 |
+
# ⚠️ KHÓA numpy = bản Kaggle đang có → pip KHÔNG được nâng/hạ numpy → tránh "SystemError: bad call flags"
|
| 103 |
+
# (lỗi import torch do numpy lệch phiên bản với torch đã biên dịch sẵn).
|
| 104 |
+
_NPIN = f"numpy=={_np.__version__}"
|
| 105 |
+
print("Khóa numpy ở:", _NPIN)
|
| 106 |
+
|
| 107 |
+
def pip_install(*pkgs):
|
| 108 |
+
subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs, _NPIN], check=True)
|
| 109 |
+
|
| 110 |
+
# Kaggle đã có sẵn torch/transformers/librosa/scipy/sklearn/pandas/tqdm/huggingface_hub/safetensors.
|
| 111 |
+
# Chỉ cài thêm vài gói speech còn thiếu (kèm khóa numpy ở trên).
|
| 112 |
+
pip_install("loralib", "speechmos", "soundfile")
|
| 113 |
+
if INIT_MODE == "sailer":
|
| 114 |
+
pip_install("speechbrain")
|
| 115 |
+
|
| 116 |
+
if INIT_MODE == "sailer":
|
| 117 |
+
REPO_DIR = "/kaggle/working/vox-profile-release"
|
| 118 |
+
if not os.path.exists(REPO_DIR):
|
| 119 |
+
subprocess.run(["git", "clone", "--depth", "1",
|
| 120 |
+
"https://github.com/tiantiaf0627/vox-profile-release.git", REPO_DIR], check=True)
|
| 121 |
+
if REPO_DIR not in sys.path:
|
| 122 |
+
sys.path.insert(0, REPO_DIR)
|
| 123 |
+
|
| 124 |
+
# %% [markdown]
|
| 125 |
+
# ## 2. Dựng WavLM theo INIT_MODE
|
| 126 |
+
|
| 127 |
+
# %%
|
| 128 |
+
import torch
|
| 129 |
+
import torch.nn as nn
|
| 130 |
+
import torch.nn.functional as F
|
| 131 |
+
import numpy as np
|
| 132 |
+
|
| 133 |
+
device = DEVICE if torch.cuda.is_available() else "cpu"
|
| 134 |
+
print("Device:", device, ("✅ " + torch.cuda.get_device_name(0)) if device == "cuda" else "⚠️ CPU (rất chậm!)")
|
| 135 |
+
|
| 136 |
+
from transformers import WavLMModel, WavLMConfig
|
| 137 |
+
|
| 138 |
+
def find_hf_backbone(module):
|
| 139 |
+
cands = []
|
| 140 |
+
for name, m in module.named_modules():
|
| 141 |
+
enc = getattr(m, "encoder", None)
|
| 142 |
+
if getattr(m, "feature_extractor", None) is not None and enc is not None \
|
| 143 |
+
and getattr(enc, "layers", None) is not None:
|
| 144 |
+
cands.append((name, m))
|
| 145 |
+
if not cands:
|
| 146 |
+
return None, None
|
| 147 |
+
cands.sort(key=lambda nm: sum(p.numel() for p in nm[1].parameters()), reverse=True)
|
| 148 |
+
return cands[0]
|
| 149 |
+
|
| 150 |
+
wavlm = None
|
| 151 |
+
if INIT_MODE == "scratch":
|
| 152 |
+
# Random init NHƯNG giữ ĐÚNG kiến trúc large (để công bằng với base/sailer)
|
| 153 |
+
cfg = WavLMConfig.from_pretrained("microsoft/wavlm-large")
|
| 154 |
+
wavlm = WavLMModel(cfg) # KHÔNG load trọng số → ngẫu nhiên
|
| 155 |
+
print("🎲 WavLM-large khởi tạo NGẪU NHIÊN (from scratch, không pretrain).")
|
| 156 |
+
elif INIT_MODE == "base":
|
| 157 |
+
wavlm = WavLMModel.from_pretrained("microsoft/wavlm-large")
|
| 158 |
+
print("📦 WavLM-large pretrain SSL (chưa học cảm xúc).")
|
| 159 |
+
elif INIT_MODE == "sailer":
|
| 160 |
+
try:
|
| 161 |
+
from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402
|
| 162 |
+
_wrapper = WavLMWrapper.from_pretrained("tiantiaf/wavlm-large-categorical-emotion")
|
| 163 |
+
name, wavlm = find_hf_backbone(_wrapper)
|
| 164 |
+
print(f"🔥 WavLM warm-start SAILER (cảm xúc) tại '.{name}'")
|
| 165 |
+
except Exception as e:
|
| 166 |
+
print("⚠️ Lỗi nạp SAILER:", repr(e), "→ fallback base pretrained.")
|
| 167 |
+
wavlm = WavLMModel.from_pretrained("microsoft/wavlm-large")
|
| 168 |
+
|
| 169 |
+
wavlm = wavlm.to(device)
|
| 170 |
+
WAVLM_DIM = int(wavlm.config.hidden_size)
|
| 171 |
+
wavlm.config.layerdrop = 0.0 # ⚠️ BẮT BUỘC khi dùng gradient-checkpointing (tránh CheckpointError do bỏ lớp ngẫu nhiên)
|
| 172 |
+
|
| 173 |
+
# Mở băng theo cấu hình
|
| 174 |
+
if UNFREEZE_TOP_LAYERS == "all":
|
| 175 |
+
for p in wavlm.parameters():
|
| 176 |
+
p.requires_grad = True
|
| 177 |
+
n_open = "ALL"
|
| 178 |
+
else:
|
| 179 |
+
for p in wavlm.parameters():
|
| 180 |
+
p.requires_grad = False
|
| 181 |
+
_wl = wavlm.encoder.layers
|
| 182 |
+
for layer in _wl[max(0, len(_wl) - UNFREEZE_TOP_LAYERS):]:
|
| 183 |
+
for p in layer.parameters():
|
| 184 |
+
p.requires_grad = True
|
| 185 |
+
n_open = f"top {min(UNFREEZE_TOP_LAYERS, len(_wl))}/{len(_wl)}"
|
| 186 |
+
print(f"WavLM mở băng: {n_open} → {sum(p.numel() for p in wavlm.parameters() if p.requires_grad)/1e6:.1f}M param train (dim {WAVLM_DIM})")
|
| 187 |
+
|
| 188 |
+
if USE_GRAD_CKPT:
|
| 189 |
+
wavlm.gradient_checkpointing_enable()
|
| 190 |
+
if hasattr(wavlm, "enable_input_require_grads"):
|
| 191 |
+
wavlm.enable_input_require_grads()
|
| 192 |
+
|
| 193 |
+
def masked_mean(hidden, attn_mask):
|
| 194 |
+
if attn_mask is None:
|
| 195 |
+
return hidden.mean(dim=1)
|
| 196 |
+
try:
|
| 197 |
+
fm = wavlm._get_feature_vector_attention_mask(hidden.shape[1], attn_mask)
|
| 198 |
+
except Exception:
|
| 199 |
+
return hidden.mean(dim=1)
|
| 200 |
+
fm = fm.unsqueeze(-1).to(hidden.dtype)
|
| 201 |
+
return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)
|
| 202 |
+
|
| 203 |
+
def wavlm_embed(input_values, attn_mask):
|
| 204 |
+
out = wavlm(input_values, attention_mask=attn_mask).last_hidden_state
|
| 205 |
+
return masked_mean(out, attn_mask)
|
| 206 |
+
|
| 207 |
+
# %% [markdown]
|
| 208 |
+
# ## 3. Đọc & gộp nhãn theo wavID
|
| 209 |
+
|
| 210 |
+
# %%
|
| 211 |
+
import librosa
|
| 212 |
+
import pandas as pd
|
| 213 |
+
from tqdm.auto import tqdm
|
| 214 |
+
|
| 215 |
+
def load_target_emotions():
|
| 216 |
+
tgt = {}
|
| 217 |
+
with open(METADATA_CSV, encoding="utf-8") as f:
|
| 218 |
+
for ln in f:
|
| 219 |
+
parts = ln.strip().split("|")
|
| 220 |
+
if len(parts) >= 2:
|
| 221 |
+
tgt[stem(parts[0])] = norm_emotion(parts[1])
|
| 222 |
+
return tgt
|
| 223 |
+
|
| 224 |
+
def _col(cols_map, *names, df=None, default_idx=None):
|
| 225 |
+
for n in names:
|
| 226 |
+
if n in cols_map:
|
| 227 |
+
return cols_map[n]
|
| 228 |
+
return list(df.columns)[default_idx] if default_idx is not None else None
|
| 229 |
+
|
| 230 |
+
def parse_emocat_votes(cell):
|
| 231 |
+
v = np.zeros(len(EMOTIONS5), dtype=np.float32)
|
| 232 |
+
for tok in str(cell).replace("/", ",").replace(";", ",").replace("|", ",").replace(" ", ",").split(","):
|
| 233 |
+
e = norm_emotion(tok)
|
| 234 |
+
if e in EMOTIONS5:
|
| 235 |
+
v[EMOTIONS5.index(e)] += 1.0
|
| 236 |
+
return v
|
| 237 |
+
|
| 238 |
+
def load_train_labels():
|
| 239 |
+
df = pd.read_csv(TRAIN_CSV, sep="|")
|
| 240 |
+
cols = {c.lower().strip(): c for c in df.columns}
|
| 241 |
+
wav_col = _col(cols, "wavid", "wav", df=df, default_idx=1)
|
| 242 |
+
emos_col = _col(cols, "emos", "emo", "emomos")
|
| 243 |
+
val_col = _col(cols, "val", "valence"); aro_col = _col(cols, "aro", "arousal"); dom_col = _col(cols, "dom", "dominance")
|
| 244 |
+
cat_col = _col(cols, "emocat", "cat", "emotion")
|
| 245 |
+
assert emos_col, f"Không thấy cột eMOS (cột: {list(df.columns)})"
|
| 246 |
+
df["_stem"] = df[wav_col].map(stem)
|
| 247 |
+
rows = []
|
| 248 |
+
for sid, g in df.groupby("_stem"):
|
| 249 |
+
rec = {"wavID": sid, "emos": float(g[emos_col].mean())}
|
| 250 |
+
rec["val"] = float(g[val_col].mean()) if val_col else np.nan
|
| 251 |
+
rec["aro"] = float(g[aro_col].mean()) if aro_col else np.nan
|
| 252 |
+
rec["dom"] = float(g[dom_col].mean()) if dom_col else np.nan
|
| 253 |
+
votes = np.zeros(len(EMOTIONS5), dtype=np.float32)
|
| 254 |
+
if cat_col:
|
| 255 |
+
for cell in g[cat_col]:
|
| 256 |
+
votes += parse_emocat_votes(cell)
|
| 257 |
+
s = votes.sum()
|
| 258 |
+
cat = votes / s if s > 0 else np.full(len(EMOTIONS5), 0.2, dtype=np.float32)
|
| 259 |
+
for i in range(len(EMOTIONS5)):
|
| 260 |
+
rec[f"cat{i}"] = float(cat[i])
|
| 261 |
+
rows.append(rec)
|
| 262 |
+
return pd.DataFrame(rows)
|
| 263 |
+
|
| 264 |
+
target_map = load_target_emotions()
|
| 265 |
+
train_df = load_train_labels()
|
| 266 |
+
HAS_VAD = bool(train_df["val"].notna().any())
|
| 267 |
+
print(f"Target: {len(target_map)} | wav train (gộp): {len(train_df)} | có VAD: {HAS_VAD}")
|
| 268 |
+
|
| 269 |
+
# %% [markdown]
|
| 270 |
+
# ## 4. Dataset/loader (chỉ raw wave cho WavLM)
|
| 271 |
+
|
| 272 |
+
# %%
|
| 273 |
+
from torch.utils.data import Dataset, DataLoader
|
| 274 |
+
from sklearn.model_selection import train_test_split
|
| 275 |
+
|
| 276 |
+
train_stems = [s for s in train_df["wavID"] if target_map.get(s) is not None]
|
| 277 |
+
if LIMIT_TRAIN:
|
| 278 |
+
train_stems = train_stems[:LIMIT_TRAIN]
|
| 279 |
+
lab = train_df.set_index("wavID")
|
| 280 |
+
|
| 281 |
+
def _zfit(a):
|
| 282 |
+
a = np.asarray(a, dtype=np.float32); return float(np.nanmean(a)), float(np.nanstd(a) + 1e-6)
|
| 283 |
+
emos_mu, emos_sd = _zfit([lab.loc[s, "emos"] for s in train_stems])
|
| 284 |
+
if HAS_VAD:
|
| 285 |
+
vad_mu = np.array([_zfit([lab.loc[s, c] for s in train_stems])[0] for c in ["val", "aro", "dom"]], np.float32)
|
| 286 |
+
vad_sd = np.array([_zfit([lab.loc[s, c] for s in train_stems])[1] for c in ["val", "aro", "dom"]], np.float32)
|
| 287 |
+
else:
|
| 288 |
+
vad_mu = np.zeros(3, np.float32); vad_sd = np.ones(3, np.float32)
|
| 289 |
+
print(f"Chuẩn hóa: emos μ={emos_mu:.3f} σ={emos_sd:.3f}")
|
| 290 |
+
|
| 291 |
+
def onehot_target(tgt):
|
| 292 |
+
v = np.zeros(len(EMOTIONS5), dtype=np.float32)
|
| 293 |
+
if tgt in EMOTIONS5:
|
| 294 |
+
v[EMOTIONS5.index(tgt)] = 1.0
|
| 295 |
+
return v
|
| 296 |
+
|
| 297 |
+
def load_wav(sid):
|
| 298 |
+
p = os.path.join(WAV_DIR, sid if str(sid).endswith(".wav") else str(sid) + ".wav")
|
| 299 |
+
if not os.path.exists(p):
|
| 300 |
+
return None
|
| 301 |
+
wave, _ = librosa.load(p, sr=SR, mono=True)
|
| 302 |
+
return wave[: MAX_SECONDS * SR].astype(np.float32)
|
| 303 |
+
|
| 304 |
+
class EmoDataset(Dataset):
|
| 305 |
+
def __init__(self, stems):
|
| 306 |
+
self.stems = [s for s in stems if load_wav(s) is not None]
|
| 307 |
+
def __len__(self):
|
| 308 |
+
return len(self.stems)
|
| 309 |
+
def __getitem__(self, i):
|
| 310 |
+
s = self.stems[i]
|
| 311 |
+
wave = load_wav(s)
|
| 312 |
+
emos = (float(lab.loc[s, "emos"]) - emos_mu) / emos_sd
|
| 313 |
+
if HAS_VAD:
|
| 314 |
+
vad = (np.array([lab.loc[s, "val"], lab.loc[s, "aro"], lab.loc[s, "dom"]], np.float32) - vad_mu) / vad_sd
|
| 315 |
+
else:
|
| 316 |
+
vad = np.zeros(3, dtype=np.float32)
|
| 317 |
+
cat = np.array([lab.loc[s, f"cat{j}"] for j in range(len(EMOTIONS5))], dtype=np.float32)
|
| 318 |
+
return {"wave": wave, "tgt": onehot_target(target_map.get(s)),
|
| 319 |
+
"emos": np.float32(emos), "vad": vad, "cat": cat,
|
| 320 |
+
"emos_raw": np.float32(lab.loc[s, "emos"]),
|
| 321 |
+
"vad_raw": np.array([lab.loc[s, "val"], lab.loc[s, "aro"], lab.loc[s, "dom"]], np.float32)}
|
| 322 |
+
|
| 323 |
+
def collate(batch):
|
| 324 |
+
L = max(len(b["wave"]) for b in batch)
|
| 325 |
+
waves = np.zeros((len(batch), L), dtype=np.float32)
|
| 326 |
+
mask = np.zeros((len(batch), L), dtype=np.float32)
|
| 327 |
+
for i, b in enumerate(batch):
|
| 328 |
+
waves[i, : len(b["wave"])] = b["wave"]; mask[i, : len(b["wave"])] = 1.0
|
| 329 |
+
return {
|
| 330 |
+
"input_values": torch.from_numpy(waves), "attn_mask": torch.from_numpy(mask).long(),
|
| 331 |
+
"tgt": torch.from_numpy(np.stack([b["tgt"] for b in batch])),
|
| 332 |
+
"emos": torch.from_numpy(np.stack([b["emos"] for b in batch])).unsqueeze(1),
|
| 333 |
+
"vad": torch.from_numpy(np.stack([b["vad"] for b in batch])),
|
| 334 |
+
"cat": torch.from_numpy(np.stack([b["cat"] for b in batch])),
|
| 335 |
+
"emos_raw": np.stack([b["emos_raw"] for b in batch]),
|
| 336 |
+
"vad_raw": np.stack([b["vad_raw"] for b in batch]),
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
ds = EmoDataset(train_stems)
|
| 340 |
+
print("Dataset hợp lệ:", len(ds), "wav")
|
| 341 |
+
tr_i, va_i = train_test_split(np.arange(len(ds)), test_size=VAL_FRAC, random_state=SEED)
|
| 342 |
+
tr_loader = DataLoader(torch.utils.data.Subset(ds, tr_i), batch_size=BATCH, shuffle=True, collate_fn=collate, num_workers=2)
|
| 343 |
+
va_loader = DataLoader(torch.utils.data.Subset(ds, va_i), batch_size=BATCH, shuffle=False, collate_fn=collate, num_workers=2)
|
| 344 |
+
|
| 345 |
+
# %% [markdown]
|
| 346 |
+
# ## 5. Heads + train loop
|
| 347 |
+
|
| 348 |
+
# %%
|
| 349 |
+
from scipy.stats import spearmanr
|
| 350 |
+
|
| 351 |
+
torch.manual_seed(SEED); np.random.seed(SEED)
|
| 352 |
+
N_EMO = len(EMOTIONS5)
|
| 353 |
+
|
| 354 |
+
class EmoHeads(nn.Module):
|
| 355 |
+
def __init__(self, d_in, trunk_h, head_h, p, n_emo):
|
| 356 |
+
super().__init__()
|
| 357 |
+
self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),
|
| 358 |
+
nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))
|
| 359 |
+
self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))
|
| 360 |
+
self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))
|
| 361 |
+
self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))
|
| 362 |
+
def forward(self, feat, tgt):
|
| 363 |
+
h = self.trunk(feat)
|
| 364 |
+
return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)
|
| 365 |
+
|
| 366 |
+
heads = EmoHeads(WAVLM_DIM, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device)
|
| 367 |
+
|
| 368 |
+
TASKS = ["emos", "cat", "val", "aro", "dom"]
|
| 369 |
+
log_var = nn.Parameter(torch.zeros(len(TASKS), device=device))
|
| 370 |
+
bb_params = [p for p in wavlm.parameters() if p.requires_grad]
|
| 371 |
+
head_params = list(heads.parameters()) + ([log_var] if USE_UNCERTAINTY else [])
|
| 372 |
+
opt = torch.optim.AdamW([{"params": bb_params, "lr": LR_BACKBONE},
|
| 373 |
+
{"params": head_params, "lr": LR_HEAD}], weight_decay=WEIGHT_DECAY)
|
| 374 |
+
scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP and device == "cuda")
|
| 375 |
+
mse = nn.MSELoss()
|
| 376 |
+
|
| 377 |
+
def soft_ce(logits, target_dist):
|
| 378 |
+
return -(target_dist * F.log_softmax(logits, dim=1)).sum(1).mean()
|
| 379 |
+
|
| 380 |
+
def forward_batch(b):
|
| 381 |
+
feat = wavlm_embed(b["input_values"].to(device), b["attn_mask"].to(device))
|
| 382 |
+
return heads(feat, b["tgt"].to(device))
|
| 383 |
+
|
| 384 |
+
def compute_loss(emos_p, cat_l, vad_p, b):
|
| 385 |
+
L = {}
|
| 386 |
+
L["emos"] = mse(emos_p, b["emos"].to(device))
|
| 387 |
+
L["cat"] = soft_ce(cat_l, b["cat"].to(device))
|
| 388 |
+
if HAS_VAD:
|
| 389 |
+
vt = b["vad"].to(device)
|
| 390 |
+
L["val"] = mse(vad_p[:, 0:1], vt[:, 0:1]); L["aro"] = mse(vad_p[:, 1:2], vt[:, 1:2]); L["dom"] = mse(vad_p[:, 2:3], vt[:, 2:3])
|
| 391 |
+
else:
|
| 392 |
+
z = torch.zeros((), device=device); L["val"] = L["aro"] = L["dom"] = z
|
| 393 |
+
if USE_UNCERTAINTY:
|
| 394 |
+
return sum(torch.exp(-log_var[i]) * L[t] + log_var[i] for i, t in enumerate(TASKS))
|
| 395 |
+
return sum(L.values())
|
| 396 |
+
|
| 397 |
+
@torch.no_grad()
|
| 398 |
+
def evaluate():
|
| 399 |
+
wavlm.eval(); heads.eval()
|
| 400 |
+
P = {"emos": [], "val": [], "aro": [], "dom": []}; Y = {"emos": [], "val": [], "aro": [], "dom": []}
|
| 401 |
+
catP, catY = [], []
|
| 402 |
+
for b in va_loader:
|
| 403 |
+
with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
|
| 404 |
+
emos_p, cat_l, vad_p = forward_batch(b)
|
| 405 |
+
P["emos"] += emos_p.float().cpu().numpy().ravel().tolist(); Y["emos"] += b["emos_raw"].tolist()
|
| 406 |
+
vad_p = vad_p.float().cpu().numpy()
|
| 407 |
+
for j, t in enumerate(["val", "aro", "dom"]):
|
| 408 |
+
P[t] += vad_p[:, j].tolist(); Y[t] += b["vad_raw"][:, j].tolist()
|
| 409 |
+
catP.append(F.softmax(cat_l, 1).float().cpu().numpy()); catY.append(b["cat"])
|
| 410 |
+
out = {}
|
| 411 |
+
for t in ["emos"] + (["val", "aro", "dom"] if HAS_VAD else []):
|
| 412 |
+
out[t] = spearmanr(P[t], Y[t]).correlation
|
| 413 |
+
q = np.concatenate(catP); p = np.concatenate(catY)
|
| 414 |
+
out["cat_err"] = float(np.abs(q - p).sum(1).mean())
|
| 415 |
+
return out
|
| 416 |
+
|
| 417 |
+
def mean_srcc(m):
|
| 418 |
+
keys = ["emos"] + (["val", "aro", "dom"] if HAS_VAD else [])
|
| 419 |
+
return float(np.mean([m[k] for k in keys]))
|
| 420 |
+
|
| 421 |
+
CKPT_PATH = os.path.join(OUT_DIR, f"ft_wavlm_{INIT_MODE}.pt")
|
| 422 |
+
def save_full(state, val_emos=float("nan")):
|
| 423 |
+
torch.save({"wavlm": state["wavlm"], "heads": state["heads"], "INIT_MODE": INIT_MODE,
|
| 424 |
+
"emos_mu": emos_mu, "emos_sd": emos_sd, "vad_mu": vad_mu, "vad_sd": vad_sd,
|
| 425 |
+
"WAVLM_DIM": WAVLM_DIM, "val_emos": float(val_emos)}, CKPT_PATH)
|
| 426 |
+
|
| 427 |
+
best, best_state, bad = -1e9, None, 0
|
| 428 |
+
for ep in range(1, EPOCHS + 1):
|
| 429 |
+
wavlm.train(); heads.train()
|
| 430 |
+
opt.zero_grad(); run = 0.0; nb = 0
|
| 431 |
+
for step, b in enumerate(tqdm(tr_loader, desc=f"[{INIT_MODE}] epoch {ep}")):
|
| 432 |
+
with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
|
| 433 |
+
emos_p, cat_l, vad_p = forward_batch(b)
|
| 434 |
+
loss = compute_loss(emos_p, cat_l, vad_p, b) / ACCUM
|
| 435 |
+
scaler.scale(loss).backward()
|
| 436 |
+
if (step + 1) % ACCUM == 0:
|
| 437 |
+
scaler.step(opt); scaler.update(); opt.zero_grad()
|
| 438 |
+
run += loss.item() * ACCUM; nb += 1
|
| 439 |
+
m = evaluate(); sc = mean_srcc(m)
|
| 440 |
+
msg = " ".join(f"{k}={m[k]:.3f}" for k in ["emos", "val", "aro", "dom"] if k in m)
|
| 441 |
+
print(f"[{INIT_MODE}] epoch {ep:2d} | loss {run/max(nb,1):.4f} | {msg} | cat_err {m['cat_err']:.3f} | mean {sc:.4f} (best {max(best,sc):.4f})")
|
| 442 |
+
if sc > best:
|
| 443 |
+
best = sc
|
| 444 |
+
best_state = {"wavlm": {k: v.cpu().clone() for k, v in wavlm.state_dict().items()},
|
| 445 |
+
"heads": {k: v.cpu().clone() for k, v in heads.state_dict().items()}}
|
| 446 |
+
save_full(best_state, m["emos"]); bad = 0
|
| 447 |
+
print(f" 💾 lưu best → {CKPT_PATH} (epoch {ep}, mean {sc:.4f})")
|
| 448 |
+
else:
|
| 449 |
+
bad += 1
|
| 450 |
+
if bad >= PATIENCE:
|
| 451 |
+
print(f"Early stop ở epoch {ep}."); break
|
| 452 |
+
|
| 453 |
+
if best_state:
|
| 454 |
+
wavlm.load_state_dict(best_state["wavlm"]); heads.load_state_dict(best_state["heads"])
|
| 455 |
+
final = evaluate()
|
| 456 |
+
print(f"\n✅ VAL (nội bộ) — exp12 INIT_MODE={INIT_MODE}:")
|
| 457 |
+
print(f" EMOS={final['emos']:.4f}", end="")
|
| 458 |
+
if HAS_VAD:
|
| 459 |
+
print(f" | VAL/ARO/DOM={final['val']:.4f}/{final['aro']:.4f}/{final['dom']:.4f}")
|
| 460 |
+
else:
|
| 461 |
+
print()
|
| 462 |
+
print(f" cat_err={final['cat_err']:.4f} | mean SRCC={mean_srcc(final):.4f}")
|
| 463 |
+
print(f" (Mốc DEV exp08 để tham khảo: EMOS {EXP08['emos']}, VAD {EXP08['val']}/{EXP08['aro']}/{EXP08['dom']})")
|
| 464 |
+
print(" ➜ GHI con số này vào bảng ablation 04_ rồi đổi INIT_MODE chạy lại để so 3 mode.")
|
| 465 |
+
|
| 466 |
+
# %% [markdown]
|
| 467 |
+
# ## 6. Dự đoán DEV → answer.txt (QMOS mượn exp07 / UTMOSv2)
|
| 468 |
+
|
| 469 |
+
# %%
|
| 470 |
+
def list_dev():
|
| 471 |
+
with open(DEV_SCP) as f:
|
| 472 |
+
return [ln.strip() for ln in f if ln.strip()]
|
| 473 |
+
|
| 474 |
+
dev_names = list_dev()
|
| 475 |
+
if LIMIT_DEV:
|
| 476 |
+
dev_names = dev_names[:LIMIT_DEV]
|
| 477 |
+
print("DEV:", len(dev_names), "mẫu")
|
| 478 |
+
|
| 479 |
+
def load_exp07_qmos():
|
| 480 |
+
if EXP07_ANSWER and os.path.exists(EXP07_ANSWER):
|
| 481 |
+
import csv
|
| 482 |
+
d = {}
|
| 483 |
+
with open(EXP07_ANSWER) as f:
|
| 484 |
+
for row in csv.DictReader(f):
|
| 485 |
+
d[row["wav"]] = float(row["QMOS"]); d[stem(row["wav"])] = float(row["QMOS"])
|
| 486 |
+
print(f"✅ Mượn QMOS exp07: {len(d)//2} wav")
|
| 487 |
+
return d
|
| 488 |
+
return None
|
| 489 |
+
|
| 490 |
+
qmos_map = load_exp07_qmos()
|
| 491 |
+
if qmos_map is None:
|
| 492 |
+
print("ℹ️ Không có exp07 → QMOS bằng UTMOSv2.")
|
| 493 |
+
pip_install("git+https://github.com/sarulab-speech/UTMOSv2.git")
|
| 494 |
+
import utmosv2
|
| 495 |
+
v2 = utmosv2.create_model(pretrained=True)
|
| 496 |
+
qmos_map = {}
|
| 497 |
+
for n in tqdm(dev_names, desc="UTMOSv2"):
|
| 498 |
+
wav = os.path.join(WAV_DIR, n if str(n).endswith(".wav") else str(n) + ".wav")
|
| 499 |
+
if os.path.exists(wav):
|
| 500 |
+
o = v2.predict(input_path=wav)
|
| 501 |
+
qmos_map[n] = float(o["predicted_mos"]) if isinstance(o, dict) else float(o)
|
| 502 |
+
del v2; torch.cuda.empty_cache() if device == "cuda" else None
|
| 503 |
+
|
| 504 |
+
@torch.no_grad()
|
| 505 |
+
def predict_emotion(sid):
|
| 506 |
+
wave = load_wav(sid)
|
| 507 |
+
if wave is None:
|
| 508 |
+
return None
|
| 509 |
+
wavlm.eval(); heads.eval()
|
| 510 |
+
iv = torch.from_numpy(wave).unsqueeze(0).to(device)
|
| 511 |
+
am = torch.ones((1, len(wave)), dtype=torch.long, device=device)
|
| 512 |
+
tgt = torch.from_numpy(onehot_target(target_map.get(sid))).unsqueeze(0).to(device)
|
| 513 |
+
with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
|
| 514 |
+
feat = wavlm_embed(iv, am)
|
| 515 |
+
emos_p, cat_l, vad_p = heads(feat, tgt)
|
| 516 |
+
emos = float(emos_p.item()) * emos_sd + emos_mu
|
| 517 |
+
cat5 = F.softmax(cat_l, 1)[0].float().cpu().numpy()
|
| 518 |
+
vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu
|
| 519 |
+
return emos, cat5, vad3
|
| 520 |
+
|
| 521 |
+
def fmt_cat(p5):
|
| 522 |
+
return "|".join(f"{e}:{p5[i]:.6g}" for i, e in enumerate(EMOTIONS5))
|
| 523 |
+
|
| 524 |
+
answer_path = os.path.join(OUT_DIR, f"answer_{INIT_MODE}.txt")
|
| 525 |
+
n_real = n_def = 0
|
| 526 |
+
with open(answer_path, "w") as f:
|
| 527 |
+
f.write("wav,QMOS,EMOS,CAT,VAL,ARO,DOM\n")
|
| 528 |
+
for name in tqdm(dev_names, desc=f"answer[{INIT_MODE}]"):
|
| 529 |
+
sid = stem(name)
|
| 530 |
+
pr = predict_emotion(sid)
|
| 531 |
+
if pr is None:
|
| 532 |
+
emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0]); n_def += 1
|
| 533 |
+
else:
|
| 534 |
+
emos, cat5, vad3 = pr; n_real += 1
|
| 535 |
+
qmos = qmos_map.get(name, qmos_map.get(sid, 3.0))
|
| 536 |
+
f.write(f"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\n")
|
| 537 |
+
print(f"Ghi {len(dev_names)} dòng → {answer_path} | thật {n_real}, mặc định {n_def}")
|
| 538 |
+
|
| 539 |
+
# %% [markdown]
|
| 540 |
+
# ## 7. Validate + zip
|
| 541 |
+
|
| 542 |
+
# %%
|
| 543 |
+
def validate(path):
|
| 544 |
+
import csv
|
| 545 |
+
with open(path) as f:
|
| 546 |
+
rows = list(csv.reader(f))
|
| 547 |
+
assert rows[0][0] == "wav" and "QMOS" in rows[0], "Header sai"
|
| 548 |
+
for i, r in enumerate(rows[1:], 2):
|
| 549 |
+
assert len(r) == len(rows[0]), f"Dòng {i} sai số cột"
|
| 550 |
+
print(f"OK: {len(rows)-1} dòng, header = {rows[0]}")
|
| 551 |
+
|
| 552 |
+
validate(answer_path)
|
| 553 |
+
os.system(f"cd {OUT_DIR} && cp answer_{INIT_MODE}.txt answer.txt && zip -j submission_track2_exp12_{INIT_MODE}.zip answer.txt && unzip -l submission_track2_exp12_{INIT_MODE}.zip")
|
| 554 |
+
print("Sẵn sàng nộp:", os.path.join(OUT_DIR, f"submission_track2_exp12_{INIT_MODE}.zip"))
|
| 555 |
+
|
| 556 |
+
# %% [markdown]
|
| 557 |
+
# ## Ghi chú
|
| 558 |
+
# - **Chạy 3 lần** đổi `INIT_MODE` ("scratch"→"base"→"sailer"), ghi `mean SRCC` mỗi lần vào BẢNG ABLATION
|
| 559 |
+
# trong `docs/04_experiments_log.md` → trả lời mentor bằng số: from-scratch tốt hơn fine-tune không?
|
| 560 |
+
# - **scratch nặng:** mở băng toàn bộ WavLM-large. Nếu OOM → giảm `BATCH` (4→2), `MAX_SECONDS` (6→5),
|
| 561 |
+
# hoặc đổi sang `microsoft/wavlm-base-plus` (sửa cell 2) cho khả thi (lưu ý: khác kiến trúc → so kém công bằng hơn).
|
| 562 |
+
# - **scratch chậm + cần nhiều epoch hơn** (random init): để `EPOCHS=15`, `PATIENCE=5`. Vẫn nhiều khả năng < base/sailer.
|
| 563 |
+
# - **Đừng nhầm VAL nội bộ với DEV.** So 3 mode bằng VAL nội bộ đã đủ kết luận; muốn chắc thì nộp mode tốt nhất.
|
| 564 |
+
# - Checkpoint lưu `ft_wavlm_<mode>.pt`. Save Version sau mỗi lần chạy.
|
track2/exp13_finetune_qmos.ipynb
ADDED
|
@@ -0,0 +1,733 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "d3c827cb",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"# VMC2026 Track 2 — exp13 (FINE-TUNE UTMOS cho QMOS) + answer 6 cột — Kaggle\n",
|
| 9 |
+
"\n",
|
| 10 |
+
"**Mục tiêu:** QMOS hiện tốt nhất = 0.548 (exp07, head ĐÓNG BĂNG + neo UTMOS). exp13 thử **MỞ BĂNG\n",
|
| 11 |
+
"(fine-tune) thẳng UTMOS** trên nhãn `qMOS` thật của Track 2 → kéo model chất lượng về đúng domain giọng\n",
|
| 12 |
+
"cảm xúc. Sau đó **mượn 5 cột cảm xúc từ checkpoint exp08** (`ft_emotion_full_20epoch.pt` — bản TỐT NHẤT)\n",
|
| 13 |
+
"→ ghép `answer.txt` 6 cột.\n",
|
| 14 |
+
"\n",
|
| 15 |
+
"## Vì sao fine-tune UTMOS (không phải UTMOSv2)\n",
|
| 16 |
+
"- UTMOS (`utmos22_strong`, tarepan/SpeechMOS) = **1 model đơn**, tải qua `torch.hub`, **bản thân đã dự đoán\n",
|
| 17 |
+
" QMOS** → warm-start hoàn hảo cho cột chất lượng (khác UTMOSv2 = ensemble nhiều fold + 2 luồng → khó train).\n",
|
| 18 |
+
"- forward: `model(wave[B,T], sr) -> MOS[B]`, là `nn.Module` chuẩn → backprop được toàn model.\n",
|
| 19 |
+
"- **Không dùng neo UTMOS riêng** (đã chốt): khi fine-tune chính UTMOS thì \"neo\" nằm sẵn trong trọng số\n",
|
| 20 |
+
" warm-start → head/neo ngoài là thừa.\n",
|
| 21 |
+
"\n",
|
| 22 |
+
"## Thiết kế\n",
|
| 23 |
+
"```\n",
|
| 24 |
+
" [PHẦN A] wav ─► UTMOS (utmos22_strong, TRAINABLE, warm-start pretrained) ─► QMOS (train trên qMOS gold)\n",
|
| 25 |
+
" [PHẦN B] wav ─► WavLM(exp08 ft) + audeering(frozen) ─► EMOS/CAT/VAD (NẠP ckpt, chỉ inference)\n",
|
| 26 |
+
" [PHẦN C] ghép QMOS(A) + 5 cột cảm xúc(B) ─► answer.txt 6 cột ─► validate ─► zip\n",
|
| 27 |
+
"```\n",
|
| 28 |
+
"\n",
|
| 29 |
+
"## ⚠️ Phải biết trước\n",
|
| 30 |
+
"- Fine-tune = **không cache** (mỗi epoch chạy lại UTMOS forward+backward) → tốn giờ GPU. **Lần đầu BẮT BUỘC\n",
|
| 31 |
+
" `LIMIT_TRAIN=300`, `LIMIT_DEV=20`** để chỉnh trơn rồi mới `None`.\n",
|
| 32 |
+
"- Lưới an toàn: chỉ nộp QMOS fine-tune nếu **SRCC val nội bộ > zero-shot UTMOS** (mục A in cả 2 số).\n",
|
| 33 |
+
"- **Lưu checkpoint `ft_qmos_utmos.pt` mỗi best + Save Version NGAY** (bài học exp08: kernel chết là mất).\n",
|
| 34 |
+
"\n",
|
| 35 |
+
"**Cách chạy Kaggle:** GPU **T4** + Internet **On** → Add Input (1) dataset Track 2, (2) dataset chứa\n",
|
| 36 |
+
"`ft_emotion_full.pt` (exp08), (3) tùy chọn cache `aud_dev.npz` → sửa slug cell 0 → Run All."
|
| 37 |
+
]
|
| 38 |
+
},
|
| 39 |
+
{
|
| 40 |
+
"cell_type": "markdown",
|
| 41 |
+
"id": "6a78806d",
|
| 42 |
+
"metadata": {},
|
| 43 |
+
"source": [
|
| 44 |
+
"## 0. Cấu hình — SỬA Ở ĐÂY"
|
| 45 |
+
]
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"cell_type": "code",
|
| 49 |
+
"execution_count": null,
|
| 50 |
+
"id": "1374fa7d",
|
| 51 |
+
"metadata": {},
|
| 52 |
+
"outputs": [],
|
| 53 |
+
"source": [
|
| 54 |
+
"import os, shutil, glob\n",
|
| 55 |
+
"\n",
|
| 56 |
+
"# ── TỰ DÒ DATA_ROOT (quét /kaggle/input tìm thư mục có sets/train.csv + wav/ + metadata.csv) ──\n",
|
| 57 |
+
"def find_data_root(search_root=\"/kaggle/input\"):\n",
|
| 58 |
+
" cands = []\n",
|
| 59 |
+
" for train_csv in glob.glob(os.path.join(search_root, \"**\", \"sets\", \"train.csv\"), recursive=True):\n",
|
| 60 |
+
" root = os.path.dirname(os.path.dirname(train_csv)) # .../<root>/sets/train.csv → <root>\n",
|
| 61 |
+
" score = os.path.isdir(os.path.join(root, \"wav\")) + os.path.exists(os.path.join(root, \"metadata.csv\"))\n",
|
| 62 |
+
" cands.append((score, root))\n",
|
| 63 |
+
" cands.sort(reverse=True) # ưu tiên thư mục đủ wav + metadata\n",
|
| 64 |
+
" return cands\n",
|
| 65 |
+
"\n",
|
| 66 |
+
"_cands = find_data_root(\"/kaggle/input\")\n",
|
| 67 |
+
"if _cands:\n",
|
| 68 |
+
" print(\"🔎 Ứng viên DATA_ROOT (điểm cao = đủ wav+metadata):\")\n",
|
| 69 |
+
" for sc, r in _cands:\n",
|
| 70 |
+
" print(f\" [{sc}/2] {r}\")\n",
|
| 71 |
+
" DATA_ROOT = _cands[0][1]\n",
|
| 72 |
+
" print(f\"👉 Tự chọn DATA_ROOT = {DATA_ROOT}\")\n",
|
| 73 |
+
"else:\n",
|
| 74 |
+
" DATA_ROOT = \"/kaggle/input/datasets/minhtoan2\" # dự phòng — sửa tay nếu auto-dò không thấy\n",
|
| 75 |
+
" print(f\"❌ Không thấy sets/train.csv trong /kaggle/input → dùng dự phòng {DATA_ROOT} (đã Add Input chưa?)\")\n",
|
| 76 |
+
"\n",
|
| 77 |
+
"WAV_DIR = f\"{DATA_ROOT}/wav\"\n",
|
| 78 |
+
"METADATA_CSV = f\"{DATA_ROOT}/metadata.csv\" # wavID|emotion|transcript (cho cột cảm xúc)\n",
|
| 79 |
+
"TRAIN_CSV = f\"{DATA_ROOT}/sets/train.csv\" # lisID|wavID|qMOS|emoCat|eMOS|val|dom|aro\n",
|
| 80 |
+
"DEV_SCP = f\"{DATA_ROOT}/sets/dev.scp\"\n",
|
| 81 |
+
"\n",
|
| 82 |
+
"# ── Checkpoint cảm xúc exp08 (để sinh 5 cột EMOS/CAT/VAD) ─────────────────────\n",
|
| 83 |
+
"# ⭐ TỐT NHẤT = ft_emotion_full_20epoch.pt (bản 20 epoch) — dùng bản này, KHÔNG dùng ft_emotion_full.pt.\n",
|
| 84 |
+
"EMO_CKPT = \"/kaggle/input/ft-emotion-full/ft_emotion_full_20epoch.pt\" # << ckpt exp08 20ep (CÓ backbone WavLM)\n",
|
| 85 |
+
"CACHE_INPUT = \"/kaggle/input/ft-emotion-cache\" # << (tùy chọn) thư mục chứa aud_dev.npz; \"\" nếu không có\n",
|
| 86 |
+
"\n",
|
| 87 |
+
"OUT_DIR = \"/kaggle/working\"\n",
|
| 88 |
+
"CACHE_DIR = \"/kaggle/working/ft_cache\" # /kaggle/input read-only → copy cache audeering sang đây\n",
|
| 89 |
+
"os.makedirs(CACHE_DIR, exist_ok=True)\n",
|
| 90 |
+
"\n",
|
| 91 |
+
"# ── PHẦN A: fine-tune UTMOS (QMOS) ───────────────────────────────────────────\n",
|
| 92 |
+
"DEVICE = \"cuda\"\n",
|
| 93 |
+
"SR = 16000\n",
|
| 94 |
+
"QMOS_MAX_SEC = 12 # cắt audio chặn bộ nhớ backprop (UTMOS); OOM thì giảm 10/8\n",
|
| 95 |
+
"LR = 1e-5 # LR nhỏ cho fine-tune (warm-start sẵn tốt)\n",
|
| 96 |
+
"WEIGHT_DECAY = 1e-5\n",
|
| 97 |
+
"EPOCHS = 10 # TRẦN; early-stop quyết số epoch thật\n",
|
| 98 |
+
"PATIENCE = 3\n",
|
| 99 |
+
"BATCH = 1 # UTMOS forward KHÔNG có attention-mask → BATCH=1 an toàn (pad zero sẽ lệch pooling)\n",
|
| 100 |
+
"ACCUM = 16 # effective batch = BATCH*ACCUM = 16\n",
|
| 101 |
+
"VAL_FRAC = 0.10\n",
|
| 102 |
+
"SEED = 42\n",
|
| 103 |
+
"USE_AMP = True\n",
|
| 104 |
+
"RANK_LAMBDA = 0.0 # 0 = chỉ MSE. >0 (vd 0.3) = cộng pairwise ranking loss (tối ưu thẳng thứ hạng=SRCC)\n",
|
| 105 |
+
"FREEZE_FEAT_EXT = True # đóng băng feature-extractor (CNN conv) của UTMOS → đỡ VRAM + chống overfit\n",
|
| 106 |
+
"\n",
|
| 107 |
+
"# ── PHẦN B: inference cảm xúc (PHẢI khớp kiến trúc exp08) ─────────────────────\n",
|
| 108 |
+
"EMO_MAX_SEC = 8\n",
|
| 109 |
+
"UNFREEZE_TOP_LAYERS = 6 # khớp ckpt exp08\n",
|
| 110 |
+
"TRUNK_HIDDEN = 512\n",
|
| 111 |
+
"HEAD_HIDDEN = 128\n",
|
| 112 |
+
"DROPOUT = 0.3\n",
|
| 113 |
+
"USE_AUDEERING = True # khớp ckpt exp08\n",
|
| 114 |
+
"\n",
|
| 115 |
+
"LIMIT_TRAIN = 300 # << LẦN ĐẦU 300; chạy thật None\n",
|
| 116 |
+
"LIMIT_DEV = 20 # << LẦN ĐẦU 20; chạy thật None\n",
|
| 117 |
+
"\n",
|
| 118 |
+
"# Mốc QMOS để so (leaderboard DEV)\n",
|
| 119 |
+
"QMOS_BASELINE = {\"utmos_zeroshot\": 0.414, \"exp07_head\": 0.548}\n",
|
| 120 |
+
"\n",
|
| 121 |
+
"EMOTIONS5 = [\"angry\", \"happy\", \"neutral\", \"sad\", \"surprised\"]\n",
|
| 122 |
+
"_EMO_ALIAS = {\n",
|
| 123 |
+
" \"angry\": \"angry\", \"anger\": \"angry\",\n",
|
| 124 |
+
" \"happy\": \"happy\", \"happiness\": \"happy\", \"joy\": \"happy\",\n",
|
| 125 |
+
" \"neutral\": \"neutral\", \"calm\": \"neutral\",\n",
|
| 126 |
+
" \"sad\": \"sad\", \"sadness\": \"sad\",\n",
|
| 127 |
+
" \"surprise\": \"surprised\", \"surprised\": \"surprised\", \"surprising\": \"surprised\",\n",
|
| 128 |
+
"}\n",
|
| 129 |
+
"\n",
|
| 130 |
+
"def norm_emotion(label):\n",
|
| 131 |
+
" key = str(label).strip().lower()\n",
|
| 132 |
+
" return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)\n",
|
| 133 |
+
"\n",
|
| 134 |
+
"def stem(p):\n",
|
| 135 |
+
" return os.path.splitext(os.path.basename(str(p)))[0]\n",
|
| 136 |
+
"\n",
|
| 137 |
+
"print(\"DATA_ROOT:\", DATA_ROOT)\n",
|
| 138 |
+
"for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP, EMO_CKPT]:\n",
|
| 139 |
+
" print((\" ✅ \" if os.path.exists(p) else \" ❌ THIẾU \") + p)\n",
|
| 140 |
+
"print(f\"Fine-tune UTMOS: LR {LR} · BATCH {BATCH}×ACCUM {ACCUM} · MAX {QMOS_MAX_SEC}s · rank λ {RANK_LAMBDA}\")\n",
|
| 141 |
+
"\n",
|
| 142 |
+
"# Copy cache audeering (aud_dev.npz) từ input read-only sang working (để cột cảm xúc khỏi trích lại)\n",
|
| 143 |
+
"if CACHE_INPUT and os.path.isdir(CACHE_INPUT):\n",
|
| 144 |
+
" n = 0\n",
|
| 145 |
+
" for fn in os.listdir(CACHE_INPUT):\n",
|
| 146 |
+
" if fn.startswith(\"aud_\") and fn.endswith(\".npz\"):\n",
|
| 147 |
+
" shutil.copy(os.path.join(CACHE_INPUT, fn), os.path.join(CACHE_DIR, fn)); n += 1\n",
|
| 148 |
+
" print(f\"📦 Copy {n} file cache audeering từ {CACHE_INPUT} → {CACHE_DIR}\")\n",
|
| 149 |
+
"else:\n",
|
| 150 |
+
" print(\"ℹ️ Không có CACHE_INPUT → sẽ tự trích audeering cho DEV (chậm hơn lần đầu).\")"
|
| 151 |
+
]
|
| 152 |
+
},
|
| 153 |
+
{
|
| 154 |
+
"cell_type": "markdown",
|
| 155 |
+
"id": "5e568431",
|
| 156 |
+
"metadata": {},
|
| 157 |
+
"source": [
|
| 158 |
+
"## 1. Cài đặt"
|
| 159 |
+
]
|
| 160 |
+
},
|
| 161 |
+
{
|
| 162 |
+
"cell_type": "code",
|
| 163 |
+
"execution_count": null,
|
| 164 |
+
"id": "731d1056",
|
| 165 |
+
"metadata": {},
|
| 166 |
+
"outputs": [],
|
| 167 |
+
"source": [
|
| 168 |
+
"import sys, subprocess\n",
|
| 169 |
+
"\n",
|
| 170 |
+
"def pip_install(*pkgs):\n",
|
| 171 |
+
" subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *pkgs], check=True)\n",
|
| 172 |
+
"\n",
|
| 173 |
+
"pip_install(\"speechmos\", \"loralib\", \"speechbrain\", \"librosa\", \"soundfile\",\n",
|
| 174 |
+
" \"scipy\", \"scikit-learn\", \"pandas\", \"tqdm\")\n",
|
| 175 |
+
"\n",
|
| 176 |
+
"# Code SAILER (để dựng đúng kiến trúc WavLM rồi nạp ckpt exp08 đè lên) — chỉ cần cho PHẦN B\n",
|
| 177 |
+
"REPO_DIR = \"/kaggle/working/vox-profile-release\"\n",
|
| 178 |
+
"if not os.path.exists(REPO_DIR):\n",
|
| 179 |
+
" subprocess.run([\"git\", \"clone\", \"--depth\", \"1\",\n",
|
| 180 |
+
" \"https://github.com/tiantiaf0627/vox-profile-release.git\", REPO_DIR], check=True)\n",
|
| 181 |
+
"if REPO_DIR not in sys.path:\n",
|
| 182 |
+
" sys.path.insert(0, REPO_DIR)"
|
| 183 |
+
]
|
| 184 |
+
},
|
| 185 |
+
{
|
| 186 |
+
"cell_type": "markdown",
|
| 187 |
+
"id": "732b81f8",
|
| 188 |
+
"metadata": {},
|
| 189 |
+
"source": [
|
| 190 |
+
"## 2. Nhãn vàng qMOS (gộp trung bình theo wav) — như exp06/exp09a"
|
| 191 |
+
]
|
| 192 |
+
},
|
| 193 |
+
{
|
| 194 |
+
"cell_type": "code",
|
| 195 |
+
"execution_count": null,
|
| 196 |
+
"id": "8cfb94af",
|
| 197 |
+
"metadata": {},
|
| 198 |
+
"outputs": [],
|
| 199 |
+
"source": [
|
| 200 |
+
"import numpy as np\n",
|
| 201 |
+
"import pandas as pd\n",
|
| 202 |
+
"\n",
|
| 203 |
+
"def load_qmos_labels():\n",
|
| 204 |
+
" df = pd.read_csv(TRAIN_CSV, sep=\"|\")\n",
|
| 205 |
+
" cols = {c.lower().strip(): c for c in df.columns}\n",
|
| 206 |
+
" wav_col = cols.get(\"wavid\") or cols.get(\"wav\") or list(df.columns)[1]\n",
|
| 207 |
+
" qmos_col = cols.get(\"qmos\") or cols.get(\"mos\")\n",
|
| 208 |
+
" assert qmos_col, f\"Không thấy cột qMOS (cột: {list(df.columns)})\"\n",
|
| 209 |
+
" df[\"_stem\"] = df[wav_col].map(stem)\n",
|
| 210 |
+
" g = df.groupby(\"_stem\")[qmos_col].mean()\n",
|
| 211 |
+
" return {s: float(v) for s, v in g.items()}\n",
|
| 212 |
+
"\n",
|
| 213 |
+
"qmos_gold = load_qmos_labels()\n",
|
| 214 |
+
"print(f\"Số wav train có nhãn qMOS: {len(qmos_gold)}\")\n",
|
| 215 |
+
"_vals = np.array(list(qmos_gold.values()))\n",
|
| 216 |
+
"print(f\"qMOS gold: mean {_vals.mean():.3f} · std {_vals.std():.3f} · min {_vals.min():.2f} · max {_vals.max():.2f}\")"
|
| 217 |
+
]
|
| 218 |
+
},
|
| 219 |
+
{
|
| 220 |
+
"cell_type": "markdown",
|
| 221 |
+
"id": "374534a0",
|
| 222 |
+
"metadata": {},
|
| 223 |
+
"source": [
|
| 224 |
+
"## 3. PHẦN A — Fine-tune UTMOS trên qMOS\n",
|
| 225 |
+
"UTMOS xuất MOS thang ~1–5 (đã warm-start) → train MSE trên thang GỐC (không z-score, để giữ ý nghĩa warm-start).\n",
|
| 226 |
+
"`BATCH=1` + grad-accum: tránh phải pad (UTMOS forward không nhận attention-mask)."
|
| 227 |
+
]
|
| 228 |
+
},
|
| 229 |
+
{
|
| 230 |
+
"cell_type": "code",
|
| 231 |
+
"execution_count": null,
|
| 232 |
+
"id": "4636d35c",
|
| 233 |
+
"metadata": {
|
| 234 |
+
"lines_to_next_cell": 1
|
| 235 |
+
},
|
| 236 |
+
"outputs": [],
|
| 237 |
+
"source": [
|
| 238 |
+
"import torch\n",
|
| 239 |
+
"import torch.nn as nn\n",
|
| 240 |
+
"import librosa\n",
|
| 241 |
+
"from tqdm.auto import tqdm\n",
|
| 242 |
+
"from scipy.stats import spearmanr\n",
|
| 243 |
+
"from sklearn.model_selection import train_test_split\n",
|
| 244 |
+
"\n",
|
| 245 |
+
"device = DEVICE if torch.cuda.is_available() else \"cpu\"\n",
|
| 246 |
+
"print(\"Device:\", device, (\"✅ \" + torch.cuda.get_device_name(0)) if device == \"cuda\" else \"⚠️ CPU (rất chậm!)\")\n",
|
| 247 |
+
"torch.manual_seed(SEED); np.random.seed(SEED)\n",
|
| 248 |
+
"\n",
|
| 249 |
+
"# Nạp UTMOS (torch.hub) — model nn.Module, forward(wave[B,T], sr) -> MOS[B]\n",
|
| 250 |
+
"utmos = torch.hub.load(\"tarepan/SpeechMOS:v1.2.0\", \"utmos22_strong\", trust_repo=True).to(device)\n",
|
| 251 |
+
"n_all = sum(p.numel() for p in utmos.parameters())\n",
|
| 252 |
+
"\n",
|
| 253 |
+
"# (tùy chọn) đóng băng feature-extractor (các lớp conv trích đặc trưng) → đỡ VRAM + chống overfit\n",
|
| 254 |
+
"if FREEZE_FEAT_EXT:\n",
|
| 255 |
+
" n_frozen = 0\n",
|
| 256 |
+
" for name, p in utmos.named_parameters():\n",
|
| 257 |
+
" if \"feature_extractor\" in name or \"feature_projection\" in name or \"conv\" in name.lower():\n",
|
| 258 |
+
" p.requires_grad = False; n_frozen += p.numel()\n",
|
| 259 |
+
" print(f\"❄️ Đóng băng feature-extractor: {n_frozen/1e6:.1f}M / {n_all/1e6:.1f}M param\")\n",
|
| 260 |
+
"n_train = sum(p.numel() for p in utmos.parameters() if p.requires_grad)\n",
|
| 261 |
+
"print(f\"UTMOS: {n_all/1e6:.1f}M param tổng · {n_train/1e6:.1f}M param sẽ train\")\n",
|
| 262 |
+
"\n",
|
| 263 |
+
"def load_wav_qmos(sid):\n",
|
| 264 |
+
" p = os.path.join(WAV_DIR, sid + \".wav\")\n",
|
| 265 |
+
" if not os.path.exists(p):\n",
|
| 266 |
+
" return None\n",
|
| 267 |
+
" wave, _ = librosa.load(p, sr=SR, mono=True)\n",
|
| 268 |
+
" return wave[: QMOS_MAX_SEC * SR].astype(np.float32)\n",
|
| 269 |
+
"\n",
|
| 270 |
+
"# Tập train QMOS: chỉ wav tồn tại trên đĩa\n",
|
| 271 |
+
"train_stems_q = [s for s in qmos_gold if os.path.exists(os.path.join(WAV_DIR, s + \".wav\"))]\n",
|
| 272 |
+
"np.random.shuffle(train_stems_q)\n",
|
| 273 |
+
"if LIMIT_TRAIN:\n",
|
| 274 |
+
" train_stems_q = train_stems_q[:LIMIT_TRAIN]\n",
|
| 275 |
+
"tr_q, va_q = train_test_split(train_stems_q, test_size=VAL_FRAC, random_state=SEED)\n",
|
| 276 |
+
"print(f\"QMOS train: {len(tr_q)} · val nội bộ: {len(va_q)}\")\n",
|
| 277 |
+
"\n",
|
| 278 |
+
"opt = torch.optim.AdamW([p for p in utmos.parameters() if p.requires_grad],\n",
|
| 279 |
+
" lr=LR, weight_decay=WEIGHT_DECAY)\n",
|
| 280 |
+
"scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP and device == \"cuda\")\n",
|
| 281 |
+
"mse = nn.MSELoss()\n",
|
| 282 |
+
"\n",
|
| 283 |
+
"def utmos_forward(wave_np):\n",
|
| 284 |
+
" \"\"\"1 wav numpy -> MOS scalar tensor (giữ grad).\"\"\"\n",
|
| 285 |
+
" x = torch.from_numpy(wave_np).unsqueeze(0).to(device) # [1, T]\n",
|
| 286 |
+
" out = utmos(x, SR) # [1] (hoặc [1,?])\n",
|
| 287 |
+
" return out.reshape(-1).mean() # scalar an toàn mọi shape\n",
|
| 288 |
+
"\n",
|
| 289 |
+
"def pairwise_rank_loss(preds, targets):\n",
|
| 290 |
+
" \"\"\"Hinge ranking trên các cặp trong 1 nhóm (khuyến khích đúng thứ hạng = tối ưu SRCC).\"\"\"\n",
|
| 291 |
+
" p = torch.stack(preds); t = torch.tensor(targets, device=device, dtype=torch.float32)\n",
|
| 292 |
+
" if len(p) < 2:\n",
|
| 293 |
+
" return torch.zeros((), device=device)\n",
|
| 294 |
+
" sign = torch.sign(t.unsqueeze(0) - t.unsqueeze(1))\n",
|
| 295 |
+
" diff = p.unsqueeze(0) - p.unsqueeze(1)\n",
|
| 296 |
+
" return torch.relu(-sign * diff).mean()\n",
|
| 297 |
+
"\n",
|
| 298 |
+
"@torch.no_grad()\n",
|
| 299 |
+
"def eval_qmos_val():\n",
|
| 300 |
+
" utmos.eval()\n",
|
| 301 |
+
" preds, gts = [], []\n",
|
| 302 |
+
" for s in va_q:\n",
|
| 303 |
+
" wave = load_wav_qmos(s)\n",
|
| 304 |
+
" if wave is None:\n",
|
| 305 |
+
" continue\n",
|
| 306 |
+
" with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
|
| 307 |
+
" preds.append(float(utmos_forward(wave).item()))\n",
|
| 308 |
+
" gts.append(qmos_gold[s])\n",
|
| 309 |
+
" return float(spearmanr(preds, gts).correlation)\n",
|
| 310 |
+
"\n",
|
| 311 |
+
"# Baseline ZERO-SHOT (trước khi train) trên CÙNG val → mốc phải vượt\n",
|
| 312 |
+
"srcc_zeroshot = eval_qmos_val()\n",
|
| 313 |
+
"print(f\"\\n📍 UTMOS zero-shot (val nội bộ): SRCC = {srcc_zeroshot:.4f} \"\n",
|
| 314 |
+
" f\"(leaderboard DEV ~{QMOS_BASELINE['utmos_zeroshot']}; exp07 head {QMOS_BASELINE['exp07_head']})\")\n",
|
| 315 |
+
"\n",
|
| 316 |
+
"CKPT_QMOS = os.path.join(OUT_DIR, \"ft_qmos_utmos.pt\")\n",
|
| 317 |
+
"def save_qmos_ckpt(srcc):\n",
|
| 318 |
+
" torch.save({\"utmos_state\": {k: v.cpu() for k, v in utmos.state_dict().items()},\n",
|
| 319 |
+
" \"val_srcc\": float(srcc), \"raw_scale\": True,\n",
|
| 320 |
+
" \"QMOS_MAX_SEC\": QMOS_MAX_SEC, \"FREEZE_FEAT_EXT\": FREEZE_FEAT_EXT}, CKPT_QMOS)\n",
|
| 321 |
+
"\n",
|
| 322 |
+
"best, best_state, bad = srcc_zeroshot, {k: v.cpu().clone() for k, v in utmos.state_dict().items()}, 0\n",
|
| 323 |
+
"save_qmos_ckpt(best) # lưu sẵn bản zero-shot (worst case vẫn = baseline)\n",
|
| 324 |
+
"\n",
|
| 325 |
+
"# Gom theo CỬA SỔ = ACCUM mẫu HỢP LỆ (micro). Hai chế độ backward:\n",
|
| 326 |
+
"# • RANK off (mặc định) → backward NGAY từng mẫu → đồ thị giải phóng liền → VRAM thấp.\n",
|
| 327 |
+
"# • RANK on → ranking cần SO các pred TRONG cửa sổ → PHẢI giữ đồ thị cả cửa sổ →\n",
|
| 328 |
+
"# gom MSE (win_loss) + pred (buf_p) rồi backward MỘT lần (MSE_mean + λ·rank).\n",
|
| 329 |
+
"# ⚠️ Lỗi cũ: backward MSE từng bước đã giải phóng đồ thị → rank_loss.backward() sau đó\n",
|
| 330 |
+
"# sẽ lỗi \"backward through the graph a second time\". Bản này gom rồi backward 1 lần → hết lỗi.\n",
|
| 331 |
+
"for ep in range(1, EPOCHS + 1):\n",
|
| 332 |
+
" utmos.train()\n",
|
| 333 |
+
" opt.zero_grad()\n",
|
| 334 |
+
" np.random.shuffle(tr_q)\n",
|
| 335 |
+
" run = 0.0; nb = 0\n",
|
| 336 |
+
" micro = 0; win_loss = None; buf_p, buf_t = [], []\n",
|
| 337 |
+
" for s in tqdm(tr_q, desc=f\"epoch {ep}\"):\n",
|
| 338 |
+
" wave = load_wav_qmos(s)\n",
|
| 339 |
+
" if wave is None:\n",
|
| 340 |
+
" continue\n",
|
| 341 |
+
" with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
|
| 342 |
+
" pred = utmos_forward(wave)\n",
|
| 343 |
+
" loss = mse(pred, torch.tensor(qmos_gold[s], device=device, dtype=pred.dtype))\n",
|
| 344 |
+
" run += float(loss.item()); nb += 1\n",
|
| 345 |
+
" if RANK_LAMBDA > 0:\n",
|
| 346 |
+
" win_loss = loss if win_loss is None else win_loss + loss # GIỮ đồ thị (không backward ngay)\n",
|
| 347 |
+
" buf_p.append(pred); buf_t.append(qmos_gold[s]); micro += 1\n",
|
| 348 |
+
" else:\n",
|
| 349 |
+
" scaler.scale(loss / ACCUM).backward(); micro += 1 # backward ngay → VRAM thấp\n",
|
| 350 |
+
" if micro == ACCUM:\n",
|
| 351 |
+
" if RANK_LAMBDA > 0:\n",
|
| 352 |
+
" total = win_loss / micro\n",
|
| 353 |
+
" if len(buf_p) >= 2:\n",
|
| 354 |
+
" total = total + RANK_LAMBDA * pairwise_rank_loss(buf_p, buf_t)\n",
|
| 355 |
+
" scaler.scale(total).backward()\n",
|
| 356 |
+
" scaler.step(opt); scaler.update(); opt.zero_grad()\n",
|
| 357 |
+
" micro = 0; win_loss = None; buf_p, buf_t = [], []\n",
|
| 358 |
+
" # flush cửa sổ dư cuối epoch (số mẫu không chia hết cho ACCUM)\n",
|
| 359 |
+
" if micro > 0:\n",
|
| 360 |
+
" if RANK_LAMBDA > 0:\n",
|
| 361 |
+
" total = win_loss / micro\n",
|
| 362 |
+
" if len(buf_p) >= 2:\n",
|
| 363 |
+
" total = total + RANK_LAMBDA * pairwise_rank_loss(buf_p, buf_t)\n",
|
| 364 |
+
" scaler.scale(total).backward()\n",
|
| 365 |
+
" scaler.step(opt); scaler.update(); opt.zero_grad()\n",
|
| 366 |
+
" sc = eval_qmos_val()\n",
|
| 367 |
+
" print(f\"epoch {ep:2d} | loss {run/max(nb,1):.4f} | val SRCC {sc:.4f} \"\n",
|
| 368 |
+
" f\"(zero-shot {srcc_zeroshot:.4f} · best {max(best,sc):.4f})\")\n",
|
| 369 |
+
" if sc > best:\n",
|
| 370 |
+
" best = sc\n",
|
| 371 |
+
" best_state = {k: v.cpu().clone() for k, v in utmos.state_dict().items()}\n",
|
| 372 |
+
" save_qmos_ckpt(best)\n",
|
| 373 |
+
" print(f\" 💾 lưu best → {CKPT_QMOS} (epoch {ep}, SRCC {sc:.4f})\")\n",
|
| 374 |
+
" bad = 0\n",
|
| 375 |
+
" else:\n",
|
| 376 |
+
" bad += 1\n",
|
| 377 |
+
" if bad >= PATIENCE:\n",
|
| 378 |
+
" print(f\"Early stop ở epoch {ep}.\"); break\n",
|
| 379 |
+
"\n",
|
| 380 |
+
"utmos.load_state_dict(best_state)\n",
|
| 381 |
+
"print(f\"\\n✅ PHẦN A xong — QMOS val nội bộ: zero-shot {srcc_zeroshot:.4f} → fine-tune {best:.4f} \"\n",
|
| 382 |
+
" + (\"🚀 cải thiện\" if best > srcc_zeroshot + 1e-4 else \"➖ KHÔNG vượt zero-shot\"))\n",
|
| 383 |
+
"if best <= srcc_zeroshot + 1e-4:\n",
|
| 384 |
+
" print(\" ⚠️ Fine-tune chưa vượt zero-shot → cân nhắc tăng EPOCHS / bật RANK_LAMBDA=0.3 / \"\n",
|
| 385 |
+
" \"mở băng feature-extractor (FREEZE_FEAT_EXT=False); hoặc giữ QMOS exp07 (0.548).\")"
|
| 386 |
+
]
|
| 387 |
+
},
|
| 388 |
+
{
|
| 389 |
+
"cell_type": "markdown",
|
| 390 |
+
"id": "d2448fa6",
|
| 391 |
+
"metadata": {},
|
| 392 |
+
"source": [
|
| 393 |
+
"## 4. PHẦN A (tiếp) — Dự đoán QMOS cho DEV bằng UTMOS đã fine-tune"
|
| 394 |
+
]
|
| 395 |
+
},
|
| 396 |
+
{
|
| 397 |
+
"cell_type": "code",
|
| 398 |
+
"execution_count": null,
|
| 399 |
+
"id": "f8463ca5",
|
| 400 |
+
"metadata": {},
|
| 401 |
+
"outputs": [],
|
| 402 |
+
"source": [
|
| 403 |
+
"def list_dev():\n",
|
| 404 |
+
" with open(DEV_SCP) as f:\n",
|
| 405 |
+
" return [ln.strip() for ln in f if ln.strip()]\n",
|
| 406 |
+
"\n",
|
| 407 |
+
"dev_names = list_dev()\n",
|
| 408 |
+
"if LIMIT_DEV:\n",
|
| 409 |
+
" dev_names = dev_names[:LIMIT_DEV]\n",
|
| 410 |
+
"print(\"DEV:\", len(dev_names), \"mẫu\")\n",
|
| 411 |
+
"\n",
|
| 412 |
+
"@torch.no_grad()\n",
|
| 413 |
+
"def predict_qmos(name):\n",
|
| 414 |
+
" p = os.path.join(WAV_DIR, name if str(name).endswith(\".wav\") else str(name) + \".wav\")\n",
|
| 415 |
+
" if not os.path.exists(p):\n",
|
| 416 |
+
" return None\n",
|
| 417 |
+
" wave, _ = librosa.load(p, sr=SR, mono=True)\n",
|
| 418 |
+
" wave = wave[: QMOS_MAX_SEC * SR].astype(np.float32)\n",
|
| 419 |
+
" utmos.eval()\n",
|
| 420 |
+
" with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
|
| 421 |
+
" v = float(utmos_forward(wave).item())\n",
|
| 422 |
+
" return float(np.clip(v, 1.0, 5.0))\n",
|
| 423 |
+
"\n",
|
| 424 |
+
"qmos_pred = {}\n",
|
| 425 |
+
"n_real = n_def = 0\n",
|
| 426 |
+
"for name in tqdm(dev_names, desc=\"QMOS dev\"):\n",
|
| 427 |
+
" v = predict_qmos(name)\n",
|
| 428 |
+
" if v is None:\n",
|
| 429 |
+
" v = 3.0; n_def += 1\n",
|
| 430 |
+
" else:\n",
|
| 431 |
+
" n_real += 1\n",
|
| 432 |
+
" qmos_pred[name] = v\n",
|
| 433 |
+
"print(f\"QMOS dự đoán: thật {n_real}, mặc định {n_def}\")\n",
|
| 434 |
+
"\n",
|
| 435 |
+
"# Giải phóng UTMOS trước khi nạp backbone cảm xúc (đỡ VRAM T4)\n",
|
| 436 |
+
"del utmos, opt, scaler\n",
|
| 437 |
+
"torch.cuda.empty_cache() if device == \"cuda\" else None"
|
| 438 |
+
]
|
| 439 |
+
},
|
| 440 |
+
{
|
| 441 |
+
"cell_type": "markdown",
|
| 442 |
+
"id": "b67b5d6b",
|
| 443 |
+
"metadata": {},
|
| 444 |
+
"source": [
|
| 445 |
+
"## 5. PHẦN B — Nạp ckpt exp08 (WavLM ft + audeering) → 5 cột cảm xúc cho DEV\n",
|
| 446 |
+
"Tái dùng nguyên cơ chế load của exp08b: dựng kiến trúc → `load_state_dict` từ `ft_emotion_full_20epoch.pt`."
|
| 447 |
+
]
|
| 448 |
+
},
|
| 449 |
+
{
|
| 450 |
+
"cell_type": "code",
|
| 451 |
+
"execution_count": null,
|
| 452 |
+
"id": "f98eca99",
|
| 453 |
+
"metadata": {
|
| 454 |
+
"lines_to_next_cell": 1
|
| 455 |
+
},
|
| 456 |
+
"outputs": [],
|
| 457 |
+
"source": [
|
| 458 |
+
"import torch.nn.functional as F\n",
|
| 459 |
+
"\n",
|
| 460 |
+
"ckpt = torch.load(EMO_CKPT, map_location=\"cpu\", weights_only=False) # ckpt có numpy → weights_only=False\n",
|
| 461 |
+
"assert \"wavlm\" in ckpt, (\"❌ EMO_CKPT không có 'wavlm' (backbone). Cần ft_emotion_full_20epoch.pt (bản đủ backbone), \"\n",
|
| 462 |
+
" \"KHÔNG phải ft_emotion_meta.pt cũ.\")\n",
|
| 463 |
+
"print(\"✅ Nạp ckpt cảm xúc:\", EMO_CKPT, \"| keys:\", list(ckpt.keys()))\n",
|
| 464 |
+
"\n",
|
| 465 |
+
"def find_hf_backbone(module):\n",
|
| 466 |
+
" cands = []\n",
|
| 467 |
+
" for nm, m in module.named_modules():\n",
|
| 468 |
+
" enc = getattr(m, \"encoder\", None)\n",
|
| 469 |
+
" if getattr(m, \"feature_extractor\", None) is not None and enc is not None \\\n",
|
| 470 |
+
" and getattr(enc, \"layers\", None) is not None:\n",
|
| 471 |
+
" cands.append((nm, m))\n",
|
| 472 |
+
" if not cands:\n",
|
| 473 |
+
" return None, None\n",
|
| 474 |
+
" cands.sort(key=lambda x: sum(p.numel() for p in x[1].parameters()), reverse=True)\n",
|
| 475 |
+
" return cands[0]\n",
|
| 476 |
+
"\n",
|
| 477 |
+
"wavlm = None\n",
|
| 478 |
+
"try:\n",
|
| 479 |
+
" from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402\n",
|
| 480 |
+
" _wrapper = WavLMWrapper.from_pretrained(\"tiantiaf/wavlm-large-categorical-emotion\")\n",
|
| 481 |
+
" _name, wavlm = find_hf_backbone(_wrapper)\n",
|
| 482 |
+
" if wavlm is not None:\n",
|
| 483 |
+
" print(f\"✅ Dựng backbone WavLM từ SAILER wrapper tại '.{_name}'\")\n",
|
| 484 |
+
"except Exception as e:\n",
|
| 485 |
+
" print(\"⚠️ Lỗi nạp SAILER wrapper:\", repr(e), \"→ fallback WavLM trắng.\")\n",
|
| 486 |
+
"if wavlm is None:\n",
|
| 487 |
+
" from transformers import WavLMModel\n",
|
| 488 |
+
" wavlm = WavLMModel.from_pretrained(\"microsoft/wavlm-large\")\n",
|
| 489 |
+
" print(\"ℹ️ Fallback: microsoft/wavlm-large.\")\n",
|
| 490 |
+
"\n",
|
| 491 |
+
"wavlm = wavlm.to(device).eval()\n",
|
| 492 |
+
"WAVLM_DIM = int(wavlm.config.hidden_size)\n",
|
| 493 |
+
"miss, unexp = wavlm.load_state_dict(ckpt[\"wavlm\"], strict=False)\n",
|
| 494 |
+
"print(f\"🔁 load wavlm từ ckpt: thiếu {len(miss)} / dư {len(unexp)} key (kỳ vọng ~0).\")\n",
|
| 495 |
+
"\n",
|
| 496 |
+
"def masked_mean(hidden, attn_mask):\n",
|
| 497 |
+
" if attn_mask is None:\n",
|
| 498 |
+
" return hidden.mean(dim=1)\n",
|
| 499 |
+
" try:\n",
|
| 500 |
+
" fm = wavlm._get_feature_vector_attention_mask(hidden.shape[1], attn_mask)\n",
|
| 501 |
+
" except Exception:\n",
|
| 502 |
+
" return hidden.mean(dim=1)\n",
|
| 503 |
+
" fm = fm.unsqueeze(-1).to(hidden.dtype)\n",
|
| 504 |
+
" return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)\n",
|
| 505 |
+
"\n",
|
| 506 |
+
"@torch.no_grad()\n",
|
| 507 |
+
"def wavlm_embed(input_values, attn_mask):\n",
|
| 508 |
+
" out = wavlm(input_values, attention_mask=attn_mask).last_hidden_state\n",
|
| 509 |
+
" return masked_mean(out, attn_mask)\n",
|
| 510 |
+
"\n",
|
| 511 |
+
"# ── audeering FROZEN (đặc trưng phụ) — như exp08 ──\n",
|
| 512 |
+
"AUD_DIM = 0\n",
|
| 513 |
+
"aud_backbone = aud_head = aud_proc = None\n",
|
| 514 |
+
"if USE_AUDEERING:\n",
|
| 515 |
+
" from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor\n",
|
| 516 |
+
" from huggingface_hub import hf_hub_download\n",
|
| 517 |
+
" AUD_NAME = \"audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim\"\n",
|
| 518 |
+
" aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)\n",
|
| 519 |
+
" aud_cfg = Wav2Vec2Config.from_pretrained(AUD_NAME)\n",
|
| 520 |
+
" aud_backbone = Wav2Vec2Model(aud_cfg)\n",
|
| 521 |
+
" try:\n",
|
| 522 |
+
" _sd = __import__(\"safetensors.torch\", fromlist=[\"load_file\"]).load_file(\n",
|
| 523 |
+
" hf_hub_download(AUD_NAME, \"model.safetensors\"))\n",
|
| 524 |
+
" except Exception:\n",
|
| 525 |
+
" _sd = torch.load(hf_hub_download(AUD_NAME, \"pytorch_model.bin\"), map_location=\"cpu\")\n",
|
| 526 |
+
" bb_sd = {k[len(\"wav2vec2.\"):]: v for k, v in _sd.items() if k.startswith(\"wav2vec2.\")}\n",
|
| 527 |
+
" aud_backbone.load_state_dict(bb_sd, strict=False)\n",
|
| 528 |
+
" _hid = _sd[\"classifier.dense.weight\"].shape[0]\n",
|
| 529 |
+
" _out = _sd[\"classifier.out_proj.weight\"].shape[0]\n",
|
| 530 |
+
" aud_head = nn.Sequential(nn.Linear(_hid, _hid), nn.Tanh(), nn.Linear(_hid, _out))\n",
|
| 531 |
+
" aud_head[0].weight.data.copy_(_sd[\"classifier.dense.weight\"]); aud_head[0].bias.data.copy_(_sd[\"classifier.dense.bias\"])\n",
|
| 532 |
+
" aud_head[2].weight.data.copy_(_sd[\"classifier.out_proj.weight\"]); aud_head[2].bias.data.copy_(_sd[\"classifier.out_proj.bias\"])\n",
|
| 533 |
+
" aud_backbone = aud_backbone.to(device).eval()\n",
|
| 534 |
+
" aud_head = aud_head.to(device).eval()\n",
|
| 535 |
+
" AUD_DIM = _hid + 3\n",
|
| 536 |
+
" print(f\"✅ audeering frozen ({AUD_DIM}-D)\")\n",
|
| 537 |
+
"\n",
|
| 538 |
+
"def load_wav_emo(sid):\n",
|
| 539 |
+
" p = os.path.join(WAV_DIR, sid + \".wav\")\n",
|
| 540 |
+
" if not os.path.exists(p):\n",
|
| 541 |
+
" return None\n",
|
| 542 |
+
" wave, _ = librosa.load(p, sr=SR, mono=True)\n",
|
| 543 |
+
" return wave[: EMO_MAX_SEC * SR].astype(np.float32)\n",
|
| 544 |
+
"\n",
|
| 545 |
+
"@torch.no_grad()\n",
|
| 546 |
+
"def extract_audeering(stems, tag):\n",
|
| 547 |
+
" if not USE_AUDEERING:\n",
|
| 548 |
+
" return {}\n",
|
| 549 |
+
" cache_path = os.path.join(CACHE_DIR, f\"aud_{tag}.npz\")\n",
|
| 550 |
+
" store = {}\n",
|
| 551 |
+
" if os.path.exists(cache_path):\n",
|
| 552 |
+
" z = np.load(cache_path, allow_pickle=True)\n",
|
| 553 |
+
" store = {k: z[k] for k in z.files}\n",
|
| 554 |
+
" print(f\"[aud/{tag}] nạp cache: {len(store)}\")\n",
|
| 555 |
+
" todo = [s for s in stems if s not in store]\n",
|
| 556 |
+
" for i, s in enumerate(tqdm(todo, desc=f\"audeering {tag}\")):\n",
|
| 557 |
+
" wave = load_wav_emo(s)\n",
|
| 558 |
+
" if wave is None:\n",
|
| 559 |
+
" continue\n",
|
| 560 |
+
" x = aud_proc(wave, sampling_rate=SR).input_values[0]\n",
|
| 561 |
+
" x = torch.from_numpy(np.asarray(x, dtype=np.float32)).unsqueeze(0).to(device)\n",
|
| 562 |
+
" h = aud_backbone(x)[0].mean(dim=1)\n",
|
| 563 |
+
" out = aud_head(h)[0].cpu().numpy()\n",
|
| 564 |
+
" vad = np.array([1 + 4 * out[2], 1 + 4 * out[0], 1 + 4 * out[1]], dtype=np.float32) # [VAL,ARO,DOM]\n",
|
| 565 |
+
" store[s] = np.concatenate([h[0].cpu().numpy(), vad]).astype(np.float32)\n",
|
| 566 |
+
" if (i + 1) % 500 == 0:\n",
|
| 567 |
+
" np.savez(cache_path, **store)\n",
|
| 568 |
+
" if todo:\n",
|
| 569 |
+
" np.savez(cache_path, **store)\n",
|
| 570 |
+
" return store\n",
|
| 571 |
+
"\n",
|
| 572 |
+
"# ── EmoHeads (khớp exp08) + nạp trọng số head + thống kê chuẩn hóa từ ckpt ──\n",
|
| 573 |
+
"N_EMO = len(EMOTIONS5)\n",
|
| 574 |
+
"TRUNK_IN = WAVLM_DIM + (AUD_DIM if USE_AUDEERING else 0)\n",
|
| 575 |
+
"\n",
|
| 576 |
+
"class EmoHeads(nn.Module):\n",
|
| 577 |
+
" def __init__(self, d_in, trunk_h, head_h, p, n_emo):\n",
|
| 578 |
+
" super().__init__()\n",
|
| 579 |
+
" self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),\n",
|
| 580 |
+
" nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))\n",
|
| 581 |
+
" self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))\n",
|
| 582 |
+
" self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))\n",
|
| 583 |
+
" self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))\n",
|
| 584 |
+
" def forward(self, feat, tgt):\n",
|
| 585 |
+
" h = self.trunk(feat)\n",
|
| 586 |
+
" return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)\n",
|
| 587 |
+
"\n",
|
| 588 |
+
"heads = EmoHeads(TRUNK_IN, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device).eval()\n",
|
| 589 |
+
"hmiss, hunexp = heads.load_state_dict(ckpt[\"heads\"], strict=False)\n",
|
| 590 |
+
"print(f\"🔁 load heads từ ckpt: thiếu {len(hmiss)} / dư {len(hunexp)} key (kỳ vọng 0).\")\n",
|
| 591 |
+
"\n",
|
| 592 |
+
"emos_mu = float(ckpt[\"emos_mu\"]); emos_sd = float(ckpt[\"emos_sd\"])\n",
|
| 593 |
+
"vad_mu = np.asarray(ckpt[\"vad_mu\"], dtype=np.float32); vad_sd = np.asarray(ckpt[\"vad_sd\"], dtype=np.float32)\n",
|
| 594 |
+
"print(f\"Chuẩn hóa từ ckpt: emos μ={emos_mu:.3f} σ={emos_sd:.3f} | vad μ={np.round(vad_mu,2)}\")\n",
|
| 595 |
+
"\n",
|
| 596 |
+
"# Target cảm xúc (cho EMOS head) từ metadata\n",
|
| 597 |
+
"def load_target_emotions():\n",
|
| 598 |
+
" tgt = {}\n",
|
| 599 |
+
" with open(METADATA_CSV, encoding=\"utf-8\") as f:\n",
|
| 600 |
+
" for ln in f:\n",
|
| 601 |
+
" parts = ln.strip().split(\"|\")\n",
|
| 602 |
+
" if len(parts) >= 2:\n",
|
| 603 |
+
" tgt[stem(parts[0])] = norm_emotion(parts[1])\n",
|
| 604 |
+
" return tgt\n",
|
| 605 |
+
"\n",
|
| 606 |
+
"target_map = load_target_emotions()\n",
|
| 607 |
+
"\n",
|
| 608 |
+
"def onehot_target(tgt):\n",
|
| 609 |
+
" v = np.zeros(N_EMO, dtype=np.float32)\n",
|
| 610 |
+
" if tgt in EMOTIONS5:\n",
|
| 611 |
+
" v[EMOTIONS5.index(tgt)] = 1.0\n",
|
| 612 |
+
" return v\n",
|
| 613 |
+
"\n",
|
| 614 |
+
"dev_stems = [stem(n) for n in dev_names]\n",
|
| 615 |
+
"aud_dev = extract_audeering(dev_stems, \"dev\")\n",
|
| 616 |
+
"\n",
|
| 617 |
+
"@torch.no_grad()\n",
|
| 618 |
+
"def predict_emotion(sid):\n",
|
| 619 |
+
" wave = load_wav_emo(sid)\n",
|
| 620 |
+
" if wave is None or (USE_AUDEERING and sid not in aud_dev):\n",
|
| 621 |
+
" return None\n",
|
| 622 |
+
" iv = torch.from_numpy(wave).unsqueeze(0).to(device)\n",
|
| 623 |
+
" am = torch.ones((1, len(wave)), dtype=torch.long, device=device)\n",
|
| 624 |
+
" tgt = torch.from_numpy(onehot_target(target_map.get(sid))).unsqueeze(0).to(device)\n",
|
| 625 |
+
" with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
|
| 626 |
+
" fw = wavlm_embed(iv, am)\n",
|
| 627 |
+
" feat = torch.cat([fw, torch.from_numpy(aud_dev[sid]).unsqueeze(0).to(device)], dim=1) if USE_AUDEERING else fw\n",
|
| 628 |
+
" emos_p, cat_l, vad_p = heads(feat, tgt)\n",
|
| 629 |
+
" emos = float(emos_p.item()) * emos_sd + emos_mu\n",
|
| 630 |
+
" cat5 = F.softmax(cat_l, 1)[0].float().cpu().numpy()\n",
|
| 631 |
+
" vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu\n",
|
| 632 |
+
" return emos, cat5, vad3"
|
| 633 |
+
]
|
| 634 |
+
},
|
| 635 |
+
{
|
| 636 |
+
"cell_type": "markdown",
|
| 637 |
+
"id": "f813bfaf",
|
| 638 |
+
"metadata": {},
|
| 639 |
+
"source": [
|
| 640 |
+
"## 6. PHẦN C — Ghép QMOS (fine-tune) + 5 cột cảm xúc (exp08) → answer.txt 6 cột"
|
| 641 |
+
]
|
| 642 |
+
},
|
| 643 |
+
{
|
| 644 |
+
"cell_type": "code",
|
| 645 |
+
"execution_count": null,
|
| 646 |
+
"id": "f9fd3208",
|
| 647 |
+
"metadata": {
|
| 648 |
+
"lines_to_next_cell": 1
|
| 649 |
+
},
|
| 650 |
+
"outputs": [],
|
| 651 |
+
"source": [
|
| 652 |
+
"def fmt_cat(p5):\n",
|
| 653 |
+
" return \"|\".join(f\"{e}:{p5[i]:.6g}\" for i, e in enumerate(EMOTIONS5))\n",
|
| 654 |
+
"\n",
|
| 655 |
+
"def build_answer(out_path):\n",
|
| 656 |
+
" n_real = n_def = 0\n",
|
| 657 |
+
" with open(out_path, \"w\") as f:\n",
|
| 658 |
+
" f.write(\"wav,QMOS,EMOS,CAT,VAL,ARO,DOM\\n\")\n",
|
| 659 |
+
" for name in tqdm(dev_names, desc=\"answer\"):\n",
|
| 660 |
+
" sid = stem(name)\n",
|
| 661 |
+
" pr = predict_emotion(sid)\n",
|
| 662 |
+
" if pr is None:\n",
|
| 663 |
+
" emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0]); n_def += 1\n",
|
| 664 |
+
" else:\n",
|
| 665 |
+
" emos, cat5, vad3 = pr; n_real += 1\n",
|
| 666 |
+
" qmos = qmos_pred.get(name, qmos_pred.get(sid, 3.0))\n",
|
| 667 |
+
" f.write(f\"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\\n\")\n",
|
| 668 |
+
" print(f\"Ghi {len(dev_names)} dòng → {out_path} | cảm xúc thật {n_real}, mặc định {n_def}\")\n",
|
| 669 |
+
"\n",
|
| 670 |
+
"answer_path = os.path.join(OUT_DIR, \"answer.txt\")\n",
|
| 671 |
+
"build_answer(answer_path)"
|
| 672 |
+
]
|
| 673 |
+
},
|
| 674 |
+
{
|
| 675 |
+
"cell_type": "markdown",
|
| 676 |
+
"id": "3f78d7fe",
|
| 677 |
+
"metadata": {},
|
| 678 |
+
"source": [
|
| 679 |
+
"## 7. Validate + zip"
|
| 680 |
+
]
|
| 681 |
+
},
|
| 682 |
+
{
|
| 683 |
+
"cell_type": "code",
|
| 684 |
+
"execution_count": null,
|
| 685 |
+
"id": "d873783b",
|
| 686 |
+
"metadata": {},
|
| 687 |
+
"outputs": [],
|
| 688 |
+
"source": [
|
| 689 |
+
"def validate(path):\n",
|
| 690 |
+
" import csv\n",
|
| 691 |
+
" with open(path) as f:\n",
|
| 692 |
+
" rows = list(csv.reader(f))\n",
|
| 693 |
+
" assert rows[0][0] == \"wav\" and \"QMOS\" in rows[0] and \"EMOS\" in rows[0], \"Header sai\"\n",
|
| 694 |
+
" for i, r in enumerate(rows[1:], 2):\n",
|
| 695 |
+
" assert len(r) == len(rows[0]), f\"Dòng {i} sai số cột\"\n",
|
| 696 |
+
" print(f\"OK: {len(rows)-1} dòng, header = {rows[0]}\")\n",
|
| 697 |
+
"\n",
|
| 698 |
+
"validate(answer_path)\n",
|
| 699 |
+
"os.system(f\"cd {OUT_DIR} && zip -j submission_track2_exp13_ft-qmos.zip answer.txt \"\n",
|
| 700 |
+
" f\"&& unzip -l submission_track2_exp13_ft-qmos.zip\")\n",
|
| 701 |
+
"print(\"Sẵn sàng nộp:\", os.path.join(OUT_DIR, \"submission_track2_exp13_ft-qmos.zip\"))"
|
| 702 |
+
]
|
| 703 |
+
},
|
| 704 |
+
{
|
| 705 |
+
"cell_type": "markdown",
|
| 706 |
+
"id": "33290a62",
|
| 707 |
+
"metadata": {},
|
| 708 |
+
"source": [
|
| 709 |
+
"## Ghi chú\n",
|
| 710 |
+
"- **Lần đầu** `LIMIT_TRAIN=300`, `LIMIT_DEV=20` để chạy trơn (không OOM, 1 epoch xong); rồi đặt `None`.\n",
|
| 711 |
+
"- **OOM trên T4?** giảm `QMOS_MAX_SEC` (12→10→8); giữ `FREEZE_FEAT_EXT=True`; `BATCH=1` đã là min.\n",
|
| 712 |
+
" ⚠️ **Bật `RANK_LAMBDA>0` tốn VRAM hơn** vì phải GIỮ đồ thị cả cửa sổ ACCUM (=16) để so thứ hạng →\n",
|
| 713 |
+
" nếu OOM khi bật ranking: giảm `ACCUM` (vd 8, cũng là kích thước nhóm ranking) hoặc `QMOS_MAX_SEC`.\n",
|
| 714 |
+
"- **Đọc mục A:** so `val SRCC fine-tune` với `zero-shot`. Chỉ nộp QMOS fine-tune nếu **vượt zero-shot**\n",
|
| 715 |
+
" (lý tưởng vượt cả exp07 0.548); nếu không → giữ QMOS exp07 (Add Input answer.txt exp07, đổi cột QMOS).\n",
|
| 716 |
+
"- Nếu chưa vượt: tăng `EPOCHS`, bật `RANK_LAMBDA=0.3` (tối ưu thẳng thứ hạng), hoặc `FREEZE_FEAT_EXT=False`\n",
|
| 717 |
+
" (mở băng feature-extractor — mạnh hơn nhưng dễ overfit + nặng VRAM).\n",
|
| 718 |
+
"- **Lưu checkpoint:** `ft_qmos_utmos.pt` lưu mỗi best → **Save Version NGAY** sau khi chạy (bài học exp08).\n",
|
| 719 |
+
"- **License QMOS:** UTMOS/SpeechMOS (kiểm tra license tarepan/SpeechMOS) — khai báo `docs/12_system_description.md`.\n",
|
| 720 |
+
"- Ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (mục exp13)."
|
| 721 |
+
]
|
| 722 |
+
}
|
| 723 |
+
],
|
| 724 |
+
"metadata": {
|
| 725 |
+
"jupytext": {
|
| 726 |
+
"cell_metadata_filter": "-all",
|
| 727 |
+
"main_language": "python",
|
| 728 |
+
"notebook_metadata_filter": "-all"
|
| 729 |
+
}
|
| 730 |
+
},
|
| 731 |
+
"nbformat": 4,
|
| 732 |
+
"nbformat_minor": 5
|
| 733 |
+
}
|
track2/exp13_finetune_qmos_pipeline.py
ADDED
|
@@ -0,0 +1,607 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %% [markdown]
|
| 2 |
+
# # VMC2026 Track 2 — exp13 (FINE-TUNE UTMOS cho QMOS) + answer 6 cột — Kaggle
|
| 3 |
+
#
|
| 4 |
+
# **Mục tiêu:** QMOS hiện tốt nhất = 0.548 (exp07, head ĐÓNG BĂNG + neo UTMOS). exp13 thử **MỞ BĂNG
|
| 5 |
+
# (fine-tune) thẳng UTMOS** trên nhãn `qMOS` thật của Track 2 → kéo model chất lượng về đúng domain giọng
|
| 6 |
+
# cảm xúc. Sau đó **mượn 5 cột cảm xúc từ checkpoint exp08** (`ft_emotion_full_20epoch.pt` — bản TỐT NHẤT)
|
| 7 |
+
# → ghép `answer.txt` 6 cột.
|
| 8 |
+
#
|
| 9 |
+
# ## Vì sao fine-tune UTMOS (không phải UTMOSv2)
|
| 10 |
+
# - UTMOS (`utmos22_strong`, tarepan/SpeechMOS) = **1 model đơn**, tải qua `torch.hub`, **bản thân đã dự đoán
|
| 11 |
+
# QMOS** → warm-start hoàn hảo cho cột chất lượng (khác UTMOSv2 = ensemble nhiều fold + 2 luồng → khó train).
|
| 12 |
+
# - forward: `model(wave[B,T], sr) -> MOS[B]`, là `nn.Module` chuẩn → backprop được toàn model.
|
| 13 |
+
# - **Không dùng neo UTMOS riêng** (đã chốt): khi fine-tune chính UTMOS thì "neo" nằm sẵn trong trọng số
|
| 14 |
+
# warm-start → head/neo ngoài là thừa.
|
| 15 |
+
#
|
| 16 |
+
# ## Thiết kế
|
| 17 |
+
# ```
|
| 18 |
+
# [PHẦN A] wav ─► UTMOS (utmos22_strong, TRAINABLE, warm-start pretrained) ─► QMOS (train trên qMOS gold)
|
| 19 |
+
# [PHẦN B] wav ─► WavLM(exp08 ft) + audeering(frozen) ─► EMOS/CAT/VAD (NẠP ckpt, chỉ inference)
|
| 20 |
+
# [PHẦN C] ghép QMOS(A) + 5 cột cảm xúc(B) ─► answer.txt 6 cột ─► validate ─► zip
|
| 21 |
+
# ```
|
| 22 |
+
#
|
| 23 |
+
# ## ⚠️ Phải biết trước
|
| 24 |
+
# - Fine-tune = **không cache** (mỗi epoch chạy lại UTMOS forward+backward) → tốn giờ GPU. **Lần đầu BẮT BUỘC
|
| 25 |
+
# `LIMIT_TRAIN=300`, `LIMIT_DEV=20`** để chỉnh trơn rồi mới `None`.
|
| 26 |
+
# - Lưới an toàn: chỉ nộp QMOS fine-tune nếu **SRCC val nội bộ > zero-shot UTMOS** (mục A in cả 2 số).
|
| 27 |
+
# - **Lưu checkpoint `ft_qmos_utmos.pt` mỗi best + Save Version NGAY** (bài học exp08: kernel chết là mất).
|
| 28 |
+
#
|
| 29 |
+
# **Cách chạy Kaggle:** GPU **T4** + Internet **On** → Add Input (1) dataset Track 2, (2) dataset chứa
|
| 30 |
+
# `ft_emotion_full.pt` (exp08), (3) tùy chọn cache `aud_dev.npz` → sửa slug cell 0 → Run All.
|
| 31 |
+
|
| 32 |
+
# %% [markdown]
|
| 33 |
+
# ## 0. Cấu hình — SỬA Ở ĐÂY
|
| 34 |
+
|
| 35 |
+
# %%
|
| 36 |
+
import os, shutil, glob
|
| 37 |
+
|
| 38 |
+
# ── TỰ DÒ DATA_ROOT (quét /kaggle/input tìm thư mục có sets/train.csv + wav/ + metadata.csv) ──
|
| 39 |
+
def find_data_root(search_root="/kaggle/input"):
|
| 40 |
+
cands = []
|
| 41 |
+
for train_csv in glob.glob(os.path.join(search_root, "**", "sets", "train.csv"), recursive=True):
|
| 42 |
+
root = os.path.dirname(os.path.dirname(train_csv)) # .../<root>/sets/train.csv → <root>
|
| 43 |
+
score = os.path.isdir(os.path.join(root, "wav")) + os.path.exists(os.path.join(root, "metadata.csv"))
|
| 44 |
+
cands.append((score, root))
|
| 45 |
+
cands.sort(reverse=True) # ưu tiên thư mục đủ wav + metadata
|
| 46 |
+
return cands
|
| 47 |
+
|
| 48 |
+
_cands = find_data_root("/kaggle/input")
|
| 49 |
+
if _cands:
|
| 50 |
+
print("🔎 Ứng viên DATA_ROOT (điểm cao = đủ wav+metadata):")
|
| 51 |
+
for sc, r in _cands:
|
| 52 |
+
print(f" [{sc}/2] {r}")
|
| 53 |
+
DATA_ROOT = _cands[0][1]
|
| 54 |
+
print(f"👉 Tự chọn DATA_ROOT = {DATA_ROOT}")
|
| 55 |
+
else:
|
| 56 |
+
DATA_ROOT = "/kaggle/input/datasets/minhtoan2" # dự phòng — sửa tay nếu auto-dò không thấy
|
| 57 |
+
print(f"❌ Không thấy sets/train.csv trong /kaggle/input → dùng dự phòng {DATA_ROOT} (đã Add Input chưa?)")
|
| 58 |
+
|
| 59 |
+
WAV_DIR = f"{DATA_ROOT}/wav"
|
| 60 |
+
METADATA_CSV = f"{DATA_ROOT}/metadata.csv" # wavID|emotion|transcript (cho cột cảm xúc)
|
| 61 |
+
TRAIN_CSV = f"{DATA_ROOT}/sets/train.csv" # lisID|wavID|qMOS|emoCat|eMOS|val|dom|aro
|
| 62 |
+
DEV_SCP = f"{DATA_ROOT}/sets/dev.scp"
|
| 63 |
+
|
| 64 |
+
# ── Checkpoint cảm xúc exp08 (để sinh 5 cột EMOS/CAT/VAD) ─────────────────────
|
| 65 |
+
# ⭐ TỐT NHẤT = ft_emotion_full_20epoch.pt (bản 20 epoch) — dùng bản này, KHÔNG dùng ft_emotion_full.pt.
|
| 66 |
+
EMO_CKPT = "/kaggle/input/ft-emotion-full/ft_emotion_full_20epoch.pt" # << ckpt exp08 20ep (CÓ backbone WavLM)
|
| 67 |
+
CACHE_INPUT = "/kaggle/input/ft-emotion-cache" # << (tùy chọn) thư mục chứa aud_dev.npz; "" nếu không có
|
| 68 |
+
|
| 69 |
+
OUT_DIR = "/kaggle/working"
|
| 70 |
+
CACHE_DIR = "/kaggle/working/ft_cache" # /kaggle/input read-only → copy cache audeering sang đây
|
| 71 |
+
os.makedirs(CACHE_DIR, exist_ok=True)
|
| 72 |
+
|
| 73 |
+
# ── PHẦN A: fine-tune UTMOS (QMOS) ───────────────────────────────────────────
|
| 74 |
+
DEVICE = "cuda"
|
| 75 |
+
SR = 16000
|
| 76 |
+
QMOS_MAX_SEC = 12 # cắt audio chặn bộ nhớ backprop (UTMOS); OOM thì giảm 10/8
|
| 77 |
+
LR = 1e-5 # LR nhỏ cho fine-tune (warm-start sẵn tốt)
|
| 78 |
+
WEIGHT_DECAY = 1e-5
|
| 79 |
+
EPOCHS = 10 # TRẦN; early-stop quyết số epoch thật
|
| 80 |
+
PATIENCE = 3
|
| 81 |
+
BATCH = 1 # UTMOS forward KHÔNG có attention-mask → BATCH=1 an toàn (pad zero sẽ lệch pooling)
|
| 82 |
+
ACCUM = 16 # effective batch = BATCH*ACCUM = 16
|
| 83 |
+
VAL_FRAC = 0.10
|
| 84 |
+
SEED = 42
|
| 85 |
+
USE_AMP = True
|
| 86 |
+
RANK_LAMBDA = 0.0 # 0 = chỉ MSE. >0 (vd 0.3) = cộng pairwise ranking loss (tối ưu thẳng thứ hạng=SRCC)
|
| 87 |
+
FREEZE_FEAT_EXT = True # đóng băng feature-extractor (CNN conv) của UTMOS → đỡ VRAM + chống overfit
|
| 88 |
+
|
| 89 |
+
# ── PHẦN B: inference cảm xúc (PHẢI khớp kiến trúc exp08) ─────────────────────
|
| 90 |
+
EMO_MAX_SEC = 8
|
| 91 |
+
UNFREEZE_TOP_LAYERS = 6 # khớp ckpt exp08
|
| 92 |
+
TRUNK_HIDDEN = 512
|
| 93 |
+
HEAD_HIDDEN = 128
|
| 94 |
+
DROPOUT = 0.3
|
| 95 |
+
USE_AUDEERING = True # khớp ckpt exp08
|
| 96 |
+
|
| 97 |
+
LIMIT_TRAIN = 300 # << LẦN ĐẦU 300; chạy thật None
|
| 98 |
+
LIMIT_DEV = 20 # << LẦN ĐẦU 20; chạy thật None
|
| 99 |
+
|
| 100 |
+
# Mốc QMOS để so (leaderboard DEV)
|
| 101 |
+
QMOS_BASELINE = {"utmos_zeroshot": 0.414, "exp07_head": 0.548}
|
| 102 |
+
|
| 103 |
+
EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"]
|
| 104 |
+
_EMO_ALIAS = {
|
| 105 |
+
"angry": "angry", "anger": "angry",
|
| 106 |
+
"happy": "happy", "happiness": "happy", "joy": "happy",
|
| 107 |
+
"neutral": "neutral", "calm": "neutral",
|
| 108 |
+
"sad": "sad", "sadness": "sad",
|
| 109 |
+
"surprise": "surprised", "surprised": "surprised", "surprising": "surprised",
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
def norm_emotion(label):
|
| 113 |
+
key = str(label).strip().lower()
|
| 114 |
+
return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)
|
| 115 |
+
|
| 116 |
+
def stem(p):
|
| 117 |
+
return os.path.splitext(os.path.basename(str(p)))[0]
|
| 118 |
+
|
| 119 |
+
print("DATA_ROOT:", DATA_ROOT)
|
| 120 |
+
for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP, EMO_CKPT]:
|
| 121 |
+
print((" ✅ " if os.path.exists(p) else " ❌ THIẾU ") + p)
|
| 122 |
+
print(f"Fine-tune UTMOS: LR {LR} · BATCH {BATCH}×ACCUM {ACCUM} · MAX {QMOS_MAX_SEC}s · rank λ {RANK_LAMBDA}")
|
| 123 |
+
|
| 124 |
+
# Copy cache audeering (aud_dev.npz) từ input read-only sang working (để cột cảm xúc khỏi trích lại)
|
| 125 |
+
if CACHE_INPUT and os.path.isdir(CACHE_INPUT):
|
| 126 |
+
n = 0
|
| 127 |
+
for fn in os.listdir(CACHE_INPUT):
|
| 128 |
+
if fn.startswith("aud_") and fn.endswith(".npz"):
|
| 129 |
+
shutil.copy(os.path.join(CACHE_INPUT, fn), os.path.join(CACHE_DIR, fn)); n += 1
|
| 130 |
+
print(f"📦 Copy {n} file cache audeering từ {CACHE_INPUT} → {CACHE_DIR}")
|
| 131 |
+
else:
|
| 132 |
+
print("ℹ️ Không có CACHE_INPUT → sẽ tự trích audeering cho DEV (chậm hơn lần đầu).")
|
| 133 |
+
|
| 134 |
+
# %% [markdown]
|
| 135 |
+
# ## 1. Cài đặt
|
| 136 |
+
|
| 137 |
+
# %%
|
| 138 |
+
import sys, subprocess
|
| 139 |
+
|
| 140 |
+
def pip_install(*pkgs):
|
| 141 |
+
subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs], check=True)
|
| 142 |
+
|
| 143 |
+
pip_install("speechmos", "loralib", "speechbrain", "librosa", "soundfile",
|
| 144 |
+
"scipy", "scikit-learn", "pandas", "tqdm")
|
| 145 |
+
|
| 146 |
+
# Code SAILER (để dựng đúng kiến trúc WavLM rồi nạp ckpt exp08 đè lên) — chỉ cần cho PHẦN B
|
| 147 |
+
REPO_DIR = "/kaggle/working/vox-profile-release"
|
| 148 |
+
if not os.path.exists(REPO_DIR):
|
| 149 |
+
subprocess.run(["git", "clone", "--depth", "1",
|
| 150 |
+
"https://github.com/tiantiaf0627/vox-profile-release.git", REPO_DIR], check=True)
|
| 151 |
+
if REPO_DIR not in sys.path:
|
| 152 |
+
sys.path.insert(0, REPO_DIR)
|
| 153 |
+
|
| 154 |
+
# %% [markdown]
|
| 155 |
+
# ## 2. Nhãn vàng qMOS (gộp trung bình theo wav) — như exp06/exp09a
|
| 156 |
+
|
| 157 |
+
# %%
|
| 158 |
+
import numpy as np
|
| 159 |
+
import pandas as pd
|
| 160 |
+
|
| 161 |
+
def load_qmos_labels():
|
| 162 |
+
df = pd.read_csv(TRAIN_CSV, sep="|")
|
| 163 |
+
cols = {c.lower().strip(): c for c in df.columns}
|
| 164 |
+
wav_col = cols.get("wavid") or cols.get("wav") or list(df.columns)[1]
|
| 165 |
+
qmos_col = cols.get("qmos") or cols.get("mos")
|
| 166 |
+
assert qmos_col, f"Không thấy cột qMOS (cột: {list(df.columns)})"
|
| 167 |
+
df["_stem"] = df[wav_col].map(stem)
|
| 168 |
+
g = df.groupby("_stem")[qmos_col].mean()
|
| 169 |
+
return {s: float(v) for s, v in g.items()}
|
| 170 |
+
|
| 171 |
+
qmos_gold = load_qmos_labels()
|
| 172 |
+
print(f"Số wav train có nhãn qMOS: {len(qmos_gold)}")
|
| 173 |
+
_vals = np.array(list(qmos_gold.values()))
|
| 174 |
+
print(f"qMOS gold: mean {_vals.mean():.3f} · std {_vals.std():.3f} · min {_vals.min():.2f} · max {_vals.max():.2f}")
|
| 175 |
+
|
| 176 |
+
# %% [markdown]
|
| 177 |
+
# ## 3. PHẦN A — Fine-tune UTMOS trên qMOS
|
| 178 |
+
# UTMOS xuất MOS thang ~1–5 (đã warm-start) → train MSE trên thang GỐC (không z-score, để giữ ý nghĩa warm-start).
|
| 179 |
+
# `BATCH=1` + grad-accum: tránh phải pad (UTMOS forward không nhận attention-mask).
|
| 180 |
+
|
| 181 |
+
# %%
|
| 182 |
+
import torch
|
| 183 |
+
import torch.nn as nn
|
| 184 |
+
import librosa
|
| 185 |
+
from tqdm.auto import tqdm
|
| 186 |
+
from scipy.stats import spearmanr
|
| 187 |
+
from sklearn.model_selection import train_test_split
|
| 188 |
+
|
| 189 |
+
device = DEVICE if torch.cuda.is_available() else "cpu"
|
| 190 |
+
print("Device:", device, ("✅ " + torch.cuda.get_device_name(0)) if device == "cuda" else "⚠️ CPU (rất chậm!)")
|
| 191 |
+
torch.manual_seed(SEED); np.random.seed(SEED)
|
| 192 |
+
|
| 193 |
+
# Nạp UTMOS (torch.hub) — model nn.Module, forward(wave[B,T], sr) -> MOS[B]
|
| 194 |
+
utmos = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True).to(device)
|
| 195 |
+
n_all = sum(p.numel() for p in utmos.parameters())
|
| 196 |
+
|
| 197 |
+
# (tùy chọn) đóng băng feature-extractor (các lớp conv trích đặc trưng) → đỡ VRAM + chống overfit
|
| 198 |
+
if FREEZE_FEAT_EXT:
|
| 199 |
+
n_frozen = 0
|
| 200 |
+
for name, p in utmos.named_parameters():
|
| 201 |
+
if "feature_extractor" in name or "feature_projection" in name or "conv" in name.lower():
|
| 202 |
+
p.requires_grad = False; n_frozen += p.numel()
|
| 203 |
+
print(f"❄️ Đóng băng feature-extractor: {n_frozen/1e6:.1f}M / {n_all/1e6:.1f}M param")
|
| 204 |
+
n_train = sum(p.numel() for p in utmos.parameters() if p.requires_grad)
|
| 205 |
+
print(f"UTMOS: {n_all/1e6:.1f}M param tổng · {n_train/1e6:.1f}M param sẽ train")
|
| 206 |
+
|
| 207 |
+
def load_wav_qmos(sid):
|
| 208 |
+
p = os.path.join(WAV_DIR, sid + ".wav")
|
| 209 |
+
if not os.path.exists(p):
|
| 210 |
+
return None
|
| 211 |
+
wave, _ = librosa.load(p, sr=SR, mono=True)
|
| 212 |
+
return wave[: QMOS_MAX_SEC * SR].astype(np.float32)
|
| 213 |
+
|
| 214 |
+
# Tập train QMOS: chỉ wav tồn tại trên đĩa
|
| 215 |
+
train_stems_q = [s for s in qmos_gold if os.path.exists(os.path.join(WAV_DIR, s + ".wav"))]
|
| 216 |
+
np.random.shuffle(train_stems_q)
|
| 217 |
+
if LIMIT_TRAIN:
|
| 218 |
+
train_stems_q = train_stems_q[:LIMIT_TRAIN]
|
| 219 |
+
tr_q, va_q = train_test_split(train_stems_q, test_size=VAL_FRAC, random_state=SEED)
|
| 220 |
+
print(f"QMOS train: {len(tr_q)} · val nội bộ: {len(va_q)}")
|
| 221 |
+
|
| 222 |
+
opt = torch.optim.AdamW([p for p in utmos.parameters() if p.requires_grad],
|
| 223 |
+
lr=LR, weight_decay=WEIGHT_DECAY)
|
| 224 |
+
scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP and device == "cuda")
|
| 225 |
+
mse = nn.MSELoss()
|
| 226 |
+
|
| 227 |
+
def utmos_forward(wave_np):
|
| 228 |
+
"""1 wav numpy -> MOS scalar tensor (giữ grad)."""
|
| 229 |
+
x = torch.from_numpy(wave_np).unsqueeze(0).to(device) # [1, T]
|
| 230 |
+
out = utmos(x, SR) # [1] (hoặc [1,?])
|
| 231 |
+
return out.reshape(-1).mean() # scalar an toàn mọi shape
|
| 232 |
+
|
| 233 |
+
def pairwise_rank_loss(preds, targets):
|
| 234 |
+
"""Hinge ranking trên các cặp trong 1 nhóm (khuyến khích đúng thứ hạng = tối ưu SRCC)."""
|
| 235 |
+
p = torch.stack(preds); t = torch.tensor(targets, device=device, dtype=torch.float32)
|
| 236 |
+
if len(p) < 2:
|
| 237 |
+
return torch.zeros((), device=device)
|
| 238 |
+
sign = torch.sign(t.unsqueeze(0) - t.unsqueeze(1))
|
| 239 |
+
diff = p.unsqueeze(0) - p.unsqueeze(1)
|
| 240 |
+
return torch.relu(-sign * diff).mean()
|
| 241 |
+
|
| 242 |
+
@torch.no_grad()
|
| 243 |
+
def eval_qmos_val():
|
| 244 |
+
utmos.eval()
|
| 245 |
+
preds, gts = [], []
|
| 246 |
+
for s in va_q:
|
| 247 |
+
wave = load_wav_qmos(s)
|
| 248 |
+
if wave is None:
|
| 249 |
+
continue
|
| 250 |
+
with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
|
| 251 |
+
preds.append(float(utmos_forward(wave).item()))
|
| 252 |
+
gts.append(qmos_gold[s])
|
| 253 |
+
return float(spearmanr(preds, gts).correlation)
|
| 254 |
+
|
| 255 |
+
# Baseline ZERO-SHOT (trước khi train) trên CÙNG val → mốc phải vượt
|
| 256 |
+
srcc_zeroshot = eval_qmos_val()
|
| 257 |
+
print(f"\n📍 UTMOS zero-shot (val nội bộ): SRCC = {srcc_zeroshot:.4f} "
|
| 258 |
+
f"(leaderboard DEV ~{QMOS_BASELINE['utmos_zeroshot']}; exp07 head {QMOS_BASELINE['exp07_head']})")
|
| 259 |
+
|
| 260 |
+
CKPT_QMOS = os.path.join(OUT_DIR, "ft_qmos_utmos.pt")
|
| 261 |
+
def save_qmos_ckpt(srcc):
|
| 262 |
+
torch.save({"utmos_state": {k: v.cpu() for k, v in utmos.state_dict().items()},
|
| 263 |
+
"val_srcc": float(srcc), "raw_scale": True,
|
| 264 |
+
"QMOS_MAX_SEC": QMOS_MAX_SEC, "FREEZE_FEAT_EXT": FREEZE_FEAT_EXT}, CKPT_QMOS)
|
| 265 |
+
|
| 266 |
+
best, best_state, bad = srcc_zeroshot, {k: v.cpu().clone() for k, v in utmos.state_dict().items()}, 0
|
| 267 |
+
save_qmos_ckpt(best) # lưu sẵn bản zero-shot (worst case vẫn = baseline)
|
| 268 |
+
|
| 269 |
+
# Gom theo CỬA SỔ = ACCUM mẫu HỢP LỆ (micro). Hai chế độ backward:
|
| 270 |
+
# • RANK off (mặc định) → backward NGAY từng mẫu → đồ thị giải phóng liền → VRAM thấp.
|
| 271 |
+
# • RANK on → ranking cần SO các pred TRONG cửa sổ → PHẢI giữ đồ thị cả cửa sổ →
|
| 272 |
+
# gom MSE (win_loss) + pred (buf_p) rồi backward MỘT lần (MSE_mean + λ·rank).
|
| 273 |
+
# ⚠️ Lỗi cũ: backward MSE từng bước đã giải phóng đồ thị → rank_loss.backward() sau đó
|
| 274 |
+
# sẽ lỗi "backward through the graph a second time". Bản này gom rồi backward 1 lần → hết lỗi.
|
| 275 |
+
for ep in range(1, EPOCHS + 1):
|
| 276 |
+
utmos.train()
|
| 277 |
+
opt.zero_grad()
|
| 278 |
+
np.random.shuffle(tr_q)
|
| 279 |
+
run = 0.0; nb = 0
|
| 280 |
+
micro = 0; win_loss = None; buf_p, buf_t = [], []
|
| 281 |
+
for s in tqdm(tr_q, desc=f"epoch {ep}"):
|
| 282 |
+
wave = load_wav_qmos(s)
|
| 283 |
+
if wave is None:
|
| 284 |
+
continue
|
| 285 |
+
with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
|
| 286 |
+
pred = utmos_forward(wave)
|
| 287 |
+
loss = mse(pred, torch.tensor(qmos_gold[s], device=device, dtype=pred.dtype))
|
| 288 |
+
run += float(loss.item()); nb += 1
|
| 289 |
+
if RANK_LAMBDA > 0:
|
| 290 |
+
win_loss = loss if win_loss is None else win_loss + loss # GIỮ đồ thị (không backward ngay)
|
| 291 |
+
buf_p.append(pred); buf_t.append(qmos_gold[s]); micro += 1
|
| 292 |
+
else:
|
| 293 |
+
scaler.scale(loss / ACCUM).backward(); micro += 1 # backward ngay → VRAM thấp
|
| 294 |
+
if micro == ACCUM:
|
| 295 |
+
if RANK_LAMBDA > 0:
|
| 296 |
+
total = win_loss / micro
|
| 297 |
+
if len(buf_p) >= 2:
|
| 298 |
+
total = total + RANK_LAMBDA * pairwise_rank_loss(buf_p, buf_t)
|
| 299 |
+
scaler.scale(total).backward()
|
| 300 |
+
scaler.step(opt); scaler.update(); opt.zero_grad()
|
| 301 |
+
micro = 0; win_loss = None; buf_p, buf_t = [], []
|
| 302 |
+
# flush cửa sổ dư cuối epoch (số mẫu không chia hết cho ACCUM)
|
| 303 |
+
if micro > 0:
|
| 304 |
+
if RANK_LAMBDA > 0:
|
| 305 |
+
total = win_loss / micro
|
| 306 |
+
if len(buf_p) >= 2:
|
| 307 |
+
total = total + RANK_LAMBDA * pairwise_rank_loss(buf_p, buf_t)
|
| 308 |
+
scaler.scale(total).backward()
|
| 309 |
+
scaler.step(opt); scaler.update(); opt.zero_grad()
|
| 310 |
+
sc = eval_qmos_val()
|
| 311 |
+
print(f"epoch {ep:2d} | loss {run/max(nb,1):.4f} | val SRCC {sc:.4f} "
|
| 312 |
+
f"(zero-shot {srcc_zeroshot:.4f} · best {max(best,sc):.4f})")
|
| 313 |
+
if sc > best:
|
| 314 |
+
best = sc
|
| 315 |
+
best_state = {k: v.cpu().clone() for k, v in utmos.state_dict().items()}
|
| 316 |
+
save_qmos_ckpt(best)
|
| 317 |
+
print(f" 💾 lưu best → {CKPT_QMOS} (epoch {ep}, SRCC {sc:.4f})")
|
| 318 |
+
bad = 0
|
| 319 |
+
else:
|
| 320 |
+
bad += 1
|
| 321 |
+
if bad >= PATIENCE:
|
| 322 |
+
print(f"Early stop ở epoch {ep}."); break
|
| 323 |
+
|
| 324 |
+
utmos.load_state_dict(best_state)
|
| 325 |
+
print(f"\n✅ PHẦN A xong — QMOS val nội bộ: zero-shot {srcc_zeroshot:.4f} → fine-tune {best:.4f} "
|
| 326 |
+
+ ("🚀 cải thiện" if best > srcc_zeroshot + 1e-4 else "➖ KHÔNG vượt zero-shot"))
|
| 327 |
+
if best <= srcc_zeroshot + 1e-4:
|
| 328 |
+
print(" ⚠️ Fine-tune chưa vượt zero-shot → cân nhắc tăng EPOCHS / bật RANK_LAMBDA=0.3 / "
|
| 329 |
+
"mở băng feature-extractor (FREEZE_FEAT_EXT=False); hoặc giữ QMOS exp07 (0.548).")
|
| 330 |
+
|
| 331 |
+
# %% [markdown]
|
| 332 |
+
# ## 4. PHẦN A (tiếp) — Dự đoán QMOS cho DEV bằng UTMOS đã fine-tune
|
| 333 |
+
|
| 334 |
+
# %%
|
| 335 |
+
def list_dev():
|
| 336 |
+
with open(DEV_SCP) as f:
|
| 337 |
+
return [ln.strip() for ln in f if ln.strip()]
|
| 338 |
+
|
| 339 |
+
dev_names = list_dev()
|
| 340 |
+
if LIMIT_DEV:
|
| 341 |
+
dev_names = dev_names[:LIMIT_DEV]
|
| 342 |
+
print("DEV:", len(dev_names), "mẫu")
|
| 343 |
+
|
| 344 |
+
@torch.no_grad()
|
| 345 |
+
def predict_qmos(name):
|
| 346 |
+
p = os.path.join(WAV_DIR, name if str(name).endswith(".wav") else str(name) + ".wav")
|
| 347 |
+
if not os.path.exists(p):
|
| 348 |
+
return None
|
| 349 |
+
wave, _ = librosa.load(p, sr=SR, mono=True)
|
| 350 |
+
wave = wave[: QMOS_MAX_SEC * SR].astype(np.float32)
|
| 351 |
+
utmos.eval()
|
| 352 |
+
with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
|
| 353 |
+
v = float(utmos_forward(wave).item())
|
| 354 |
+
return float(np.clip(v, 1.0, 5.0))
|
| 355 |
+
|
| 356 |
+
qmos_pred = {}
|
| 357 |
+
n_real = n_def = 0
|
| 358 |
+
for name in tqdm(dev_names, desc="QMOS dev"):
|
| 359 |
+
v = predict_qmos(name)
|
| 360 |
+
if v is None:
|
| 361 |
+
v = 3.0; n_def += 1
|
| 362 |
+
else:
|
| 363 |
+
n_real += 1
|
| 364 |
+
qmos_pred[name] = v
|
| 365 |
+
print(f"QMOS dự đoán: thật {n_real}, mặc định {n_def}")
|
| 366 |
+
|
| 367 |
+
# Giải phóng UTMOS trước khi nạp backbone cảm xúc (đỡ VRAM T4)
|
| 368 |
+
del utmos, opt, scaler
|
| 369 |
+
torch.cuda.empty_cache() if device == "cuda" else None
|
| 370 |
+
|
| 371 |
+
# %% [markdown]
|
| 372 |
+
# ## 5. PHẦN B — Nạp ckpt exp08 (WavLM ft + audeering) → 5 cột cảm xúc cho DEV
|
| 373 |
+
# Tái dùng nguyên cơ chế load của exp08b: dựng kiến trúc → `load_state_dict` từ `ft_emotion_full_20epoch.pt`.
|
| 374 |
+
|
| 375 |
+
# %%
|
| 376 |
+
import torch.nn.functional as F
|
| 377 |
+
|
| 378 |
+
ckpt = torch.load(EMO_CKPT, map_location="cpu", weights_only=False) # ckpt có numpy → weights_only=False
|
| 379 |
+
assert "wavlm" in ckpt, ("❌ EMO_CKPT không có 'wavlm' (backbone). Cần ft_emotion_full_20epoch.pt (bản đủ backbone), "
|
| 380 |
+
"KHÔNG phải ft_emotion_meta.pt cũ.")
|
| 381 |
+
print("✅ Nạp ckpt cảm xúc:", EMO_CKPT, "| keys:", list(ckpt.keys()))
|
| 382 |
+
|
| 383 |
+
def find_hf_backbone(module):
|
| 384 |
+
cands = []
|
| 385 |
+
for nm, m in module.named_modules():
|
| 386 |
+
enc = getattr(m, "encoder", None)
|
| 387 |
+
if getattr(m, "feature_extractor", None) is not None and enc is not None \
|
| 388 |
+
and getattr(enc, "layers", None) is not None:
|
| 389 |
+
cands.append((nm, m))
|
| 390 |
+
if not cands:
|
| 391 |
+
return None, None
|
| 392 |
+
cands.sort(key=lambda x: sum(p.numel() for p in x[1].parameters()), reverse=True)
|
| 393 |
+
return cands[0]
|
| 394 |
+
|
| 395 |
+
wavlm = None
|
| 396 |
+
try:
|
| 397 |
+
from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402
|
| 398 |
+
_wrapper = WavLMWrapper.from_pretrained("tiantiaf/wavlm-large-categorical-emotion")
|
| 399 |
+
_name, wavlm = find_hf_backbone(_wrapper)
|
| 400 |
+
if wavlm is not None:
|
| 401 |
+
print(f"✅ Dựng backbone WavLM từ SAILER wrapper tại '.{_name}'")
|
| 402 |
+
except Exception as e:
|
| 403 |
+
print("⚠️ Lỗi nạp SAILER wrapper:", repr(e), "→ fallback WavLM trắng.")
|
| 404 |
+
if wavlm is None:
|
| 405 |
+
from transformers import WavLMModel
|
| 406 |
+
wavlm = WavLMModel.from_pretrained("microsoft/wavlm-large")
|
| 407 |
+
print("ℹ️ Fallback: microsoft/wavlm-large.")
|
| 408 |
+
|
| 409 |
+
wavlm = wavlm.to(device).eval()
|
| 410 |
+
WAVLM_DIM = int(wavlm.config.hidden_size)
|
| 411 |
+
miss, unexp = wavlm.load_state_dict(ckpt["wavlm"], strict=False)
|
| 412 |
+
print(f"🔁 load wavlm từ ckpt: thiếu {len(miss)} / dư {len(unexp)} key (kỳ vọng ~0).")
|
| 413 |
+
|
| 414 |
+
def masked_mean(hidden, attn_mask):
|
| 415 |
+
if attn_mask is None:
|
| 416 |
+
return hidden.mean(dim=1)
|
| 417 |
+
try:
|
| 418 |
+
fm = wavlm._get_feature_vector_attention_mask(hidden.shape[1], attn_mask)
|
| 419 |
+
except Exception:
|
| 420 |
+
return hidden.mean(dim=1)
|
| 421 |
+
fm = fm.unsqueeze(-1).to(hidden.dtype)
|
| 422 |
+
return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)
|
| 423 |
+
|
| 424 |
+
@torch.no_grad()
|
| 425 |
+
def wavlm_embed(input_values, attn_mask):
|
| 426 |
+
out = wavlm(input_values, attention_mask=attn_mask).last_hidden_state
|
| 427 |
+
return masked_mean(out, attn_mask)
|
| 428 |
+
|
| 429 |
+
# ── audeering FROZEN (đặc trưng phụ) — như exp08 ──
|
| 430 |
+
AUD_DIM = 0
|
| 431 |
+
aud_backbone = aud_head = aud_proc = None
|
| 432 |
+
if USE_AUDEERING:
|
| 433 |
+
from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor
|
| 434 |
+
from huggingface_hub import hf_hub_download
|
| 435 |
+
AUD_NAME = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
|
| 436 |
+
aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)
|
| 437 |
+
aud_cfg = Wav2Vec2Config.from_pretrained(AUD_NAME)
|
| 438 |
+
aud_backbone = Wav2Vec2Model(aud_cfg)
|
| 439 |
+
try:
|
| 440 |
+
_sd = __import__("safetensors.torch", fromlist=["load_file"]).load_file(
|
| 441 |
+
hf_hub_download(AUD_NAME, "model.safetensors"))
|
| 442 |
+
except Exception:
|
| 443 |
+
_sd = torch.load(hf_hub_download(AUD_NAME, "pytorch_model.bin"), map_location="cpu")
|
| 444 |
+
bb_sd = {k[len("wav2vec2."):]: v for k, v in _sd.items() if k.startswith("wav2vec2.")}
|
| 445 |
+
aud_backbone.load_state_dict(bb_sd, strict=False)
|
| 446 |
+
_hid = _sd["classifier.dense.weight"].shape[0]
|
| 447 |
+
_out = _sd["classifier.out_proj.weight"].shape[0]
|
| 448 |
+
aud_head = nn.Sequential(nn.Linear(_hid, _hid), nn.Tanh(), nn.Linear(_hid, _out))
|
| 449 |
+
aud_head[0].weight.data.copy_(_sd["classifier.dense.weight"]); aud_head[0].bias.data.copy_(_sd["classifier.dense.bias"])
|
| 450 |
+
aud_head[2].weight.data.copy_(_sd["classifier.out_proj.weight"]); aud_head[2].bias.data.copy_(_sd["classifier.out_proj.bias"])
|
| 451 |
+
aud_backbone = aud_backbone.to(device).eval()
|
| 452 |
+
aud_head = aud_head.to(device).eval()
|
| 453 |
+
AUD_DIM = _hid + 3
|
| 454 |
+
print(f"✅ audeering frozen ({AUD_DIM}-D)")
|
| 455 |
+
|
| 456 |
+
def load_wav_emo(sid):
|
| 457 |
+
p = os.path.join(WAV_DIR, sid + ".wav")
|
| 458 |
+
if not os.path.exists(p):
|
| 459 |
+
return None
|
| 460 |
+
wave, _ = librosa.load(p, sr=SR, mono=True)
|
| 461 |
+
return wave[: EMO_MAX_SEC * SR].astype(np.float32)
|
| 462 |
+
|
| 463 |
+
@torch.no_grad()
|
| 464 |
+
def extract_audeering(stems, tag):
|
| 465 |
+
if not USE_AUDEERING:
|
| 466 |
+
return {}
|
| 467 |
+
cache_path = os.path.join(CACHE_DIR, f"aud_{tag}.npz")
|
| 468 |
+
store = {}
|
| 469 |
+
if os.path.exists(cache_path):
|
| 470 |
+
z = np.load(cache_path, allow_pickle=True)
|
| 471 |
+
store = {k: z[k] for k in z.files}
|
| 472 |
+
print(f"[aud/{tag}] nạp cache: {len(store)}")
|
| 473 |
+
todo = [s for s in stems if s not in store]
|
| 474 |
+
for i, s in enumerate(tqdm(todo, desc=f"audeering {tag}")):
|
| 475 |
+
wave = load_wav_emo(s)
|
| 476 |
+
if wave is None:
|
| 477 |
+
continue
|
| 478 |
+
x = aud_proc(wave, sampling_rate=SR).input_values[0]
|
| 479 |
+
x = torch.from_numpy(np.asarray(x, dtype=np.float32)).unsqueeze(0).to(device)
|
| 480 |
+
h = aud_backbone(x)[0].mean(dim=1)
|
| 481 |
+
out = aud_head(h)[0].cpu().numpy()
|
| 482 |
+
vad = np.array([1 + 4 * out[2], 1 + 4 * out[0], 1 + 4 * out[1]], dtype=np.float32) # [VAL,ARO,DOM]
|
| 483 |
+
store[s] = np.concatenate([h[0].cpu().numpy(), vad]).astype(np.float32)
|
| 484 |
+
if (i + 1) % 500 == 0:
|
| 485 |
+
np.savez(cache_path, **store)
|
| 486 |
+
if todo:
|
| 487 |
+
np.savez(cache_path, **store)
|
| 488 |
+
return store
|
| 489 |
+
|
| 490 |
+
# ── EmoHeads (khớp exp08) + nạp trọng số head + thống kê chuẩn hóa từ ckpt ──
|
| 491 |
+
N_EMO = len(EMOTIONS5)
|
| 492 |
+
TRUNK_IN = WAVLM_DIM + (AUD_DIM if USE_AUDEERING else 0)
|
| 493 |
+
|
| 494 |
+
class EmoHeads(nn.Module):
|
| 495 |
+
def __init__(self, d_in, trunk_h, head_h, p, n_emo):
|
| 496 |
+
super().__init__()
|
| 497 |
+
self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),
|
| 498 |
+
nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))
|
| 499 |
+
self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))
|
| 500 |
+
self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))
|
| 501 |
+
self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))
|
| 502 |
+
def forward(self, feat, tgt):
|
| 503 |
+
h = self.trunk(feat)
|
| 504 |
+
return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)
|
| 505 |
+
|
| 506 |
+
heads = EmoHeads(TRUNK_IN, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device).eval()
|
| 507 |
+
hmiss, hunexp = heads.load_state_dict(ckpt["heads"], strict=False)
|
| 508 |
+
print(f"🔁 load heads từ ckpt: thiếu {len(hmiss)} / dư {len(hunexp)} key (kỳ vọng 0).")
|
| 509 |
+
|
| 510 |
+
emos_mu = float(ckpt["emos_mu"]); emos_sd = float(ckpt["emos_sd"])
|
| 511 |
+
vad_mu = np.asarray(ckpt["vad_mu"], dtype=np.float32); vad_sd = np.asarray(ckpt["vad_sd"], dtype=np.float32)
|
| 512 |
+
print(f"Chuẩn hóa từ ckpt: emos μ={emos_mu:.3f} σ={emos_sd:.3f} | vad μ={np.round(vad_mu,2)}")
|
| 513 |
+
|
| 514 |
+
# Target cảm xúc (cho EMOS head) từ metadata
|
| 515 |
+
def load_target_emotions():
|
| 516 |
+
tgt = {}
|
| 517 |
+
with open(METADATA_CSV, encoding="utf-8") as f:
|
| 518 |
+
for ln in f:
|
| 519 |
+
parts = ln.strip().split("|")
|
| 520 |
+
if len(parts) >= 2:
|
| 521 |
+
tgt[stem(parts[0])] = norm_emotion(parts[1])
|
| 522 |
+
return tgt
|
| 523 |
+
|
| 524 |
+
target_map = load_target_emotions()
|
| 525 |
+
|
| 526 |
+
def onehot_target(tgt):
|
| 527 |
+
v = np.zeros(N_EMO, dtype=np.float32)
|
| 528 |
+
if tgt in EMOTIONS5:
|
| 529 |
+
v[EMOTIONS5.index(tgt)] = 1.0
|
| 530 |
+
return v
|
| 531 |
+
|
| 532 |
+
dev_stems = [stem(n) for n in dev_names]
|
| 533 |
+
aud_dev = extract_audeering(dev_stems, "dev")
|
| 534 |
+
|
| 535 |
+
@torch.no_grad()
|
| 536 |
+
def predict_emotion(sid):
|
| 537 |
+
wave = load_wav_emo(sid)
|
| 538 |
+
if wave is None or (USE_AUDEERING and sid not in aud_dev):
|
| 539 |
+
return None
|
| 540 |
+
iv = torch.from_numpy(wave).unsqueeze(0).to(device)
|
| 541 |
+
am = torch.ones((1, len(wave)), dtype=torch.long, device=device)
|
| 542 |
+
tgt = torch.from_numpy(onehot_target(target_map.get(sid))).unsqueeze(0).to(device)
|
| 543 |
+
with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
|
| 544 |
+
fw = wavlm_embed(iv, am)
|
| 545 |
+
feat = torch.cat([fw, torch.from_numpy(aud_dev[sid]).unsqueeze(0).to(device)], dim=1) if USE_AUDEERING else fw
|
| 546 |
+
emos_p, cat_l, vad_p = heads(feat, tgt)
|
| 547 |
+
emos = float(emos_p.item()) * emos_sd + emos_mu
|
| 548 |
+
cat5 = F.softmax(cat_l, 1)[0].float().cpu().numpy()
|
| 549 |
+
vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu
|
| 550 |
+
return emos, cat5, vad3
|
| 551 |
+
|
| 552 |
+
# %% [markdown]
|
| 553 |
+
# ## 6. PHẦN C — Ghép QMOS (fine-tune) + 5 cột cảm xúc (exp08) → answer.txt 6 cột
|
| 554 |
+
|
| 555 |
+
# %%
|
| 556 |
+
def fmt_cat(p5):
|
| 557 |
+
return "|".join(f"{e}:{p5[i]:.6g}" for i, e in enumerate(EMOTIONS5))
|
| 558 |
+
|
| 559 |
+
def build_answer(out_path):
|
| 560 |
+
n_real = n_def = 0
|
| 561 |
+
with open(out_path, "w") as f:
|
| 562 |
+
f.write("wav,QMOS,EMOS,CAT,VAL,ARO,DOM\n")
|
| 563 |
+
for name in tqdm(dev_names, desc="answer"):
|
| 564 |
+
sid = stem(name)
|
| 565 |
+
pr = predict_emotion(sid)
|
| 566 |
+
if pr is None:
|
| 567 |
+
emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0]); n_def += 1
|
| 568 |
+
else:
|
| 569 |
+
emos, cat5, vad3 = pr; n_real += 1
|
| 570 |
+
qmos = qmos_pred.get(name, qmos_pred.get(sid, 3.0))
|
| 571 |
+
f.write(f"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\n")
|
| 572 |
+
print(f"Ghi {len(dev_names)} dòng → {out_path} | cảm xúc thật {n_real}, mặc định {n_def}")
|
| 573 |
+
|
| 574 |
+
answer_path = os.path.join(OUT_DIR, "answer.txt")
|
| 575 |
+
build_answer(answer_path)
|
| 576 |
+
|
| 577 |
+
# %% [markdown]
|
| 578 |
+
# ## 7. Validate + zip
|
| 579 |
+
|
| 580 |
+
# %%
|
| 581 |
+
def validate(path):
|
| 582 |
+
import csv
|
| 583 |
+
with open(path) as f:
|
| 584 |
+
rows = list(csv.reader(f))
|
| 585 |
+
assert rows[0][0] == "wav" and "QMOS" in rows[0] and "EMOS" in rows[0], "Header sai"
|
| 586 |
+
for i, r in enumerate(rows[1:], 2):
|
| 587 |
+
assert len(r) == len(rows[0]), f"Dòng {i} sai số cột"
|
| 588 |
+
print(f"OK: {len(rows)-1} dòng, header = {rows[0]}")
|
| 589 |
+
|
| 590 |
+
validate(answer_path)
|
| 591 |
+
os.system(f"cd {OUT_DIR} && zip -j submission_track2_exp13_ft-qmos.zip answer.txt "
|
| 592 |
+
f"&& unzip -l submission_track2_exp13_ft-qmos.zip")
|
| 593 |
+
print("Sẵn sàng nộp:", os.path.join(OUT_DIR, "submission_track2_exp13_ft-qmos.zip"))
|
| 594 |
+
|
| 595 |
+
# %% [markdown]
|
| 596 |
+
# ## Ghi chú
|
| 597 |
+
# - **Lần đầu** `LIMIT_TRAIN=300`, `LIMIT_DEV=20` để chạy trơn (không OOM, 1 epoch xong); rồi đặt `None`.
|
| 598 |
+
# - **OOM trên T4?** giảm `QMOS_MAX_SEC` (12→10→8); giữ `FREEZE_FEAT_EXT=True`; `BATCH=1` đã là min.
|
| 599 |
+
# ⚠️ **Bật `RANK_LAMBDA>0` tốn VRAM hơn** vì phải GIỮ đồ thị cả cửa sổ ACCUM (=16) để so thứ hạng →
|
| 600 |
+
# nếu OOM khi bật ranking: giảm `ACCUM` (vd 8, cũng là kích thước nhóm ranking) hoặc `QMOS_MAX_SEC`.
|
| 601 |
+
# - **Đọc mục A:** so `val SRCC fine-tune` với `zero-shot`. Chỉ nộp QMOS fine-tune nếu **vượt zero-shot**
|
| 602 |
+
# (lý tưởng vượt cả exp07 0.548); nếu không → giữ QMOS exp07 (Add Input answer.txt exp07, đổi cột QMOS).
|
| 603 |
+
# - Nếu chưa vượt: tăng `EPOCHS`, bật `RANK_LAMBDA=0.3` (tối ưu thẳng thứ hạng), hoặc `FREEZE_FEAT_EXT=False`
|
| 604 |
+
# (mở băng feature-extractor — mạnh hơn nhưng dễ overfit + nặng VRAM).
|
| 605 |
+
# - **Lưu checkpoint:** `ft_qmos_utmos.pt` lưu mỗi best → **Save Version NGAY** sau khi chạy (bài học exp08).
|
| 606 |
+
# - **License QMOS:** UTMOS/SpeechMOS (kiểm tra license tarepan/SpeechMOS) — khai báo `docs/12_system_description.md`.
|
| 607 |
+
# - Ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (mục exp13).
|
track2/exp14_mamba_head.ipynb
ADDED
|
@@ -0,0 +1,952 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "63b4bfa4",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"# VMC2026 Track 2 — exp14 (MAMBA temporal head, CỘNG vào FUSION 6 cột) — Kaggle\n",
|
| 9 |
+
"\n",
|
| 10 |
+
"**Ý tưởng (theo gợi ý mentor \"thử Mamba\"):** exp04/exp07 đều **mean-pool** đặc trưng SSL →\n",
|
| 11 |
+
"mỗi wav thành 1 vector → mất hết **động lực theo thời gian** (lên/xuống giọng, ngắt quãng, rung).\n",
|
| 12 |
+
"**Mamba** là State Space Model (SSM) xử lý **chuỗi** với độ phức tạp tuyến tính → cho nó **dãy frame**\n",
|
| 13 |
+
"(chưa pool) để học temporal dynamics, rồi mới pool. Tham khảo: MambaRate (AudioMOS 2025), arXiv:2507.12090.\n",
|
| 14 |
+
"\n",
|
| 15 |
+
"## exp14 = exp07 + 1 nhánh Mamba (CỘNG thêm, không thay thế)\n",
|
| 16 |
+
"```\n",
|
| 17 |
+
" ┌─ đặc trưng POOLED [e2v_emb|e2v_p5|sailer_emb|sailer_p9|sailer_vad3] (y hệt exp07 → DÙNG LẠI cache)\n",
|
| 18 |
+
" mỗi wav ──┤\n",
|
| 19 |
+
" └─ WavLM frame-level (chuỗi T×1024) ─► Mamba (2 lớp, 2 chiều) ─► attn-pool ─► z_seq (Z chiều)\n",
|
| 20 |
+
" │\n",
|
| 21 |
+
" concat ──► TRUNK chung ──► 6 head: QMOS · EMOS · CAT · VAL · ARO · DOM\n",
|
| 22 |
+
"```\n",
|
| 23 |
+
"- **Cờ `USE_MAMBA`:** `False` → chạy ra **đúng exp07** (kiểm chứng tái lập ~0.548/0.795). `True` → bật nhánh Mamba.\n",
|
| 24 |
+
" Đây CHÍNH là **ablation \"có/không Mamba\"** cho paper.\n",
|
| 25 |
+
"- WavLM **đóng băng** (chỉ trích đặc trưng) → Mamba head nhỏ → train nhanh, vừa T4.\n",
|
| 26 |
+
"\n",
|
| 27 |
+
"## 2 gotcha Kaggle đã xử trong file\n",
|
| 28 |
+
"1. `mamba-ssm` hay lỗi build CUDA → **nhúng sẵn Mamba thuần PyTorch** (không cần pip); tự dùng `mamba-ssm` nếu import được.\n",
|
| 29 |
+
"2. Cache frame-level RẤT nặng → **cap `MAX_FRAMES`** + lưu **fp16**. Ước lượng: MAX_FRAMES=256, 1024 chiều, fp16\n",
|
| 30 |
+
" ≈ 0.5 MB/wav → train ~12k ≈ 6 GB, dev ~2.7k ≈ 1.4 GB (vừa /kaggle/working). **Save Version** để giữ cache.\n",
|
| 31 |
+
"\n",
|
| 32 |
+
"**Cách chạy:** GPU T4 + Internet On → Add Input dataset Track 2 → sửa `DATA_ROOT` → Run All.\n",
|
| 33 |
+
"Lần đầu đặt `LIMIT_TRAIN=300`, `LIMIT_DEV=20` để soi nhanh; OK rồi đặt `None`."
|
| 34 |
+
]
|
| 35 |
+
},
|
| 36 |
+
{
|
| 37 |
+
"cell_type": "markdown",
|
| 38 |
+
"id": "3fe243f8",
|
| 39 |
+
"metadata": {},
|
| 40 |
+
"source": [
|
| 41 |
+
"## 0. Cấu hình — SỬA Ở ĐÂY"
|
| 42 |
+
]
|
| 43 |
+
},
|
| 44 |
+
{
|
| 45 |
+
"cell_type": "code",
|
| 46 |
+
"execution_count": null,
|
| 47 |
+
"id": "bd2e582a",
|
| 48 |
+
"metadata": {},
|
| 49 |
+
"outputs": [],
|
| 50 |
+
"source": [
|
| 51 |
+
"import os\n",
|
| 52 |
+
"\n",
|
| 53 |
+
"DATA_ROOT = \"/kaggle/input/vmc2026-track2-full/vmc2026-track2\" # << SỬA slug\n",
|
| 54 |
+
"WAV_DIR = f\"{DATA_ROOT}/wav\"\n",
|
| 55 |
+
"METADATA_CSV = f\"{DATA_ROOT}/metadata.csv\" # wavID|emotion|transcript (KHÔNG header)\n",
|
| 56 |
+
"TRAIN_CSV = f\"{DATA_ROOT}/sets/train.csv\" # lisID|wavID|qMOS|emoCat|eMOS|val|dom|aro\n",
|
| 57 |
+
"DEV_SCP = f\"{DATA_ROOT}/sets/dev.scp\"\n",
|
| 58 |
+
"\n",
|
| 59 |
+
"OUT_DIR = \"/kaggle/working\"\n",
|
| 60 |
+
"CACHE_DIR = \"/kaggle/working/fusion_cache\" # DÙNG CHUNG với exp04/exp07 (e2v_*, sailer_*, utmos_*)\n",
|
| 61 |
+
"SEQ_DIR = \"/kaggle/working/wavlm_seq_cache\" # MỚI: cache frame-level WavLM (fp16)\n",
|
| 62 |
+
"os.makedirs(CACHE_DIR, exist_ok=True)\n",
|
| 63 |
+
"os.makedirs(SEQ_DIR, exist_ok=True)\n",
|
| 64 |
+
"\n",
|
| 65 |
+
"# ── Bật/tắt nhánh Mamba (ablation chính) ─────────────────────────────────────\n",
|
| 66 |
+
"USE_MAMBA = True # False → ra ĐÚNG exp07 (sanity check). True → bật nhánh Mamba.\n",
|
| 67 |
+
"\n",
|
| 68 |
+
"# ── Siêu tham số nhánh Mamba ─────────────────────────────────────────────────\n",
|
| 69 |
+
"WAVLM_NAME = \"microsoft/wavlm-large\" # backbone frame-level (đóng băng). Trả chuỗi (T, 1024).\n",
|
| 70 |
+
"MAX_FRAMES = 256 # cap độ dài chuỗi (256 frame ≈ 5.1s @ 50Hz). Giảm nếu hết đĩa.\n",
|
| 71 |
+
"MAMBA_DMODEL = 256 # chiều ẩn của khối Mamba (proj 1024→256 trước khi vào Mamba)\n",
|
| 72 |
+
"MAMBA_LAYERS = 2 # số khối Mamba xếp chồng\n",
|
| 73 |
+
"MAMBA_DSTATE = 16 # chiều state SSM\n",
|
| 74 |
+
"BIDIRECTIONAL = True # chạy Mamba cả 2 chiều (xuôi + ngược) rồi cộng\n",
|
| 75 |
+
"Z_DIM = 128 # chiều vector z_seq sau attentive-pool, đem concat vào fusion\n",
|
| 76 |
+
"\n",
|
| 77 |
+
"# ── Siêu tham số fusion (giống exp07) ────────────────────────────────────────\n",
|
| 78 |
+
"DEVICE = \"cuda\"\n",
|
| 79 |
+
"TRUNK_HIDDEN = 512\n",
|
| 80 |
+
"HEAD_HIDDEN = 128\n",
|
| 81 |
+
"DROPOUT = 0.3\n",
|
| 82 |
+
"LR = 1e-3\n",
|
| 83 |
+
"EPOCHS = 80\n",
|
| 84 |
+
"BATCH = 32 # nhỏ hơn exp07 (64) vì có nhánh Mamba tốn RAM hơn\n",
|
| 85 |
+
"VAL_FRAC = 0.10\n",
|
| 86 |
+
"PATIENCE = 15\n",
|
| 87 |
+
"SEED = 42\n",
|
| 88 |
+
"\n",
|
| 89 |
+
"USE_UNCERTAINTY = True\n",
|
| 90 |
+
"LOSS_W = {\"qmos\": 1.0, \"emos\": 1.0, \"cat\": 1.0, \"val\": 1.0, \"aro\": 1.0, \"dom\": 1.0}\n",
|
| 91 |
+
"USE_E2V = True\n",
|
| 92 |
+
"USE_SAILER = True\n",
|
| 93 |
+
"USE_CLASSPROB = True\n",
|
| 94 |
+
"USE_UTMOS_FEAT = True\n",
|
| 95 |
+
"\n",
|
| 96 |
+
"LIMIT_TRAIN = None\n",
|
| 97 |
+
"LIMIT_DEV = None\n",
|
| 98 |
+
"\n",
|
| 99 |
+
"# Mốc exp07 để so (đây là hệ thống đang tốt nhất)\n",
|
| 100 |
+
"EXP07 = {\"qmos\": 0.548, \"emos\": 0.795, \"cat_err\": 0.153, \"val\": 0.581, \"aro\": 0.752, \"dom\": 0.705}\n",
|
| 101 |
+
"\n",
|
| 102 |
+
"EMOTIONS5 = [\"angry\", \"happy\", \"neutral\", \"sad\", \"surprised\"]\n",
|
| 103 |
+
"SAILER9 = [\"Anger\", \"Contempt\", \"Disgust\", \"Fear\", \"Happiness\", \"Neutral\", \"Sadness\", \"Surprise\", \"Other\"]\n",
|
| 104 |
+
"\n",
|
| 105 |
+
"_EMO_ALIAS = {\n",
|
| 106 |
+
" \"angry\": \"angry\", \"anger\": \"angry\",\n",
|
| 107 |
+
" \"happy\": \"happy\", \"happiness\": \"happy\", \"joy\": \"happy\",\n",
|
| 108 |
+
" \"neutral\": \"neutral\", \"calm\": \"neutral\",\n",
|
| 109 |
+
" \"sad\": \"sad\", \"sadness\": \"sad\",\n",
|
| 110 |
+
" \"surprise\": \"surprised\", \"surprised\": \"surprised\", \"surprising\": \"surprised\",\n",
|
| 111 |
+
"}\n",
|
| 112 |
+
"\n",
|
| 113 |
+
"def norm_emotion(label):\n",
|
| 114 |
+
" key = str(label).strip().lower()\n",
|
| 115 |
+
" return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)\n",
|
| 116 |
+
"\n",
|
| 117 |
+
"def stem(p):\n",
|
| 118 |
+
" return os.path.splitext(os.path.basename(str(p)))[0]\n",
|
| 119 |
+
"\n",
|
| 120 |
+
"assert USE_E2V or USE_SAILER, \"Phải bật ít nhất 1 backbone pooled.\"\n",
|
| 121 |
+
"print(\"USE_MAMBA =\", USE_MAMBA, \"| nếu False → ra đúng exp07\")\n",
|
| 122 |
+
"print(\"DATA_ROOT:\", DATA_ROOT)\n",
|
| 123 |
+
"for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP]:\n",
|
| 124 |
+
" print((\" ✅ \" if os.path.exists(p) else \" ❌ THIẾU \") + p)"
|
| 125 |
+
]
|
| 126 |
+
},
|
| 127 |
+
{
|
| 128 |
+
"cell_type": "markdown",
|
| 129 |
+
"id": "5ad58750",
|
| 130 |
+
"metadata": {},
|
| 131 |
+
"source": [
|
| 132 |
+
"## 1. Cài đặt + tải code SAILER\n",
|
| 133 |
+
"Chỉ cài gói còn thiếu (Kaggle có sẵn torch/transformers). KHÔNG đụng numpy (tránh lệch ABI torch — bài học exp12)."
|
| 134 |
+
]
|
| 135 |
+
},
|
| 136 |
+
{
|
| 137 |
+
"cell_type": "code",
|
| 138 |
+
"execution_count": null,
|
| 139 |
+
"id": "3260eb06",
|
| 140 |
+
"metadata": {},
|
| 141 |
+
"outputs": [],
|
| 142 |
+
"source": [
|
| 143 |
+
"import sys, subprocess\n",
|
| 144 |
+
"\n",
|
| 145 |
+
"def pip_install(*pkgs):\n",
|
| 146 |
+
" subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *pkgs], check=True)\n",
|
| 147 |
+
"\n",
|
| 148 |
+
"pip_install(\"speechmos\", \"funasr\", \"librosa\", \"soundfile\", \"pandas\", \"scipy\", \"scikit-learn\", \"tqdm\")\n",
|
| 149 |
+
"\n",
|
| 150 |
+
"if USE_SAILER:\n",
|
| 151 |
+
" pip_install(\"loralib\", \"speechbrain\")\n",
|
| 152 |
+
" REPO_DIR = \"/kaggle/working/vox-profile-release\"\n",
|
| 153 |
+
" if not os.path.exists(REPO_DIR):\n",
|
| 154 |
+
" subprocess.run([\"git\", \"clone\", \"--depth\", \"1\",\n",
|
| 155 |
+
" \"https://github.com/tiantiaf0627/vox-profile-release.git\", REPO_DIR], check=True)\n",
|
| 156 |
+
" if REPO_DIR not in sys.path:\n",
|
| 157 |
+
" sys.path.insert(0, REPO_DIR)"
|
| 158 |
+
]
|
| 159 |
+
},
|
| 160 |
+
{
|
| 161 |
+
"cell_type": "markdown",
|
| 162 |
+
"id": "f92c0e17",
|
| 163 |
+
"metadata": {},
|
| 164 |
+
"source": [
|
| 165 |
+
"## 2. Đọc & gộp nhãn theo wavID (giống exp07)"
|
| 166 |
+
]
|
| 167 |
+
},
|
| 168 |
+
{
|
| 169 |
+
"cell_type": "code",
|
| 170 |
+
"execution_count": null,
|
| 171 |
+
"id": "bab3f8d5",
|
| 172 |
+
"metadata": {},
|
| 173 |
+
"outputs": [],
|
| 174 |
+
"source": [
|
| 175 |
+
"import numpy as np\n",
|
| 176 |
+
"import pandas as pd\n",
|
| 177 |
+
"\n",
|
| 178 |
+
"def load_target_emotions():\n",
|
| 179 |
+
" tgt = {}\n",
|
| 180 |
+
" with open(METADATA_CSV, encoding=\"utf-8\") as f:\n",
|
| 181 |
+
" for ln in f:\n",
|
| 182 |
+
" parts = ln.strip().split(\"|\")\n",
|
| 183 |
+
" if len(parts) < 2:\n",
|
| 184 |
+
" continue\n",
|
| 185 |
+
" tgt[stem(parts[0])] = norm_emotion(parts[1])\n",
|
| 186 |
+
" return tgt\n",
|
| 187 |
+
"\n",
|
| 188 |
+
"def _col(cols_map, *names, default_idx=None, df=None):\n",
|
| 189 |
+
" for n in names:\n",
|
| 190 |
+
" if n in cols_map:\n",
|
| 191 |
+
" return cols_map[n]\n",
|
| 192 |
+
" return list(df.columns)[default_idx] if default_idx is not None else None\n",
|
| 193 |
+
"\n",
|
| 194 |
+
"def parse_emocat_votes(cell):\n",
|
| 195 |
+
" v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
|
| 196 |
+
" for tok in str(cell).replace(\"/\", \",\").replace(\";\", \",\").replace(\"|\", \",\").replace(\" \", \",\").split(\",\"):\n",
|
| 197 |
+
" e = norm_emotion(tok)\n",
|
| 198 |
+
" if e in EMOTIONS5:\n",
|
| 199 |
+
" v[EMOTIONS5.index(e)] += 1.0\n",
|
| 200 |
+
" return v\n",
|
| 201 |
+
"\n",
|
| 202 |
+
"def load_train_labels():\n",
|
| 203 |
+
" df = pd.read_csv(TRAIN_CSV, sep=\"|\")\n",
|
| 204 |
+
" cols = {c.lower().strip(): c for c in df.columns}\n",
|
| 205 |
+
" wav_col = _col(cols, \"wavid\", \"wav\", default_idx=1, df=df)\n",
|
| 206 |
+
" qmos_col = _col(cols, \"qmos\", \"mos\")\n",
|
| 207 |
+
" emos_col = _col(cols, \"emos\", \"emo\", \"emomos\")\n",
|
| 208 |
+
" val_col = _col(cols, \"val\", \"valence\")\n",
|
| 209 |
+
" aro_col = _col(cols, \"aro\", \"arousal\")\n",
|
| 210 |
+
" dom_col = _col(cols, \"dom\", \"dominance\")\n",
|
| 211 |
+
" cat_col = _col(cols, \"emocat\", \"cat\", \"emotion\")\n",
|
| 212 |
+
" assert qmos_col and emos_col, f\"Thiếu cột qMOS/eMOS (cột: {list(df.columns)})\"\n",
|
| 213 |
+
" df[\"_stem\"] = df[wav_col].map(stem)\n",
|
| 214 |
+
" rows = []\n",
|
| 215 |
+
" for sid, g in df.groupby(\"_stem\"):\n",
|
| 216 |
+
" rec = {\"wavID\": sid, \"qmos\": float(g[qmos_col].mean()), \"emos\": float(g[emos_col].mean())}\n",
|
| 217 |
+
" rec[\"val\"] = float(g[val_col].mean()) if val_col else np.nan\n",
|
| 218 |
+
" rec[\"aro\"] = float(g[aro_col].mean()) if aro_col else np.nan\n",
|
| 219 |
+
" rec[\"dom\"] = float(g[dom_col].mean()) if dom_col else np.nan\n",
|
| 220 |
+
" votes = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
|
| 221 |
+
" if cat_col:\n",
|
| 222 |
+
" for cell in g[cat_col]:\n",
|
| 223 |
+
" votes += parse_emocat_votes(cell)\n",
|
| 224 |
+
" s = votes.sum()\n",
|
| 225 |
+
" cat = votes / s if s > 0 else np.full(len(EMOTIONS5), 1.0 / len(EMOTIONS5), dtype=np.float32)\n",
|
| 226 |
+
" for i in range(len(EMOTIONS5)):\n",
|
| 227 |
+
" rec[f\"cat{i}\"] = float(cat[i])\n",
|
| 228 |
+
" rows.append(rec)\n",
|
| 229 |
+
" return pd.DataFrame(rows)\n",
|
| 230 |
+
"\n",
|
| 231 |
+
"target_map = load_target_emotions()\n",
|
| 232 |
+
"train_df = load_train_labels()\n",
|
| 233 |
+
"HAS_VAD = bool(train_df[\"val\"].notna().any())\n",
|
| 234 |
+
"print(f\"Target: {len(target_map)} | wav train (gộp): {len(train_df)} | có VAD: {HAS_VAD}\")"
|
| 235 |
+
]
|
| 236 |
+
},
|
| 237 |
+
{
|
| 238 |
+
"cell_type": "markdown",
|
| 239 |
+
"id": "a5cd1ff1",
|
| 240 |
+
"metadata": {},
|
| 241 |
+
"source": [
|
| 242 |
+
"## 3. Đặc trưng POOLED (e2v + sailer + UTMOS) — TÁI DÙNG cache exp04/exp07\n",
|
| 243 |
+
"(Y hệt exp07; nếu đã chạy exp07 thì cache `fusion_cache/` còn nguyên → không tính lại.)"
|
| 244 |
+
]
|
| 245 |
+
},
|
| 246 |
+
{
|
| 247 |
+
"cell_type": "code",
|
| 248 |
+
"execution_count": null,
|
| 249 |
+
"id": "8c31b6a4",
|
| 250 |
+
"metadata": {
|
| 251 |
+
"lines_to_next_cell": 1
|
| 252 |
+
},
|
| 253 |
+
"outputs": [],
|
| 254 |
+
"source": [
|
| 255 |
+
"import torch\n",
|
| 256 |
+
"import torch.nn.functional as F\n",
|
| 257 |
+
"\n",
|
| 258 |
+
"device = DEVICE if torch.cuda.is_available() else \"cpu\"\n",
|
| 259 |
+
"print(\"Device:\", device, (\"✅ \" + torch.cuda.get_device_name(0)) if device == \"cuda\" else \"⚠️ CPU\")\n",
|
| 260 |
+
"\n",
|
| 261 |
+
"def extract_e2v(stems, tag):\n",
|
| 262 |
+
" from tqdm.auto import tqdm\n",
|
| 263 |
+
" cache_path = os.path.join(CACHE_DIR, f\"e2v_{tag}.npz\")\n",
|
| 264 |
+
" store = {}\n",
|
| 265 |
+
" if os.path.exists(cache_path):\n",
|
| 266 |
+
" z = np.load(cache_path, allow_pickle=True)\n",
|
| 267 |
+
" store = {k: z[k] for k in z.files}\n",
|
| 268 |
+
" print(f\"[e2v/{tag}] nạp cache: {len(store)}\")\n",
|
| 269 |
+
" todo = [s for s in stems if s not in store]\n",
|
| 270 |
+
" if todo:\n",
|
| 271 |
+
" from funasr import AutoModel\n",
|
| 272 |
+
" m = AutoModel(model=\"iic/emotion2vec_plus_large\", hub=\"hf\", device=device)\n",
|
| 273 |
+
" for i, s in enumerate(tqdm(todo, desc=f\"e2v {tag}\")):\n",
|
| 274 |
+
" wav = os.path.join(WAV_DIR, s + \".wav\")\n",
|
| 275 |
+
" if not os.path.exists(wav):\n",
|
| 276 |
+
" continue\n",
|
| 277 |
+
" r = m.generate(wav, granularity=\"utterance\", extract_embedding=True)[0]\n",
|
| 278 |
+
" emb = np.asarray(r[\"feats\"], dtype=np.float32).reshape(-1)\n",
|
| 279 |
+
" probs = {e: 0.0 for e in EMOTIONS5}\n",
|
| 280 |
+
" for lab, sc in zip(r[\"labels\"], r[\"scores\"]):\n",
|
| 281 |
+
" name = lab.split(\"/\")[-1]\n",
|
| 282 |
+
" if name in probs:\n",
|
| 283 |
+
" probs[name] = float(sc)\n",
|
| 284 |
+
" tot = sum(probs.values())\n",
|
| 285 |
+
" p5 = np.array([probs[e] / tot if tot > 0 else 0.2 for e in EMOTIONS5], dtype=np.float32)\n",
|
| 286 |
+
" store[s] = np.concatenate([emb, p5]).astype(np.float32)\n",
|
| 287 |
+
" if (i + 1) % 500 == 0:\n",
|
| 288 |
+
" np.savez(cache_path, **store)\n",
|
| 289 |
+
" np.savez(cache_path, **store)\n",
|
| 290 |
+
" del m\n",
|
| 291 |
+
" torch.cuda.empty_cache() if device == \"cuda\" else None\n",
|
| 292 |
+
" return {s: (v[:-5], v[-5:]) for s, v in store.items()}\n",
|
| 293 |
+
"\n",
|
| 294 |
+
"def _pool_feat(features):\n",
|
| 295 |
+
" f = features.detach().cpu().numpy()\n",
|
| 296 |
+
" if f.ndim <= 1:\n",
|
| 297 |
+
" return f.reshape(-1).astype(np.float32)\n",
|
| 298 |
+
" return f.mean(axis=tuple(range(f.ndim - 1))).reshape(-1).astype(np.float32)\n",
|
| 299 |
+
"\n",
|
| 300 |
+
"def extract_sailer(stems, tag):\n",
|
| 301 |
+
" import librosa\n",
|
| 302 |
+
" from tqdm.auto import tqdm\n",
|
| 303 |
+
" cache_path = os.path.join(CACHE_DIR, f\"sailer_{tag}.npz\")\n",
|
| 304 |
+
" store = {}\n",
|
| 305 |
+
" if os.path.exists(cache_path):\n",
|
| 306 |
+
" z = np.load(cache_path, allow_pickle=True)\n",
|
| 307 |
+
" store = {k: z[k] for k in z.files}\n",
|
| 308 |
+
" print(f\"[sailer/{tag}] nạp cache: {len(store)}\")\n",
|
| 309 |
+
" todo = [s for s in stems if s not in store]\n",
|
| 310 |
+
" if todo:\n",
|
| 311 |
+
" from src.model.emotion.wavlm_emotion import WavLMWrapper\n",
|
| 312 |
+
" sailer = WavLMWrapper.from_pretrained(\"tiantiaf/wavlm-large-categorical-emotion\").to(device).eval()\n",
|
| 313 |
+
" with torch.no_grad():\n",
|
| 314 |
+
" for i, s in enumerate(tqdm(todo, desc=f\"sailer {tag}\")):\n",
|
| 315 |
+
" wav = os.path.join(WAV_DIR, s + \".wav\")\n",
|
| 316 |
+
" if not os.path.exists(wav):\n",
|
| 317 |
+
" continue\n",
|
| 318 |
+
" wave, _ = librosa.load(wav, sr=16000, mono=True)\n",
|
| 319 |
+
" wave = wave[: 15 * 16000]\n",
|
| 320 |
+
" data = torch.from_numpy(wave).float().unsqueeze(0).to(device)\n",
|
| 321 |
+
" logits, feat, _det, arousal, valence, dominance = sailer(data, return_feature=True)\n",
|
| 322 |
+
" emb = _pool_feat(feat)\n",
|
| 323 |
+
" p9 = F.softmax(logits, dim=1)[0].detach().cpu().numpy().astype(np.float32)\n",
|
| 324 |
+
" vad3 = np.array([1 + 4 * float(valence.item()),\n",
|
| 325 |
+
" 1 + 4 * float(arousal.item()),\n",
|
| 326 |
+
" 1 + 4 * float(dominance.item())], dtype=np.float32)\n",
|
| 327 |
+
" store[s] = np.concatenate([emb, p9, vad3]).astype(np.float32)\n",
|
| 328 |
+
" if (i + 1) % 500 == 0:\n",
|
| 329 |
+
" np.savez(cache_path, **store)\n",
|
| 330 |
+
" np.savez(cache_path, **store)\n",
|
| 331 |
+
" del sailer\n",
|
| 332 |
+
" torch.cuda.empty_cache() if device == \"cuda\" else None\n",
|
| 333 |
+
" return {s: (v[:-12], v[-12:-3], v[-3:]) for s, v in store.items()}\n",
|
| 334 |
+
"\n",
|
| 335 |
+
"def extract_utmos(names, tag):\n",
|
| 336 |
+
" import librosa\n",
|
| 337 |
+
" from tqdm.auto import tqdm\n",
|
| 338 |
+
" cache_path = os.path.join(CACHE_DIR, f\"utmos_{tag}.npz\")\n",
|
| 339 |
+
" store = {}\n",
|
| 340 |
+
" if os.path.exists(cache_path):\n",
|
| 341 |
+
" z = np.load(cache_path, allow_pickle=True)\n",
|
| 342 |
+
" store = {k: float(z[k]) for k in z.files}\n",
|
| 343 |
+
" print(f\"[utmos/{tag}] nạp cache: {len(store)}\")\n",
|
| 344 |
+
" todo = [n for n in names if stem(n) not in store]\n",
|
| 345 |
+
" if todo:\n",
|
| 346 |
+
" predictor = torch.hub.load(\"tarepan/SpeechMOS:v1.2.0\", \"utmos22_strong\",\n",
|
| 347 |
+
" trust_repo=True).to(device).eval()\n",
|
| 348 |
+
" with torch.no_grad():\n",
|
| 349 |
+
" for i, n in enumerate(tqdm(todo, desc=f\"utmos {tag}\")):\n",
|
| 350 |
+
" wav = os.path.join(WAV_DIR, n if str(n).endswith(\".wav\") else n + \".wav\")\n",
|
| 351 |
+
" if not os.path.exists(wav):\n",
|
| 352 |
+
" continue\n",
|
| 353 |
+
" wave, _ = librosa.load(wav, sr=16000, mono=True)\n",
|
| 354 |
+
" store[stem(n)] = float(predictor(torch.from_numpy(wave).unsqueeze(0).to(device),\n",
|
| 355 |
+
" sr=16000).mean().item())\n",
|
| 356 |
+
" if (i + 1) % 500 == 0:\n",
|
| 357 |
+
" np.savez(cache_path, **{k: np.float32(v) for k, v in store.items()})\n",
|
| 358 |
+
" np.savez(cache_path, **{k: np.float32(v) for k, v in store.items()})\n",
|
| 359 |
+
" del predictor\n",
|
| 360 |
+
" torch.cuda.empty_cache() if device == \"cuda\" else None\n",
|
| 361 |
+
" return store"
|
| 362 |
+
]
|
| 363 |
+
},
|
| 364 |
+
{
|
| 365 |
+
"cell_type": "markdown",
|
| 366 |
+
"id": "a6a9dfc9",
|
| 367 |
+
"metadata": {},
|
| 368 |
+
"source": [
|
| 369 |
+
"## 3b. Đặc trưng FRAME-LEVEL WavLM (chuỗi T×1024) cho nhánh Mamba — cache fp16\n",
|
| 370 |
+
"Mỗi wav lưu 1 file `.npy` riêng trong `SEQ_DIR` (mảng fp16 [T, 1024], T ≤ MAX_FRAMES).\n",
|
| 371 |
+
"WavLM **đóng băng** (eval, no_grad) → layerdrop tự tắt ở eval, không đụng gotcha checkpoint."
|
| 372 |
+
]
|
| 373 |
+
},
|
| 374 |
+
{
|
| 375 |
+
"cell_type": "code",
|
| 376 |
+
"execution_count": null,
|
| 377 |
+
"id": "60c9e86e",
|
| 378 |
+
"metadata": {
|
| 379 |
+
"lines_to_next_cell": 1
|
| 380 |
+
},
|
| 381 |
+
"outputs": [],
|
| 382 |
+
"source": [
|
| 383 |
+
"_wavlm = None\n",
|
| 384 |
+
"def _get_wavlm():\n",
|
| 385 |
+
" \"\"\"Lazy-load microsoft/wavlm-large (đóng băng). Trả model + feature_extractor.\"\"\"\n",
|
| 386 |
+
" global _wavlm\n",
|
| 387 |
+
" if _wavlm is None:\n",
|
| 388 |
+
" from transformers import WavLMModel, AutoFeatureExtractor\n",
|
| 389 |
+
" fe = AutoFeatureExtractor.from_pretrained(WAVLM_NAME)\n",
|
| 390 |
+
" mdl = WavLMModel.from_pretrained(WAVLM_NAME).to(device).eval()\n",
|
| 391 |
+
" for p in mdl.parameters():\n",
|
| 392 |
+
" p.requires_grad = False\n",
|
| 393 |
+
" _wavlm = (mdl, fe)\n",
|
| 394 |
+
" return _wavlm\n",
|
| 395 |
+
"\n",
|
| 396 |
+
"def seq_path(sid):\n",
|
| 397 |
+
" return os.path.join(SEQ_DIR, sid + \".npy\")\n",
|
| 398 |
+
"\n",
|
| 399 |
+
"def extract_wavlm_seq(stems, tag):\n",
|
| 400 |
+
" \"\"\"Trích frame-level WavLM cho từng wav, cache fp16 ra .npy. Trả set stem đã có.\"\"\"\n",
|
| 401 |
+
" if not USE_MAMBA:\n",
|
| 402 |
+
" return set()\n",
|
| 403 |
+
" import librosa\n",
|
| 404 |
+
" from tqdm.auto import tqdm\n",
|
| 405 |
+
" todo = [s for s in stems if not os.path.exists(seq_path(s))]\n",
|
| 406 |
+
" if todo:\n",
|
| 407 |
+
" mdl, fe = _get_wavlm()\n",
|
| 408 |
+
" with torch.no_grad():\n",
|
| 409 |
+
" for i, s in enumerate(tqdm(todo, desc=f\"wavlm-seq {tag}\")):\n",
|
| 410 |
+
" wav = os.path.join(WAV_DIR, s + \".wav\")\n",
|
| 411 |
+
" if not os.path.exists(wav):\n",
|
| 412 |
+
" continue\n",
|
| 413 |
+
" wave, _ = librosa.load(wav, sr=16000, mono=True)\n",
|
| 414 |
+
" wave = wave[: 15 * 16000]\n",
|
| 415 |
+
" inp = fe(wave, sampling_rate=16000, return_tensors=\"pt\").input_values.to(device)\n",
|
| 416 |
+
" hs = mdl(inp).last_hidden_state[0] # (T, 1024)\n",
|
| 417 |
+
" if hs.shape[0] > MAX_FRAMES: # cap độ dài (đều theo thời gian)\n",
|
| 418 |
+
" idx = torch.linspace(0, hs.shape[0] - 1, MAX_FRAMES).long()\n",
|
| 419 |
+
" hs = hs[idx]\n",
|
| 420 |
+
" np.save(seq_path(s), hs.cpu().numpy().astype(np.float16))\n",
|
| 421 |
+
" torch.cuda.empty_cache() if device == \"cuda\" else None\n",
|
| 422 |
+
" return {s for s in stems if os.path.exists(seq_path(s))}\n",
|
| 423 |
+
"\n",
|
| 424 |
+
"def load_seq(sid):\n",
|
| 425 |
+
" \"\"\"Đọc chuỗi fp16 → tensor float32 (T, 1024). Thiếu file → None.\"\"\"\n",
|
| 426 |
+
" p = seq_path(sid)\n",
|
| 427 |
+
" if not os.path.exists(p):\n",
|
| 428 |
+
" return None\n",
|
| 429 |
+
" return torch.from_numpy(np.load(p).astype(np.float32))\n",
|
| 430 |
+
"\n",
|
| 431 |
+
"def collate_seqs(sids):\n",
|
| 432 |
+
" \"\"\"Gộp list chuỗi độ dài khác nhau → (B, Lmax, 1024) + mask (B, Lmax) bool (True=thật).\"\"\"\n",
|
| 433 |
+
" seqs = [load_seq(s) for s in sids]\n",
|
| 434 |
+
" lens = [t.shape[0] for t in seqs]\n",
|
| 435 |
+
" Lmax = max(lens)\n",
|
| 436 |
+
" B = len(seqs)\n",
|
| 437 |
+
" x = torch.zeros(B, Lmax, seqs[0].shape[1], dtype=torch.float32)\n",
|
| 438 |
+
" mask = torch.zeros(B, Lmax, dtype=torch.bool)\n",
|
| 439 |
+
" for i, t in enumerate(seqs):\n",
|
| 440 |
+
" x[i, : t.shape[0]] = t\n",
|
| 441 |
+
" mask[i, : t.shape[0]] = True\n",
|
| 442 |
+
" return x, mask"
|
| 443 |
+
]
|
| 444 |
+
},
|
| 445 |
+
{
|
| 446 |
+
"cell_type": "markdown",
|
| 447 |
+
"id": "328a5f30",
|
| 448 |
+
"metadata": {},
|
| 449 |
+
"source": [
|
| 450 |
+
"## 4. Dựng feature pooled + nhãn cho train (lọc các wav đủ mọi nguồn)"
|
| 451 |
+
]
|
| 452 |
+
},
|
| 453 |
+
{
|
| 454 |
+
"cell_type": "code",
|
| 455 |
+
"execution_count": null,
|
| 456 |
+
"id": "4449a153",
|
| 457 |
+
"metadata": {},
|
| 458 |
+
"outputs": [],
|
| 459 |
+
"source": [
|
| 460 |
+
"train_stems = list(train_df[\"wavID\"])\n",
|
| 461 |
+
"if LIMIT_TRAIN:\n",
|
| 462 |
+
" train_stems = train_stems[:LIMIT_TRAIN]\n",
|
| 463 |
+
"\n",
|
| 464 |
+
"e2v_tr = extract_e2v(train_stems, \"train\") if USE_E2V else {}\n",
|
| 465 |
+
"sailer_tr = extract_sailer(train_stems, \"train\") if USE_SAILER else {}\n",
|
| 466 |
+
"utmos_tr = extract_utmos(train_stems, \"train\") if USE_UTMOS_FEAT else {}\n",
|
| 467 |
+
"seq_tr = extract_wavlm_seq(train_stems, \"train\")\n",
|
| 468 |
+
"\n",
|
| 469 |
+
"def audio_feature(sid, e2v_map, sailer_map):\n",
|
| 470 |
+
" parts = []\n",
|
| 471 |
+
" if USE_E2V:\n",
|
| 472 |
+
" pk = e2v_map.get(sid)\n",
|
| 473 |
+
" if pk is None:\n",
|
| 474 |
+
" return None\n",
|
| 475 |
+
" emb, p5 = pk\n",
|
| 476 |
+
" parts.append(emb)\n",
|
| 477 |
+
" if USE_CLASSPROB:\n",
|
| 478 |
+
" parts.append(p5)\n",
|
| 479 |
+
" if USE_SAILER:\n",
|
| 480 |
+
" pk = sailer_map.get(sid)\n",
|
| 481 |
+
" if pk is None:\n",
|
| 482 |
+
" return None\n",
|
| 483 |
+
" emb, p9, vad3 = pk\n",
|
| 484 |
+
" parts.append(emb)\n",
|
| 485 |
+
" if USE_CLASSPROB:\n",
|
| 486 |
+
" parts.append(p9); parts.append(vad3)\n",
|
| 487 |
+
" return np.concatenate(parts).astype(np.float32)\n",
|
| 488 |
+
"\n",
|
| 489 |
+
"def onehot_target(tgt):\n",
|
| 490 |
+
" v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
|
| 491 |
+
" if tgt in EMOTIONS5:\n",
|
| 492 |
+
" v[EMOTIONS5.index(tgt)] = 1.0\n",
|
| 493 |
+
" return v\n",
|
| 494 |
+
"\n",
|
| 495 |
+
"lab = train_df.set_index(\"wavID\")\n",
|
| 496 |
+
"keep_sids, X, T, U = [], [], [], []\n",
|
| 497 |
+
"y_qmos, y_emos, y_vad, y_cat = [], [], [], []\n",
|
| 498 |
+
"for s in train_stems:\n",
|
| 499 |
+
" f = audio_feature(s, e2v_tr, sailer_tr)\n",
|
| 500 |
+
" tgt = target_map.get(s)\n",
|
| 501 |
+
" if f is None or tgt is None or s not in lab.index:\n",
|
| 502 |
+
" continue\n",
|
| 503 |
+
" if USE_UTMOS_FEAT and s not in utmos_tr:\n",
|
| 504 |
+
" continue\n",
|
| 505 |
+
" if USE_MAMBA and s not in seq_tr: # cần có chuỗi WavLM nếu bật Mamba\n",
|
| 506 |
+
" continue\n",
|
| 507 |
+
" keep_sids.append(s)\n",
|
| 508 |
+
" X.append(f)\n",
|
| 509 |
+
" T.append(onehot_target(tgt))\n",
|
| 510 |
+
" U.append(utmos_tr.get(s, 3.0) if USE_UTMOS_FEAT else 0.0)\n",
|
| 511 |
+
" y_qmos.append(lab.loc[s, \"qmos\"]); y_emos.append(lab.loc[s, \"emos\"])\n",
|
| 512 |
+
" y_vad.append([lab.loc[s, \"val\"], lab.loc[s, \"aro\"], lab.loc[s, \"dom\"]])\n",
|
| 513 |
+
" y_cat.append([lab.loc[s, f\"cat{i}\"] for i in range(len(EMOTIONS5))])\n",
|
| 514 |
+
"\n",
|
| 515 |
+
"X = np.stack(X).astype(np.float32)\n",
|
| 516 |
+
"T = np.stack(T).astype(np.float32)\n",
|
| 517 |
+
"U = np.array(U, dtype=np.float32).reshape(-1, 1)\n",
|
| 518 |
+
"y_qmos = np.array(y_qmos, dtype=np.float32); y_emos = np.array(y_emos, dtype=np.float32)\n",
|
| 519 |
+
"y_vad = np.array(y_vad, dtype=np.float32); y_cat = np.array(y_cat, dtype=np.float32)\n",
|
| 520 |
+
"FEAT_DIM = X.shape[1]\n",
|
| 521 |
+
"print(f\"Train giữ lại: {len(keep_sids)} wav | X={X.shape} | Mamba={'ON' if USE_MAMBA else 'OFF'}\")\n",
|
| 522 |
+
"\n",
|
| 523 |
+
"# Chuẩn hóa feature pooled + UTMOS + nhãn liên tục (z-score)\n",
|
| 524 |
+
"feat_mean = X.mean(0, keepdims=True); feat_std = X.std(0, keepdims=True) + 1e-6\n",
|
| 525 |
+
"Xn = (X - feat_mean) / feat_std\n",
|
| 526 |
+
"u_mu, u_sd = float(U.mean()), float(U.std() + 1e-6); Un = (U - u_mu) / u_sd\n",
|
| 527 |
+
"qmos_mu, qmos_sd = float(y_qmos.mean()), float(y_qmos.std() + 1e-6); y_qmos_z = (y_qmos - qmos_mu) / qmos_sd\n",
|
| 528 |
+
"emos_mu, emos_sd = float(y_emos.mean()), float(y_emos.std() + 1e-6); y_emos_z = (y_emos - emos_mu) / emos_sd\n",
|
| 529 |
+
"if HAS_VAD:\n",
|
| 530 |
+
" vad_mu = np.nanmean(y_vad, axis=0); vad_sd = np.nanstd(y_vad, axis=0) + 1e-6\n",
|
| 531 |
+
" y_vad_z = (y_vad - vad_mu) / vad_sd\n",
|
| 532 |
+
"else:\n",
|
| 533 |
+
" vad_mu = np.zeros(3, dtype=np.float32); vad_sd = np.ones(3, dtype=np.float32); y_vad_z = np.zeros_like(y_vad)"
|
| 534 |
+
]
|
| 535 |
+
},
|
| 536 |
+
{
|
| 537 |
+
"cell_type": "markdown",
|
| 538 |
+
"id": "5f0a94ff",
|
| 539 |
+
"metadata": {},
|
| 540 |
+
"source": [
|
| 541 |
+
"## 5a. Khối MAMBA (thuần PyTorch, không cần `mamba-ssm`)\n",
|
| 542 |
+
"Tự dùng `mamba-ssm` nếu import được (nhanh hơn); nếu không → bản thuần PyTorch (selective scan vòng lặp thời gian).\n",
|
| 543 |
+
"Bản này theo \"mamba-minimal\" (johnma2006) — đúng công thức, chỉ chậm hơn kernel CUDA, nhưng head nhỏ nên OK trên T4."
|
| 544 |
+
]
|
| 545 |
+
},
|
| 546 |
+
{
|
| 547 |
+
"cell_type": "code",
|
| 548 |
+
"execution_count": null,
|
| 549 |
+
"id": "535fcd63",
|
| 550 |
+
"metadata": {
|
| 551 |
+
"lines_to_next_cell": 1
|
| 552 |
+
},
|
| 553 |
+
"outputs": [],
|
| 554 |
+
"source": [
|
| 555 |
+
"import math\n",
|
| 556 |
+
"import torch.nn as nn\n",
|
| 557 |
+
"\n",
|
| 558 |
+
"try:\n",
|
| 559 |
+
" from mamba_ssm import Mamba as _OfficialMamba # nếu cài được thì dùng (tùy chọn)\n",
|
| 560 |
+
" _HAS_MAMBA_SSM = True\n",
|
| 561 |
+
" print(\"✅ Dùng mamba-ssm (CUDA kernel)\")\n",
|
| 562 |
+
"except Exception:\n",
|
| 563 |
+
" _HAS_MAMBA_SSM = False\n",
|
| 564 |
+
" print(\"ℹ️ Không có mamba-ssm → dùng Mamba thuần PyTorch (nhúng sẵn)\")\n",
|
| 565 |
+
"\n",
|
| 566 |
+
"class MambaBlockTorch(nn.Module):\n",
|
| 567 |
+
" \"\"\"Một khối Mamba (selective SSM) thuần PyTorch. d_model = chiều ẩn.\"\"\"\n",
|
| 568 |
+
" def __init__(self, d_model, d_state=16, d_conv=4, expand=2):\n",
|
| 569 |
+
" super().__init__()\n",
|
| 570 |
+
" self.d_inner = expand * d_model\n",
|
| 571 |
+
" self.dt_rank = math.ceil(d_model / 16)\n",
|
| 572 |
+
" self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)\n",
|
| 573 |
+
" self.conv1d = nn.Conv1d(self.d_inner, self.d_inner, kernel_size=d_conv,\n",
|
| 574 |
+
" groups=self.d_inner, padding=d_conv - 1, bias=True)\n",
|
| 575 |
+
" self.x_proj = nn.Linear(self.d_inner, self.dt_rank + d_state * 2, bias=False)\n",
|
| 576 |
+
" self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)\n",
|
| 577 |
+
" A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)\n",
|
| 578 |
+
" self.A_log = nn.Parameter(torch.log(A)) # (d_inner, d_state)\n",
|
| 579 |
+
" self.D = nn.Parameter(torch.ones(self.d_inner))\n",
|
| 580 |
+
" self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)\n",
|
| 581 |
+
" self.d_state = d_state\n",
|
| 582 |
+
"\n",
|
| 583 |
+
" def forward(self, x): # x: (B, L, d_model)\n",
|
| 584 |
+
" B, L, _ = x.shape\n",
|
| 585 |
+
" xz = self.in_proj(x) # (B, L, 2*d_inner)\n",
|
| 586 |
+
" xin, z = xz.chunk(2, dim=-1)\n",
|
| 587 |
+
" xin = xin.transpose(1, 2) # (B, d_inner, L)\n",
|
| 588 |
+
" xin = self.conv1d(xin)[..., :L].transpose(1, 2) # (B, L, d_inner) causal conv\n",
|
| 589 |
+
" xin = F.silu(xin)\n",
|
| 590 |
+
" y = self._ssm(xin) # (B, L, d_inner)\n",
|
| 591 |
+
" y = y * F.silu(z)\n",
|
| 592 |
+
" return self.out_proj(y)\n",
|
| 593 |
+
"\n",
|
| 594 |
+
" def _ssm(self, x): # x: (B, L, d_inner)\n",
|
| 595 |
+
" A = -torch.exp(self.A_log) # (d_inner, d_state)\n",
|
| 596 |
+
" x_dbl = self.x_proj(x) # (B, L, dt_rank + 2*d_state)\n",
|
| 597 |
+
" delta, Bm, Cm = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)\n",
|
| 598 |
+
" delta = F.softplus(self.dt_proj(delta)) # (B, L, d_inner)\n",
|
| 599 |
+
" dA = torch.exp(delta.unsqueeze(-1) * A) # (B, L, d_inner, d_state)\n",
|
| 600 |
+
" dB_x = delta.unsqueeze(-1) * Bm.unsqueeze(2) * x.unsqueeze(-1) # (B, L, d_inner, d_state)\n",
|
| 601 |
+
" h = torch.zeros(x.shape[0], self.d_inner, self.d_state, device=x.device, dtype=x.dtype)\n",
|
| 602 |
+
" ys = []\n",
|
| 603 |
+
" for t in range(x.shape[1]): # selective scan theo thời gian\n",
|
| 604 |
+
" h = dA[:, t] * h + dB_x[:, t]\n",
|
| 605 |
+
" ys.append((h * Cm[:, t].unsqueeze(1)).sum(-1)) # (B, d_inner)\n",
|
| 606 |
+
" y = torch.stack(ys, dim=1) # (B, L, d_inner)\n",
|
| 607 |
+
" return y + x * self.D\n",
|
| 608 |
+
"\n",
|
| 609 |
+
"class MambaLayer(nn.Module):\n",
|
| 610 |
+
" \"\"\"Pre-norm residual quanh 1 khối Mamba (chọn official nếu có).\"\"\"\n",
|
| 611 |
+
" def __init__(self, d_model, d_state):\n",
|
| 612 |
+
" super().__init__()\n",
|
| 613 |
+
" self.norm = nn.LayerNorm(d_model)\n",
|
| 614 |
+
" if _HAS_MAMBA_SSM:\n",
|
| 615 |
+
" self.mix = _OfficialMamba(d_model=d_model, d_state=d_state, d_conv=4, expand=2)\n",
|
| 616 |
+
" else:\n",
|
| 617 |
+
" self.mix = MambaBlockTorch(d_model, d_state=d_state)\n",
|
| 618 |
+
"\n",
|
| 619 |
+
" def forward(self, x):\n",
|
| 620 |
+
" return x + self.mix(self.norm(x))\n",
|
| 621 |
+
"\n",
|
| 622 |
+
"class MambaEncoder(nn.Module):\n",
|
| 623 |
+
" \"\"\"1024 → d_model → [Mamba ×L] (2 chiều nếu BIDIRECTIONAL) → attentive-pool → Z_DIM.\"\"\"\n",
|
| 624 |
+
" def __init__(self, d_in, d_model, n_layers, d_state, z_dim, bidir):\n",
|
| 625 |
+
" super().__init__()\n",
|
| 626 |
+
" self.bidir = bidir\n",
|
| 627 |
+
" self.proj = nn.Linear(d_in, d_model)\n",
|
| 628 |
+
" self.fwd = nn.ModuleList([MambaLayer(d_model, d_state) for _ in range(n_layers)])\n",
|
| 629 |
+
" if bidir:\n",
|
| 630 |
+
" self.bwd = nn.ModuleList([MambaLayer(d_model, d_state) for _ in range(n_layers)])\n",
|
| 631 |
+
" self.attn = nn.Linear(d_model, 1) # attentive pooling\n",
|
| 632 |
+
" self.out = nn.Linear(d_model, z_dim)\n",
|
| 633 |
+
"\n",
|
| 634 |
+
" def _run(self, layers, h):\n",
|
| 635 |
+
" for L in layers:\n",
|
| 636 |
+
" h = L(h)\n",
|
| 637 |
+
" return h\n",
|
| 638 |
+
"\n",
|
| 639 |
+
" def forward(self, x, mask): # x: (B, L, 1024), mask: (B, L) bool\n",
|
| 640 |
+
" h = self.proj(x)\n",
|
| 641 |
+
" out = self._run(self.fwd, h)\n",
|
| 642 |
+
" if self.bidir:\n",
|
| 643 |
+
" rev = torch.flip(h, dims=[1])\n",
|
| 644 |
+
" out = out + torch.flip(self._run(self.bwd, rev), dims=[1])\n",
|
| 645 |
+
" a = self.attn(out).squeeze(-1) # (B, L)\n",
|
| 646 |
+
" a = a.masked_fill(~mask, float(\"-inf\"))\n",
|
| 647 |
+
" w = torch.softmax(a, dim=1).unsqueeze(-1) # (B, L, 1)\n",
|
| 648 |
+
" pooled = (out * w).sum(1) # (B, d_model)\n",
|
| 649 |
+
" return self.out(pooled) # (B, z_dim)"
|
| 650 |
+
]
|
| 651 |
+
},
|
| 652 |
+
{
|
| 653 |
+
"cell_type": "markdown",
|
| 654 |
+
"id": "a1a3026b",
|
| 655 |
+
"metadata": {},
|
| 656 |
+
"source": [
|
| 657 |
+
"## 5b. Model fusion 6 head + nhánh Mamba + train loop"
|
| 658 |
+
]
|
| 659 |
+
},
|
| 660 |
+
{
|
| 661 |
+
"cell_type": "code",
|
| 662 |
+
"execution_count": null,
|
| 663 |
+
"id": "e5f743ef",
|
| 664 |
+
"metadata": {
|
| 665 |
+
"lines_to_next_cell": 1
|
| 666 |
+
},
|
| 667 |
+
"outputs": [],
|
| 668 |
+
"source": [
|
| 669 |
+
"from scipy.stats import spearmanr\n",
|
| 670 |
+
"from sklearn.model_selection import train_test_split\n",
|
| 671 |
+
"\n",
|
| 672 |
+
"torch.manual_seed(SEED); np.random.seed(SEED)\n",
|
| 673 |
+
"N_EMO = len(EMOTIONS5)\n",
|
| 674 |
+
"idx_all = np.arange(X.shape[0])\n",
|
| 675 |
+
"tr_idx, va_idx = train_test_split(idx_all, test_size=VAL_FRAC, random_state=SEED)\n",
|
| 676 |
+
"\n",
|
| 677 |
+
"def to_t(a):\n",
|
| 678 |
+
" return torch.tensor(a, dtype=torch.float32, device=device)\n",
|
| 679 |
+
"\n",
|
| 680 |
+
"Xn_t, T_t, Un_t = to_t(Xn), to_t(T), to_t(Un)\n",
|
| 681 |
+
"qmos_t = to_t(y_qmos_z).unsqueeze(1); emos_t = to_t(y_emos_z).unsqueeze(1)\n",
|
| 682 |
+
"vad_t = to_t(y_vad_z); cat_t = to_t(y_cat)\n",
|
| 683 |
+
"\n",
|
| 684 |
+
"class FusionMamba6(nn.Module):\n",
|
| 685 |
+
" def __init__(self, d_in, trunk_h, head_h, p, n_emo, use_utmos, use_mamba):\n",
|
| 686 |
+
" super().__init__()\n",
|
| 687 |
+
" self.use_utmos = use_utmos\n",
|
| 688 |
+
" self.use_mamba = use_mamba\n",
|
| 689 |
+
" z_extra = Z_DIM if use_mamba else 0\n",
|
| 690 |
+
" if use_mamba:\n",
|
| 691 |
+
" self.enc = MambaEncoder(1024, MAMBA_DMODEL, MAMBA_LAYERS, MAMBA_DSTATE, Z_DIM, BIDIRECTIONAL)\n",
|
| 692 |
+
" self.trunk = nn.Sequential(\n",
|
| 693 |
+
" nn.Linear(d_in + z_extra, trunk_h), nn.ReLU(), nn.Dropout(p),\n",
|
| 694 |
+
" nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))\n",
|
| 695 |
+
" self.qmos = nn.Sequential(\n",
|
| 696 |
+
" nn.Linear(trunk_h + (1 if use_utmos else 0), head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))\n",
|
| 697 |
+
" self.emos = nn.Sequential(\n",
|
| 698 |
+
" nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))\n",
|
| 699 |
+
" self.cat = nn.Sequential(\n",
|
| 700 |
+
" nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))\n",
|
| 701 |
+
" self.vad = nn.Sequential(\n",
|
| 702 |
+
" nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))\n",
|
| 703 |
+
"\n",
|
| 704 |
+
" def forward(self, x, tgt, utmos, seq=None, mask=None):\n",
|
| 705 |
+
" if self.use_mamba:\n",
|
| 706 |
+
" z = self.enc(seq, mask)\n",
|
| 707 |
+
" x = torch.cat([x, z], dim=1)\n",
|
| 708 |
+
" h = self.trunk(x)\n",
|
| 709 |
+
" qmos_in = torch.cat([h, utmos], dim=1) if self.use_utmos else h\n",
|
| 710 |
+
" return self.qmos(qmos_in), self.emos(torch.cat([h, tgt], dim=1)), self.cat(h), self.vad(h)\n",
|
| 711 |
+
"\n",
|
| 712 |
+
"model = FusionMamba6(FEAT_DIM, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO, USE_UTMOS_FEAT, USE_MAMBA).to(device)\n",
|
| 713 |
+
"n_par = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
|
| 714 |
+
"print(f\"Tham số train được: {n_par/1e6:.2f} M\")\n",
|
| 715 |
+
"\n",
|
| 716 |
+
"TASKS = [\"qmos\", \"emos\", \"cat\", \"val\", \"aro\", \"dom\"]\n",
|
| 717 |
+
"log_var = nn.Parameter(torch.zeros(len(TASKS), device=device))\n",
|
| 718 |
+
"params = list(model.parameters()) + ([log_var] if USE_UNCERTAINTY else [])\n",
|
| 719 |
+
"opt = torch.optim.Adam(params, lr=LR, weight_decay=1e-5)\n",
|
| 720 |
+
"mse = nn.MSELoss(reduction=\"none\")\n",
|
| 721 |
+
"\n",
|
| 722 |
+
"def soft_ce(logits, target_dist):\n",
|
| 723 |
+
" return -(target_dist * F.log_softmax(logits, dim=1)).sum(dim=1)\n",
|
| 724 |
+
"\n",
|
| 725 |
+
"def task_losses(qmos_p, emos_p, cat_logits, vad_p, b):\n",
|
| 726 |
+
" L = {\"qmos\": mse(qmos_p, qmos_t[b]).mean(),\n",
|
| 727 |
+
" \"emos\": mse(emos_p, emos_t[b]).mean(),\n",
|
| 728 |
+
" \"cat\": soft_ce(cat_logits, cat_t[b]).mean()}\n",
|
| 729 |
+
" if HAS_VAD:\n",
|
| 730 |
+
" L[\"val\"] = mse(vad_p[:, 0:1], vad_t[b, 0:1]).mean()\n",
|
| 731 |
+
" L[\"aro\"] = mse(vad_p[:, 1:2], vad_t[b, 1:2]).mean()\n",
|
| 732 |
+
" L[\"dom\"] = mse(vad_p[:, 2:3], vad_t[b, 2:3]).mean()\n",
|
| 733 |
+
" else:\n",
|
| 734 |
+
" z = torch.zeros((), device=device); L[\"val\"] = L[\"aro\"] = L[\"dom\"] = z\n",
|
| 735 |
+
" return L\n",
|
| 736 |
+
"\n",
|
| 737 |
+
"def combine(L):\n",
|
| 738 |
+
" if USE_UNCERTAINTY:\n",
|
| 739 |
+
" return sum(torch.exp(-log_var[i]) * L[t] + log_var[i] for i, t in enumerate(TASKS))\n",
|
| 740 |
+
" return sum(LOSS_W[t] * L[t] for t in TASKS)\n",
|
| 741 |
+
"\n",
|
| 742 |
+
"# batch theo INDEX (vì nhánh Mamba cần đọc chuỗi theo sid → collate động)\n",
|
| 743 |
+
"sids_arr = np.array(keep_sids)\n",
|
| 744 |
+
"\n",
|
| 745 |
+
"def forward_batch(bidx):\n",
|
| 746 |
+
" \"\"\"bidx: numpy index. Trả output model cho batch (tự collate chuỗi nếu bật Mamba).\"\"\"\n",
|
| 747 |
+
" bt = torch.tensor(bidx, device=device)\n",
|
| 748 |
+
" if USE_MAMBA:\n",
|
| 749 |
+
" seq, mask = collate_seqs(list(sids_arr[bidx]))\n",
|
| 750 |
+
" seq, mask = seq.to(device), mask.to(device)\n",
|
| 751 |
+
" return model(Xn_t[bt], T_t[bt], Un_t[bt], seq, mask)\n",
|
| 752 |
+
" return model(Xn_t[bt], T_t[bt], Un_t[bt])\n",
|
| 753 |
+
"\n",
|
| 754 |
+
"@torch.no_grad()\n",
|
| 755 |
+
"def eval_val():\n",
|
| 756 |
+
" model.eval()\n",
|
| 757 |
+
" qp, ep, vp = [], [], []\n",
|
| 758 |
+
" for i in range(0, len(va_idx), BATCH):\n",
|
| 759 |
+
" b = va_idx[i:i + BATCH]\n",
|
| 760 |
+
" q, e, _cl, v = forward_batch(b)\n",
|
| 761 |
+
" qp.append(q.cpu().numpy().ravel()); ep.append(e.cpu().numpy().ravel()); vp.append(v.cpu().numpy())\n",
|
| 762 |
+
" qp = np.concatenate(qp); ep = np.concatenate(ep); vp = np.concatenate(vp)\n",
|
| 763 |
+
" out = {\"qmos\": spearmanr(qp, y_qmos[va_idx]).correlation,\n",
|
| 764 |
+
" \"emos\": spearmanr(ep, y_emos[va_idx]).correlation}\n",
|
| 765 |
+
" if USE_UTMOS_FEAT:\n",
|
| 766 |
+
" out[\"qmos_utmos\"] = spearmanr(U[va_idx, 0], y_qmos[va_idx]).correlation\n",
|
| 767 |
+
" if HAS_VAD:\n",
|
| 768 |
+
" for j, t in enumerate([\"val\", \"aro\", \"dom\"]):\n",
|
| 769 |
+
" out[t] = spearmanr(vp[:, j], y_vad[va_idx, j]).correlation\n",
|
| 770 |
+
" return out\n",
|
| 771 |
+
"\n",
|
| 772 |
+
"def val_score(m):\n",
|
| 773 |
+
" keys = [\"qmos\", \"emos\"] + ([\"val\", \"aro\", \"dom\"] if HAS_VAD else [])\n",
|
| 774 |
+
" return float(np.mean([m[k] for k in keys]))\n",
|
| 775 |
+
"\n",
|
| 776 |
+
"best_score, best_state, bad = -1e9, None, 0\n",
|
| 777 |
+
"for ep_i in range(1, EPOCHS + 1):\n",
|
| 778 |
+
" model.train()\n",
|
| 779 |
+
" perm = np.random.permutation(tr_idx)\n",
|
| 780 |
+
" run = 0.0\n",
|
| 781 |
+
" for i in range(0, len(perm), BATCH):\n",
|
| 782 |
+
" b = perm[i:i + BATCH]\n",
|
| 783 |
+
" opt.zero_grad()\n",
|
| 784 |
+
" q, e, cl, v = forward_batch(b)\n",
|
| 785 |
+
" loss = combine(task_losses(q, e, cl, v, torch.tensor(b, device=device)))\n",
|
| 786 |
+
" loss.backward(); opt.step()\n",
|
| 787 |
+
" run += loss.item() * len(b)\n",
|
| 788 |
+
" m = eval_val(); sc = val_score(m)\n",
|
| 789 |
+
" if sc > best_score:\n",
|
| 790 |
+
" best_score = sc; bad = 0\n",
|
| 791 |
+
" best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}\n",
|
| 792 |
+
" else:\n",
|
| 793 |
+
" bad += 1\n",
|
| 794 |
+
" if ep_i % 2 == 0 or ep_i == 1:\n",
|
| 795 |
+
" msg = \" \".join(f\"{k}={m[k]:.3f}\" for k in [\"qmos\", \"emos\", \"val\", \"aro\", \"dom\"] if k in m)\n",
|
| 796 |
+
" print(f\"epoch {ep_i:3d} | loss {run/len(perm):.4f} | {msg} | best {best_score:.4f}\")\n",
|
| 797 |
+
" if bad >= PATIENCE:\n",
|
| 798 |
+
" print(f\"Early stop ở epoch {ep_i}.\"); break\n",
|
| 799 |
+
"\n",
|
| 800 |
+
"model.load_state_dict(best_state)\n",
|
| 801 |
+
"final = eval_val()\n",
|
| 802 |
+
"print(f\"\\n✅ VAL (nội bộ) — exp14 (Mamba={'ON' if USE_MAMBA else 'OFF'}):\")\n",
|
| 803 |
+
"print(f\" QMOS={final['qmos']:.4f} (exp07 {EXP07['qmos']}) | EMOS={final['emos']:.4f} (exp07 {EXP07['emos']})\")\n",
|
| 804 |
+
"if HAS_VAD:\n",
|
| 805 |
+
" print(f\" VAL/ARO/DOM={final['val']:.4f}/{final['aro']:.4f}/{final['dom']:.4f}\"\n",
|
| 806 |
+
" f\" (exp07 {EXP07['val']}/{EXP07['aro']}/{EXP07['dom']})\")\n",
|
| 807 |
+
"print(\" → So sánh USE_MAMBA True vs False = ablation Mamba cho paper.\")\n",
|
| 808 |
+
"\n",
|
| 809 |
+
"torch.save({\"state\": best_state, \"feat_mean\": feat_mean, \"feat_std\": feat_std,\n",
|
| 810 |
+
" \"u_mu\": u_mu, \"u_sd\": u_sd, \"qmos_mu\": qmos_mu, \"qmos_sd\": qmos_sd,\n",
|
| 811 |
+
" \"emos_mu\": emos_mu, \"emos_sd\": emos_sd, \"vad_mu\": vad_mu, \"vad_sd\": vad_sd,\n",
|
| 812 |
+
" \"FEAT_DIM\": FEAT_DIM, \"USE_MAMBA\": USE_MAMBA, \"val_score\": best_score},\n",
|
| 813 |
+
" os.path.join(OUT_DIR, \"fusion_mamba_mtl.pt\"))\n",
|
| 814 |
+
"print(\"Đã lưu\", os.path.join(OUT_DIR, \"fusion_mamba_mtl.pt\"))"
|
| 815 |
+
]
|
| 816 |
+
},
|
| 817 |
+
{
|
| 818 |
+
"cell_type": "markdown",
|
| 819 |
+
"id": "ea38383a",
|
| 820 |
+
"metadata": {},
|
| 821 |
+
"source": [
|
| 822 |
+
"## 6. Dự đoán DEV → `answer.txt` đủ 6 cột"
|
| 823 |
+
]
|
| 824 |
+
},
|
| 825 |
+
{
|
| 826 |
+
"cell_type": "code",
|
| 827 |
+
"execution_count": null,
|
| 828 |
+
"id": "6e774431",
|
| 829 |
+
"metadata": {
|
| 830 |
+
"lines_to_next_cell": 1
|
| 831 |
+
},
|
| 832 |
+
"outputs": [],
|
| 833 |
+
"source": [
|
| 834 |
+
"def list_dev():\n",
|
| 835 |
+
" with open(DEV_SCP) as f:\n",
|
| 836 |
+
" return [ln.strip() for ln in f if ln.strip()]\n",
|
| 837 |
+
"\n",
|
| 838 |
+
"dev_names = list_dev()\n",
|
| 839 |
+
"if LIMIT_DEV:\n",
|
| 840 |
+
" dev_names = dev_names[:LIMIT_DEV]\n",
|
| 841 |
+
"dev_stems = [stem(n) for n in dev_names]\n",
|
| 842 |
+
"print(\"DEV:\", len(dev_names), \"mẫu\")\n",
|
| 843 |
+
"\n",
|
| 844 |
+
"e2v_dev = extract_e2v(dev_stems, \"dev\") if USE_E2V else {}\n",
|
| 845 |
+
"sailer_dev = extract_sailer(dev_stems, \"dev\") if USE_SAILER else {}\n",
|
| 846 |
+
"utmos_dev = extract_utmos(dev_names, \"dev\") if USE_UTMOS_FEAT else {}\n",
|
| 847 |
+
"seq_dev = extract_wavlm_seq(dev_stems, \"dev\")\n",
|
| 848 |
+
"\n",
|
| 849 |
+
"@torch.no_grad()\n",
|
| 850 |
+
"def predict_all(sid):\n",
|
| 851 |
+
" f = audio_feature(sid, e2v_dev, sailer_dev)\n",
|
| 852 |
+
" if f is None:\n",
|
| 853 |
+
" return None\n",
|
| 854 |
+
" if USE_MAMBA and not os.path.exists(seq_path(sid)):\n",
|
| 855 |
+
" return None\n",
|
| 856 |
+
" fn = (f[None, :] - feat_mean) / feat_std\n",
|
| 857 |
+
" tgt = onehot_target(target_map.get(sid))[None, :]\n",
|
| 858 |
+
" u = np.array([[utmos_dev.get(sid, 3.0)]], dtype=np.float32); un = (u - u_mu) / u_sd\n",
|
| 859 |
+
" model.eval()\n",
|
| 860 |
+
" if USE_MAMBA:\n",
|
| 861 |
+
" seq, mask = collate_seqs([sid]); seq, mask = seq.to(device), mask.to(device)\n",
|
| 862 |
+
" q, e, cl, v = model(to_t(fn), to_t(tgt), to_t(un), seq, mask)\n",
|
| 863 |
+
" else:\n",
|
| 864 |
+
" q, e, cl, v = model(to_t(fn), to_t(tgt), to_t(un))\n",
|
| 865 |
+
" qmos = float(q.item()) * qmos_sd + qmos_mu\n",
|
| 866 |
+
" emos = float(e.item()) * emos_sd + emos_mu\n",
|
| 867 |
+
" cat5 = F.softmax(cl, dim=1)[0].cpu().numpy()\n",
|
| 868 |
+
" vad3 = v[0].cpu().numpy() * vad_sd + vad_mu\n",
|
| 869 |
+
" return qmos, emos, cat5, vad3\n",
|
| 870 |
+
"\n",
|
| 871 |
+
"def fmt_cat(probs5):\n",
|
| 872 |
+
" return \"|\".join(f\"{e}:{probs5[i]:.6g}\" for i, e in enumerate(EMOTIONS5))\n",
|
| 873 |
+
"\n",
|
| 874 |
+
"def build_answer(out_path):\n",
|
| 875 |
+
" from tqdm.auto import tqdm\n",
|
| 876 |
+
" n_real = n_default = 0\n",
|
| 877 |
+
" with open(out_path, \"w\") as f:\n",
|
| 878 |
+
" f.write(\"wav,QMOS,EMOS,CAT,VAL,ARO,DOM\\n\")\n",
|
| 879 |
+
" for name in tqdm(dev_names, desc=\"answer\"):\n",
|
| 880 |
+
" sid = stem(name)\n",
|
| 881 |
+
" pred = predict_all(sid)\n",
|
| 882 |
+
" if pred is None:\n",
|
| 883 |
+
" qmos = utmos_dev.get(sid, 3.0)\n",
|
| 884 |
+
" emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0])\n",
|
| 885 |
+
" n_default += 1\n",
|
| 886 |
+
" else:\n",
|
| 887 |
+
" qmos, emos, cat5, vad3 = pred; n_real += 1\n",
|
| 888 |
+
" f.write(f\"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},\"\n",
|
| 889 |
+
" f\"{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\\n\")\n",
|
| 890 |
+
" print(f\"Ghi {len(dev_names)} dòng → {out_path} | head thật {n_real}, mặc định {n_default}\")\n",
|
| 891 |
+
"\n",
|
| 892 |
+
"answer_path = os.path.join(OUT_DIR, \"answer.txt\")\n",
|
| 893 |
+
"build_answer(answer_path)"
|
| 894 |
+
]
|
| 895 |
+
},
|
| 896 |
+
{
|
| 897 |
+
"cell_type": "markdown",
|
| 898 |
+
"id": "bcab20d3",
|
| 899 |
+
"metadata": {},
|
| 900 |
+
"source": [
|
| 901 |
+
"## 7. Validate + đóng zip"
|
| 902 |
+
]
|
| 903 |
+
},
|
| 904 |
+
{
|
| 905 |
+
"cell_type": "code",
|
| 906 |
+
"execution_count": null,
|
| 907 |
+
"id": "e9b2e0ab",
|
| 908 |
+
"metadata": {},
|
| 909 |
+
"outputs": [],
|
| 910 |
+
"source": [
|
| 911 |
+
"def validate(path):\n",
|
| 912 |
+
" import csv\n",
|
| 913 |
+
" with open(path) as f:\n",
|
| 914 |
+
" rows = list(csv.reader(f))\n",
|
| 915 |
+
" header = rows[0]\n",
|
| 916 |
+
" assert header[0] == \"wav\" and \"QMOS\" in header and \"EMOS\" in header, \"Header sai\"\n",
|
| 917 |
+
" for i, r in enumerate(rows[1:], 2):\n",
|
| 918 |
+
" assert len(r) == len(header), f\"Dòng {i} sai số cột\"\n",
|
| 919 |
+
" print(f\"OK: {len(rows)-1} dòng, header = {header}\")\n",
|
| 920 |
+
"\n",
|
| 921 |
+
"validate(answer_path)\n",
|
| 922 |
+
"os.system(f\"cd {OUT_DIR} && zip -j submission_track2_exp14_mamba.zip answer.txt \"\n",
|
| 923 |
+
" f\"&& unzip -l submission_track2_exp14_mamba.zip\")\n",
|
| 924 |
+
"print(\"Sẵn sàng nộp:\", os.path.join(OUT_DIR, \"submission_track2_exp14_mamba.zip\"))"
|
| 925 |
+
]
|
| 926 |
+
},
|
| 927 |
+
{
|
| 928 |
+
"cell_type": "markdown",
|
| 929 |
+
"id": "7604df81",
|
| 930 |
+
"metadata": {},
|
| 931 |
+
"source": [
|
| 932 |
+
"## Ghi chú\n",
|
| 933 |
+
"- **Ablation chính cho paper:** chạy 2 lần — `USE_MAMBA=False` (= exp07, mốc) và `USE_MAMBA=True`.\n",
|
| 934 |
+
" So QMOS/EMOS/VAD nội bộ → trả lời \"bộ mã hóa thời gian Mamba có hơn mean-pooling không?\".\n",
|
| 935 |
+
"- **Nếu hết đĩa khi cache chuỗi:** giảm `MAX_FRAMES` (256→160) hoặc xóa `wavlm_seq_cache/` sau khi chạy xong.\n",
|
| 936 |
+
"- **Nếu Mamba chậm:** thử `pip install mamba-ssm causal-conv1d` (file tự dùng nếu import được); hoặc giảm\n",
|
| 937 |
+
" `MAMBA_LAYERS`/`MAX_FRAMES`. Bản thuần PyTorch dùng vòng lặp thời gian nên chậm hơn kernel CUDA.\n",
|
| 938 |
+
"- **Save Version** để giữ cache `fusion_cache/` + `wavlm_seq_cache/` cho lần sau.\n",
|
| 939 |
+
"- Ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (mục exp14)."
|
| 940 |
+
]
|
| 941 |
+
}
|
| 942 |
+
],
|
| 943 |
+
"metadata": {
|
| 944 |
+
"jupytext": {
|
| 945 |
+
"cell_metadata_filter": "-all",
|
| 946 |
+
"main_language": "python",
|
| 947 |
+
"notebook_metadata_filter": "-all"
|
| 948 |
+
}
|
| 949 |
+
},
|
| 950 |
+
"nbformat": 4,
|
| 951 |
+
"nbformat_minor": 5
|
| 952 |
+
}
|
track2/exp14_mamba_head_pipeline.py
ADDED
|
@@ -0,0 +1,798 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %% [markdown]
|
| 2 |
+
# # VMC2026 Track 2 — exp14 (MAMBA temporal head, CỘNG vào FUSION 6 cột) — Kaggle
|
| 3 |
+
#
|
| 4 |
+
# **Ý tưởng (theo gợi ý mentor "thử Mamba"):** exp04/exp07 đều **mean-pool** đặc trưng SSL →
|
| 5 |
+
# mỗi wav thành 1 vector → mất hết **động lực theo thời gian** (lên/xuống giọng, ngắt quãng, rung).
|
| 6 |
+
# **Mamba** là State Space Model (SSM) xử lý **chuỗi** với độ phức tạp tuyến tính → cho nó **dãy frame**
|
| 7 |
+
# (chưa pool) để học temporal dynamics, rồi mới pool. Tham khảo: MambaRate (AudioMOS 2025), arXiv:2507.12090.
|
| 8 |
+
#
|
| 9 |
+
# ## exp14 = exp07 + 1 nhánh Mamba (CỘNG thêm, không thay thế)
|
| 10 |
+
# ```
|
| 11 |
+
# ┌─ đặc trưng POOLED [e2v_emb|e2v_p5|sailer_emb|sailer_p9|sailer_vad3] (y hệt exp07 → DÙNG LẠI cache)
|
| 12 |
+
# mỗi wav ──┤
|
| 13 |
+
# └─ WavLM frame-level (chuỗi T×1024) ─► Mamba (2 lớp, 2 chiều) ─► attn-pool ─► z_seq (Z chiều)
|
| 14 |
+
# │
|
| 15 |
+
# concat ──► TRUNK chung ──► 6 head: QMOS · EMOS · CAT · VAL · ARO · DOM
|
| 16 |
+
# ```
|
| 17 |
+
# - **Cờ `USE_MAMBA`:** `False` → chạy ra **đúng exp07** (kiểm chứng tái lập ~0.548/0.795). `True` → bật nhánh Mamba.
|
| 18 |
+
# Đây CHÍNH là **ablation "có/không Mamba"** cho paper.
|
| 19 |
+
# - WavLM **đóng băng** (chỉ trích đặc trưng) → Mamba head nhỏ → train nhanh, vừa T4.
|
| 20 |
+
#
|
| 21 |
+
# ## 2 gotcha Kaggle đã xử trong file
|
| 22 |
+
# 1. `mamba-ssm` hay lỗi build CUDA → **nhúng sẵn Mamba thuần PyTorch** (không cần pip); tự dùng `mamba-ssm` nếu import được.
|
| 23 |
+
# 2. Cache frame-level RẤT nặng → **cap `MAX_FRAMES`** + lưu **fp16**. Ước lượng: MAX_FRAMES=256, 1024 chiều, fp16
|
| 24 |
+
# ≈ 0.5 MB/wav → train ~12k ≈ 6 GB, dev ~2.7k ≈ 1.4 GB (vừa /kaggle/working). **Save Version** để giữ cache.
|
| 25 |
+
#
|
| 26 |
+
# **Cách chạy:** GPU T4 + Internet On → Add Input dataset Track 2 → sửa `DATA_ROOT` → Run All.
|
| 27 |
+
# Lần đầu đặt `LIMIT_TRAIN=300`, `LIMIT_DEV=20` để soi nhanh; OK rồi đặt `None`.
|
| 28 |
+
|
| 29 |
+
# %% [markdown]
|
| 30 |
+
# ## 0. Cấu hình — SỬA Ở ĐÂY
|
| 31 |
+
|
| 32 |
+
# %%
|
| 33 |
+
import os
|
| 34 |
+
|
| 35 |
+
DATA_ROOT = "/kaggle/input/vmc2026-track2-full/vmc2026-track2" # << SỬA slug
|
| 36 |
+
WAV_DIR = f"{DATA_ROOT}/wav"
|
| 37 |
+
METADATA_CSV = f"{DATA_ROOT}/metadata.csv" # wavID|emotion|transcript (KHÔNG header)
|
| 38 |
+
TRAIN_CSV = f"{DATA_ROOT}/sets/train.csv" # lisID|wavID|qMOS|emoCat|eMOS|val|dom|aro
|
| 39 |
+
DEV_SCP = f"{DATA_ROOT}/sets/dev.scp"
|
| 40 |
+
|
| 41 |
+
OUT_DIR = "/kaggle/working"
|
| 42 |
+
CACHE_DIR = "/kaggle/working/fusion_cache" # DÙNG CHUNG với exp04/exp07 (e2v_*, sailer_*, utmos_*)
|
| 43 |
+
SEQ_DIR = "/kaggle/working/wavlm_seq_cache" # MỚI: cache frame-level WavLM (fp16)
|
| 44 |
+
os.makedirs(CACHE_DIR, exist_ok=True)
|
| 45 |
+
os.makedirs(SEQ_DIR, exist_ok=True)
|
| 46 |
+
|
| 47 |
+
# ── Bật/tắt nhánh Mamba (ablation chính) ─────────────────────────────────────
|
| 48 |
+
USE_MAMBA = True # False → ra ĐÚNG exp07 (sanity check). True → bật nhánh Mamba.
|
| 49 |
+
|
| 50 |
+
# ── Siêu tham số nhánh Mamba ─────────────────────────────────────────────────
|
| 51 |
+
WAVLM_NAME = "microsoft/wavlm-large" # backbone frame-level (đóng băng). Trả chuỗi (T, 1024).
|
| 52 |
+
MAX_FRAMES = 256 # cap độ dài chuỗi (256 frame ≈ 5.1s @ 50Hz). Giảm nếu hết đĩa.
|
| 53 |
+
MAMBA_DMODEL = 256 # chiều ẩn của khối Mamba (proj 1024→256 trước khi vào Mamba)
|
| 54 |
+
MAMBA_LAYERS = 2 # số khối Mamba xếp chồng
|
| 55 |
+
MAMBA_DSTATE = 16 # chiều state SSM
|
| 56 |
+
BIDIRECTIONAL = True # chạy Mamba cả 2 chiều (xuôi + ngược) rồi cộng
|
| 57 |
+
Z_DIM = 128 # chiều vector z_seq sau attentive-pool, đem concat vào fusion
|
| 58 |
+
|
| 59 |
+
# ── Siêu tham số fusion (giống exp07) ────────────────────────────────────────
|
| 60 |
+
DEVICE = "cuda"
|
| 61 |
+
TRUNK_HIDDEN = 512
|
| 62 |
+
HEAD_HIDDEN = 128
|
| 63 |
+
DROPOUT = 0.3
|
| 64 |
+
LR = 1e-3
|
| 65 |
+
EPOCHS = 80
|
| 66 |
+
BATCH = 32 # nhỏ hơn exp07 (64) vì có nhánh Mamba tốn RAM hơn
|
| 67 |
+
VAL_FRAC = 0.10
|
| 68 |
+
PATIENCE = 15
|
| 69 |
+
SEED = 42
|
| 70 |
+
|
| 71 |
+
USE_UNCERTAINTY = True
|
| 72 |
+
LOSS_W = {"qmos": 1.0, "emos": 1.0, "cat": 1.0, "val": 1.0, "aro": 1.0, "dom": 1.0}
|
| 73 |
+
USE_E2V = True
|
| 74 |
+
USE_SAILER = True
|
| 75 |
+
USE_CLASSPROB = True
|
| 76 |
+
USE_UTMOS_FEAT = True
|
| 77 |
+
|
| 78 |
+
LIMIT_TRAIN = None
|
| 79 |
+
LIMIT_DEV = None
|
| 80 |
+
|
| 81 |
+
# Mốc exp07 để so (đây là hệ thống đang tốt nhất)
|
| 82 |
+
EXP07 = {"qmos": 0.548, "emos": 0.795, "cat_err": 0.153, "val": 0.581, "aro": 0.752, "dom": 0.705}
|
| 83 |
+
|
| 84 |
+
EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"]
|
| 85 |
+
SAILER9 = ["Anger", "Contempt", "Disgust", "Fear", "Happiness", "Neutral", "Sadness", "Surprise", "Other"]
|
| 86 |
+
|
| 87 |
+
_EMO_ALIAS = {
|
| 88 |
+
"angry": "angry", "anger": "angry",
|
| 89 |
+
"happy": "happy", "happiness": "happy", "joy": "happy",
|
| 90 |
+
"neutral": "neutral", "calm": "neutral",
|
| 91 |
+
"sad": "sad", "sadness": "sad",
|
| 92 |
+
"surprise": "surprised", "surprised": "surprised", "surprising": "surprised",
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
def norm_emotion(label):
|
| 96 |
+
key = str(label).strip().lower()
|
| 97 |
+
return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)
|
| 98 |
+
|
| 99 |
+
def stem(p):
|
| 100 |
+
return os.path.splitext(os.path.basename(str(p)))[0]
|
| 101 |
+
|
| 102 |
+
assert USE_E2V or USE_SAILER, "Phải bật ít nhất 1 backbone pooled."
|
| 103 |
+
print("USE_MAMBA =", USE_MAMBA, "| nếu False → ra đúng exp07")
|
| 104 |
+
print("DATA_ROOT:", DATA_ROOT)
|
| 105 |
+
for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP]:
|
| 106 |
+
print((" ✅ " if os.path.exists(p) else " ❌ THIẾU ") + p)
|
| 107 |
+
|
| 108 |
+
# %% [markdown]
|
| 109 |
+
# ## 1. Cài đặt + tải code SAILER
|
| 110 |
+
# Chỉ cài gói còn thiếu (Kaggle có sẵn torch/transformers). KHÔNG đụng numpy (tránh lệch ABI torch — bài học exp12).
|
| 111 |
+
|
| 112 |
+
# %%
|
| 113 |
+
import sys, subprocess
|
| 114 |
+
|
| 115 |
+
def pip_install(*pkgs):
|
| 116 |
+
subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs], check=True)
|
| 117 |
+
|
| 118 |
+
pip_install("speechmos", "funasr", "librosa", "soundfile", "pandas", "scipy", "scikit-learn", "tqdm")
|
| 119 |
+
|
| 120 |
+
if USE_SAILER:
|
| 121 |
+
pip_install("loralib", "speechbrain")
|
| 122 |
+
REPO_DIR = "/kaggle/working/vox-profile-release"
|
| 123 |
+
if not os.path.exists(REPO_DIR):
|
| 124 |
+
subprocess.run(["git", "clone", "--depth", "1",
|
| 125 |
+
"https://github.com/tiantiaf0627/vox-profile-release.git", REPO_DIR], check=True)
|
| 126 |
+
if REPO_DIR not in sys.path:
|
| 127 |
+
sys.path.insert(0, REPO_DIR)
|
| 128 |
+
|
| 129 |
+
# %% [markdown]
|
| 130 |
+
# ## 2. Đọc & gộp nhãn theo wavID (giống exp07)
|
| 131 |
+
|
| 132 |
+
# %%
|
| 133 |
+
import numpy as np
|
| 134 |
+
import pandas as pd
|
| 135 |
+
|
| 136 |
+
def load_target_emotions():
|
| 137 |
+
tgt = {}
|
| 138 |
+
with open(METADATA_CSV, encoding="utf-8") as f:
|
| 139 |
+
for ln in f:
|
| 140 |
+
parts = ln.strip().split("|")
|
| 141 |
+
if len(parts) < 2:
|
| 142 |
+
continue
|
| 143 |
+
tgt[stem(parts[0])] = norm_emotion(parts[1])
|
| 144 |
+
return tgt
|
| 145 |
+
|
| 146 |
+
def _col(cols_map, *names, default_idx=None, df=None):
|
| 147 |
+
for n in names:
|
| 148 |
+
if n in cols_map:
|
| 149 |
+
return cols_map[n]
|
| 150 |
+
return list(df.columns)[default_idx] if default_idx is not None else None
|
| 151 |
+
|
| 152 |
+
def parse_emocat_votes(cell):
|
| 153 |
+
v = np.zeros(len(EMOTIONS5), dtype=np.float32)
|
| 154 |
+
for tok in str(cell).replace("/", ",").replace(";", ",").replace("|", ",").replace(" ", ",").split(","):
|
| 155 |
+
e = norm_emotion(tok)
|
| 156 |
+
if e in EMOTIONS5:
|
| 157 |
+
v[EMOTIONS5.index(e)] += 1.0
|
| 158 |
+
return v
|
| 159 |
+
|
| 160 |
+
def load_train_labels():
|
| 161 |
+
df = pd.read_csv(TRAIN_CSV, sep="|")
|
| 162 |
+
cols = {c.lower().strip(): c for c in df.columns}
|
| 163 |
+
wav_col = _col(cols, "wavid", "wav", default_idx=1, df=df)
|
| 164 |
+
qmos_col = _col(cols, "qmos", "mos")
|
| 165 |
+
emos_col = _col(cols, "emos", "emo", "emomos")
|
| 166 |
+
val_col = _col(cols, "val", "valence")
|
| 167 |
+
aro_col = _col(cols, "aro", "arousal")
|
| 168 |
+
dom_col = _col(cols, "dom", "dominance")
|
| 169 |
+
cat_col = _col(cols, "emocat", "cat", "emotion")
|
| 170 |
+
assert qmos_col and emos_col, f"Thiếu cột qMOS/eMOS (cột: {list(df.columns)})"
|
| 171 |
+
df["_stem"] = df[wav_col].map(stem)
|
| 172 |
+
rows = []
|
| 173 |
+
for sid, g in df.groupby("_stem"):
|
| 174 |
+
rec = {"wavID": sid, "qmos": float(g[qmos_col].mean()), "emos": float(g[emos_col].mean())}
|
| 175 |
+
rec["val"] = float(g[val_col].mean()) if val_col else np.nan
|
| 176 |
+
rec["aro"] = float(g[aro_col].mean()) if aro_col else np.nan
|
| 177 |
+
rec["dom"] = float(g[dom_col].mean()) if dom_col else np.nan
|
| 178 |
+
votes = np.zeros(len(EMOTIONS5), dtype=np.float32)
|
| 179 |
+
if cat_col:
|
| 180 |
+
for cell in g[cat_col]:
|
| 181 |
+
votes += parse_emocat_votes(cell)
|
| 182 |
+
s = votes.sum()
|
| 183 |
+
cat = votes / s if s > 0 else np.full(len(EMOTIONS5), 1.0 / len(EMOTIONS5), dtype=np.float32)
|
| 184 |
+
for i in range(len(EMOTIONS5)):
|
| 185 |
+
rec[f"cat{i}"] = float(cat[i])
|
| 186 |
+
rows.append(rec)
|
| 187 |
+
return pd.DataFrame(rows)
|
| 188 |
+
|
| 189 |
+
target_map = load_target_emotions()
|
| 190 |
+
train_df = load_train_labels()
|
| 191 |
+
HAS_VAD = bool(train_df["val"].notna().any())
|
| 192 |
+
print(f"Target: {len(target_map)} | wav train (gộp): {len(train_df)} | có VAD: {HAS_VAD}")
|
| 193 |
+
|
| 194 |
+
# %% [markdown]
|
| 195 |
+
# ## 3. Đặc trưng POOLED (e2v + sailer + UTMOS) — TÁI DÙNG cache exp04/exp07
|
| 196 |
+
# (Y hệt exp07; nếu đã chạy exp07 thì cache `fusion_cache/` còn nguyên → không tính lại.)
|
| 197 |
+
|
| 198 |
+
# %%
|
| 199 |
+
import torch
|
| 200 |
+
import torch.nn.functional as F
|
| 201 |
+
|
| 202 |
+
device = DEVICE if torch.cuda.is_available() else "cpu"
|
| 203 |
+
print("Device:", device, ("✅ " + torch.cuda.get_device_name(0)) if device == "cuda" else "⚠️ CPU")
|
| 204 |
+
|
| 205 |
+
def extract_e2v(stems, tag):
|
| 206 |
+
from tqdm.auto import tqdm
|
| 207 |
+
cache_path = os.path.join(CACHE_DIR, f"e2v_{tag}.npz")
|
| 208 |
+
store = {}
|
| 209 |
+
if os.path.exists(cache_path):
|
| 210 |
+
z = np.load(cache_path, allow_pickle=True)
|
| 211 |
+
store = {k: z[k] for k in z.files}
|
| 212 |
+
print(f"[e2v/{tag}] nạp cache: {len(store)}")
|
| 213 |
+
todo = [s for s in stems if s not in store]
|
| 214 |
+
if todo:
|
| 215 |
+
from funasr import AutoModel
|
| 216 |
+
m = AutoModel(model="iic/emotion2vec_plus_large", hub="hf", device=device)
|
| 217 |
+
for i, s in enumerate(tqdm(todo, desc=f"e2v {tag}")):
|
| 218 |
+
wav = os.path.join(WAV_DIR, s + ".wav")
|
| 219 |
+
if not os.path.exists(wav):
|
| 220 |
+
continue
|
| 221 |
+
r = m.generate(wav, granularity="utterance", extract_embedding=True)[0]
|
| 222 |
+
emb = np.asarray(r["feats"], dtype=np.float32).reshape(-1)
|
| 223 |
+
probs = {e: 0.0 for e in EMOTIONS5}
|
| 224 |
+
for lab, sc in zip(r["labels"], r["scores"]):
|
| 225 |
+
name = lab.split("/")[-1]
|
| 226 |
+
if name in probs:
|
| 227 |
+
probs[name] = float(sc)
|
| 228 |
+
tot = sum(probs.values())
|
| 229 |
+
p5 = np.array([probs[e] / tot if tot > 0 else 0.2 for e in EMOTIONS5], dtype=np.float32)
|
| 230 |
+
store[s] = np.concatenate([emb, p5]).astype(np.float32)
|
| 231 |
+
if (i + 1) % 500 == 0:
|
| 232 |
+
np.savez(cache_path, **store)
|
| 233 |
+
np.savez(cache_path, **store)
|
| 234 |
+
del m
|
| 235 |
+
torch.cuda.empty_cache() if device == "cuda" else None
|
| 236 |
+
return {s: (v[:-5], v[-5:]) for s, v in store.items()}
|
| 237 |
+
|
| 238 |
+
def _pool_feat(features):
|
| 239 |
+
f = features.detach().cpu().numpy()
|
| 240 |
+
if f.ndim <= 1:
|
| 241 |
+
return f.reshape(-1).astype(np.float32)
|
| 242 |
+
return f.mean(axis=tuple(range(f.ndim - 1))).reshape(-1).astype(np.float32)
|
| 243 |
+
|
| 244 |
+
def extract_sailer(stems, tag):
|
| 245 |
+
import librosa
|
| 246 |
+
from tqdm.auto import tqdm
|
| 247 |
+
cache_path = os.path.join(CACHE_DIR, f"sailer_{tag}.npz")
|
| 248 |
+
store = {}
|
| 249 |
+
if os.path.exists(cache_path):
|
| 250 |
+
z = np.load(cache_path, allow_pickle=True)
|
| 251 |
+
store = {k: z[k] for k in z.files}
|
| 252 |
+
print(f"[sailer/{tag}] nạp cache: {len(store)}")
|
| 253 |
+
todo = [s for s in stems if s not in store]
|
| 254 |
+
if todo:
|
| 255 |
+
from src.model.emotion.wavlm_emotion import WavLMWrapper
|
| 256 |
+
sailer = WavLMWrapper.from_pretrained("tiantiaf/wavlm-large-categorical-emotion").to(device).eval()
|
| 257 |
+
with torch.no_grad():
|
| 258 |
+
for i, s in enumerate(tqdm(todo, desc=f"sailer {tag}")):
|
| 259 |
+
wav = os.path.join(WAV_DIR, s + ".wav")
|
| 260 |
+
if not os.path.exists(wav):
|
| 261 |
+
continue
|
| 262 |
+
wave, _ = librosa.load(wav, sr=16000, mono=True)
|
| 263 |
+
wave = wave[: 15 * 16000]
|
| 264 |
+
data = torch.from_numpy(wave).float().unsqueeze(0).to(device)
|
| 265 |
+
logits, feat, _det, arousal, valence, dominance = sailer(data, return_feature=True)
|
| 266 |
+
emb = _pool_feat(feat)
|
| 267 |
+
p9 = F.softmax(logits, dim=1)[0].detach().cpu().numpy().astype(np.float32)
|
| 268 |
+
vad3 = np.array([1 + 4 * float(valence.item()),
|
| 269 |
+
1 + 4 * float(arousal.item()),
|
| 270 |
+
1 + 4 * float(dominance.item())], dtype=np.float32)
|
| 271 |
+
store[s] = np.concatenate([emb, p9, vad3]).astype(np.float32)
|
| 272 |
+
if (i + 1) % 500 == 0:
|
| 273 |
+
np.savez(cache_path, **store)
|
| 274 |
+
np.savez(cache_path, **store)
|
| 275 |
+
del sailer
|
| 276 |
+
torch.cuda.empty_cache() if device == "cuda" else None
|
| 277 |
+
return {s: (v[:-12], v[-12:-3], v[-3:]) for s, v in store.items()}
|
| 278 |
+
|
| 279 |
+
def extract_utmos(names, tag):
|
| 280 |
+
import librosa
|
| 281 |
+
from tqdm.auto import tqdm
|
| 282 |
+
cache_path = os.path.join(CACHE_DIR, f"utmos_{tag}.npz")
|
| 283 |
+
store = {}
|
| 284 |
+
if os.path.exists(cache_path):
|
| 285 |
+
z = np.load(cache_path, allow_pickle=True)
|
| 286 |
+
store = {k: float(z[k]) for k in z.files}
|
| 287 |
+
print(f"[utmos/{tag}] nạp cache: {len(store)}")
|
| 288 |
+
todo = [n for n in names if stem(n) not in store]
|
| 289 |
+
if todo:
|
| 290 |
+
predictor = torch.hub.load("tarepan/SpeechMOS:v1.2.0", "utmos22_strong",
|
| 291 |
+
trust_repo=True).to(device).eval()
|
| 292 |
+
with torch.no_grad():
|
| 293 |
+
for i, n in enumerate(tqdm(todo, desc=f"utmos {tag}")):
|
| 294 |
+
wav = os.path.join(WAV_DIR, n if str(n).endswith(".wav") else n + ".wav")
|
| 295 |
+
if not os.path.exists(wav):
|
| 296 |
+
continue
|
| 297 |
+
wave, _ = librosa.load(wav, sr=16000, mono=True)
|
| 298 |
+
store[stem(n)] = float(predictor(torch.from_numpy(wave).unsqueeze(0).to(device),
|
| 299 |
+
sr=16000).mean().item())
|
| 300 |
+
if (i + 1) % 500 == 0:
|
| 301 |
+
np.savez(cache_path, **{k: np.float32(v) for k, v in store.items()})
|
| 302 |
+
np.savez(cache_path, **{k: np.float32(v) for k, v in store.items()})
|
| 303 |
+
del predictor
|
| 304 |
+
torch.cuda.empty_cache() if device == "cuda" else None
|
| 305 |
+
return store
|
| 306 |
+
|
| 307 |
+
# %% [markdown]
|
| 308 |
+
# ## 3b. Đặc trưng FRAME-LEVEL WavLM (chuỗi T×1024) cho nhánh Mamba — cache fp16
|
| 309 |
+
# Mỗi wav lưu 1 file `.npy` riêng trong `SEQ_DIR` (mảng fp16 [T, 1024], T ≤ MAX_FRAMES).
|
| 310 |
+
# WavLM **đóng băng** (eval, no_grad) → layerdrop tự tắt ở eval, không đụng gotcha checkpoint.
|
| 311 |
+
|
| 312 |
+
# %%
|
| 313 |
+
_wavlm = None
|
| 314 |
+
def _get_wavlm():
|
| 315 |
+
"""Lazy-load microsoft/wavlm-large (đóng băng). Trả model + feature_extractor."""
|
| 316 |
+
global _wavlm
|
| 317 |
+
if _wavlm is None:
|
| 318 |
+
from transformers import WavLMModel, AutoFeatureExtractor
|
| 319 |
+
fe = AutoFeatureExtractor.from_pretrained(WAVLM_NAME)
|
| 320 |
+
mdl = WavLMModel.from_pretrained(WAVLM_NAME).to(device).eval()
|
| 321 |
+
for p in mdl.parameters():
|
| 322 |
+
p.requires_grad = False
|
| 323 |
+
_wavlm = (mdl, fe)
|
| 324 |
+
return _wavlm
|
| 325 |
+
|
| 326 |
+
def seq_path(sid):
|
| 327 |
+
return os.path.join(SEQ_DIR, sid + ".npy")
|
| 328 |
+
|
| 329 |
+
def extract_wavlm_seq(stems, tag):
|
| 330 |
+
"""Trích frame-level WavLM cho từng wav, cache fp16 ra .npy. Trả set stem đã có."""
|
| 331 |
+
if not USE_MAMBA:
|
| 332 |
+
return set()
|
| 333 |
+
import librosa
|
| 334 |
+
from tqdm.auto import tqdm
|
| 335 |
+
todo = [s for s in stems if not os.path.exists(seq_path(s))]
|
| 336 |
+
if todo:
|
| 337 |
+
mdl, fe = _get_wavlm()
|
| 338 |
+
with torch.no_grad():
|
| 339 |
+
for i, s in enumerate(tqdm(todo, desc=f"wavlm-seq {tag}")):
|
| 340 |
+
wav = os.path.join(WAV_DIR, s + ".wav")
|
| 341 |
+
if not os.path.exists(wav):
|
| 342 |
+
continue
|
| 343 |
+
wave, _ = librosa.load(wav, sr=16000, mono=True)
|
| 344 |
+
wave = wave[: 15 * 16000]
|
| 345 |
+
inp = fe(wave, sampling_rate=16000, return_tensors="pt").input_values.to(device)
|
| 346 |
+
hs = mdl(inp).last_hidden_state[0] # (T, 1024)
|
| 347 |
+
if hs.shape[0] > MAX_FRAMES: # cap độ dài (đều theo thời gian)
|
| 348 |
+
idx = torch.linspace(0, hs.shape[0] - 1, MAX_FRAMES).long()
|
| 349 |
+
hs = hs[idx]
|
| 350 |
+
np.save(seq_path(s), hs.cpu().numpy().astype(np.float16))
|
| 351 |
+
torch.cuda.empty_cache() if device == "cuda" else None
|
| 352 |
+
return {s for s in stems if os.path.exists(seq_path(s))}
|
| 353 |
+
|
| 354 |
+
def load_seq(sid):
|
| 355 |
+
"""Đọc chuỗi fp16 → tensor float32 (T, 1024). Thiếu file → None."""
|
| 356 |
+
p = seq_path(sid)
|
| 357 |
+
if not os.path.exists(p):
|
| 358 |
+
return None
|
| 359 |
+
return torch.from_numpy(np.load(p).astype(np.float32))
|
| 360 |
+
|
| 361 |
+
def collate_seqs(sids):
|
| 362 |
+
"""Gộp list chuỗi độ dài khác nhau → (B, Lmax, 1024) + mask (B, Lmax) bool (True=thật)."""
|
| 363 |
+
seqs = [load_seq(s) for s in sids]
|
| 364 |
+
lens = [t.shape[0] for t in seqs]
|
| 365 |
+
Lmax = max(lens)
|
| 366 |
+
B = len(seqs)
|
| 367 |
+
x = torch.zeros(B, Lmax, seqs[0].shape[1], dtype=torch.float32)
|
| 368 |
+
mask = torch.zeros(B, Lmax, dtype=torch.bool)
|
| 369 |
+
for i, t in enumerate(seqs):
|
| 370 |
+
x[i, : t.shape[0]] = t
|
| 371 |
+
mask[i, : t.shape[0]] = True
|
| 372 |
+
return x, mask
|
| 373 |
+
|
| 374 |
+
# %% [markdown]
|
| 375 |
+
# ## 4. Dựng feature pooled + nhãn cho train (lọc các wav đủ mọi nguồn)
|
| 376 |
+
|
| 377 |
+
# %%
|
| 378 |
+
train_stems = list(train_df["wavID"])
|
| 379 |
+
if LIMIT_TRAIN:
|
| 380 |
+
train_stems = train_stems[:LIMIT_TRAIN]
|
| 381 |
+
|
| 382 |
+
e2v_tr = extract_e2v(train_stems, "train") if USE_E2V else {}
|
| 383 |
+
sailer_tr = extract_sailer(train_stems, "train") if USE_SAILER else {}
|
| 384 |
+
utmos_tr = extract_utmos(train_stems, "train") if USE_UTMOS_FEAT else {}
|
| 385 |
+
seq_tr = extract_wavlm_seq(train_stems, "train")
|
| 386 |
+
|
| 387 |
+
def audio_feature(sid, e2v_map, sailer_map):
|
| 388 |
+
parts = []
|
| 389 |
+
if USE_E2V:
|
| 390 |
+
pk = e2v_map.get(sid)
|
| 391 |
+
if pk is None:
|
| 392 |
+
return None
|
| 393 |
+
emb, p5 = pk
|
| 394 |
+
parts.append(emb)
|
| 395 |
+
if USE_CLASSPROB:
|
| 396 |
+
parts.append(p5)
|
| 397 |
+
if USE_SAILER:
|
| 398 |
+
pk = sailer_map.get(sid)
|
| 399 |
+
if pk is None:
|
| 400 |
+
return None
|
| 401 |
+
emb, p9, vad3 = pk
|
| 402 |
+
parts.append(emb)
|
| 403 |
+
if USE_CLASSPROB:
|
| 404 |
+
parts.append(p9); parts.append(vad3)
|
| 405 |
+
return np.concatenate(parts).astype(np.float32)
|
| 406 |
+
|
| 407 |
+
def onehot_target(tgt):
|
| 408 |
+
v = np.zeros(len(EMOTIONS5), dtype=np.float32)
|
| 409 |
+
if tgt in EMOTIONS5:
|
| 410 |
+
v[EMOTIONS5.index(tgt)] = 1.0
|
| 411 |
+
return v
|
| 412 |
+
|
| 413 |
+
lab = train_df.set_index("wavID")
|
| 414 |
+
keep_sids, X, T, U = [], [], [], []
|
| 415 |
+
y_qmos, y_emos, y_vad, y_cat = [], [], [], []
|
| 416 |
+
for s in train_stems:
|
| 417 |
+
f = audio_feature(s, e2v_tr, sailer_tr)
|
| 418 |
+
tgt = target_map.get(s)
|
| 419 |
+
if f is None or tgt is None or s not in lab.index:
|
| 420 |
+
continue
|
| 421 |
+
if USE_UTMOS_FEAT and s not in utmos_tr:
|
| 422 |
+
continue
|
| 423 |
+
if USE_MAMBA and s not in seq_tr: # cần có chuỗi WavLM nếu bật Mamba
|
| 424 |
+
continue
|
| 425 |
+
keep_sids.append(s)
|
| 426 |
+
X.append(f)
|
| 427 |
+
T.append(onehot_target(tgt))
|
| 428 |
+
U.append(utmos_tr.get(s, 3.0) if USE_UTMOS_FEAT else 0.0)
|
| 429 |
+
y_qmos.append(lab.loc[s, "qmos"]); y_emos.append(lab.loc[s, "emos"])
|
| 430 |
+
y_vad.append([lab.loc[s, "val"], lab.loc[s, "aro"], lab.loc[s, "dom"]])
|
| 431 |
+
y_cat.append([lab.loc[s, f"cat{i}"] for i in range(len(EMOTIONS5))])
|
| 432 |
+
|
| 433 |
+
X = np.stack(X).astype(np.float32)
|
| 434 |
+
T = np.stack(T).astype(np.float32)
|
| 435 |
+
U = np.array(U, dtype=np.float32).reshape(-1, 1)
|
| 436 |
+
y_qmos = np.array(y_qmos, dtype=np.float32); y_emos = np.array(y_emos, dtype=np.float32)
|
| 437 |
+
y_vad = np.array(y_vad, dtype=np.float32); y_cat = np.array(y_cat, dtype=np.float32)
|
| 438 |
+
FEAT_DIM = X.shape[1]
|
| 439 |
+
print(f"Train giữ lại: {len(keep_sids)} wav | X={X.shape} | Mamba={'ON' if USE_MAMBA else 'OFF'}")
|
| 440 |
+
|
| 441 |
+
# Chuẩn hóa feature pooled + UTMOS + nhãn liên tục (z-score)
|
| 442 |
+
feat_mean = X.mean(0, keepdims=True); feat_std = X.std(0, keepdims=True) + 1e-6
|
| 443 |
+
Xn = (X - feat_mean) / feat_std
|
| 444 |
+
u_mu, u_sd = float(U.mean()), float(U.std() + 1e-6); Un = (U - u_mu) / u_sd
|
| 445 |
+
qmos_mu, qmos_sd = float(y_qmos.mean()), float(y_qmos.std() + 1e-6); y_qmos_z = (y_qmos - qmos_mu) / qmos_sd
|
| 446 |
+
emos_mu, emos_sd = float(y_emos.mean()), float(y_emos.std() + 1e-6); y_emos_z = (y_emos - emos_mu) / emos_sd
|
| 447 |
+
if HAS_VAD:
|
| 448 |
+
vad_mu = np.nanmean(y_vad, axis=0); vad_sd = np.nanstd(y_vad, axis=0) + 1e-6
|
| 449 |
+
y_vad_z = (y_vad - vad_mu) / vad_sd
|
| 450 |
+
else:
|
| 451 |
+
vad_mu = np.zeros(3, dtype=np.float32); vad_sd = np.ones(3, dtype=np.float32); y_vad_z = np.zeros_like(y_vad)
|
| 452 |
+
|
| 453 |
+
# %% [markdown]
|
| 454 |
+
# ## 5a. Khối MAMBA (thuần PyTorch, không cần `mamba-ssm`)
|
| 455 |
+
# Tự dùng `mamba-ssm` nếu import được (nhanh hơn); nếu không → bản thuần PyTorch (selective scan vòng lặp thời gian).
|
| 456 |
+
# Bản này theo "mamba-minimal" (johnma2006) — đúng công thức, chỉ chậm hơn kernel CUDA, nhưng head nhỏ nên OK trên T4.
|
| 457 |
+
|
| 458 |
+
# %%
|
| 459 |
+
import math
|
| 460 |
+
import torch.nn as nn
|
| 461 |
+
|
| 462 |
+
try:
|
| 463 |
+
from mamba_ssm import Mamba as _OfficialMamba # nếu cài được thì dùng (tùy chọn)
|
| 464 |
+
_HAS_MAMBA_SSM = True
|
| 465 |
+
print("✅ Dùng mamba-ssm (CUDA kernel)")
|
| 466 |
+
except Exception:
|
| 467 |
+
_HAS_MAMBA_SSM = False
|
| 468 |
+
print("ℹ️ Không có mamba-ssm → dùng Mamba thuần PyTorch (nhúng sẵn)")
|
| 469 |
+
|
| 470 |
+
class MambaBlockTorch(nn.Module):
|
| 471 |
+
"""Một khối Mamba (selective SSM) thuần PyTorch. d_model = chiều ẩn."""
|
| 472 |
+
def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
|
| 473 |
+
super().__init__()
|
| 474 |
+
self.d_inner = expand * d_model
|
| 475 |
+
self.dt_rank = math.ceil(d_model / 16)
|
| 476 |
+
self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
|
| 477 |
+
self.conv1d = nn.Conv1d(self.d_inner, self.d_inner, kernel_size=d_conv,
|
| 478 |
+
groups=self.d_inner, padding=d_conv - 1, bias=True)
|
| 479 |
+
self.x_proj = nn.Linear(self.d_inner, self.dt_rank + d_state * 2, bias=False)
|
| 480 |
+
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)
|
| 481 |
+
A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)
|
| 482 |
+
self.A_log = nn.Parameter(torch.log(A)) # (d_inner, d_state)
|
| 483 |
+
self.D = nn.Parameter(torch.ones(self.d_inner))
|
| 484 |
+
self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
|
| 485 |
+
self.d_state = d_state
|
| 486 |
+
|
| 487 |
+
def forward(self, x): # x: (B, L, d_model)
|
| 488 |
+
B, L, _ = x.shape
|
| 489 |
+
xz = self.in_proj(x) # (B, L, 2*d_inner)
|
| 490 |
+
xin, z = xz.chunk(2, dim=-1)
|
| 491 |
+
xin = xin.transpose(1, 2) # (B, d_inner, L)
|
| 492 |
+
xin = self.conv1d(xin)[..., :L].transpose(1, 2) # (B, L, d_inner) causal conv
|
| 493 |
+
xin = F.silu(xin)
|
| 494 |
+
y = self._ssm(xin) # (B, L, d_inner)
|
| 495 |
+
y = y * F.silu(z)
|
| 496 |
+
return self.out_proj(y)
|
| 497 |
+
|
| 498 |
+
def _ssm(self, x): # x: (B, L, d_inner)
|
| 499 |
+
A = -torch.exp(self.A_log) # (d_inner, d_state)
|
| 500 |
+
x_dbl = self.x_proj(x) # (B, L, dt_rank + 2*d_state)
|
| 501 |
+
delta, Bm, Cm = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
| 502 |
+
delta = F.softplus(self.dt_proj(delta)) # (B, L, d_inner)
|
| 503 |
+
dA = torch.exp(delta.unsqueeze(-1) * A) # (B, L, d_inner, d_state)
|
| 504 |
+
dB_x = delta.unsqueeze(-1) * Bm.unsqueeze(2) * x.unsqueeze(-1) # (B, L, d_inner, d_state)
|
| 505 |
+
h = torch.zeros(x.shape[0], self.d_inner, self.d_state, device=x.device, dtype=x.dtype)
|
| 506 |
+
ys = []
|
| 507 |
+
for t in range(x.shape[1]): # selective scan theo thời gian
|
| 508 |
+
h = dA[:, t] * h + dB_x[:, t]
|
| 509 |
+
ys.append((h * Cm[:, t].unsqueeze(1)).sum(-1)) # (B, d_inner)
|
| 510 |
+
y = torch.stack(ys, dim=1) # (B, L, d_inner)
|
| 511 |
+
return y + x * self.D
|
| 512 |
+
|
| 513 |
+
class MambaLayer(nn.Module):
|
| 514 |
+
"""Pre-norm residual quanh 1 khối Mamba (chọn official nếu có)."""
|
| 515 |
+
def __init__(self, d_model, d_state):
|
| 516 |
+
super().__init__()
|
| 517 |
+
self.norm = nn.LayerNorm(d_model)
|
| 518 |
+
if _HAS_MAMBA_SSM:
|
| 519 |
+
self.mix = _OfficialMamba(d_model=d_model, d_state=d_state, d_conv=4, expand=2)
|
| 520 |
+
else:
|
| 521 |
+
self.mix = MambaBlockTorch(d_model, d_state=d_state)
|
| 522 |
+
|
| 523 |
+
def forward(self, x):
|
| 524 |
+
return x + self.mix(self.norm(x))
|
| 525 |
+
|
| 526 |
+
class MambaEncoder(nn.Module):
|
| 527 |
+
"""1024 → d_model → [Mamba ×L] (2 chiều nếu BIDIRECTIONAL) → attentive-pool → Z_DIM."""
|
| 528 |
+
def __init__(self, d_in, d_model, n_layers, d_state, z_dim, bidir):
|
| 529 |
+
super().__init__()
|
| 530 |
+
self.bidir = bidir
|
| 531 |
+
self.proj = nn.Linear(d_in, d_model)
|
| 532 |
+
self.fwd = nn.ModuleList([MambaLayer(d_model, d_state) for _ in range(n_layers)])
|
| 533 |
+
if bidir:
|
| 534 |
+
self.bwd = nn.ModuleList([MambaLayer(d_model, d_state) for _ in range(n_layers)])
|
| 535 |
+
self.attn = nn.Linear(d_model, 1) # attentive pooling
|
| 536 |
+
self.out = nn.Linear(d_model, z_dim)
|
| 537 |
+
|
| 538 |
+
def _run(self, layers, h):
|
| 539 |
+
for L in layers:
|
| 540 |
+
h = L(h)
|
| 541 |
+
return h
|
| 542 |
+
|
| 543 |
+
def forward(self, x, mask): # x: (B, L, 1024), mask: (B, L) bool
|
| 544 |
+
h = self.proj(x)
|
| 545 |
+
out = self._run(self.fwd, h)
|
| 546 |
+
if self.bidir:
|
| 547 |
+
rev = torch.flip(h, dims=[1])
|
| 548 |
+
out = out + torch.flip(self._run(self.bwd, rev), dims=[1])
|
| 549 |
+
a = self.attn(out).squeeze(-1) # (B, L)
|
| 550 |
+
a = a.masked_fill(~mask, float("-inf"))
|
| 551 |
+
w = torch.softmax(a, dim=1).unsqueeze(-1) # (B, L, 1)
|
| 552 |
+
pooled = (out * w).sum(1) # (B, d_model)
|
| 553 |
+
return self.out(pooled) # (B, z_dim)
|
| 554 |
+
|
| 555 |
+
# %% [markdown]
|
| 556 |
+
# ## 5b. Model fusion 6 head + nhánh Mamba + train loop
|
| 557 |
+
|
| 558 |
+
# %%
|
| 559 |
+
from scipy.stats import spearmanr
|
| 560 |
+
from sklearn.model_selection import train_test_split
|
| 561 |
+
|
| 562 |
+
torch.manual_seed(SEED); np.random.seed(SEED)
|
| 563 |
+
N_EMO = len(EMOTIONS5)
|
| 564 |
+
idx_all = np.arange(X.shape[0])
|
| 565 |
+
tr_idx, va_idx = train_test_split(idx_all, test_size=VAL_FRAC, random_state=SEED)
|
| 566 |
+
|
| 567 |
+
def to_t(a):
|
| 568 |
+
return torch.tensor(a, dtype=torch.float32, device=device)
|
| 569 |
+
|
| 570 |
+
Xn_t, T_t, Un_t = to_t(Xn), to_t(T), to_t(Un)
|
| 571 |
+
qmos_t = to_t(y_qmos_z).unsqueeze(1); emos_t = to_t(y_emos_z).unsqueeze(1)
|
| 572 |
+
vad_t = to_t(y_vad_z); cat_t = to_t(y_cat)
|
| 573 |
+
|
| 574 |
+
class FusionMamba6(nn.Module):
|
| 575 |
+
def __init__(self, d_in, trunk_h, head_h, p, n_emo, use_utmos, use_mamba):
|
| 576 |
+
super().__init__()
|
| 577 |
+
self.use_utmos = use_utmos
|
| 578 |
+
self.use_mamba = use_mamba
|
| 579 |
+
z_extra = Z_DIM if use_mamba else 0
|
| 580 |
+
if use_mamba:
|
| 581 |
+
self.enc = MambaEncoder(1024, MAMBA_DMODEL, MAMBA_LAYERS, MAMBA_DSTATE, Z_DIM, BIDIRECTIONAL)
|
| 582 |
+
self.trunk = nn.Sequential(
|
| 583 |
+
nn.Linear(d_in + z_extra, trunk_h), nn.ReLU(), nn.Dropout(p),
|
| 584 |
+
nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))
|
| 585 |
+
self.qmos = nn.Sequential(
|
| 586 |
+
nn.Linear(trunk_h + (1 if use_utmos else 0), head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))
|
| 587 |
+
self.emos = nn.Sequential(
|
| 588 |
+
nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))
|
| 589 |
+
self.cat = nn.Sequential(
|
| 590 |
+
nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))
|
| 591 |
+
self.vad = nn.Sequential(
|
| 592 |
+
nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))
|
| 593 |
+
|
| 594 |
+
def forward(self, x, tgt, utmos, seq=None, mask=None):
|
| 595 |
+
if self.use_mamba:
|
| 596 |
+
z = self.enc(seq, mask)
|
| 597 |
+
x = torch.cat([x, z], dim=1)
|
| 598 |
+
h = self.trunk(x)
|
| 599 |
+
qmos_in = torch.cat([h, utmos], dim=1) if self.use_utmos else h
|
| 600 |
+
return self.qmos(qmos_in), self.emos(torch.cat([h, tgt], dim=1)), self.cat(h), self.vad(h)
|
| 601 |
+
|
| 602 |
+
model = FusionMamba6(FEAT_DIM, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO, USE_UTMOS_FEAT, USE_MAMBA).to(device)
|
| 603 |
+
n_par = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 604 |
+
print(f"Tham số train được: {n_par/1e6:.2f} M")
|
| 605 |
+
|
| 606 |
+
TASKS = ["qmos", "emos", "cat", "val", "aro", "dom"]
|
| 607 |
+
log_var = nn.Parameter(torch.zeros(len(TASKS), device=device))
|
| 608 |
+
params = list(model.parameters()) + ([log_var] if USE_UNCERTAINTY else [])
|
| 609 |
+
opt = torch.optim.Adam(params, lr=LR, weight_decay=1e-5)
|
| 610 |
+
mse = nn.MSELoss(reduction="none")
|
| 611 |
+
|
| 612 |
+
def soft_ce(logits, target_dist):
|
| 613 |
+
return -(target_dist * F.log_softmax(logits, dim=1)).sum(dim=1)
|
| 614 |
+
|
| 615 |
+
def task_losses(qmos_p, emos_p, cat_logits, vad_p, b):
|
| 616 |
+
L = {"qmos": mse(qmos_p, qmos_t[b]).mean(),
|
| 617 |
+
"emos": mse(emos_p, emos_t[b]).mean(),
|
| 618 |
+
"cat": soft_ce(cat_logits, cat_t[b]).mean()}
|
| 619 |
+
if HAS_VAD:
|
| 620 |
+
L["val"] = mse(vad_p[:, 0:1], vad_t[b, 0:1]).mean()
|
| 621 |
+
L["aro"] = mse(vad_p[:, 1:2], vad_t[b, 1:2]).mean()
|
| 622 |
+
L["dom"] = mse(vad_p[:, 2:3], vad_t[b, 2:3]).mean()
|
| 623 |
+
else:
|
| 624 |
+
z = torch.zeros((), device=device); L["val"] = L["aro"] = L["dom"] = z
|
| 625 |
+
return L
|
| 626 |
+
|
| 627 |
+
def combine(L):
|
| 628 |
+
if USE_UNCERTAINTY:
|
| 629 |
+
return sum(torch.exp(-log_var[i]) * L[t] + log_var[i] for i, t in enumerate(TASKS))
|
| 630 |
+
return sum(LOSS_W[t] * L[t] for t in TASKS)
|
| 631 |
+
|
| 632 |
+
# batch theo INDEX (vì nhánh Mamba cần đọc chuỗi theo sid → collate động)
|
| 633 |
+
sids_arr = np.array(keep_sids)
|
| 634 |
+
|
| 635 |
+
def forward_batch(bidx):
|
| 636 |
+
"""bidx: numpy index. Trả output model cho batch (tự collate chuỗi nếu bật Mamba)."""
|
| 637 |
+
bt = torch.tensor(bidx, device=device)
|
| 638 |
+
if USE_MAMBA:
|
| 639 |
+
seq, mask = collate_seqs(list(sids_arr[bidx]))
|
| 640 |
+
seq, mask = seq.to(device), mask.to(device)
|
| 641 |
+
return model(Xn_t[bt], T_t[bt], Un_t[bt], seq, mask)
|
| 642 |
+
return model(Xn_t[bt], T_t[bt], Un_t[bt])
|
| 643 |
+
|
| 644 |
+
@torch.no_grad()
|
| 645 |
+
def eval_val():
|
| 646 |
+
model.eval()
|
| 647 |
+
qp, ep, vp = [], [], []
|
| 648 |
+
for i in range(0, len(va_idx), BATCH):
|
| 649 |
+
b = va_idx[i:i + BATCH]
|
| 650 |
+
q, e, _cl, v = forward_batch(b)
|
| 651 |
+
qp.append(q.cpu().numpy().ravel()); ep.append(e.cpu().numpy().ravel()); vp.append(v.cpu().numpy())
|
| 652 |
+
qp = np.concatenate(qp); ep = np.concatenate(ep); vp = np.concatenate(vp)
|
| 653 |
+
out = {"qmos": spearmanr(qp, y_qmos[va_idx]).correlation,
|
| 654 |
+
"emos": spearmanr(ep, y_emos[va_idx]).correlation}
|
| 655 |
+
if USE_UTMOS_FEAT:
|
| 656 |
+
out["qmos_utmos"] = spearmanr(U[va_idx, 0], y_qmos[va_idx]).correlation
|
| 657 |
+
if HAS_VAD:
|
| 658 |
+
for j, t in enumerate(["val", "aro", "dom"]):
|
| 659 |
+
out[t] = spearmanr(vp[:, j], y_vad[va_idx, j]).correlation
|
| 660 |
+
return out
|
| 661 |
+
|
| 662 |
+
def val_score(m):
|
| 663 |
+
keys = ["qmos", "emos"] + (["val", "aro", "dom"] if HAS_VAD else [])
|
| 664 |
+
return float(np.mean([m[k] for k in keys]))
|
| 665 |
+
|
| 666 |
+
best_score, best_state, bad = -1e9, None, 0
|
| 667 |
+
for ep_i in range(1, EPOCHS + 1):
|
| 668 |
+
model.train()
|
| 669 |
+
perm = np.random.permutation(tr_idx)
|
| 670 |
+
run = 0.0
|
| 671 |
+
for i in range(0, len(perm), BATCH):
|
| 672 |
+
b = perm[i:i + BATCH]
|
| 673 |
+
opt.zero_grad()
|
| 674 |
+
q, e, cl, v = forward_batch(b)
|
| 675 |
+
loss = combine(task_losses(q, e, cl, v, torch.tensor(b, device=device)))
|
| 676 |
+
loss.backward(); opt.step()
|
| 677 |
+
run += loss.item() * len(b)
|
| 678 |
+
m = eval_val(); sc = val_score(m)
|
| 679 |
+
if sc > best_score:
|
| 680 |
+
best_score = sc; bad = 0
|
| 681 |
+
best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
|
| 682 |
+
else:
|
| 683 |
+
bad += 1
|
| 684 |
+
if ep_i % 2 == 0 or ep_i == 1:
|
| 685 |
+
msg = " ".join(f"{k}={m[k]:.3f}" for k in ["qmos", "emos", "val", "aro", "dom"] if k in m)
|
| 686 |
+
print(f"epoch {ep_i:3d} | loss {run/len(perm):.4f} | {msg} | best {best_score:.4f}")
|
| 687 |
+
if bad >= PATIENCE:
|
| 688 |
+
print(f"Early stop ở epoch {ep_i}."); break
|
| 689 |
+
|
| 690 |
+
model.load_state_dict(best_state)
|
| 691 |
+
final = eval_val()
|
| 692 |
+
print(f"\n✅ VAL (nội bộ) — exp14 (Mamba={'ON' if USE_MAMBA else 'OFF'}):")
|
| 693 |
+
print(f" QMOS={final['qmos']:.4f} (exp07 {EXP07['qmos']}) | EMOS={final['emos']:.4f} (exp07 {EXP07['emos']})")
|
| 694 |
+
if HAS_VAD:
|
| 695 |
+
print(f" VAL/ARO/DOM={final['val']:.4f}/{final['aro']:.4f}/{final['dom']:.4f}"
|
| 696 |
+
f" (exp07 {EXP07['val']}/{EXP07['aro']}/{EXP07['dom']})")
|
| 697 |
+
print(" → So sánh USE_MAMBA True vs False = ablation Mamba cho paper.")
|
| 698 |
+
|
| 699 |
+
torch.save({"state": best_state, "feat_mean": feat_mean, "feat_std": feat_std,
|
| 700 |
+
"u_mu": u_mu, "u_sd": u_sd, "qmos_mu": qmos_mu, "qmos_sd": qmos_sd,
|
| 701 |
+
"emos_mu": emos_mu, "emos_sd": emos_sd, "vad_mu": vad_mu, "vad_sd": vad_sd,
|
| 702 |
+
"FEAT_DIM": FEAT_DIM, "USE_MAMBA": USE_MAMBA, "val_score": best_score},
|
| 703 |
+
os.path.join(OUT_DIR, "fusion_mamba_mtl.pt"))
|
| 704 |
+
print("Đã lưu", os.path.join(OUT_DIR, "fusion_mamba_mtl.pt"))
|
| 705 |
+
|
| 706 |
+
# %% [markdown]
|
| 707 |
+
# ## 6. Dự đoán DEV → `answer.txt` đủ 6 cột
|
| 708 |
+
|
| 709 |
+
# %%
|
| 710 |
+
def list_dev():
|
| 711 |
+
with open(DEV_SCP) as f:
|
| 712 |
+
return [ln.strip() for ln in f if ln.strip()]
|
| 713 |
+
|
| 714 |
+
dev_names = list_dev()
|
| 715 |
+
if LIMIT_DEV:
|
| 716 |
+
dev_names = dev_names[:LIMIT_DEV]
|
| 717 |
+
dev_stems = [stem(n) for n in dev_names]
|
| 718 |
+
print("DEV:", len(dev_names), "mẫu")
|
| 719 |
+
|
| 720 |
+
e2v_dev = extract_e2v(dev_stems, "dev") if USE_E2V else {}
|
| 721 |
+
sailer_dev = extract_sailer(dev_stems, "dev") if USE_SAILER else {}
|
| 722 |
+
utmos_dev = extract_utmos(dev_names, "dev") if USE_UTMOS_FEAT else {}
|
| 723 |
+
seq_dev = extract_wavlm_seq(dev_stems, "dev")
|
| 724 |
+
|
| 725 |
+
@torch.no_grad()
|
| 726 |
+
def predict_all(sid):
|
| 727 |
+
f = audio_feature(sid, e2v_dev, sailer_dev)
|
| 728 |
+
if f is None:
|
| 729 |
+
return None
|
| 730 |
+
if USE_MAMBA and not os.path.exists(seq_path(sid)):
|
| 731 |
+
return None
|
| 732 |
+
fn = (f[None, :] - feat_mean) / feat_std
|
| 733 |
+
tgt = onehot_target(target_map.get(sid))[None, :]
|
| 734 |
+
u = np.array([[utmos_dev.get(sid, 3.0)]], dtype=np.float32); un = (u - u_mu) / u_sd
|
| 735 |
+
model.eval()
|
| 736 |
+
if USE_MAMBA:
|
| 737 |
+
seq, mask = collate_seqs([sid]); seq, mask = seq.to(device), mask.to(device)
|
| 738 |
+
q, e, cl, v = model(to_t(fn), to_t(tgt), to_t(un), seq, mask)
|
| 739 |
+
else:
|
| 740 |
+
q, e, cl, v = model(to_t(fn), to_t(tgt), to_t(un))
|
| 741 |
+
qmos = float(q.item()) * qmos_sd + qmos_mu
|
| 742 |
+
emos = float(e.item()) * emos_sd + emos_mu
|
| 743 |
+
cat5 = F.softmax(cl, dim=1)[0].cpu().numpy()
|
| 744 |
+
vad3 = v[0].cpu().numpy() * vad_sd + vad_mu
|
| 745 |
+
return qmos, emos, cat5, vad3
|
| 746 |
+
|
| 747 |
+
def fmt_cat(probs5):
|
| 748 |
+
return "|".join(f"{e}:{probs5[i]:.6g}" for i, e in enumerate(EMOTIONS5))
|
| 749 |
+
|
| 750 |
+
def build_answer(out_path):
|
| 751 |
+
from tqdm.auto import tqdm
|
| 752 |
+
n_real = n_default = 0
|
| 753 |
+
with open(out_path, "w") as f:
|
| 754 |
+
f.write("wav,QMOS,EMOS,CAT,VAL,ARO,DOM\n")
|
| 755 |
+
for name in tqdm(dev_names, desc="answer"):
|
| 756 |
+
sid = stem(name)
|
| 757 |
+
pred = predict_all(sid)
|
| 758 |
+
if pred is None:
|
| 759 |
+
qmos = utmos_dev.get(sid, 3.0)
|
| 760 |
+
emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0])
|
| 761 |
+
n_default += 1
|
| 762 |
+
else:
|
| 763 |
+
qmos, emos, cat5, vad3 = pred; n_real += 1
|
| 764 |
+
f.write(f"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},"
|
| 765 |
+
f"{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\n")
|
| 766 |
+
print(f"Ghi {len(dev_names)} dòng → {out_path} | head thật {n_real}, mặc định {n_default}")
|
| 767 |
+
|
| 768 |
+
answer_path = os.path.join(OUT_DIR, "answer.txt")
|
| 769 |
+
build_answer(answer_path)
|
| 770 |
+
|
| 771 |
+
# %% [markdown]
|
| 772 |
+
# ## 7. Validate + đóng zip
|
| 773 |
+
|
| 774 |
+
# %%
|
| 775 |
+
def validate(path):
|
| 776 |
+
import csv
|
| 777 |
+
with open(path) as f:
|
| 778 |
+
rows = list(csv.reader(f))
|
| 779 |
+
header = rows[0]
|
| 780 |
+
assert header[0] == "wav" and "QMOS" in header and "EMOS" in header, "Header sai"
|
| 781 |
+
for i, r in enumerate(rows[1:], 2):
|
| 782 |
+
assert len(r) == len(header), f"Dòng {i} sai số cột"
|
| 783 |
+
print(f"OK: {len(rows)-1} dòng, header = {header}")
|
| 784 |
+
|
| 785 |
+
validate(answer_path)
|
| 786 |
+
os.system(f"cd {OUT_DIR} && zip -j submission_track2_exp14_mamba.zip answer.txt "
|
| 787 |
+
f"&& unzip -l submission_track2_exp14_mamba.zip")
|
| 788 |
+
print("Sẵn sàng nộp:", os.path.join(OUT_DIR, "submission_track2_exp14_mamba.zip"))
|
| 789 |
+
|
| 790 |
+
# %% [markdown]
|
| 791 |
+
# ## Ghi chú
|
| 792 |
+
# - **Ablation chính cho paper:** chạy 2 lần — `USE_MAMBA=False` (= exp07, mốc) và `USE_MAMBA=True`.
|
| 793 |
+
# So QMOS/EMOS/VAD nội bộ → trả lời "bộ mã hóa thời gian Mamba có hơn mean-pooling không?".
|
| 794 |
+
# - **Nếu hết đĩa khi cache chuỗi:** giảm `MAX_FRAMES` (256→160) hoặc xóa `wavlm_seq_cache/` sau khi chạy xong.
|
| 795 |
+
# - **Nếu Mamba chậm:** thử `pip install mamba-ssm causal-conv1d` (file tự dùng nếu import được); hoặc giảm
|
| 796 |
+
# `MAMBA_LAYERS`/`MAX_FRAMES`. Bản thuần PyTorch dùng vòng lặp thời gian nên chậm hơn kernel CUDA.
|
| 797 |
+
# - **Save Version** để giữ cache `fusion_cache/` + `wavlm_seq_cache/` cho lần sau.
|
| 798 |
+
# - Ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (mục exp14).
|
track2/exp15_predict.ipynb
ADDED
|
@@ -0,0 +1,698 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "9f2c52f2",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"# VMC2026 Track 2 — exp15 PREDICT-ONLY (nạp checkpoint → chấm DEV, KHÔNG train) — Kaggle\n",
|
| 9 |
+
"\n",
|
| 10 |
+
"**Mục đích:** bạn ĐÃ có checkpoint exp15 (`ft_mamba_emotion_full*.pt`, lưu cả backbone WavLM + Mamba enc + heads).\n",
|
| 11 |
+
"File này **chỉ inference**: dựng lại đúng kiến trúc → nạp trọng số + thống kê chuẩn hóa TỪ ckpt →\n",
|
| 12 |
+
"dự đoán 5 cột cảm xúc trên tập DEV → ghép QMOS (exp07/UTMOSv2) → `answer.txt` → zip nộp.\n",
|
| 13 |
+
"**KHÔNG** train, **KHÔNG** cần train.csv (chỉ cần wav DEV + metadata.csv để lấy cảm xúc target cho EMOS).\n",
|
| 14 |
+
"\n",
|
| 15 |
+
"## Vì sao nhanh\n",
|
| 16 |
+
"- Không có vòng train → chỉ 1 lượt forward qua DEV (~2730 mẫu). Việc lâu nhất là trích audeering DEV\n",
|
| 17 |
+
" (~vài phút; có cache thì gần như tức thì).\n",
|
| 18 |
+
"\n",
|
| 19 |
+
"## Chuẩn bị input trên Kaggle (Add Input)\n",
|
| 20 |
+
"1. Dataset Track 2 (wav + `metadata.csv` + `sets/dev.scp`).\n",
|
| 21 |
+
"2. **Checkpoint** exp15: dataset chứa `ft_mamba_emotion_full*.pt` (vd `cache_exp8`). Auto-dò; hoặc trỏ `CKPT_PATH`.\n",
|
| 22 |
+
"3. (tùy chọn) cache audeering `aud_dev.npz` để khỏi trích lại.\n",
|
| 23 |
+
"4. (tùy chọn) `answer.txt` exp07 để mượn cột QMOS 0.548.\n",
|
| 24 |
+
"\n",
|
| 25 |
+
"**Cách chạy:** GPU **T4** + Internet **On** → Add Input → Run All."
|
| 26 |
+
]
|
| 27 |
+
},
|
| 28 |
+
{
|
| 29 |
+
"cell_type": "markdown",
|
| 30 |
+
"id": "adbc7c65",
|
| 31 |
+
"metadata": {},
|
| 32 |
+
"source": [
|
| 33 |
+
"## 0. Cấu hình — SỬA Ở ĐÂY"
|
| 34 |
+
]
|
| 35 |
+
},
|
| 36 |
+
{
|
| 37 |
+
"cell_type": "code",
|
| 38 |
+
"execution_count": null,
|
| 39 |
+
"id": "7eb066d5",
|
| 40 |
+
"metadata": {},
|
| 41 |
+
"outputs": [],
|
| 42 |
+
"source": [
|
| 43 |
+
"import os, glob\n",
|
| 44 |
+
"\n",
|
| 45 |
+
"# ── TỰ DÒ DATA_ROOT (quét /kaggle/input tìm thư mục có sets + wav/ + metadata.csv) ──\n",
|
| 46 |
+
"def find_data_root(search_root=\"/kaggle/input\"):\n",
|
| 47 |
+
" cands = []\n",
|
| 48 |
+
" for dev_scp in glob.glob(os.path.join(search_root, \"**\", \"sets\", \"dev.scp\"), recursive=True):\n",
|
| 49 |
+
" root = os.path.dirname(os.path.dirname(dev_scp))\n",
|
| 50 |
+
" score = os.path.isdir(os.path.join(root, \"wav\")) + os.path.exists(os.path.join(root, \"metadata.csv\"))\n",
|
| 51 |
+
" cands.append((score, root))\n",
|
| 52 |
+
" cands.sort(reverse=True)\n",
|
| 53 |
+
" return cands\n",
|
| 54 |
+
"\n",
|
| 55 |
+
"_cands = find_data_root(\"/kaggle/input\")\n",
|
| 56 |
+
"if _cands:\n",
|
| 57 |
+
" print(\"🔎 Ứng viên DATA_ROOT:\")\n",
|
| 58 |
+
" for sc, r in _cands:\n",
|
| 59 |
+
" print(f\" [{sc}/2] {r}\")\n",
|
| 60 |
+
" DATA_ROOT = _cands[0][1]\n",
|
| 61 |
+
" print(f\"👉 Tự chọn DATA_ROOT = {DATA_ROOT}\")\n",
|
| 62 |
+
"else:\n",
|
| 63 |
+
" DATA_ROOT = \"/kaggle/input/datasets/minhtoan2\" # dự phòng — sửa tay\n",
|
| 64 |
+
" print(f\"❌ Không thấy sets/dev.scp → dùng dự phòng {DATA_ROOT} (đã Add Input chưa?)\")\n",
|
| 65 |
+
"\n",
|
| 66 |
+
"WAV_DIR = f\"{DATA_ROOT}/wav\"\n",
|
| 67 |
+
"METADATA_CSV = f\"{DATA_ROOT}/metadata.csv\" # wavID|emotion|transcript (KHÔNG header) — lấy cảm xúc target cho EMOS\n",
|
| 68 |
+
"DEV_SCP = f\"{DATA_ROOT}/sets/dev.scp\"\n",
|
| 69 |
+
"\n",
|
| 70 |
+
"OUT_DIR = \"/kaggle/working\"\n",
|
| 71 |
+
"CACHE_DIR = \"/kaggle/working/ft_cache\"\n",
|
| 72 |
+
"os.makedirs(CACHE_DIR, exist_ok=True)\n",
|
| 73 |
+
"\n",
|
| 74 |
+
"# ── CHECKPOINT exp15 (đủ backbone + Mamba + heads) ───────────────────────────\n",
|
| 75 |
+
"CKPT_PATH = \"\" # << \"\" = auto-dò ft_mamba_emotion_full*.pt; hoặc \"/kaggle/input/<slug>/ft_mamba_emotion_full (2).pt\"\n",
|
| 76 |
+
"\n",
|
| 77 |
+
"def find_ckpt(explicit):\n",
|
| 78 |
+
" \"\"\"Tìm checkpoint exp15. Khớp cả tên bị thêm hậu tố trùng, vd 'ft_mamba_emotion_full (2).pt'.\"\"\"\n",
|
| 79 |
+
" if explicit and os.path.exists(explicit):\n",
|
| 80 |
+
" return explicit\n",
|
| 81 |
+
" for base in [\"/kaggle/input\", \"/kaggle/working\"]:\n",
|
| 82 |
+
" hits = sorted(glob.glob(os.path.join(base, \"**\", \"ft_mamba_emotion_full*.pt\"), recursive=True))\n",
|
| 83 |
+
" if hits:\n",
|
| 84 |
+
" return hits[0]\n",
|
| 85 |
+
" return \"\"\n",
|
| 86 |
+
"\n",
|
| 87 |
+
"CKPT_PATH = find_ckpt(CKPT_PATH)\n",
|
| 88 |
+
"assert CKPT_PATH, \"❌ Không thấy checkpoint ft_mamba_emotion_full*.pt. Đã Add Input dataset chứa ckpt chưa?\"\n",
|
| 89 |
+
"print(\"✅ Dùng checkpoint:\", CKPT_PATH)\n",
|
| 90 |
+
"\n",
|
| 91 |
+
"# (Tùy chọn) tái dùng cache audeering DEV — quét đệ quy (file có thể nằm trong archive/)\n",
|
| 92 |
+
"CACHE_INPUT = \"/kaggle/input/cache-exp8\" # << SỬA slug (hoặc \"\")\n",
|
| 93 |
+
"if CACHE_INPUT and os.path.isdir(CACHE_INPUT):\n",
|
| 94 |
+
" import shutil\n",
|
| 95 |
+
" _n = 0\n",
|
| 96 |
+
" for _fp in glob.glob(os.path.join(CACHE_INPUT, \"**\", \"aud_*.npz\"), recursive=True):\n",
|
| 97 |
+
" shutil.copy(_fp, os.path.join(CACHE_DIR, os.path.basename(_fp))); _n += 1\n",
|
| 98 |
+
" print(f\"📦 Copy {_n} file aud_*.npz từ {CACHE_INPUT}\")\n",
|
| 99 |
+
"\n",
|
| 100 |
+
"# Mượn cột QMOS exp07 (0.548). Trỏ answer.txt exp07 nếu có; không thì UTMOSv2.\n",
|
| 101 |
+
"EXP07_ANSWER = \"/kaggle/input/exp07-answer/answer.txt\" # << (tùy chọn)\n",
|
| 102 |
+
"\n",
|
| 103 |
+
"# ── Siêu tham số PHẢI KHỚP lúc train exp15 (ckpt không lưu các số này của Mamba) ──\n",
|
| 104 |
+
"MAMBA_DMODEL = 256\n",
|
| 105 |
+
"MAMBA_LAYERS = 2\n",
|
| 106 |
+
"MAMBA_DSTATE = 16\n",
|
| 107 |
+
"BIDIRECTIONAL = True\n",
|
| 108 |
+
"TRUNK_HIDDEN = 512\n",
|
| 109 |
+
"HEAD_HIDDEN = 128\n",
|
| 110 |
+
"DROPOUT = 0.3 # không ảnh hưởng eval (model.eval() tắt dropout) — chỉ để dựng đúng shape\n",
|
| 111 |
+
"\n",
|
| 112 |
+
"DEVICE = \"cuda\"\n",
|
| 113 |
+
"SR = 16000\n",
|
| 114 |
+
"MAX_SECONDS = 6 # khớp lúc train (exp15 = 6)\n",
|
| 115 |
+
"USE_AMP = True\n",
|
| 116 |
+
"LIMIT_DEV = None # << để None chấm ĐỦ 2730; đặt 20 để smoke-test nhanh\n",
|
| 117 |
+
"\n",
|
| 118 |
+
"EMOTIONS5 = [\"angry\", \"happy\", \"neutral\", \"sad\", \"surprised\"]\n",
|
| 119 |
+
"_EMO_ALIAS = {\n",
|
| 120 |
+
" \"angry\": \"angry\", \"anger\": \"angry\",\n",
|
| 121 |
+
" \"happy\": \"happy\", \"happiness\": \"happy\", \"joy\": \"happy\",\n",
|
| 122 |
+
" \"neutral\": \"neutral\", \"calm\": \"neutral\",\n",
|
| 123 |
+
" \"sad\": \"sad\", \"sadness\": \"sad\",\n",
|
| 124 |
+
" \"surprise\": \"surprised\", \"surprised\": \"surprised\", \"surprising\": \"surprised\",\n",
|
| 125 |
+
"}\n",
|
| 126 |
+
"\n",
|
| 127 |
+
"def norm_emotion(label):\n",
|
| 128 |
+
" key = str(label).strip().lower()\n",
|
| 129 |
+
" return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)\n",
|
| 130 |
+
"\n",
|
| 131 |
+
"def stem(p):\n",
|
| 132 |
+
" return os.path.splitext(os.path.basename(str(p)))[0]\n",
|
| 133 |
+
"\n",
|
| 134 |
+
"print(\"DATA_ROOT:\", DATA_ROOT)\n",
|
| 135 |
+
"for p in [WAV_DIR, METADATA_CSV, DEV_SCP, CKPT_PATH]:\n",
|
| 136 |
+
" print((\" ✅ \" if os.path.exists(p) else \" ❌ THIẾU \") + p)"
|
| 137 |
+
]
|
| 138 |
+
},
|
| 139 |
+
{
|
| 140 |
+
"cell_type": "markdown",
|
| 141 |
+
"id": "febe8bdc",
|
| 142 |
+
"metadata": {},
|
| 143 |
+
"source": [
|
| 144 |
+
"## 1. Cài đặt + tải code SAILER (để dựng đúng kiến trúc WavLM rồi nạp ckpt đè lên)"
|
| 145 |
+
]
|
| 146 |
+
},
|
| 147 |
+
{
|
| 148 |
+
"cell_type": "code",
|
| 149 |
+
"execution_count": null,
|
| 150 |
+
"id": "7732e245",
|
| 151 |
+
"metadata": {},
|
| 152 |
+
"outputs": [],
|
| 153 |
+
"source": [
|
| 154 |
+
"import sys, subprocess\n",
|
| 155 |
+
"\n",
|
| 156 |
+
"def pip_install(*pkgs):\n",
|
| 157 |
+
" subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *pkgs], check=True)\n",
|
| 158 |
+
"\n",
|
| 159 |
+
"pip_install(\"loralib\", \"speechbrain\", \"speechmos\", \"librosa\", \"soundfile\",\n",
|
| 160 |
+
" \"scipy\", \"scikit-learn\", \"pandas\", \"tqdm\")\n",
|
| 161 |
+
"\n",
|
| 162 |
+
"# Mamba kernel CUDA (tùy chọn — không có thì dùng Mamba thuần PyTorch, inference vẫn ổn vì chỉ 1 lượt forward)\n",
|
| 163 |
+
"INSTALL_MAMBA_SSM = True\n",
|
| 164 |
+
"if INSTALL_MAMBA_SSM:\n",
|
| 165 |
+
" try:\n",
|
| 166 |
+
" subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", \"ninja\"], check=True)\n",
|
| 167 |
+
" subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", \"--no-build-isolation\", \"causal-conv1d>=1.2.0\"], check=True)\n",
|
| 168 |
+
" subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", \"--no-build-isolation\", \"mamba-ssm\"], check=True)\n",
|
| 169 |
+
" print(\"✅ Cài mamba-ssm xong (dùng kernel CUDA nếu import được).\")\n",
|
| 170 |
+
" except Exception as e:\n",
|
| 171 |
+
" print(\"⚠️ Cài mamba-ssm thất bại:\", repr(e), \"→ Mamba thuần PyTorch (inference vẫn chạy).\")\n",
|
| 172 |
+
"\n",
|
| 173 |
+
"REPO_DIR = \"/kaggle/working/vox-profile-release\"\n",
|
| 174 |
+
"if not os.path.exists(REPO_DIR):\n",
|
| 175 |
+
" subprocess.run([\"git\", \"clone\", \"--depth\", \"1\",\n",
|
| 176 |
+
" \"https://github.com/tiantiaf0627/vox-profile-release.git\", REPO_DIR], check=True)\n",
|
| 177 |
+
"if REPO_DIR not in sys.path:\n",
|
| 178 |
+
" sys.path.insert(0, REPO_DIR)"
|
| 179 |
+
]
|
| 180 |
+
},
|
| 181 |
+
{
|
| 182 |
+
"cell_type": "markdown",
|
| 183 |
+
"id": "fba12581",
|
| 184 |
+
"metadata": {},
|
| 185 |
+
"source": [
|
| 186 |
+
"## 2. Nạp checkpoint → dựng WavLM → load trọng số backbone đã fine-tune"
|
| 187 |
+
]
|
| 188 |
+
},
|
| 189 |
+
{
|
| 190 |
+
"cell_type": "code",
|
| 191 |
+
"execution_count": null,
|
| 192 |
+
"id": "61199736",
|
| 193 |
+
"metadata": {
|
| 194 |
+
"lines_to_next_cell": 1
|
| 195 |
+
},
|
| 196 |
+
"outputs": [],
|
| 197 |
+
"source": [
|
| 198 |
+
"import torch\n",
|
| 199 |
+
"import torch.nn as nn\n",
|
| 200 |
+
"import torch.nn.functional as F\n",
|
| 201 |
+
"\n",
|
| 202 |
+
"device = DEVICE if torch.cuda.is_available() else \"cpu\"\n",
|
| 203 |
+
"print(\"Device:\", device, (\"✅ \" + torch.cuda.get_device_name(0)) if device == \"cuda\" else \"⚠️ CPU (chậm)\")\n",
|
| 204 |
+
"\n",
|
| 205 |
+
"ckpt = torch.load(CKPT_PATH, map_location=\"cpu\", weights_only=False) # ckpt có numpy → cần False\n",
|
| 206 |
+
"assert \"wavlm\" in ckpt, \"❌ Checkpoint KHÔNG có 'wavlm' (backbone) → không inference được. Cần ft_mamba_emotion_full*.pt đủ.\"\n",
|
| 207 |
+
"print(\"✅ Nạp ckpt | keys:\", list(ckpt.keys()))\n",
|
| 208 |
+
"\n",
|
| 209 |
+
"# Lấy cấu hình KIẾN TRÚC từ ckpt (để dựng đúng shape head)\n",
|
| 210 |
+
"USE_MAMBA = bool(ckpt.get(\"USE_MAMBA\", True))\n",
|
| 211 |
+
"Z_DIM = int(ckpt.get(\"Z_DIM\", 256))\n",
|
| 212 |
+
"AUD_DIM = int(ckpt.get(\"AUD_DIM\", 0))\n",
|
| 213 |
+
"USE_AUDEERING = AUD_DIM > 0\n",
|
| 214 |
+
"UNFREEZE_TOP_LAYERS = int(ckpt.get(\"UNFREEZE_TOP_LAYERS\", 6))\n",
|
| 215 |
+
"print(f\"Từ ckpt: USE_MAMBA={USE_MAMBA} · Z_DIM={Z_DIM} · AUD_DIM={AUD_DIM} (audeering={'ON' if USE_AUDEERING else 'OFF'})\")\n",
|
| 216 |
+
"\n",
|
| 217 |
+
"def find_hf_backbone(module):\n",
|
| 218 |
+
" cands = []\n",
|
| 219 |
+
" for name, m in module.named_modules():\n",
|
| 220 |
+
" enc = getattr(m, \"encoder\", None)\n",
|
| 221 |
+
" if getattr(m, \"feature_extractor\", None) is not None and enc is not None \\\n",
|
| 222 |
+
" and getattr(enc, \"layers\", None) is not None:\n",
|
| 223 |
+
" cands.append((name, m))\n",
|
| 224 |
+
" if not cands:\n",
|
| 225 |
+
" return None, None\n",
|
| 226 |
+
" cands.sort(key=lambda nm: sum(p.numel() for p in nm[1].parameters()), reverse=True)\n",
|
| 227 |
+
" return cands[0]\n",
|
| 228 |
+
"\n",
|
| 229 |
+
"wavlm = None\n",
|
| 230 |
+
"try:\n",
|
| 231 |
+
" from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402\n",
|
| 232 |
+
" _wrapper = WavLMWrapper.from_pretrained(\"tiantiaf/wavlm-large-categorical-emotion\")\n",
|
| 233 |
+
" name, wavlm = find_hf_backbone(_wrapper)\n",
|
| 234 |
+
" if wavlm is not None:\n",
|
| 235 |
+
" print(f\"✅ Dựng backbone WavLM từ SAILER wrapper tại '.{name}'\")\n",
|
| 236 |
+
"except Exception as e:\n",
|
| 237 |
+
" print(\"⚠️ Lỗi nạp SAILER wrapper:\", repr(e), \"→ fallback WavLM trắng.\")\n",
|
| 238 |
+
"\n",
|
| 239 |
+
"if wavlm is None:\n",
|
| 240 |
+
" from transformers import WavLMModel\n",
|
| 241 |
+
" wavlm = WavLMModel.from_pretrained(\"microsoft/wavlm-large\")\n",
|
| 242 |
+
" print(\"ℹ️ Fallback: microsoft/wavlm-large.\")\n",
|
| 243 |
+
"\n",
|
| 244 |
+
"wavlm = wavlm.to(device)\n",
|
| 245 |
+
"WAVLM_DIM = int(wavlm.config.hidden_size)\n",
|
| 246 |
+
"wavlm.config.layerdrop = 0.0\n",
|
| 247 |
+
"\n",
|
| 248 |
+
"miss, unexp = wavlm.load_state_dict(ckpt[\"wavlm\"], strict=False)\n",
|
| 249 |
+
"print(f\"🔁 load wavlm từ ckpt: thiếu {len(miss)} / dư {len(unexp)} key (kỳ vọng ~0)\")\n",
|
| 250 |
+
"if len(miss) > 20 or len(unexp) > 20:\n",
|
| 251 |
+
" print(\" ⚠️ Lệch key nhiều → kiểm tra backbone có khớp ckpt không.\")\n",
|
| 252 |
+
"wavlm.eval()\n",
|
| 253 |
+
"\n",
|
| 254 |
+
"def frame_mask(T, attn_mask):\n",
|
| 255 |
+
" if attn_mask is None:\n",
|
| 256 |
+
" return torch.ones((1, T), dtype=torch.bool, device=device)\n",
|
| 257 |
+
" try:\n",
|
| 258 |
+
" return wavlm._get_feature_vector_attention_mask(T, attn_mask).bool()\n",
|
| 259 |
+
" except Exception:\n",
|
| 260 |
+
" return torch.ones((attn_mask.shape[0], T), dtype=torch.bool, device=attn_mask.device)\n",
|
| 261 |
+
"\n",
|
| 262 |
+
"def masked_mean(hidden, attn_mask):\n",
|
| 263 |
+
" if attn_mask is None:\n",
|
| 264 |
+
" return hidden.mean(dim=1)\n",
|
| 265 |
+
" fm = frame_mask(hidden.shape[1], attn_mask).unsqueeze(-1).to(hidden.dtype)\n",
|
| 266 |
+
" return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)"
|
| 267 |
+
]
|
| 268 |
+
},
|
| 269 |
+
{
|
| 270 |
+
"cell_type": "markdown",
|
| 271 |
+
"id": "421e0b6a",
|
| 272 |
+
"metadata": {},
|
| 273 |
+
"source": [
|
| 274 |
+
"## 3. audeering MSP-dim (FROZEN) — chỉ dựng nếu ckpt có dùng (AUD_DIM>0)"
|
| 275 |
+
]
|
| 276 |
+
},
|
| 277 |
+
{
|
| 278 |
+
"cell_type": "code",
|
| 279 |
+
"execution_count": null,
|
| 280 |
+
"id": "d37d3d53",
|
| 281 |
+
"metadata": {
|
| 282 |
+
"lines_to_next_cell": 1
|
| 283 |
+
},
|
| 284 |
+
"outputs": [],
|
| 285 |
+
"source": [
|
| 286 |
+
"import numpy as np\n",
|
| 287 |
+
"import librosa\n",
|
| 288 |
+
"from tqdm.auto import tqdm\n",
|
| 289 |
+
"\n",
|
| 290 |
+
"aud_backbone = aud_head = aud_proc = None\n",
|
| 291 |
+
"if USE_AUDEERING:\n",
|
| 292 |
+
" from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor\n",
|
| 293 |
+
" from huggingface_hub import hf_hub_download\n",
|
| 294 |
+
" AUD_NAME = \"audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim\"\n",
|
| 295 |
+
" aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)\n",
|
| 296 |
+
" aud_cfg = Wav2Vec2Config.from_pretrained(AUD_NAME)\n",
|
| 297 |
+
" aud_backbone = Wav2Vec2Model(aud_cfg)\n",
|
| 298 |
+
" try:\n",
|
| 299 |
+
" _sd = __import__(\"safetensors.torch\", fromlist=[\"load_file\"]).load_file(\n",
|
| 300 |
+
" hf_hub_download(AUD_NAME, \"model.safetensors\"))\n",
|
| 301 |
+
" except Exception:\n",
|
| 302 |
+
" _sd = torch.load(hf_hub_download(AUD_NAME, \"pytorch_model.bin\"), map_location=\"cpu\")\n",
|
| 303 |
+
" bb_sd = {k[len(\"wav2vec2.\"):]: v for k, v in _sd.items() if k.startswith(\"wav2vec2.\")}\n",
|
| 304 |
+
" aud_backbone.load_state_dict(bb_sd, strict=False)\n",
|
| 305 |
+
" _hid = _sd[\"classifier.dense.weight\"].shape[0]\n",
|
| 306 |
+
" aud_head = nn.Sequential(nn.Linear(_hid, _hid), nn.Tanh(), nn.Linear(_hid, _sd[\"classifier.out_proj.weight\"].shape[0]))\n",
|
| 307 |
+
" aud_head[0].weight.data.copy_(_sd[\"classifier.dense.weight\"]); aud_head[0].bias.data.copy_(_sd[\"classifier.dense.bias\"])\n",
|
| 308 |
+
" aud_head[2].weight.data.copy_(_sd[\"classifier.out_proj.weight\"]); aud_head[2].bias.data.copy_(_sd[\"classifier.out_proj.bias\"])\n",
|
| 309 |
+
" aud_backbone = aud_backbone.to(device).eval()\n",
|
| 310 |
+
" aud_head = aud_head.to(device).eval()\n",
|
| 311 |
+
" assert _hid + 3 == AUD_DIM, f\"⚠️ AUD_DIM dựng ({_hid+3}) ≠ ckpt ({AUD_DIM}) → audeering không khớp!\"\n",
|
| 312 |
+
" print(f\"✅ audeering frozen ({AUD_DIM}-D)\")\n",
|
| 313 |
+
"\n",
|
| 314 |
+
"def load_wav(name_or_stem):\n",
|
| 315 |
+
" p = name_or_stem if os.path.isabs(str(name_or_stem)) else os.path.join(\n",
|
| 316 |
+
" WAV_DIR, name_or_stem if str(name_or_stem).endswith(\".wav\") else str(name_or_stem) + \".wav\")\n",
|
| 317 |
+
" if not os.path.exists(p):\n",
|
| 318 |
+
" return None\n",
|
| 319 |
+
" wave, _ = librosa.load(p, sr=SR, mono=True)\n",
|
| 320 |
+
" return wave[: MAX_SECONDS * SR].astype(np.float32)\n",
|
| 321 |
+
"\n",
|
| 322 |
+
"@torch.no_grad()\n",
|
| 323 |
+
"def extract_audeering(stems, tag):\n",
|
| 324 |
+
" if not USE_AUDEERING:\n",
|
| 325 |
+
" return {}\n",
|
| 326 |
+
" cache_path = os.path.join(CACHE_DIR, f\"aud_{tag}.npz\")\n",
|
| 327 |
+
" store = {}\n",
|
| 328 |
+
" if os.path.exists(cache_path):\n",
|
| 329 |
+
" z = np.load(cache_path, allow_pickle=True)\n",
|
| 330 |
+
" store = {k: z[k] for k in z.files}\n",
|
| 331 |
+
" print(f\"[aud/{tag}] nạp cache: {len(store)}\")\n",
|
| 332 |
+
" todo = [s for s in stems if s not in store]\n",
|
| 333 |
+
" for i, s in enumerate(tqdm(todo, desc=f\"audeering {tag}\")):\n",
|
| 334 |
+
" wave = load_wav(s)\n",
|
| 335 |
+
" if wave is None:\n",
|
| 336 |
+
" continue\n",
|
| 337 |
+
" x = aud_proc(wave, sampling_rate=SR).input_values[0]\n",
|
| 338 |
+
" x = torch.from_numpy(np.asarray(x, dtype=np.float32)).unsqueeze(0).to(device)\n",
|
| 339 |
+
" h = aud_backbone(x)[0].mean(dim=1)\n",
|
| 340 |
+
" out = aud_head(h)[0].cpu().numpy()\n",
|
| 341 |
+
" vad = np.array([1 + 4 * out[2], 1 + 4 * out[0], 1 + 4 * out[1]], dtype=np.float32) # [VAL,ARO,DOM]\n",
|
| 342 |
+
" store[s] = np.concatenate([h[0].cpu().numpy(), vad]).astype(np.float32)\n",
|
| 343 |
+
" if (i + 1) % 500 == 0:\n",
|
| 344 |
+
" np.savez(cache_path, **store)\n",
|
| 345 |
+
" if todo:\n",
|
| 346 |
+
" np.savez(cache_path, **store)\n",
|
| 347 |
+
" return store"
|
| 348 |
+
]
|
| 349 |
+
},
|
| 350 |
+
{
|
| 351 |
+
"cell_type": "markdown",
|
| 352 |
+
"id": "0a04ef30",
|
| 353 |
+
"metadata": {},
|
| 354 |
+
"source": [
|
| 355 |
+
"## 4. Cảm xúc target theo wavID (cho one-hot điều kiện của head EMOS)"
|
| 356 |
+
]
|
| 357 |
+
},
|
| 358 |
+
{
|
| 359 |
+
"cell_type": "code",
|
| 360 |
+
"execution_count": null,
|
| 361 |
+
"id": "3c092318",
|
| 362 |
+
"metadata": {
|
| 363 |
+
"lines_to_next_cell": 1
|
| 364 |
+
},
|
| 365 |
+
"outputs": [],
|
| 366 |
+
"source": [
|
| 367 |
+
"def load_target_emotions():\n",
|
| 368 |
+
" tgt = {}\n",
|
| 369 |
+
" with open(METADATA_CSV, encoding=\"utf-8\") as f:\n",
|
| 370 |
+
" for ln in f:\n",
|
| 371 |
+
" parts = ln.strip().split(\"|\")\n",
|
| 372 |
+
" if len(parts) >= 2:\n",
|
| 373 |
+
" tgt[stem(parts[0])] = norm_emotion(parts[1])\n",
|
| 374 |
+
" return tgt\n",
|
| 375 |
+
"\n",
|
| 376 |
+
"target_map = load_target_emotions()\n",
|
| 377 |
+
"print(\"Target cảm xúc:\", len(target_map), \"wav\")\n",
|
| 378 |
+
"\n",
|
| 379 |
+
"def onehot_target(tgt):\n",
|
| 380 |
+
" v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
|
| 381 |
+
" if tgt in EMOTIONS5:\n",
|
| 382 |
+
" v[EMOTIONS5.index(tgt)] = 1.0\n",
|
| 383 |
+
" return v"
|
| 384 |
+
]
|
| 385 |
+
},
|
| 386 |
+
{
|
| 387 |
+
"cell_type": "markdown",
|
| 388 |
+
"id": "a0d7021a",
|
| 389 |
+
"metadata": {},
|
| 390 |
+
"source": [
|
| 391 |
+
"## 5. Khối Mamba (giống exp15) + MambaEncoder"
|
| 392 |
+
]
|
| 393 |
+
},
|
| 394 |
+
{
|
| 395 |
+
"cell_type": "code",
|
| 396 |
+
"execution_count": null,
|
| 397 |
+
"id": "d8c31f88",
|
| 398 |
+
"metadata": {
|
| 399 |
+
"lines_to_next_cell": 1
|
| 400 |
+
},
|
| 401 |
+
"outputs": [],
|
| 402 |
+
"source": [
|
| 403 |
+
"import math\n",
|
| 404 |
+
"\n",
|
| 405 |
+
"try:\n",
|
| 406 |
+
" from mamba_ssm import Mamba as _OfficialMamba\n",
|
| 407 |
+
" _HAS_MAMBA_SSM = True\n",
|
| 408 |
+
" print(\"✅ Dùng mamba-ssm (CUDA kernel)\")\n",
|
| 409 |
+
"except Exception:\n",
|
| 410 |
+
" _HAS_MAMBA_SSM = False\n",
|
| 411 |
+
" print(\"ℹ️ Không có mamba-ssm → Mamba thuần PyTorch\")\n",
|
| 412 |
+
"\n",
|
| 413 |
+
"class MambaBlockTorch(nn.Module):\n",
|
| 414 |
+
" def __init__(self, d_model, d_state=16, d_conv=4, expand=2):\n",
|
| 415 |
+
" super().__init__()\n",
|
| 416 |
+
" self.d_inner = expand * d_model\n",
|
| 417 |
+
" self.dt_rank = math.ceil(d_model / 16)\n",
|
| 418 |
+
" self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)\n",
|
| 419 |
+
" self.conv1d = nn.Conv1d(self.d_inner, self.d_inner, kernel_size=d_conv,\n",
|
| 420 |
+
" groups=self.d_inner, padding=d_conv - 1, bias=True)\n",
|
| 421 |
+
" self.x_proj = nn.Linear(self.d_inner, self.dt_rank + d_state * 2, bias=False)\n",
|
| 422 |
+
" self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)\n",
|
| 423 |
+
" A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)\n",
|
| 424 |
+
" self.A_log = nn.Parameter(torch.log(A))\n",
|
| 425 |
+
" self.D = nn.Parameter(torch.ones(self.d_inner))\n",
|
| 426 |
+
" self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)\n",
|
| 427 |
+
" self.d_state = d_state\n",
|
| 428 |
+
"\n",
|
| 429 |
+
" def forward(self, x):\n",
|
| 430 |
+
" B, L, _ = x.shape\n",
|
| 431 |
+
" xin, z = self.in_proj(x).chunk(2, dim=-1)\n",
|
| 432 |
+
" xin = xin.transpose(1, 2)\n",
|
| 433 |
+
" xin = self.conv1d(xin)[..., :L].transpose(1, 2)\n",
|
| 434 |
+
" xin = F.silu(xin)\n",
|
| 435 |
+
" y = self._ssm(xin) * F.silu(z)\n",
|
| 436 |
+
" return self.out_proj(y)\n",
|
| 437 |
+
"\n",
|
| 438 |
+
" def _ssm(self, x):\n",
|
| 439 |
+
" A = -torch.exp(self.A_log)\n",
|
| 440 |
+
" delta, Bm, Cm = torch.split(self.x_proj(x), [self.dt_rank, self.d_state, self.d_state], dim=-1)\n",
|
| 441 |
+
" delta = F.softplus(self.dt_proj(delta))\n",
|
| 442 |
+
" dA = torch.exp(delta.unsqueeze(-1) * A)\n",
|
| 443 |
+
" dB_x = delta.unsqueeze(-1) * Bm.unsqueeze(2) * x.unsqueeze(-1)\n",
|
| 444 |
+
" h = torch.zeros(x.shape[0], self.d_inner, self.d_state, device=x.device, dtype=x.dtype)\n",
|
| 445 |
+
" ys = []\n",
|
| 446 |
+
" for t in range(x.shape[1]):\n",
|
| 447 |
+
" h = dA[:, t] * h + dB_x[:, t]\n",
|
| 448 |
+
" ys.append((h * Cm[:, t].unsqueeze(1)).sum(-1))\n",
|
| 449 |
+
" return torch.stack(ys, dim=1) + x * self.D\n",
|
| 450 |
+
"\n",
|
| 451 |
+
"class MambaLayer(nn.Module):\n",
|
| 452 |
+
" def __init__(self, d_model, d_state):\n",
|
| 453 |
+
" super().__init__()\n",
|
| 454 |
+
" self.norm = nn.LayerNorm(d_model)\n",
|
| 455 |
+
" self.mix = _OfficialMamba(d_model=d_model, d_state=d_state, d_conv=4, expand=2) \\\n",
|
| 456 |
+
" if _HAS_MAMBA_SSM else MambaBlockTorch(d_model, d_state=d_state)\n",
|
| 457 |
+
" def forward(self, x):\n",
|
| 458 |
+
" return x + self.mix(self.norm(x))\n",
|
| 459 |
+
"\n",
|
| 460 |
+
"class MambaEncoder(nn.Module):\n",
|
| 461 |
+
" def __init__(self, d_in, d_model, n_layers, d_state, z_dim, bidir):\n",
|
| 462 |
+
" super().__init__()\n",
|
| 463 |
+
" self.bidir = bidir\n",
|
| 464 |
+
" self.proj = nn.Linear(d_in, d_model)\n",
|
| 465 |
+
" self.fwd = nn.ModuleList([MambaLayer(d_model, d_state) for _ in range(n_layers)])\n",
|
| 466 |
+
" if bidir:\n",
|
| 467 |
+
" self.bwd = nn.ModuleList([MambaLayer(d_model, d_state) for _ in range(n_layers)])\n",
|
| 468 |
+
" self.attn = nn.Linear(d_model, 1)\n",
|
| 469 |
+
" self.out = nn.Linear(d_model, z_dim)\n",
|
| 470 |
+
"\n",
|
| 471 |
+
" @staticmethod\n",
|
| 472 |
+
" def _run(layers, h):\n",
|
| 473 |
+
" for L in layers:\n",
|
| 474 |
+
" h = L(h)\n",
|
| 475 |
+
" return h\n",
|
| 476 |
+
"\n",
|
| 477 |
+
" def forward(self, x, mask):\n",
|
| 478 |
+
" with torch.cuda.amp.autocast(enabled=False):\n",
|
| 479 |
+
" x = x.float()\n",
|
| 480 |
+
" h = self.proj(x)\n",
|
| 481 |
+
" out = self._run(self.fwd, h)\n",
|
| 482 |
+
" if self.bidir:\n",
|
| 483 |
+
" out = out + torch.flip(self._run(self.bwd, torch.flip(h, dims=[1])), dims=[1])\n",
|
| 484 |
+
" a = self.attn(out).squeeze(-1).masked_fill(~mask, float(\"-inf\"))\n",
|
| 485 |
+
" w = torch.softmax(a, dim=1).unsqueeze(-1)\n",
|
| 486 |
+
" return self.out((out * w).sum(1))"
|
| 487 |
+
]
|
| 488 |
+
},
|
| 489 |
+
{
|
| 490 |
+
"cell_type": "markdown",
|
| 491 |
+
"id": "c8369a6b",
|
| 492 |
+
"metadata": {},
|
| 493 |
+
"source": [
|
| 494 |
+
"## 6. Dựng enc + heads → nạp trọng số từ ckpt + lấy chuẩn hóa từ ckpt"
|
| 495 |
+
]
|
| 496 |
+
},
|
| 497 |
+
{
|
| 498 |
+
"cell_type": "code",
|
| 499 |
+
"execution_count": null,
|
| 500 |
+
"id": "1c5e8556",
|
| 501 |
+
"metadata": {
|
| 502 |
+
"lines_to_next_cell": 1
|
| 503 |
+
},
|
| 504 |
+
"outputs": [],
|
| 505 |
+
"source": [
|
| 506 |
+
"N_EMO = len(EMOTIONS5)\n",
|
| 507 |
+
"WAVLM_BRANCH = Z_DIM if USE_MAMBA else WAVLM_DIM\n",
|
| 508 |
+
"TRUNK_IN = WAVLM_BRANCH + (AUD_DIM if USE_AUDEERING else 0)\n",
|
| 509 |
+
"\n",
|
| 510 |
+
"enc = MambaEncoder(WAVLM_DIM, MAMBA_DMODEL, MAMBA_LAYERS, MAMBA_DSTATE, Z_DIM, BIDIRECTIONAL).to(device) \\\n",
|
| 511 |
+
" if USE_MAMBA else None\n",
|
| 512 |
+
"\n",
|
| 513 |
+
"class EmoHeads(nn.Module):\n",
|
| 514 |
+
" def __init__(self, d_in, trunk_h, head_h, p, n_emo):\n",
|
| 515 |
+
" super().__init__()\n",
|
| 516 |
+
" self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),\n",
|
| 517 |
+
" nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))\n",
|
| 518 |
+
" self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))\n",
|
| 519 |
+
" self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))\n",
|
| 520 |
+
" self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))\n",
|
| 521 |
+
" def forward(self, feat, tgt):\n",
|
| 522 |
+
" h = self.trunk(feat)\n",
|
| 523 |
+
" return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)\n",
|
| 524 |
+
"\n",
|
| 525 |
+
"heads = EmoHeads(TRUNK_IN, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device)\n",
|
| 526 |
+
"hm, hu = heads.load_state_dict(ckpt[\"heads\"], strict=False)\n",
|
| 527 |
+
"print(f\"🔁 load heads từ ckpt: thiếu {len(hm)} / dư {len(hu)} key (kỳ vọng 0)\")\n",
|
| 528 |
+
"if USE_MAMBA:\n",
|
| 529 |
+
" assert ckpt.get(\"enc\") is not None, \"❌ ckpt USE_MAMBA=True nhưng KHÔNG có 'enc' → không inference đúng được.\"\n",
|
| 530 |
+
" em, eu = enc.load_state_dict(ckpt[\"enc\"], strict=False)\n",
|
| 531 |
+
" print(f\"🔁 load Mamba enc từ ckpt: thiếu {len(em)} / dư {len(eu)} key (kỳ vọng 0)\")\n",
|
| 532 |
+
"heads.eval()\n",
|
| 533 |
+
"if USE_MAMBA:\n",
|
| 534 |
+
" enc.eval()\n",
|
| 535 |
+
"\n",
|
| 536 |
+
"# Chuẩn hóa LẤY TỪ ckpt (head dự đoán ở thang z-score này → phải giải chuẩn đúng thang)\n",
|
| 537 |
+
"emos_mu = float(ckpt[\"emos_mu\"]); emos_sd = float(ckpt[\"emos_sd\"])\n",
|
| 538 |
+
"vad_mu = np.asarray(ckpt[\"vad_mu\"], dtype=np.float32); vad_sd = np.asarray(ckpt[\"vad_sd\"], dtype=np.float32)\n",
|
| 539 |
+
"print(f\"Chuẩn hóa từ ckpt: emos μ={emos_mu:.3f} σ={emos_sd:.3f} | vad μ={np.round(vad_mu,2)}\")\n",
|
| 540 |
+
"\n",
|
| 541 |
+
"def wavlm_branch(input_values, attn_mask):\n",
|
| 542 |
+
" out = wavlm(input_values, attention_mask=attn_mask).last_hidden_state\n",
|
| 543 |
+
" if USE_MAMBA:\n",
|
| 544 |
+
" return enc(out, frame_mask(out.shape[1], attn_mask))\n",
|
| 545 |
+
" return masked_mean(out, attn_mask)\n",
|
| 546 |
+
"\n",
|
| 547 |
+
"print(f\"Trunk input = {TRUNK_IN} (wavlm-branch {WAVLM_BRANCH} [{'Mamba' if USE_MAMBA else 'mean-pool'}] + aud {AUD_DIM if USE_AUDEERING else 0})\")"
|
| 548 |
+
]
|
| 549 |
+
},
|
| 550 |
+
{
|
| 551 |
+
"cell_type": "markdown",
|
| 552 |
+
"id": "fdcf05c2",
|
| 553 |
+
"metadata": {},
|
| 554 |
+
"source": [
|
| 555 |
+
"## 7. Dự đoán DEV → answer.txt (5 cột cảm xúc; QMOS mượn exp07/UTMOSv2)"
|
| 556 |
+
]
|
| 557 |
+
},
|
| 558 |
+
{
|
| 559 |
+
"cell_type": "code",
|
| 560 |
+
"execution_count": null,
|
| 561 |
+
"id": "4d225f54",
|
| 562 |
+
"metadata": {
|
| 563 |
+
"lines_to_next_cell": 1
|
| 564 |
+
},
|
| 565 |
+
"outputs": [],
|
| 566 |
+
"source": [
|
| 567 |
+
"def list_dev():\n",
|
| 568 |
+
" with open(DEV_SCP) as f:\n",
|
| 569 |
+
" return [ln.strip() for ln in f if ln.strip()]\n",
|
| 570 |
+
"\n",
|
| 571 |
+
"dev_names = list_dev()\n",
|
| 572 |
+
"if LIMIT_DEV:\n",
|
| 573 |
+
" dev_names = dev_names[:LIMIT_DEV]\n",
|
| 574 |
+
"dev_stems = [stem(n) for n in dev_names]\n",
|
| 575 |
+
"print(\"DEV:\", len(dev_names), \"mẫu\")\n",
|
| 576 |
+
"aud_dev = extract_audeering(dev_stems, \"dev\")\n",
|
| 577 |
+
"\n",
|
| 578 |
+
"def load_exp07_qmos():\n",
|
| 579 |
+
" if EXP07_ANSWER and os.path.exists(EXP07_ANSWER):\n",
|
| 580 |
+
" import csv\n",
|
| 581 |
+
" d = {}\n",
|
| 582 |
+
" with open(EXP07_ANSWER) as f:\n",
|
| 583 |
+
" for row in csv.DictReader(f):\n",
|
| 584 |
+
" d[row[\"wav\"]] = float(row[\"QMOS\"]); d[stem(row[\"wav\"])] = float(row[\"QMOS\"])\n",
|
| 585 |
+
" print(f\"✅ Mượn QMOS exp07 ({EXP07_ANSWER}): {len(d)//2} wav\")\n",
|
| 586 |
+
" return d\n",
|
| 587 |
+
" return None\n",
|
| 588 |
+
"\n",
|
| 589 |
+
"qmos_map = load_exp07_qmos()\n",
|
| 590 |
+
"if qmos_map is None:\n",
|
| 591 |
+
" print(\"ℹ️ Không có answer.txt exp07 → chấm QMOS bằng UTMOSv2 (T05, vô địch VMC2024).\")\n",
|
| 592 |
+
" pip_install(\"git+https://github.com/sarulab-speech/UTMOSv2.git\")\n",
|
| 593 |
+
" import utmosv2\n",
|
| 594 |
+
" v2 = utmosv2.create_model(pretrained=True)\n",
|
| 595 |
+
" qmos_map = {}\n",
|
| 596 |
+
" for n in tqdm(dev_names, desc=\"UTMOSv2\"):\n",
|
| 597 |
+
" wav = os.path.join(WAV_DIR, n if str(n).endswith(\".wav\") else str(n) + \".wav\")\n",
|
| 598 |
+
" if not os.path.exists(wav):\n",
|
| 599 |
+
" continue\n",
|
| 600 |
+
" out = v2.predict(input_path=wav)\n",
|
| 601 |
+
" qmos_map[n] = float(out[\"predicted_mos\"]) if isinstance(out, dict) else float(out)\n",
|
| 602 |
+
" del v2; torch.cuda.empty_cache() if device == \"cuda\" else None\n",
|
| 603 |
+
"\n",
|
| 604 |
+
"@torch.no_grad()\n",
|
| 605 |
+
"def predict_emotion(sid):\n",
|
| 606 |
+
" wave = load_wav(sid)\n",
|
| 607 |
+
" if wave is None or (USE_AUDEERING and sid not in aud_dev):\n",
|
| 608 |
+
" return None\n",
|
| 609 |
+
" iv = torch.from_numpy(wave).unsqueeze(0).to(device)\n",
|
| 610 |
+
" am = torch.ones((1, len(wave)), dtype=torch.long, device=device)\n",
|
| 611 |
+
" tgt = torch.from_numpy(onehot_target(target_map.get(sid))).unsqueeze(0).to(device)\n",
|
| 612 |
+
" with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
|
| 613 |
+
" fw = wavlm_branch(iv, am)\n",
|
| 614 |
+
" feat = torch.cat([fw, torch.from_numpy(aud_dev[sid]).unsqueeze(0).to(device)], dim=1) if USE_AUDEERING else fw\n",
|
| 615 |
+
" emos_p, cat_l, vad_p = heads(feat, tgt)\n",
|
| 616 |
+
" emos = float(emos_p.item()) * emos_sd + emos_mu\n",
|
| 617 |
+
" cat5 = F.softmax(cat_l, 1)[0].float().cpu().numpy()\n",
|
| 618 |
+
" vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu\n",
|
| 619 |
+
" return emos, cat5, vad3\n",
|
| 620 |
+
"\n",
|
| 621 |
+
"def fmt_cat(p5):\n",
|
| 622 |
+
" return \"|\".join(f\"{e}:{p5[i]:.6g}\" for i, e in enumerate(EMOTIONS5))\n",
|
| 623 |
+
"\n",
|
| 624 |
+
"def build_answer(out_path):\n",
|
| 625 |
+
" n_real = n_def = 0\n",
|
| 626 |
+
" with open(out_path, \"w\") as f:\n",
|
| 627 |
+
" f.write(\"wav,QMOS,EMOS,CAT,VAL,ARO,DOM\\n\")\n",
|
| 628 |
+
" for name in tqdm(dev_names, desc=\"answer\"):\n",
|
| 629 |
+
" sid = stem(name)\n",
|
| 630 |
+
" pr = predict_emotion(sid)\n",
|
| 631 |
+
" if pr is None:\n",
|
| 632 |
+
" emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0]); n_def += 1\n",
|
| 633 |
+
" else:\n",
|
| 634 |
+
" emos, cat5, vad3 = pr; n_real += 1\n",
|
| 635 |
+
" qmos = qmos_map.get(name, qmos_map.get(sid, 3.0))\n",
|
| 636 |
+
" f.write(f\"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\\n\")\n",
|
| 637 |
+
" print(f\"Ghi {len(dev_names)} dòng → {out_path} | cảm xúc thật {n_real}, mặc định {n_def}\")\n",
|
| 638 |
+
"\n",
|
| 639 |
+
"answer_path = os.path.join(OUT_DIR, \"answer.txt\")\n",
|
| 640 |
+
"build_answer(answer_path)"
|
| 641 |
+
]
|
| 642 |
+
},
|
| 643 |
+
{
|
| 644 |
+
"cell_type": "markdown",
|
| 645 |
+
"id": "42503595",
|
| 646 |
+
"metadata": {},
|
| 647 |
+
"source": [
|
| 648 |
+
"## 8. Validate + đóng zip"
|
| 649 |
+
]
|
| 650 |
+
},
|
| 651 |
+
{
|
| 652 |
+
"cell_type": "code",
|
| 653 |
+
"execution_count": null,
|
| 654 |
+
"id": "42dec31f",
|
| 655 |
+
"metadata": {},
|
| 656 |
+
"outputs": [],
|
| 657 |
+
"source": [
|
| 658 |
+
"def validate(path):\n",
|
| 659 |
+
" import csv\n",
|
| 660 |
+
" with open(path) as f:\n",
|
| 661 |
+
" rows = list(csv.reader(f))\n",
|
| 662 |
+
" assert rows[0][0] == \"wav\" and \"QMOS\" in rows[0] and \"EMOS\" in rows[0], \"Header sai\"\n",
|
| 663 |
+
" for i, r in enumerate(rows[1:], 2):\n",
|
| 664 |
+
" assert len(r) == len(rows[0]), f\"Dòng {i} sai số cột\"\n",
|
| 665 |
+
" print(f\"OK: {len(rows)-1} dòng, header = {rows[0]}\")\n",
|
| 666 |
+
"\n",
|
| 667 |
+
"validate(answer_path)\n",
|
| 668 |
+
"os.system(f\"cd {OUT_DIR} && zip -j submission_track2_exp15_predict.zip answer.txt \"\n",
|
| 669 |
+
" f\"&& unzip -l submission_track2_exp15_predict.zip\")\n",
|
| 670 |
+
"print(\"Sẵn sàng nộp:\", os.path.join(OUT_DIR, \"submission_track2_exp15_predict.zip\"))"
|
| 671 |
+
]
|
| 672 |
+
},
|
| 673 |
+
{
|
| 674 |
+
"cell_type": "markdown",
|
| 675 |
+
"id": "fbef2a21",
|
| 676 |
+
"metadata": {},
|
| 677 |
+
"source": [
|
| 678 |
+
"## Ghi chú\n",
|
| 679 |
+
"- File này **chỉ inference** — không train, không cần train.csv. Dùng khi đã có `ft_mamba_emotion_full*.pt`.\n",
|
| 680 |
+
"- ⚠️ **Siêu tham số Mamba/heads (MAMBA_DMODEL/LAYERS/DSTATE, TRUNK_HIDDEN, HEAD_HIDDEN) PHẢI khớp lúc train**\n",
|
| 681 |
+
" (ckpt không lưu các số này) — nếu lúc train exp15 bạn đổi, hãy sửa cho khớp ở cell 0, nếu không load_state_dict\n",
|
| 682 |
+
" sẽ lệch key / sai shape.\n",
|
| 683 |
+
"- `USE_MAMBA`, `Z_DIM`, `AUD_DIM`, `UNFREEZE_TOP_LAYERS` thì **đọc tự động từ ckpt**.\n",
|
| 684 |
+
"- QMOS: tốt nhất Add Input `answer.txt` exp07 (0.548); không có thì tự chấm UTMOSv2.\n",
|
| 685 |
+
"- Smoke-test: đặt `LIMIT_DEV=20` chạy thử cho nhanh, OK rồi đặt lại `None` để chấm đủ 2730."
|
| 686 |
+
]
|
| 687 |
+
}
|
| 688 |
+
],
|
| 689 |
+
"metadata": {
|
| 690 |
+
"jupytext": {
|
| 691 |
+
"cell_metadata_filter": "-all",
|
| 692 |
+
"main_language": "python",
|
| 693 |
+
"notebook_metadata_filter": "-all"
|
| 694 |
+
}
|
| 695 |
+
},
|
| 696 |
+
"nbformat": 4,
|
| 697 |
+
"nbformat_minor": 5
|
| 698 |
+
}
|
track2/exp15_predict_pipeline.py
ADDED
|
@@ -0,0 +1,554 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %% [markdown]
|
| 2 |
+
# # VMC2026 Track 2 — exp15 PREDICT-ONLY (nạp checkpoint → chấm DEV, KHÔNG train) — Kaggle
|
| 3 |
+
#
|
| 4 |
+
# **Mục đích:** bạn ĐÃ có checkpoint exp15 (`ft_mamba_emotion_full*.pt`, lưu cả backbone WavLM + Mamba enc + heads).
|
| 5 |
+
# File này **chỉ inference**: dựng lại đúng kiến trúc → nạp trọng số + thống kê chuẩn hóa TỪ ckpt →
|
| 6 |
+
# dự đoán 5 cột cảm xúc trên tập DEV → ghép QMOS (exp07/UTMOSv2) → `answer.txt` → zip nộp.
|
| 7 |
+
# **KHÔNG** train, **KHÔNG** cần train.csv (chỉ cần wav DEV + metadata.csv để lấy cảm xúc target cho EMOS).
|
| 8 |
+
#
|
| 9 |
+
# ## Vì sao nhanh
|
| 10 |
+
# - Không có vòng train → chỉ 1 lượt forward qua DEV (~2730 mẫu). Việc lâu nhất là trích audeering DEV
|
| 11 |
+
# (~vài phút; có cache thì gần như tức thì).
|
| 12 |
+
#
|
| 13 |
+
# ## Chuẩn bị input trên Kaggle (Add Input)
|
| 14 |
+
# 1. Dataset Track 2 (wav + `metadata.csv` + `sets/dev.scp`).
|
| 15 |
+
# 2. **Checkpoint** exp15: dataset chứa `ft_mamba_emotion_full*.pt` (vd `cache_exp8`). Auto-dò; hoặc trỏ `CKPT_PATH`.
|
| 16 |
+
# 3. (tùy chọn) cache audeering `aud_dev.npz` để khỏi trích lại.
|
| 17 |
+
# 4. (tùy chọn) `answer.txt` exp07 để mượn cột QMOS 0.548.
|
| 18 |
+
#
|
| 19 |
+
# **Cách chạy:** GPU **T4** + Internet **On** → Add Input → Run All.
|
| 20 |
+
|
| 21 |
+
# %% [markdown]
|
| 22 |
+
# ## 0. Cấu hình — SỬA Ở ĐÂY
|
| 23 |
+
|
| 24 |
+
# %%
|
| 25 |
+
import os, glob
|
| 26 |
+
|
| 27 |
+
# ── TỰ DÒ DATA_ROOT (quét /kaggle/input tìm thư mục có sets + wav/ + metadata.csv) ──
|
| 28 |
+
def find_data_root(search_root="/kaggle/input"):
|
| 29 |
+
cands = []
|
| 30 |
+
for dev_scp in glob.glob(os.path.join(search_root, "**", "sets", "dev.scp"), recursive=True):
|
| 31 |
+
root = os.path.dirname(os.path.dirname(dev_scp))
|
| 32 |
+
score = os.path.isdir(os.path.join(root, "wav")) + os.path.exists(os.path.join(root, "metadata.csv"))
|
| 33 |
+
cands.append((score, root))
|
| 34 |
+
cands.sort(reverse=True)
|
| 35 |
+
return cands
|
| 36 |
+
|
| 37 |
+
_cands = find_data_root("/kaggle/input")
|
| 38 |
+
if _cands:
|
| 39 |
+
print("🔎 Ứng viên DATA_ROOT:")
|
| 40 |
+
for sc, r in _cands:
|
| 41 |
+
print(f" [{sc}/2] {r}")
|
| 42 |
+
DATA_ROOT = _cands[0][1]
|
| 43 |
+
print(f"👉 Tự chọn DATA_ROOT = {DATA_ROOT}")
|
| 44 |
+
else:
|
| 45 |
+
DATA_ROOT = "/kaggle/input/datasets/minhtoan2" # dự phòng — sửa tay
|
| 46 |
+
print(f"❌ Không thấy sets/dev.scp → dùng dự phòng {DATA_ROOT} (đã Add Input chưa?)")
|
| 47 |
+
|
| 48 |
+
WAV_DIR = f"{DATA_ROOT}/wav"
|
| 49 |
+
METADATA_CSV = f"{DATA_ROOT}/metadata.csv" # wavID|emotion|transcript (KHÔNG header) — lấy cảm xúc target cho EMOS
|
| 50 |
+
DEV_SCP = f"{DATA_ROOT}/sets/dev.scp"
|
| 51 |
+
|
| 52 |
+
OUT_DIR = "/kaggle/working"
|
| 53 |
+
CACHE_DIR = "/kaggle/working/ft_cache"
|
| 54 |
+
os.makedirs(CACHE_DIR, exist_ok=True)
|
| 55 |
+
|
| 56 |
+
# ── CHECKPOINT exp15 (đủ backbone + Mamba + heads) ───────────────────────────
|
| 57 |
+
CKPT_PATH = "" # << "" = auto-dò ft_mamba_emotion_full*.pt; hoặc "/kaggle/input/<slug>/ft_mamba_emotion_full (2).pt"
|
| 58 |
+
|
| 59 |
+
def find_ckpt(explicit):
|
| 60 |
+
"""Tìm checkpoint exp15. Khớp cả tên bị thêm hậu tố trùng, vd 'ft_mamba_emotion_full (2).pt'."""
|
| 61 |
+
if explicit and os.path.exists(explicit):
|
| 62 |
+
return explicit
|
| 63 |
+
for base in ["/kaggle/input", "/kaggle/working"]:
|
| 64 |
+
hits = sorted(glob.glob(os.path.join(base, "**", "ft_mamba_emotion_full*.pt"), recursive=True))
|
| 65 |
+
if hits:
|
| 66 |
+
return hits[0]
|
| 67 |
+
return ""
|
| 68 |
+
|
| 69 |
+
CKPT_PATH = find_ckpt(CKPT_PATH)
|
| 70 |
+
assert CKPT_PATH, "❌ Không thấy checkpoint ft_mamba_emotion_full*.pt. Đã Add Input dataset chứa ckpt chưa?"
|
| 71 |
+
print("✅ Dùng checkpoint:", CKPT_PATH)
|
| 72 |
+
|
| 73 |
+
# (Tùy chọn) tái dùng cache audeering DEV — quét đệ quy (file có thể nằm trong archive/)
|
| 74 |
+
CACHE_INPUT = "/kaggle/input/cache-exp8" # << SỬA slug (hoặc "")
|
| 75 |
+
if CACHE_INPUT and os.path.isdir(CACHE_INPUT):
|
| 76 |
+
import shutil
|
| 77 |
+
_n = 0
|
| 78 |
+
for _fp in glob.glob(os.path.join(CACHE_INPUT, "**", "aud_*.npz"), recursive=True):
|
| 79 |
+
shutil.copy(_fp, os.path.join(CACHE_DIR, os.path.basename(_fp))); _n += 1
|
| 80 |
+
print(f"📦 Copy {_n} file aud_*.npz từ {CACHE_INPUT}")
|
| 81 |
+
|
| 82 |
+
# Mượn cột QMOS exp07 (0.548). Trỏ answer.txt exp07 nếu có; không thì UTMOSv2.
|
| 83 |
+
EXP07_ANSWER = "/kaggle/input/exp07-answer/answer.txt" # << (tùy chọn)
|
| 84 |
+
|
| 85 |
+
# ── Siêu tham số PHẢI KHỚP lúc train exp15 (ckpt không lưu các số này của Mamba) ──
|
| 86 |
+
MAMBA_DMODEL = 256
|
| 87 |
+
MAMBA_LAYERS = 2
|
| 88 |
+
MAMBA_DSTATE = 16
|
| 89 |
+
BIDIRECTIONAL = True
|
| 90 |
+
TRUNK_HIDDEN = 512
|
| 91 |
+
HEAD_HIDDEN = 128
|
| 92 |
+
DROPOUT = 0.3 # không ảnh hưởng eval (model.eval() tắt dropout) — chỉ để dựng đúng shape
|
| 93 |
+
|
| 94 |
+
DEVICE = "cuda"
|
| 95 |
+
SR = 16000
|
| 96 |
+
MAX_SECONDS = 6 # khớp lúc train (exp15 = 6)
|
| 97 |
+
USE_AMP = True
|
| 98 |
+
LIMIT_DEV = None # << để None chấm ĐỦ 2730; đặt 20 để smoke-test nhanh
|
| 99 |
+
|
| 100 |
+
EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"]
|
| 101 |
+
_EMO_ALIAS = {
|
| 102 |
+
"angry": "angry", "anger": "angry",
|
| 103 |
+
"happy": "happy", "happiness": "happy", "joy": "happy",
|
| 104 |
+
"neutral": "neutral", "calm": "neutral",
|
| 105 |
+
"sad": "sad", "sadness": "sad",
|
| 106 |
+
"surprise": "surprised", "surprised": "surprised", "surprising": "surprised",
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
def norm_emotion(label):
|
| 110 |
+
key = str(label).strip().lower()
|
| 111 |
+
return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)
|
| 112 |
+
|
| 113 |
+
def stem(p):
|
| 114 |
+
return os.path.splitext(os.path.basename(str(p)))[0]
|
| 115 |
+
|
| 116 |
+
print("DATA_ROOT:", DATA_ROOT)
|
| 117 |
+
for p in [WAV_DIR, METADATA_CSV, DEV_SCP, CKPT_PATH]:
|
| 118 |
+
print((" ✅ " if os.path.exists(p) else " ❌ THIẾU ") + p)
|
| 119 |
+
|
| 120 |
+
# %% [markdown]
|
| 121 |
+
# ## 1. Cài đặt + tải code SAILER (để dựng đúng kiến trúc WavLM rồi nạp ckpt đè lên)
|
| 122 |
+
|
| 123 |
+
# %%
|
| 124 |
+
import sys, subprocess
|
| 125 |
+
|
| 126 |
+
def pip_install(*pkgs):
|
| 127 |
+
subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs], check=True)
|
| 128 |
+
|
| 129 |
+
pip_install("loralib", "speechbrain", "speechmos", "librosa", "soundfile",
|
| 130 |
+
"scipy", "scikit-learn", "pandas", "tqdm")
|
| 131 |
+
|
| 132 |
+
# Mamba kernel CUDA (tùy chọn — không có thì dùng Mamba thuần PyTorch, inference vẫn ổn vì chỉ 1 lượt forward)
|
| 133 |
+
INSTALL_MAMBA_SSM = True
|
| 134 |
+
if INSTALL_MAMBA_SSM:
|
| 135 |
+
try:
|
| 136 |
+
subprocess.run([sys.executable, "-m", "pip", "install", "-q", "ninja"], check=True)
|
| 137 |
+
subprocess.run([sys.executable, "-m", "pip", "install", "-q", "--no-build-isolation", "causal-conv1d>=1.2.0"], check=True)
|
| 138 |
+
subprocess.run([sys.executable, "-m", "pip", "install", "-q", "--no-build-isolation", "mamba-ssm"], check=True)
|
| 139 |
+
print("✅ Cài mamba-ssm xong (dùng kernel CUDA nếu import được).")
|
| 140 |
+
except Exception as e:
|
| 141 |
+
print("⚠️ Cài mamba-ssm thất bại:", repr(e), "→ Mamba thuần PyTorch (inference vẫn chạy).")
|
| 142 |
+
|
| 143 |
+
REPO_DIR = "/kaggle/working/vox-profile-release"
|
| 144 |
+
if not os.path.exists(REPO_DIR):
|
| 145 |
+
subprocess.run(["git", "clone", "--depth", "1",
|
| 146 |
+
"https://github.com/tiantiaf0627/vox-profile-release.git", REPO_DIR], check=True)
|
| 147 |
+
if REPO_DIR not in sys.path:
|
| 148 |
+
sys.path.insert(0, REPO_DIR)
|
| 149 |
+
|
| 150 |
+
# %% [markdown]
|
| 151 |
+
# ## 2. Nạp checkpoint → dựng WavLM → load trọng số backbone đã fine-tune
|
| 152 |
+
|
| 153 |
+
# %%
|
| 154 |
+
import torch
|
| 155 |
+
import torch.nn as nn
|
| 156 |
+
import torch.nn.functional as F
|
| 157 |
+
|
| 158 |
+
device = DEVICE if torch.cuda.is_available() else "cpu"
|
| 159 |
+
print("Device:", device, ("✅ " + torch.cuda.get_device_name(0)) if device == "cuda" else "⚠️ CPU (chậm)")
|
| 160 |
+
|
| 161 |
+
ckpt = torch.load(CKPT_PATH, map_location="cpu", weights_only=False) # ckpt có numpy → cần False
|
| 162 |
+
assert "wavlm" in ckpt, "❌ Checkpoint KHÔNG có 'wavlm' (backbone) → không inference được. Cần ft_mamba_emotion_full*.pt đủ."
|
| 163 |
+
print("✅ Nạp ckpt | keys:", list(ckpt.keys()))
|
| 164 |
+
|
| 165 |
+
# Lấy cấu hình KIẾN TRÚC từ ckpt (để dựng đúng shape head)
|
| 166 |
+
USE_MAMBA = bool(ckpt.get("USE_MAMBA", True))
|
| 167 |
+
Z_DIM = int(ckpt.get("Z_DIM", 256))
|
| 168 |
+
AUD_DIM = int(ckpt.get("AUD_DIM", 0))
|
| 169 |
+
USE_AUDEERING = AUD_DIM > 0
|
| 170 |
+
UNFREEZE_TOP_LAYERS = int(ckpt.get("UNFREEZE_TOP_LAYERS", 6))
|
| 171 |
+
print(f"Từ ckpt: USE_MAMBA={USE_MAMBA} · Z_DIM={Z_DIM} · AUD_DIM={AUD_DIM} (audeering={'ON' if USE_AUDEERING else 'OFF'})")
|
| 172 |
+
|
| 173 |
+
def find_hf_backbone(module):
|
| 174 |
+
cands = []
|
| 175 |
+
for name, m in module.named_modules():
|
| 176 |
+
enc = getattr(m, "encoder", None)
|
| 177 |
+
if getattr(m, "feature_extractor", None) is not None and enc is not None \
|
| 178 |
+
and getattr(enc, "layers", None) is not None:
|
| 179 |
+
cands.append((name, m))
|
| 180 |
+
if not cands:
|
| 181 |
+
return None, None
|
| 182 |
+
cands.sort(key=lambda nm: sum(p.numel() for p in nm[1].parameters()), reverse=True)
|
| 183 |
+
return cands[0]
|
| 184 |
+
|
| 185 |
+
wavlm = None
|
| 186 |
+
try:
|
| 187 |
+
from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402
|
| 188 |
+
_wrapper = WavLMWrapper.from_pretrained("tiantiaf/wavlm-large-categorical-emotion")
|
| 189 |
+
name, wavlm = find_hf_backbone(_wrapper)
|
| 190 |
+
if wavlm is not None:
|
| 191 |
+
print(f"✅ Dựng backbone WavLM từ SAILER wrapper tại '.{name}'")
|
| 192 |
+
except Exception as e:
|
| 193 |
+
print("⚠️ Lỗi nạp SAILER wrapper:", repr(e), "→ fallback WavLM trắng.")
|
| 194 |
+
|
| 195 |
+
if wavlm is None:
|
| 196 |
+
from transformers import WavLMModel
|
| 197 |
+
wavlm = WavLMModel.from_pretrained("microsoft/wavlm-large")
|
| 198 |
+
print("ℹ️ Fallback: microsoft/wavlm-large.")
|
| 199 |
+
|
| 200 |
+
wavlm = wavlm.to(device)
|
| 201 |
+
WAVLM_DIM = int(wavlm.config.hidden_size)
|
| 202 |
+
wavlm.config.layerdrop = 0.0
|
| 203 |
+
|
| 204 |
+
miss, unexp = wavlm.load_state_dict(ckpt["wavlm"], strict=False)
|
| 205 |
+
print(f"🔁 load wavlm từ ckpt: thiếu {len(miss)} / dư {len(unexp)} key (kỳ vọng ~0)")
|
| 206 |
+
if len(miss) > 20 or len(unexp) > 20:
|
| 207 |
+
print(" ⚠️ Lệch key nhiều → kiểm tra backbone có khớp ckpt không.")
|
| 208 |
+
wavlm.eval()
|
| 209 |
+
|
| 210 |
+
def frame_mask(T, attn_mask):
|
| 211 |
+
if attn_mask is None:
|
| 212 |
+
return torch.ones((1, T), dtype=torch.bool, device=device)
|
| 213 |
+
try:
|
| 214 |
+
return wavlm._get_feature_vector_attention_mask(T, attn_mask).bool()
|
| 215 |
+
except Exception:
|
| 216 |
+
return torch.ones((attn_mask.shape[0], T), dtype=torch.bool, device=attn_mask.device)
|
| 217 |
+
|
| 218 |
+
def masked_mean(hidden, attn_mask):
|
| 219 |
+
if attn_mask is None:
|
| 220 |
+
return hidden.mean(dim=1)
|
| 221 |
+
fm = frame_mask(hidden.shape[1], attn_mask).unsqueeze(-1).to(hidden.dtype)
|
| 222 |
+
return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)
|
| 223 |
+
|
| 224 |
+
# %% [markdown]
|
| 225 |
+
# ## 3. audeering MSP-dim (FROZEN) — chỉ dựng nếu ckpt có dùng (AUD_DIM>0)
|
| 226 |
+
|
| 227 |
+
# %%
|
| 228 |
+
import numpy as np
|
| 229 |
+
import librosa
|
| 230 |
+
from tqdm.auto import tqdm
|
| 231 |
+
|
| 232 |
+
aud_backbone = aud_head = aud_proc = None
|
| 233 |
+
if USE_AUDEERING:
|
| 234 |
+
from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor
|
| 235 |
+
from huggingface_hub import hf_hub_download
|
| 236 |
+
AUD_NAME = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
|
| 237 |
+
aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)
|
| 238 |
+
aud_cfg = Wav2Vec2Config.from_pretrained(AUD_NAME)
|
| 239 |
+
aud_backbone = Wav2Vec2Model(aud_cfg)
|
| 240 |
+
try:
|
| 241 |
+
_sd = __import__("safetensors.torch", fromlist=["load_file"]).load_file(
|
| 242 |
+
hf_hub_download(AUD_NAME, "model.safetensors"))
|
| 243 |
+
except Exception:
|
| 244 |
+
_sd = torch.load(hf_hub_download(AUD_NAME, "pytorch_model.bin"), map_location="cpu")
|
| 245 |
+
bb_sd = {k[len("wav2vec2."):]: v for k, v in _sd.items() if k.startswith("wav2vec2.")}
|
| 246 |
+
aud_backbone.load_state_dict(bb_sd, strict=False)
|
| 247 |
+
_hid = _sd["classifier.dense.weight"].shape[0]
|
| 248 |
+
aud_head = nn.Sequential(nn.Linear(_hid, _hid), nn.Tanh(), nn.Linear(_hid, _sd["classifier.out_proj.weight"].shape[0]))
|
| 249 |
+
aud_head[0].weight.data.copy_(_sd["classifier.dense.weight"]); aud_head[0].bias.data.copy_(_sd["classifier.dense.bias"])
|
| 250 |
+
aud_head[2].weight.data.copy_(_sd["classifier.out_proj.weight"]); aud_head[2].bias.data.copy_(_sd["classifier.out_proj.bias"])
|
| 251 |
+
aud_backbone = aud_backbone.to(device).eval()
|
| 252 |
+
aud_head = aud_head.to(device).eval()
|
| 253 |
+
assert _hid + 3 == AUD_DIM, f"⚠️ AUD_DIM dựng ({_hid+3}) ≠ ckpt ({AUD_DIM}) → audeering không khớp!"
|
| 254 |
+
print(f"✅ audeering frozen ({AUD_DIM}-D)")
|
| 255 |
+
|
| 256 |
+
def load_wav(name_or_stem):
|
| 257 |
+
p = name_or_stem if os.path.isabs(str(name_or_stem)) else os.path.join(
|
| 258 |
+
WAV_DIR, name_or_stem if str(name_or_stem).endswith(".wav") else str(name_or_stem) + ".wav")
|
| 259 |
+
if not os.path.exists(p):
|
| 260 |
+
return None
|
| 261 |
+
wave, _ = librosa.load(p, sr=SR, mono=True)
|
| 262 |
+
return wave[: MAX_SECONDS * SR].astype(np.float32)
|
| 263 |
+
|
| 264 |
+
@torch.no_grad()
|
| 265 |
+
def extract_audeering(stems, tag):
|
| 266 |
+
if not USE_AUDEERING:
|
| 267 |
+
return {}
|
| 268 |
+
cache_path = os.path.join(CACHE_DIR, f"aud_{tag}.npz")
|
| 269 |
+
store = {}
|
| 270 |
+
if os.path.exists(cache_path):
|
| 271 |
+
z = np.load(cache_path, allow_pickle=True)
|
| 272 |
+
store = {k: z[k] for k in z.files}
|
| 273 |
+
print(f"[aud/{tag}] nạp cache: {len(store)}")
|
| 274 |
+
todo = [s for s in stems if s not in store]
|
| 275 |
+
for i, s in enumerate(tqdm(todo, desc=f"audeering {tag}")):
|
| 276 |
+
wave = load_wav(s)
|
| 277 |
+
if wave is None:
|
| 278 |
+
continue
|
| 279 |
+
x = aud_proc(wave, sampling_rate=SR).input_values[0]
|
| 280 |
+
x = torch.from_numpy(np.asarray(x, dtype=np.float32)).unsqueeze(0).to(device)
|
| 281 |
+
h = aud_backbone(x)[0].mean(dim=1)
|
| 282 |
+
out = aud_head(h)[0].cpu().numpy()
|
| 283 |
+
vad = np.array([1 + 4 * out[2], 1 + 4 * out[0], 1 + 4 * out[1]], dtype=np.float32) # [VAL,ARO,DOM]
|
| 284 |
+
store[s] = np.concatenate([h[0].cpu().numpy(), vad]).astype(np.float32)
|
| 285 |
+
if (i + 1) % 500 == 0:
|
| 286 |
+
np.savez(cache_path, **store)
|
| 287 |
+
if todo:
|
| 288 |
+
np.savez(cache_path, **store)
|
| 289 |
+
return store
|
| 290 |
+
|
| 291 |
+
# %% [markdown]
|
| 292 |
+
# ## 4. Cảm xúc target theo wavID (cho one-hot điều kiện của head EMOS)
|
| 293 |
+
|
| 294 |
+
# %%
|
| 295 |
+
def load_target_emotions():
|
| 296 |
+
tgt = {}
|
| 297 |
+
with open(METADATA_CSV, encoding="utf-8") as f:
|
| 298 |
+
for ln in f:
|
| 299 |
+
parts = ln.strip().split("|")
|
| 300 |
+
if len(parts) >= 2:
|
| 301 |
+
tgt[stem(parts[0])] = norm_emotion(parts[1])
|
| 302 |
+
return tgt
|
| 303 |
+
|
| 304 |
+
target_map = load_target_emotions()
|
| 305 |
+
print("Target cảm xúc:", len(target_map), "wav")
|
| 306 |
+
|
| 307 |
+
def onehot_target(tgt):
|
| 308 |
+
v = np.zeros(len(EMOTIONS5), dtype=np.float32)
|
| 309 |
+
if tgt in EMOTIONS5:
|
| 310 |
+
v[EMOTIONS5.index(tgt)] = 1.0
|
| 311 |
+
return v
|
| 312 |
+
|
| 313 |
+
# %% [markdown]
|
| 314 |
+
# ## 5. Khối Mamba (giống exp15) + MambaEncoder
|
| 315 |
+
|
| 316 |
+
# %%
|
| 317 |
+
import math
|
| 318 |
+
|
| 319 |
+
try:
|
| 320 |
+
from mamba_ssm import Mamba as _OfficialMamba
|
| 321 |
+
_HAS_MAMBA_SSM = True
|
| 322 |
+
print("✅ Dùng mamba-ssm (CUDA kernel)")
|
| 323 |
+
except Exception:
|
| 324 |
+
_HAS_MAMBA_SSM = False
|
| 325 |
+
print("ℹ️ Không có mamba-ssm → Mamba thuần PyTorch")
|
| 326 |
+
|
| 327 |
+
class MambaBlockTorch(nn.Module):
|
| 328 |
+
def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
|
| 329 |
+
super().__init__()
|
| 330 |
+
self.d_inner = expand * d_model
|
| 331 |
+
self.dt_rank = math.ceil(d_model / 16)
|
| 332 |
+
self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
|
| 333 |
+
self.conv1d = nn.Conv1d(self.d_inner, self.d_inner, kernel_size=d_conv,
|
| 334 |
+
groups=self.d_inner, padding=d_conv - 1, bias=True)
|
| 335 |
+
self.x_proj = nn.Linear(self.d_inner, self.dt_rank + d_state * 2, bias=False)
|
| 336 |
+
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)
|
| 337 |
+
A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)
|
| 338 |
+
self.A_log = nn.Parameter(torch.log(A))
|
| 339 |
+
self.D = nn.Parameter(torch.ones(self.d_inner))
|
| 340 |
+
self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
|
| 341 |
+
self.d_state = d_state
|
| 342 |
+
|
| 343 |
+
def forward(self, x):
|
| 344 |
+
B, L, _ = x.shape
|
| 345 |
+
xin, z = self.in_proj(x).chunk(2, dim=-1)
|
| 346 |
+
xin = xin.transpose(1, 2)
|
| 347 |
+
xin = self.conv1d(xin)[..., :L].transpose(1, 2)
|
| 348 |
+
xin = F.silu(xin)
|
| 349 |
+
y = self._ssm(xin) * F.silu(z)
|
| 350 |
+
return self.out_proj(y)
|
| 351 |
+
|
| 352 |
+
def _ssm(self, x):
|
| 353 |
+
A = -torch.exp(self.A_log)
|
| 354 |
+
delta, Bm, Cm = torch.split(self.x_proj(x), [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
| 355 |
+
delta = F.softplus(self.dt_proj(delta))
|
| 356 |
+
dA = torch.exp(delta.unsqueeze(-1) * A)
|
| 357 |
+
dB_x = delta.unsqueeze(-1) * Bm.unsqueeze(2) * x.unsqueeze(-1)
|
| 358 |
+
h = torch.zeros(x.shape[0], self.d_inner, self.d_state, device=x.device, dtype=x.dtype)
|
| 359 |
+
ys = []
|
| 360 |
+
for t in range(x.shape[1]):
|
| 361 |
+
h = dA[:, t] * h + dB_x[:, t]
|
| 362 |
+
ys.append((h * Cm[:, t].unsqueeze(1)).sum(-1))
|
| 363 |
+
return torch.stack(ys, dim=1) + x * self.D
|
| 364 |
+
|
| 365 |
+
class MambaLayer(nn.Module):
|
| 366 |
+
def __init__(self, d_model, d_state):
|
| 367 |
+
super().__init__()
|
| 368 |
+
self.norm = nn.LayerNorm(d_model)
|
| 369 |
+
self.mix = _OfficialMamba(d_model=d_model, d_state=d_state, d_conv=4, expand=2) \
|
| 370 |
+
if _HAS_MAMBA_SSM else MambaBlockTorch(d_model, d_state=d_state)
|
| 371 |
+
def forward(self, x):
|
| 372 |
+
return x + self.mix(self.norm(x))
|
| 373 |
+
|
| 374 |
+
class MambaEncoder(nn.Module):
|
| 375 |
+
def __init__(self, d_in, d_model, n_layers, d_state, z_dim, bidir):
|
| 376 |
+
super().__init__()
|
| 377 |
+
self.bidir = bidir
|
| 378 |
+
self.proj = nn.Linear(d_in, d_model)
|
| 379 |
+
self.fwd = nn.ModuleList([MambaLayer(d_model, d_state) for _ in range(n_layers)])
|
| 380 |
+
if bidir:
|
| 381 |
+
self.bwd = nn.ModuleList([MambaLayer(d_model, d_state) for _ in range(n_layers)])
|
| 382 |
+
self.attn = nn.Linear(d_model, 1)
|
| 383 |
+
self.out = nn.Linear(d_model, z_dim)
|
| 384 |
+
|
| 385 |
+
@staticmethod
|
| 386 |
+
def _run(layers, h):
|
| 387 |
+
for L in layers:
|
| 388 |
+
h = L(h)
|
| 389 |
+
return h
|
| 390 |
+
|
| 391 |
+
def forward(self, x, mask):
|
| 392 |
+
with torch.cuda.amp.autocast(enabled=False):
|
| 393 |
+
x = x.float()
|
| 394 |
+
h = self.proj(x)
|
| 395 |
+
out = self._run(self.fwd, h)
|
| 396 |
+
if self.bidir:
|
| 397 |
+
out = out + torch.flip(self._run(self.bwd, torch.flip(h, dims=[1])), dims=[1])
|
| 398 |
+
a = self.attn(out).squeeze(-1).masked_fill(~mask, float("-inf"))
|
| 399 |
+
w = torch.softmax(a, dim=1).unsqueeze(-1)
|
| 400 |
+
return self.out((out * w).sum(1))
|
| 401 |
+
|
| 402 |
+
# %% [markdown]
|
| 403 |
+
# ## 6. Dựng enc + heads → nạp trọng số từ ckpt + lấy chuẩn hóa từ ckpt
|
| 404 |
+
|
| 405 |
+
# %%
|
| 406 |
+
N_EMO = len(EMOTIONS5)
|
| 407 |
+
WAVLM_BRANCH = Z_DIM if USE_MAMBA else WAVLM_DIM
|
| 408 |
+
TRUNK_IN = WAVLM_BRANCH + (AUD_DIM if USE_AUDEERING else 0)
|
| 409 |
+
|
| 410 |
+
enc = MambaEncoder(WAVLM_DIM, MAMBA_DMODEL, MAMBA_LAYERS, MAMBA_DSTATE, Z_DIM, BIDIRECTIONAL).to(device) \
|
| 411 |
+
if USE_MAMBA else None
|
| 412 |
+
|
| 413 |
+
class EmoHeads(nn.Module):
|
| 414 |
+
def __init__(self, d_in, trunk_h, head_h, p, n_emo):
|
| 415 |
+
super().__init__()
|
| 416 |
+
self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),
|
| 417 |
+
nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))
|
| 418 |
+
self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))
|
| 419 |
+
self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))
|
| 420 |
+
self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))
|
| 421 |
+
def forward(self, feat, tgt):
|
| 422 |
+
h = self.trunk(feat)
|
| 423 |
+
return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)
|
| 424 |
+
|
| 425 |
+
heads = EmoHeads(TRUNK_IN, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device)
|
| 426 |
+
hm, hu = heads.load_state_dict(ckpt["heads"], strict=False)
|
| 427 |
+
print(f"🔁 load heads từ ckpt: thiếu {len(hm)} / dư {len(hu)} key (kỳ vọng 0)")
|
| 428 |
+
if USE_MAMBA:
|
| 429 |
+
assert ckpt.get("enc") is not None, "❌ ckpt USE_MAMBA=True nhưng KHÔNG có 'enc' → không inference đúng được."
|
| 430 |
+
em, eu = enc.load_state_dict(ckpt["enc"], strict=False)
|
| 431 |
+
print(f"🔁 load Mamba enc từ ckpt: thiếu {len(em)} / dư {len(eu)} key (kỳ vọng 0)")
|
| 432 |
+
heads.eval()
|
| 433 |
+
if USE_MAMBA:
|
| 434 |
+
enc.eval()
|
| 435 |
+
|
| 436 |
+
# Chuẩn hóa LẤY TỪ ckpt (head dự đoán ở thang z-score này → phải giải chuẩn đúng thang)
|
| 437 |
+
emos_mu = float(ckpt["emos_mu"]); emos_sd = float(ckpt["emos_sd"])
|
| 438 |
+
vad_mu = np.asarray(ckpt["vad_mu"], dtype=np.float32); vad_sd = np.asarray(ckpt["vad_sd"], dtype=np.float32)
|
| 439 |
+
print(f"Chuẩn hóa từ ckpt: emos μ={emos_mu:.3f} σ={emos_sd:.3f} | vad μ={np.round(vad_mu,2)}")
|
| 440 |
+
|
| 441 |
+
def wavlm_branch(input_values, attn_mask):
|
| 442 |
+
out = wavlm(input_values, attention_mask=attn_mask).last_hidden_state
|
| 443 |
+
if USE_MAMBA:
|
| 444 |
+
return enc(out, frame_mask(out.shape[1], attn_mask))
|
| 445 |
+
return masked_mean(out, attn_mask)
|
| 446 |
+
|
| 447 |
+
print(f"Trunk input = {TRUNK_IN} (wavlm-branch {WAVLM_BRANCH} [{'Mamba' if USE_MAMBA else 'mean-pool'}] + aud {AUD_DIM if USE_AUDEERING else 0})")
|
| 448 |
+
|
| 449 |
+
# %% [markdown]
|
| 450 |
+
# ## 7. Dự đoán DEV → answer.txt (5 cột cảm xúc; QMOS mượn exp07/UTMOSv2)
|
| 451 |
+
|
| 452 |
+
# %%
|
| 453 |
+
def list_dev():
|
| 454 |
+
with open(DEV_SCP) as f:
|
| 455 |
+
return [ln.strip() for ln in f if ln.strip()]
|
| 456 |
+
|
| 457 |
+
dev_names = list_dev()
|
| 458 |
+
if LIMIT_DEV:
|
| 459 |
+
dev_names = dev_names[:LIMIT_DEV]
|
| 460 |
+
dev_stems = [stem(n) for n in dev_names]
|
| 461 |
+
print("DEV:", len(dev_names), "mẫu")
|
| 462 |
+
aud_dev = extract_audeering(dev_stems, "dev")
|
| 463 |
+
|
| 464 |
+
def load_exp07_qmos():
|
| 465 |
+
if EXP07_ANSWER and os.path.exists(EXP07_ANSWER):
|
| 466 |
+
import csv
|
| 467 |
+
d = {}
|
| 468 |
+
with open(EXP07_ANSWER) as f:
|
| 469 |
+
for row in csv.DictReader(f):
|
| 470 |
+
d[row["wav"]] = float(row["QMOS"]); d[stem(row["wav"])] = float(row["QMOS"])
|
| 471 |
+
print(f"✅ Mượn QMOS exp07 ({EXP07_ANSWER}): {len(d)//2} wav")
|
| 472 |
+
return d
|
| 473 |
+
return None
|
| 474 |
+
|
| 475 |
+
qmos_map = load_exp07_qmos()
|
| 476 |
+
if qmos_map is None:
|
| 477 |
+
print("ℹ️ Không có answer.txt exp07 → chấm QMOS bằng UTMOSv2 (T05, vô địch VMC2024).")
|
| 478 |
+
pip_install("git+https://github.com/sarulab-speech/UTMOSv2.git")
|
| 479 |
+
import utmosv2
|
| 480 |
+
v2 = utmosv2.create_model(pretrained=True)
|
| 481 |
+
qmos_map = {}
|
| 482 |
+
for n in tqdm(dev_names, desc="UTMOSv2"):
|
| 483 |
+
wav = os.path.join(WAV_DIR, n if str(n).endswith(".wav") else str(n) + ".wav")
|
| 484 |
+
if not os.path.exists(wav):
|
| 485 |
+
continue
|
| 486 |
+
out = v2.predict(input_path=wav)
|
| 487 |
+
qmos_map[n] = float(out["predicted_mos"]) if isinstance(out, dict) else float(out)
|
| 488 |
+
del v2; torch.cuda.empty_cache() if device == "cuda" else None
|
| 489 |
+
|
| 490 |
+
@torch.no_grad()
|
| 491 |
+
def predict_emotion(sid):
|
| 492 |
+
wave = load_wav(sid)
|
| 493 |
+
if wave is None or (USE_AUDEERING and sid not in aud_dev):
|
| 494 |
+
return None
|
| 495 |
+
iv = torch.from_numpy(wave).unsqueeze(0).to(device)
|
| 496 |
+
am = torch.ones((1, len(wave)), dtype=torch.long, device=device)
|
| 497 |
+
tgt = torch.from_numpy(onehot_target(target_map.get(sid))).unsqueeze(0).to(device)
|
| 498 |
+
with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
|
| 499 |
+
fw = wavlm_branch(iv, am)
|
| 500 |
+
feat = torch.cat([fw, torch.from_numpy(aud_dev[sid]).unsqueeze(0).to(device)], dim=1) if USE_AUDEERING else fw
|
| 501 |
+
emos_p, cat_l, vad_p = heads(feat, tgt)
|
| 502 |
+
emos = float(emos_p.item()) * emos_sd + emos_mu
|
| 503 |
+
cat5 = F.softmax(cat_l, 1)[0].float().cpu().numpy()
|
| 504 |
+
vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu
|
| 505 |
+
return emos, cat5, vad3
|
| 506 |
+
|
| 507 |
+
def fmt_cat(p5):
|
| 508 |
+
return "|".join(f"{e}:{p5[i]:.6g}" for i, e in enumerate(EMOTIONS5))
|
| 509 |
+
|
| 510 |
+
def build_answer(out_path):
|
| 511 |
+
n_real = n_def = 0
|
| 512 |
+
with open(out_path, "w") as f:
|
| 513 |
+
f.write("wav,QMOS,EMOS,CAT,VAL,ARO,DOM\n")
|
| 514 |
+
for name in tqdm(dev_names, desc="answer"):
|
| 515 |
+
sid = stem(name)
|
| 516 |
+
pr = predict_emotion(sid)
|
| 517 |
+
if pr is None:
|
| 518 |
+
emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0]); n_def += 1
|
| 519 |
+
else:
|
| 520 |
+
emos, cat5, vad3 = pr; n_real += 1
|
| 521 |
+
qmos = qmos_map.get(name, qmos_map.get(sid, 3.0))
|
| 522 |
+
f.write(f"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\n")
|
| 523 |
+
print(f"Ghi {len(dev_names)} dòng → {out_path} | cảm xúc thật {n_real}, mặc định {n_def}")
|
| 524 |
+
|
| 525 |
+
answer_path = os.path.join(OUT_DIR, "answer.txt")
|
| 526 |
+
build_answer(answer_path)
|
| 527 |
+
|
| 528 |
+
# %% [markdown]
|
| 529 |
+
# ## 8. Validate + đóng zip
|
| 530 |
+
|
| 531 |
+
# %%
|
| 532 |
+
def validate(path):
|
| 533 |
+
import csv
|
| 534 |
+
with open(path) as f:
|
| 535 |
+
rows = list(csv.reader(f))
|
| 536 |
+
assert rows[0][0] == "wav" and "QMOS" in rows[0] and "EMOS" in rows[0], "Header sai"
|
| 537 |
+
for i, r in enumerate(rows[1:], 2):
|
| 538 |
+
assert len(r) == len(rows[0]), f"Dòng {i} sai số cột"
|
| 539 |
+
print(f"OK: {len(rows)-1} dòng, header = {rows[0]}")
|
| 540 |
+
|
| 541 |
+
validate(answer_path)
|
| 542 |
+
os.system(f"cd {OUT_DIR} && zip -j submission_track2_exp15_predict.zip answer.txt "
|
| 543 |
+
f"&& unzip -l submission_track2_exp15_predict.zip")
|
| 544 |
+
print("Sẵn sàng nộp:", os.path.join(OUT_DIR, "submission_track2_exp15_predict.zip"))
|
| 545 |
+
|
| 546 |
+
# %% [markdown]
|
| 547 |
+
# ## Ghi chú
|
| 548 |
+
# - File này **chỉ inference** — không train, không cần train.csv. Dùng khi đã có `ft_mamba_emotion_full*.pt`.
|
| 549 |
+
# - ⚠️ **Siêu tham số Mamba/heads (MAMBA_DMODEL/LAYERS/DSTATE, TRUNK_HIDDEN, HEAD_HIDDEN) PHẢI khớp lúc train**
|
| 550 |
+
# (ckpt không lưu các số này) — nếu lúc train exp15 bạn đổi, hãy sửa cho khớp ở cell 0, nếu không load_state_dict
|
| 551 |
+
# sẽ lệch key / sai shape.
|
| 552 |
+
# - `USE_MAMBA`, `Z_DIM`, `AUD_DIM`, `UNFREEZE_TOP_LAYERS` thì **đọc tự động từ ckpt**.
|
| 553 |
+
# - QMOS: tốt nhất Add Input `answer.txt` exp07 (0.548); không có thì tự chấm UTMOSv2.
|
| 554 |
+
# - Smoke-test: đặt `LIMIT_DEV=20` chạy thử cho nhanh, OK rồi đặt lại `None` để chấm đủ 2730.
|
track2/exp15_wavlm_mamba_emotion.ipynb
ADDED
|
@@ -0,0 +1,1081 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "5b4b651f",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"# VMC2026 Track 2 — exp15 (WavLM FINE-TUNE + MAMBA head cho 5 cột cảm xúc) — Kaggle\n",
|
| 9 |
+
"\n",
|
| 10 |
+
"**Ý tưởng:** exp08 fine-tune WavLM nhưng vẫn **mean-pool** đặc trưng theo thời gian → 1 vector/wav\n",
|
| 11 |
+
"(vứt bỏ động lực thời gian: lên/xuống giọng, ngắt quãng, run giọng — rất quan trọng cho cảm xúc).\n",
|
| 12 |
+
"exp15 **thay mean-pool bằng MAMBA head** (bộ mã hóa chuỗi học được, độ phức tạp tuyến tính) → kỳ vọng\n",
|
| 13 |
+
"nắm temporal dynamics tốt hơn. Tham khảo: MambaRate (AudioMOS 2025, arXiv:2507.12090).\n",
|
| 14 |
+
"\n",
|
| 15 |
+
"## Kiến trúc (= exp08 đổi đúng 1 chỗ: pool → Mamba)\n",
|
| 16 |
+
"```\n",
|
| 17 |
+
" wav ─► WavLM-large (SAILER warm-start, mở băng N lớp, TRAINABLE) ─► hidden states (B, T, 1024)\n",
|
| 18 |
+
" │ (KHÔNG mean-pool)\n",
|
| 19 |
+
" MambaEncoder (proj 1024→d, Mamba×L 2 chiều,\n",
|
| 20 |
+
" attentive-pool có mask) ─► z (B, Z_DIM)\n",
|
| 21 |
+
" │\n",
|
| 22 |
+
" (tùy chọn) audeering MSP-dim FROZEN [emb|vad3] ──concat──► TRUNK ─┬─► EMOS (+ one-hot target)\n",
|
| 23 |
+
" ├─► CAT (5, soft-CE)\n",
|
| 24 |
+
" └─► VAD (3)\n",
|
| 25 |
+
" QMOS: KHÔNG train ở đây → mượn cột QMOS exp07 (0.548) hoặc UTMOSv2.\n",
|
| 26 |
+
"```\n",
|
| 27 |
+
"- **Cờ `USE_MAMBA`:** True = Mamba head; False = quay về `masked_mean` = **đúng exp08**\n",
|
| 28 |
+
" → đây là **ablation chính cho paper** (\"Mamba temporal head vs mean-pooling\", CÙNG backbone fine-tune).\n",
|
| 29 |
+
"\n",
|
| 30 |
+
"## ⚠️ Đánh đổi / gotcha (đã phòng trong code)\n",
|
| 31 |
+
"- Fine-tune = chạy lại WavLM mỗi epoch (không cache được) → **lần đầu BẮT BUỘC `LIMIT_TRAIN=300`, `LIMIT_DEV=20`**.\n",
|
| 32 |
+
"- `mamba-ssm` khó cài Kaggle → tự fallback **Mamba thuần PyTorch** (vòng-lặp-thời-gian). Bản này khi fine-tune\n",
|
| 33 |
+
" **chậm + nặng RAM hơn** → cap `MAX_SECONDS=6`, `BATCH=2`. OOM/quá chậm → hạ MAX_SECONDS→5, MAMBA_LAYERS→1,\n",
|
| 34 |
+
" hoặc thử cài `mamba-ssm causal-conv1d`.\n",
|
| 35 |
+
"- `layerdrop=0` (tránh CheckpointError khi grad-ckpt — bài học exp12). KHÔNG đụng numpy (lệch ABI).\n",
|
| 36 |
+
"- **Checkpoint lưu CẢ backbone + Mamba + heads mỗi best** (bài học exp08 mất backbone).\n",
|
| 37 |
+
"\n",
|
| 38 |
+
"## 🔁 RESUME (yêu cầu của user): \"nếu có checkpoint thì train TIẾP, không train lại từ đầu\"\n",
|
| 39 |
+
"- Notebook **tự dò** `ft_mamba_emotion_full.pt` trong `/kaggle/input` và `/kaggle/working` (hoặc trỏ tay `RESUME_CKPT`).\n",
|
| 40 |
+
"- Có ckpt đủ (backbone WavLM + Mamba enc + heads) → **nạp lại trạng thái + thống kê chuẩn hóa TỪ ckpt** rồi train tiếp;\n",
|
| 41 |
+
" `best` khởi tạo = điểm VAL của ckpt → chỉ ghi đè khi train tiếp **TỐT HƠN** (không sợ tụt). `RESUME_LR_SCALE<1` để hạ LR.\n",
|
| 42 |
+
"- KHÔNG có ckpt → train mới từ SAILER warm-start như cũ (hành vi exp15 gốc giữ nguyên).\n",
|
| 43 |
+
"\n",
|
| 44 |
+
"**Cách chạy Kaggle:** GPU **T4** + Internet **On** → Add Input dataset Track 2 (+ Add Input checkpoint cũ nếu muốn resume)\n",
|
| 45 |
+
"→ sửa `DATA_ROOT` → Run All."
|
| 46 |
+
]
|
| 47 |
+
},
|
| 48 |
+
{
|
| 49 |
+
"cell_type": "markdown",
|
| 50 |
+
"id": "194bcd01",
|
| 51 |
+
"metadata": {},
|
| 52 |
+
"source": [
|
| 53 |
+
"## 0. Cấu hình — SỬA Ở ĐÂY"
|
| 54 |
+
]
|
| 55 |
+
},
|
| 56 |
+
{
|
| 57 |
+
"cell_type": "code",
|
| 58 |
+
"execution_count": null,
|
| 59 |
+
"id": "8ed47b3b",
|
| 60 |
+
"metadata": {},
|
| 61 |
+
"outputs": [],
|
| 62 |
+
"source": [
|
| 63 |
+
"import os, glob\n",
|
| 64 |
+
"\n",
|
| 65 |
+
"# ── TỰ DÒ DATA_ROOT (quét /kaggle/input tìm thư mục có sets/train.csv + wav/ + metadata.csv) ──\n",
|
| 66 |
+
"def find_data_root(search_root=\"/kaggle/input\"):\n",
|
| 67 |
+
" cands = []\n",
|
| 68 |
+
" for train_csv in glob.glob(os.path.join(search_root, \"**\", \"sets\", \"train.csv\"), recursive=True):\n",
|
| 69 |
+
" root = os.path.dirname(os.path.dirname(train_csv)) # .../<root>/sets/train.csv → <root>\n",
|
| 70 |
+
" score = os.path.isdir(os.path.join(root, \"wav\")) + os.path.exists(os.path.join(root, \"metadata.csv\"))\n",
|
| 71 |
+
" cands.append((score, root))\n",
|
| 72 |
+
" cands.sort(reverse=True) # ưu tiên thư mục đủ wav + metadata\n",
|
| 73 |
+
" return cands\n",
|
| 74 |
+
"\n",
|
| 75 |
+
"_cands = find_data_root(\"/kaggle/input\")\n",
|
| 76 |
+
"if _cands:\n",
|
| 77 |
+
" print(\"🔎 Ứng viên DATA_ROOT (điểm cao = đủ wav+metadata):\")\n",
|
| 78 |
+
" for sc, r in _cands:\n",
|
| 79 |
+
" print(f\" [{sc}/2] {r}\")\n",
|
| 80 |
+
" DATA_ROOT = _cands[0][1]\n",
|
| 81 |
+
" print(f\"👉 Tự chọn DATA_ROOT = {DATA_ROOT}\")\n",
|
| 82 |
+
"else:\n",
|
| 83 |
+
" DATA_ROOT = \"/kaggle/input/datasets/minhtoan2\" # dự phòng — sửa tay nếu auto-dò không thấy\n",
|
| 84 |
+
" print(f\"❌ Không thấy sets/train.csv trong /kaggle/input → dùng dự phòng {DATA_ROOT} (đã Add Input chưa?)\")\n",
|
| 85 |
+
"\n",
|
| 86 |
+
"WAV_DIR = f\"{DATA_ROOT}/wav\"\n",
|
| 87 |
+
"METADATA_CSV = f\"{DATA_ROOT}/metadata.csv\" # wavID|emotion|transcript (KHÔNG header)\n",
|
| 88 |
+
"TRAIN_CSV = f\"{DATA_ROOT}/sets/train.csv\" # lisID|wavID|qMOS|emoCat|eMOS|val|dom|aro\n",
|
| 89 |
+
"DEV_SCP = f\"{DATA_ROOT}/sets/dev.scp\"\n",
|
| 90 |
+
"\n",
|
| 91 |
+
"OUT_DIR = \"/kaggle/working\"\n",
|
| 92 |
+
"CACHE_DIR = \"/kaggle/working/ft_cache\" # cache audeering (.npz) — WavLM/Mamba KHÔNG cache (đang train)\n",
|
| 93 |
+
"os.makedirs(CACHE_DIR, exist_ok=True)\n",
|
| 94 |
+
"\n",
|
| 95 |
+
"# (Tùy chọn) tái dùng cache audeering cũ (read-only /kaggle/input → copy sang working)\n",
|
| 96 |
+
"# Dataset cache_exp8: aud_*.npz nằm trong thư mục con archive/ → quét ĐỆ QUY để bắt mọi vị trí.\n",
|
| 97 |
+
"CACHE_INPUT = \"/kaggle/input/cache-exp8\" # << SỬA slug (dataset cache_exp8 → Kaggle đổi _→-); hoặc \"\"\n",
|
| 98 |
+
"if CACHE_INPUT and os.path.isdir(CACHE_INPUT):\n",
|
| 99 |
+
" import shutil\n",
|
| 100 |
+
" _n = 0\n",
|
| 101 |
+
" for _fp in glob.glob(os.path.join(CACHE_INPUT, \"**\", \"aud_*.npz\"), recursive=True):\n",
|
| 102 |
+
" shutil.copy(_fp, os.path.join(CACHE_DIR, os.path.basename(_fp))); _n += 1\n",
|
| 103 |
+
" print(f\"📦 Tái dùng cache: copy {_n} file aud_*.npz (quét đệ quy {CACHE_INPUT})\")\n",
|
| 104 |
+
"else:\n",
|
| 105 |
+
" print(f\"ℹ️ Không thấy CACHE_INPUT={CACHE_INPUT} → sẽ tự trích audeering.\")\n",
|
| 106 |
+
"\n",
|
| 107 |
+
"# Mượn cột QMOS exp07 (0.548). Trỏ answer.txt exp07 nếu có; không thì UTMOSv2.\n",
|
| 108 |
+
"EXP07_ANSWER = \"/kaggle/input/exp07-answer/answer.txt\" # << (tùy chọn)\n",
|
| 109 |
+
"\n",
|
| 110 |
+
"# ── Cờ Mamba (ablation chính) ────────────────────────────────────────────────\n",
|
| 111 |
+
"USE_MAMBA = True # True = Mamba head; False = mean-pool = ĐÚNG exp08\n",
|
| 112 |
+
"\n",
|
| 113 |
+
"# ── Siêu tham số Mamba head ──────────────────────────────────────────────────\n",
|
| 114 |
+
"MAMBA_DMODEL = 256\n",
|
| 115 |
+
"MAMBA_LAYERS = 2\n",
|
| 116 |
+
"MAMBA_DSTATE = 16\n",
|
| 117 |
+
"BIDIRECTIONAL = True\n",
|
| 118 |
+
"Z_DIM = 256 # chiều vector ra sau attentive-pool, thay cho emb WavLM mean-pool\n",
|
| 119 |
+
"\n",
|
| 120 |
+
"# ── Fine-tune / siêu tham số (kế thừa exp08) ─────────────────────────────────\n",
|
| 121 |
+
"DEVICE = \"cuda\"\n",
|
| 122 |
+
"SR = 16000\n",
|
| 123 |
+
"MAX_SECONDS = 6 # giảm từ 8 (exp08) vì Mamba backprop-through-time nặng RAM hơn\n",
|
| 124 |
+
"UNFREEZE_TOP_LAYERS = 6 # số lớp Transformer trên cùng được train (0 = freeze hết)\n",
|
| 125 |
+
"TRUNK_HIDDEN = 512\n",
|
| 126 |
+
"HEAD_HIDDEN = 128\n",
|
| 127 |
+
"DROPOUT = 0.3\n",
|
| 128 |
+
"LR_BACKBONE = 1e-5\n",
|
| 129 |
+
"LR_HEAD = 1e-3 # cho Mamba + trunk + head (train từ đầu)\n",
|
| 130 |
+
"WEIGHT_DECAY = 1e-5\n",
|
| 131 |
+
"EPOCHS = 12\n",
|
| 132 |
+
"PATIENCE = 3\n",
|
| 133 |
+
"BATCH = 2 # nhỏ (backbone to + Mamba); bù bằng ACCUM\n",
|
| 134 |
+
"ACCUM = 16 # effective batch = 32\n",
|
| 135 |
+
"VAL_FRAC = 0.10\n",
|
| 136 |
+
"SEED = 42\n",
|
| 137 |
+
"USE_AMP = True\n",
|
| 138 |
+
"USE_GRAD_CKPT = True\n",
|
| 139 |
+
"USE_AUDEERING = True\n",
|
| 140 |
+
"USE_UNCERTAINTY = True\n",
|
| 141 |
+
"RANK_LAMBDA = 0.3 # 0 = chỉ MSE (cũ). >0 = thêm pairwise ranking loss (tối ưu thẳng SRCC) cho emos/val/aro/dom\n",
|
| 142 |
+
" # ⚠️ ranking cần NHIỀU cặp/batch mới mạnh → BATCH nhỏ (2) thì tác dụng yếu (xem Ghi chú)\n",
|
| 143 |
+
"\n",
|
| 144 |
+
"LIMIT_TRAIN = 300 # << LẦN ĐẦU 300; chạy thật None\n",
|
| 145 |
+
"LIMIT_DEV = 20 # << LẦN ĐẦU 20; chạy thật None\n",
|
| 146 |
+
"\n",
|
| 147 |
+
"# ── RESUME — train TIẾP từ checkpoint, KHÔNG train lại từ đầu ─────────────────\n",
|
| 148 |
+
"# Để \"\" + auto-dò: nếu thấy `ft_mamba_emotion_full.pt` (đủ backbone+Mamba+heads) trong /kaggle/input\n",
|
| 149 |
+
"# hoặc /kaggle/working → nạp lại rồi train tiếp. Trỏ tay RESUME_CKPT nếu muốn chỉ định file cụ thể.\n",
|
| 150 |
+
"RESUME_CKPT = \"\" # << \"\" = auto-dò; hoặc \"/kaggle/input/<slug>/ft_mamba_emotion_full.pt\"\n",
|
| 151 |
+
"RESUME_LR_SCALE = 1.0 # <1.0 hạ LR khi train tiếp (vd 0.5 nếu val đã chững)\n",
|
| 152 |
+
"\n",
|
| 153 |
+
"def find_resume_ckpt(explicit):\n",
|
| 154 |
+
" \"\"\"Tìm checkpoint exp15 để train tiếp. Ưu tiên đường dẫn user trỏ; không thì auto-dò.\n",
|
| 155 |
+
" Khớp cả tên bị Kaggle/Windows thêm hậu tố trùng, vd 'ft_mamba_emotion_full (2).pt'.\"\"\"\n",
|
| 156 |
+
" if explicit and os.path.exists(explicit):\n",
|
| 157 |
+
" return explicit\n",
|
| 158 |
+
" for base in [\"/kaggle/input\", \"/kaggle/working\"]:\n",
|
| 159 |
+
" hits = sorted(glob.glob(os.path.join(base, \"**\", \"ft_mamba_emotion_full*.pt\"), recursive=True))\n",
|
| 160 |
+
" if hits:\n",
|
| 161 |
+
" return hits[0]\n",
|
| 162 |
+
" return \"\"\n",
|
| 163 |
+
"\n",
|
| 164 |
+
"RESUME_CKPT = find_resume_ckpt(RESUME_CKPT)\n",
|
| 165 |
+
"RESUME = bool(RESUME_CKPT)\n",
|
| 166 |
+
"print(\"🔁 RESUME =\", RESUME, (\"→ train tiếp từ: \" + RESUME_CKPT) if RESUME else \"(không thấy ckpt → train MỚI từ đầu)\")\n",
|
| 167 |
+
"\n",
|
| 168 |
+
"# Mốc so (exp08 fine-tune + mean-pool — đối thủ trực tiếp của Mamba head)\n",
|
| 169 |
+
"EXP08 = {\"emos\": 0.811, \"val\": 0.659, \"aro\": 0.793, \"dom\": 0.751}\n",
|
| 170 |
+
"\n",
|
| 171 |
+
"EMOTIONS5 = [\"angry\", \"happy\", \"neutral\", \"sad\", \"surprised\"]\n",
|
| 172 |
+
"\n",
|
| 173 |
+
"_EMO_ALIAS = {\n",
|
| 174 |
+
" \"angry\": \"angry\", \"anger\": \"angry\",\n",
|
| 175 |
+
" \"happy\": \"happy\", \"happiness\": \"happy\", \"joy\": \"happy\",\n",
|
| 176 |
+
" \"neutral\": \"neutral\", \"calm\": \"neutral\",\n",
|
| 177 |
+
" \"sad\": \"sad\", \"sadness\": \"sad\",\n",
|
| 178 |
+
" \"surprise\": \"surprised\", \"surprised\": \"surprised\", \"surprising\": \"surprised\",\n",
|
| 179 |
+
"}\n",
|
| 180 |
+
"\n",
|
| 181 |
+
"def norm_emotion(label):\n",
|
| 182 |
+
" key = str(label).strip().lower()\n",
|
| 183 |
+
" return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)\n",
|
| 184 |
+
"\n",
|
| 185 |
+
"def stem(p):\n",
|
| 186 |
+
" return os.path.splitext(os.path.basename(str(p)))[0]\n",
|
| 187 |
+
"\n",
|
| 188 |
+
"print(\"USE_MAMBA =\", USE_MAMBA, \"(False → ra đúng exp08)\")\n",
|
| 189 |
+
"print(\"DATA_ROOT:\", DATA_ROOT)\n",
|
| 190 |
+
"for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP]:\n",
|
| 191 |
+
" print((\" ✅ \" if os.path.exists(p) else \" ❌ THIẾU \") + p)\n",
|
| 192 |
+
"print(f\"Fine-tune: mở băng {UNFREEZE_TOP_LAYERS} lớp · BATCH {BATCH}×ACCUM {ACCUM} · MAX {MAX_SECONDS}s\")"
|
| 193 |
+
]
|
| 194 |
+
},
|
| 195 |
+
{
|
| 196 |
+
"cell_type": "markdown",
|
| 197 |
+
"id": "c8010473",
|
| 198 |
+
"metadata": {},
|
| 199 |
+
"source": [
|
| 200 |
+
"## 1. Cài đặt + tải code SAILER (clone + sys.path)"
|
| 201 |
+
]
|
| 202 |
+
},
|
| 203 |
+
{
|
| 204 |
+
"cell_type": "code",
|
| 205 |
+
"execution_count": null,
|
| 206 |
+
"id": "1b8d9fad",
|
| 207 |
+
"metadata": {},
|
| 208 |
+
"outputs": [],
|
| 209 |
+
"source": [
|
| 210 |
+
"import sys, subprocess\n",
|
| 211 |
+
"\n",
|
| 212 |
+
"def pip_install(*pkgs):\n",
|
| 213 |
+
" subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", *pkgs], check=True)\n",
|
| 214 |
+
"\n",
|
| 215 |
+
"pip_install(\"loralib\", \"speechbrain\", \"speechmos\", \"librosa\", \"soundfile\",\n",
|
| 216 |
+
" \"scipy\", \"scikit-learn\", \"pandas\", \"tqdm\")\n",
|
| 217 |
+
"\n",
|
| 218 |
+
"# Cài kernel CUDA Mamba (nhanh + nhẹ RAM hơn bản thuần PyTorch nhiều). Build hay lỗi/chậm trên Kaggle\n",
|
| 219 |
+
"# → bọc try/except: lỗi thì BỎ QUA, mục 6a tự fallback Mamba thuần PyTorch. KHÔNG để chết notebook.\n",
|
| 220 |
+
"INSTALL_MAMBA_SSM = True # đặt False nếu muốn BỎ QUA, dùng thẳng Mamba thuần PyTorch\n",
|
| 221 |
+
"if INSTALL_MAMBA_SSM and USE_MAMBA:\n",
|
| 222 |
+
" try:\n",
|
| 223 |
+
" # --no-build-isolation cho CẢ HAI → dùng torch+CUDA sẵn có của Kaggle để biên dịch (đừng kéo torch khác).\n",
|
| 224 |
+
" # Cần ninja để build nhanh. -q ẩn log nên bước này có thể \"treo\" vài phút khi đang compile — bình thường.\n",
|
| 225 |
+
" subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", \"ninja\"], check=True)\n",
|
| 226 |
+
" subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\",\n",
|
| 227 |
+
" \"--no-build-isolation\", \"causal-conv1d>=1.2.0\"], check=True)\n",
|
| 228 |
+
" subprocess.run([sys.executable, \"-m\", \"pip\", \"install\", \"-q\",\n",
|
| 229 |
+
" \"--no-build-isolation\", \"mamba-ssm\"], check=True)\n",
|
| 230 |
+
" print(\"✅ Cài mamba-ssm + causal-conv1d xong (sẽ dùng kernel CUDA nếu import được).\")\n",
|
| 231 |
+
" except Exception as e:\n",
|
| 232 |
+
" print(\"⚠️ Cài mamba-ssm thất bại:\", repr(e), \"→ dùng Mamba thuần PyTorch (chậm hơn).\")\n",
|
| 233 |
+
" print(\" ℹ️ Vẫn chạy bình thường. Nếu chạy THẬT (LIMIT=None) quá chậm → xem Ghi chú cuối notebook.\")\n",
|
| 234 |
+
"\n",
|
| 235 |
+
"REPO_DIR = \"/kaggle/working/vox-profile-release\"\n",
|
| 236 |
+
"if not os.path.exists(REPO_DIR):\n",
|
| 237 |
+
" subprocess.run([\"git\", \"clone\", \"--depth\", \"1\",\n",
|
| 238 |
+
" \"https://github.com/tiantiaf0627/vox-profile-release.git\", REPO_DIR], check=True)\n",
|
| 239 |
+
"if REPO_DIR not in sys.path:\n",
|
| 240 |
+
" sys.path.insert(0, REPO_DIR)"
|
| 241 |
+
]
|
| 242 |
+
},
|
| 243 |
+
{
|
| 244 |
+
"cell_type": "markdown",
|
| 245 |
+
"id": "598d74d9",
|
| 246 |
+
"metadata": {},
|
| 247 |
+
"source": [
|
| 248 |
+
"## 2. Nạp SAILER → lấy backbone WavLM bên trong để FINE-TUNE (warm-start)"
|
| 249 |
+
]
|
| 250 |
+
},
|
| 251 |
+
{
|
| 252 |
+
"cell_type": "code",
|
| 253 |
+
"execution_count": null,
|
| 254 |
+
"id": "5346a63d",
|
| 255 |
+
"metadata": {
|
| 256 |
+
"lines_to_next_cell": 1
|
| 257 |
+
},
|
| 258 |
+
"outputs": [],
|
| 259 |
+
"source": [
|
| 260 |
+
"import torch\n",
|
| 261 |
+
"import torch.nn as nn\n",
|
| 262 |
+
"import torch.nn.functional as F\n",
|
| 263 |
+
"\n",
|
| 264 |
+
"device = DEVICE if torch.cuda.is_available() else \"cpu\"\n",
|
| 265 |
+
"print(\"Device:\", device, (\"✅ \" + torch.cuda.get_device_name(0)) if device == \"cuda\" else \"⚠️ CPU (rất chậm!)\")\n",
|
| 266 |
+
"\n",
|
| 267 |
+
"def find_hf_backbone(module):\n",
|
| 268 |
+
" \"\"\"Tìm submodule kiểu HF WavLM backbone: có .feature_extractor và .encoder.layers.\"\"\"\n",
|
| 269 |
+
" cands = []\n",
|
| 270 |
+
" for name, m in module.named_modules():\n",
|
| 271 |
+
" enc = getattr(m, \"encoder\", None)\n",
|
| 272 |
+
" if getattr(m, \"feature_extractor\", None) is not None and enc is not None \\\n",
|
| 273 |
+
" and getattr(enc, \"layers\", None) is not None:\n",
|
| 274 |
+
" cands.append((name, m))\n",
|
| 275 |
+
" if not cands:\n",
|
| 276 |
+
" return None, None\n",
|
| 277 |
+
" cands.sort(key=lambda nm: sum(p.numel() for p in nm[1].parameters()), reverse=True)\n",
|
| 278 |
+
" return cands[0]\n",
|
| 279 |
+
"\n",
|
| 280 |
+
"wavlm = None\n",
|
| 281 |
+
"try:\n",
|
| 282 |
+
" from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402\n",
|
| 283 |
+
" _wrapper = WavLMWrapper.from_pretrained(\"tiantiaf/wavlm-large-categorical-emotion\")\n",
|
| 284 |
+
" name, wavlm = find_hf_backbone(_wrapper)\n",
|
| 285 |
+
" if wavlm is not None:\n",
|
| 286 |
+
" print(f\"✅ Warm-start SAILER: backbone WavLM tại '.{name}' \"\n",
|
| 287 |
+
" f\"({sum(p.numel() for p in wavlm.parameters())/1e6:.0f}M params)\")\n",
|
| 288 |
+
" else:\n",
|
| 289 |
+
" print(\"⚠️ Không tìm thấy backbone HF trong wrapper SAILER → fallback WavLM trắng.\")\n",
|
| 290 |
+
"except Exception as e:\n",
|
| 291 |
+
" print(\"⚠️ Lỗi nạp SAILER wrapper:\", repr(e), \"→ fallback WavLM trắng.\")\n",
|
| 292 |
+
"\n",
|
| 293 |
+
"if wavlm is None:\n",
|
| 294 |
+
" from transformers import WavLMModel\n",
|
| 295 |
+
" wavlm = WavLMModel.from_pretrained(\"microsoft/wavlm-large\")\n",
|
| 296 |
+
" print(\"ℹ️ Fallback: microsoft/wavlm-large (KHÔNG warm-start SAILER).\")\n",
|
| 297 |
+
"\n",
|
| 298 |
+
"wavlm = wavlm.to(device)\n",
|
| 299 |
+
"WAVLM_DIM = int(wavlm.config.hidden_size)\n",
|
| 300 |
+
"wavlm.config.layerdrop = 0.0 # ⚠️ tránh CheckpointError khi grad-ckpt (bài học exp12)\n",
|
| 301 |
+
"\n",
|
| 302 |
+
"# ── RESUME: nạp trọng số backbone đã fine-tune từ checkpoint (đè lên warm-start SAILER) ──\n",
|
| 303 |
+
"resume_ckpt = None\n",
|
| 304 |
+
"if RESUME:\n",
|
| 305 |
+
" resume_ckpt = torch.load(RESUME_CKPT, map_location=\"cpu\", weights_only=False) # ckpt có numpy → cần False\n",
|
| 306 |
+
" assert \"wavlm\" in resume_ckpt, (\"❌ Checkpoint KHÔNG có 'wavlm' (backbone) → không resume được. \"\n",
|
| 307 |
+
" \"Dùng file ft_mamba_emotion_full.pt do exp15 lưu.\")\n",
|
| 308 |
+
" if resume_ckpt.get(\"USE_MAMBA\", USE_MAMBA) != USE_MAMBA:\n",
|
| 309 |
+
" print(f\" ⚠️ ckpt USE_MAMBA={resume_ckpt.get('USE_MAMBA')} ≠ cấu hình hiện tại {USE_MAMBA} → kiến trúc LỆCH! \"\n",
|
| 310 |
+
" \"Đặt USE_MAMBA cho khớp ckpt.\")\n",
|
| 311 |
+
" miss, unexp = wavlm.load_state_dict(resume_ckpt[\"wavlm\"], strict=False)\n",
|
| 312 |
+
" print(f\"🔁 RESUME load wavlm từ ckpt: thiếu {len(miss)} / dư {len(unexp)} key (kỳ vọng ~0). keys ckpt:\", list(resume_ckpt.keys()))\n",
|
| 313 |
+
" if len(miss) > 20 or len(unexp) > 20:\n",
|
| 314 |
+
" print(\" ⚠️ Lệch key nhiều → kiểm tra UNFREEZE_TOP_LAYERS / backbone có khớp ckpt không.\")\n",
|
| 315 |
+
"\n",
|
| 316 |
+
"# ── Đóng băng partial: feature-extractor + tất cả trừ UNFREEZE_TOP_LAYERS lớp trên ──\n",
|
| 317 |
+
"for p in wavlm.parameters():\n",
|
| 318 |
+
" p.requires_grad = False\n",
|
| 319 |
+
"enc_layers = wavlm.encoder.layers\n",
|
| 320 |
+
"n_layers = len(enc_layers)\n",
|
| 321 |
+
"for layer in enc_layers[max(0, n_layers - UNFREEZE_TOP_LAYERS):]:\n",
|
| 322 |
+
" for p in layer.parameters():\n",
|
| 323 |
+
" p.requires_grad = True\n",
|
| 324 |
+
"n_train = sum(p.numel() for p in wavlm.parameters() if p.requires_grad)\n",
|
| 325 |
+
"print(f\"WavLM: {n_layers} lớp · mở băng {min(UNFREEZE_TOP_LAYERS, n_layers)} → {n_train/1e6:.1f}M param train (dim {WAVLM_DIM})\")\n",
|
| 326 |
+
"\n",
|
| 327 |
+
"if USE_GRAD_CKPT:\n",
|
| 328 |
+
" wavlm.gradient_checkpointing_enable()\n",
|
| 329 |
+
" if hasattr(wavlm, \"enable_input_require_grads\"):\n",
|
| 330 |
+
" wavlm.enable_input_require_grads()\n",
|
| 331 |
+
"\n",
|
| 332 |
+
"def frame_mask(T, attn_mask):\n",
|
| 333 |
+
" \"\"\"attn_mask (B, Lwav) → frame-mask (B, T) bool (True=frame thật). Khớp downsample của WavLM.\"\"\"\n",
|
| 334 |
+
" if attn_mask is None:\n",
|
| 335 |
+
" return torch.ones((1, T), dtype=torch.bool, device=device)\n",
|
| 336 |
+
" try:\n",
|
| 337 |
+
" fm = wavlm._get_feature_vector_attention_mask(T, attn_mask)\n",
|
| 338 |
+
" return fm.bool()\n",
|
| 339 |
+
" except Exception:\n",
|
| 340 |
+
" return torch.ones((attn_mask.shape[0], T), dtype=torch.bool, device=attn_mask.device)\n",
|
| 341 |
+
"\n",
|
| 342 |
+
"def masked_mean(hidden, attn_mask):\n",
|
| 343 |
+
" \"\"\"Mean-pool theo thời gian bỏ pad (đường exp08 khi USE_MAMBA=False).\"\"\"\n",
|
| 344 |
+
" if attn_mask is None:\n",
|
| 345 |
+
" return hidden.mean(dim=1)\n",
|
| 346 |
+
" fm = frame_mask(hidden.shape[1], attn_mask).unsqueeze(-1).to(hidden.dtype)\n",
|
| 347 |
+
" return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)"
|
| 348 |
+
]
|
| 349 |
+
},
|
| 350 |
+
{
|
| 351 |
+
"cell_type": "markdown",
|
| 352 |
+
"id": "c72d5983",
|
| 353 |
+
"metadata": {},
|
| 354 |
+
"source": [
|
| 355 |
+
"## 3. Nạp audeering MSP-dim (FROZEN) — đặc trưng phụ (như exp08)"
|
| 356 |
+
]
|
| 357 |
+
},
|
| 358 |
+
{
|
| 359 |
+
"cell_type": "code",
|
| 360 |
+
"execution_count": null,
|
| 361 |
+
"id": "d967397d",
|
| 362 |
+
"metadata": {},
|
| 363 |
+
"outputs": [],
|
| 364 |
+
"source": [
|
| 365 |
+
"AUD_DIM = 0\n",
|
| 366 |
+
"aud_backbone = aud_head = aud_proc = None\n",
|
| 367 |
+
"if USE_AUDEERING:\n",
|
| 368 |
+
" from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor\n",
|
| 369 |
+
" from huggingface_hub import hf_hub_download\n",
|
| 370 |
+
" AUD_NAME = \"audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim\"\n",
|
| 371 |
+
" aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)\n",
|
| 372 |
+
" aud_cfg = Wav2Vec2Config.from_pretrained(AUD_NAME)\n",
|
| 373 |
+
" aud_backbone = Wav2Vec2Model(aud_cfg)\n",
|
| 374 |
+
" try:\n",
|
| 375 |
+
" _sd = __import__(\"safetensors.torch\", fromlist=[\"load_file\"]).load_file(\n",
|
| 376 |
+
" hf_hub_download(AUD_NAME, \"model.safetensors\"))\n",
|
| 377 |
+
" except Exception:\n",
|
| 378 |
+
" _sd = torch.load(hf_hub_download(AUD_NAME, \"pytorch_model.bin\"), map_location=\"cpu\")\n",
|
| 379 |
+
" bb_sd = {k[len(\"wav2vec2.\"):]: v for k, v in _sd.items() if k.startswith(\"wav2vec2.\")}\n",
|
| 380 |
+
" aud_backbone.load_state_dict(bb_sd, strict=False)\n",
|
| 381 |
+
" _hid = _sd[\"classifier.dense.weight\"].shape[0]\n",
|
| 382 |
+
" _out = _sd[\"classifier.out_proj.weight\"].shape[0]\n",
|
| 383 |
+
" aud_head = nn.Sequential(nn.Linear(_hid, _hid), nn.Tanh(), nn.Linear(_hid, _out))\n",
|
| 384 |
+
" aud_head[0].weight.data.copy_(_sd[\"classifier.dense.weight\"]); aud_head[0].bias.data.copy_(_sd[\"classifier.dense.bias\"])\n",
|
| 385 |
+
" aud_head[2].weight.data.copy_(_sd[\"classifier.out_proj.weight\"]); aud_head[2].bias.data.copy_(_sd[\"classifier.out_proj.bias\"])\n",
|
| 386 |
+
" aud_backbone = aud_backbone.to(device).eval()\n",
|
| 387 |
+
" aud_head = aud_head.to(device).eval()\n",
|
| 388 |
+
" AUD_DIM = _hid + 3\n",
|
| 389 |
+
" print(f\"✅ audeering frozen (đặc trưng phụ {AUD_DIM}-D = emb {_hid} + vad 3)\")"
|
| 390 |
+
]
|
| 391 |
+
},
|
| 392 |
+
{
|
| 393 |
+
"cell_type": "code",
|
| 394 |
+
"execution_count": null,
|
| 395 |
+
"id": "1a5f1592",
|
| 396 |
+
"metadata": {
|
| 397 |
+
"lines_to_next_cell": 1
|
| 398 |
+
},
|
| 399 |
+
"outputs": [],
|
| 400 |
+
"source": [
|
| 401 |
+
"import numpy as np\n",
|
| 402 |
+
"import librosa\n",
|
| 403 |
+
"from tqdm.auto import tqdm\n",
|
| 404 |
+
"\n",
|
| 405 |
+
"def load_wav(name_or_stem):\n",
|
| 406 |
+
" p = name_or_stem if os.path.isabs(str(name_or_stem)) else os.path.join(\n",
|
| 407 |
+
" WAV_DIR, name_or_stem if str(name_or_stem).endswith(\".wav\") else str(name_or_stem) + \".wav\")\n",
|
| 408 |
+
" if not os.path.exists(p):\n",
|
| 409 |
+
" return None\n",
|
| 410 |
+
" wave, _ = librosa.load(p, sr=SR, mono=True)\n",
|
| 411 |
+
" return wave[: MAX_SECONDS * SR].astype(np.float32)\n",
|
| 412 |
+
"\n",
|
| 413 |
+
"@torch.no_grad()\n",
|
| 414 |
+
"def extract_audeering(stems, tag):\n",
|
| 415 |
+
" if not USE_AUDEERING:\n",
|
| 416 |
+
" return {}\n",
|
| 417 |
+
" cache_path = os.path.join(CACHE_DIR, f\"aud_{tag}.npz\")\n",
|
| 418 |
+
" store = {}\n",
|
| 419 |
+
" if os.path.exists(cache_path):\n",
|
| 420 |
+
" z = np.load(cache_path, allow_pickle=True)\n",
|
| 421 |
+
" store = {k: z[k] for k in z.files}\n",
|
| 422 |
+
" print(f\"[aud/{tag}] nạp cache: {len(store)}\")\n",
|
| 423 |
+
" todo = [s for s in stems if s not in store]\n",
|
| 424 |
+
" for i, s in enumerate(tqdm(todo, desc=f\"audeering {tag}\")):\n",
|
| 425 |
+
" wave = load_wav(s)\n",
|
| 426 |
+
" if wave is None:\n",
|
| 427 |
+
" continue\n",
|
| 428 |
+
" x = aud_proc(wave, sampling_rate=SR).input_values[0]\n",
|
| 429 |
+
" x = torch.from_numpy(np.asarray(x, dtype=np.float32)).unsqueeze(0).to(device)\n",
|
| 430 |
+
" h = aud_backbone(x)[0].mean(dim=1)\n",
|
| 431 |
+
" out = aud_head(h)[0].cpu().numpy() # [arousal, dominance, valence] ∈[0,1]\n",
|
| 432 |
+
" vad = np.array([1 + 4 * out[2], 1 + 4 * out[0], 1 + 4 * out[1]], dtype=np.float32) # [VAL,ARO,DOM]\n",
|
| 433 |
+
" store[s] = np.concatenate([h[0].cpu().numpy(), vad]).astype(np.float32)\n",
|
| 434 |
+
" if (i + 1) % 500 == 0:\n",
|
| 435 |
+
" np.savez(cache_path, **store)\n",
|
| 436 |
+
" if todo:\n",
|
| 437 |
+
" np.savez(cache_path, **store)\n",
|
| 438 |
+
" return store"
|
| 439 |
+
]
|
| 440 |
+
},
|
| 441 |
+
{
|
| 442 |
+
"cell_type": "markdown",
|
| 443 |
+
"id": "50717e09",
|
| 444 |
+
"metadata": {},
|
| 445 |
+
"source": [
|
| 446 |
+
"## 4. Đọc & gộp nhãn theo wavID (EMOS / VAD / CAT) — như exp08"
|
| 447 |
+
]
|
| 448 |
+
},
|
| 449 |
+
{
|
| 450 |
+
"cell_type": "code",
|
| 451 |
+
"execution_count": null,
|
| 452 |
+
"id": "b5c3e935",
|
| 453 |
+
"metadata": {},
|
| 454 |
+
"outputs": [],
|
| 455 |
+
"source": [
|
| 456 |
+
"import pandas as pd\n",
|
| 457 |
+
"\n",
|
| 458 |
+
"def load_target_emotions():\n",
|
| 459 |
+
" tgt = {}\n",
|
| 460 |
+
" with open(METADATA_CSV, encoding=\"utf-8\") as f:\n",
|
| 461 |
+
" for ln in f:\n",
|
| 462 |
+
" parts = ln.strip().split(\"|\")\n",
|
| 463 |
+
" if len(parts) >= 2:\n",
|
| 464 |
+
" tgt[stem(parts[0])] = norm_emotion(parts[1])\n",
|
| 465 |
+
" return tgt\n",
|
| 466 |
+
"\n",
|
| 467 |
+
"def _col(cols_map, *names, df=None, default_idx=None):\n",
|
| 468 |
+
" for n in names:\n",
|
| 469 |
+
" if n in cols_map:\n",
|
| 470 |
+
" return cols_map[n]\n",
|
| 471 |
+
" return list(df.columns)[default_idx] if default_idx is not None else None\n",
|
| 472 |
+
"\n",
|
| 473 |
+
"def parse_emocat_votes(cell):\n",
|
| 474 |
+
" v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
|
| 475 |
+
" for tok in str(cell).replace(\"/\", \",\").replace(\";\", \",\").replace(\"|\", \",\").replace(\" \", \",\").split(\",\"):\n",
|
| 476 |
+
" e = norm_emotion(tok)\n",
|
| 477 |
+
" if e in EMOTIONS5:\n",
|
| 478 |
+
" v[EMOTIONS5.index(e)] += 1.0\n",
|
| 479 |
+
" return v\n",
|
| 480 |
+
"\n",
|
| 481 |
+
"def load_train_labels():\n",
|
| 482 |
+
" df = pd.read_csv(TRAIN_CSV, sep=\"|\")\n",
|
| 483 |
+
" cols = {c.lower().strip(): c for c in df.columns}\n",
|
| 484 |
+
" wav_col = _col(cols, \"wavid\", \"wav\", df=df, default_idx=1)\n",
|
| 485 |
+
" emos_col = _col(cols, \"emos\", \"emo\", \"emomos\")\n",
|
| 486 |
+
" val_col = _col(cols, \"val\", \"valence\"); aro_col = _col(cols, \"aro\", \"arousal\"); dom_col = _col(cols, \"dom\", \"dominance\")\n",
|
| 487 |
+
" cat_col = _col(cols, \"emocat\", \"cat\", \"emotion\")\n",
|
| 488 |
+
" assert emos_col, f\"Không thấy cột eMOS (cột: {list(df.columns)})\"\n",
|
| 489 |
+
" df[\"_stem\"] = df[wav_col].map(stem)\n",
|
| 490 |
+
" rows = []\n",
|
| 491 |
+
" for sid, g in df.groupby(\"_stem\"):\n",
|
| 492 |
+
" rec = {\"wavID\": sid, \"emos\": float(g[emos_col].mean())}\n",
|
| 493 |
+
" rec[\"val\"] = float(g[val_col].mean()) if val_col else np.nan\n",
|
| 494 |
+
" rec[\"aro\"] = float(g[aro_col].mean()) if aro_col else np.nan\n",
|
| 495 |
+
" rec[\"dom\"] = float(g[dom_col].mean()) if dom_col else np.nan\n",
|
| 496 |
+
" votes = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
|
| 497 |
+
" if cat_col:\n",
|
| 498 |
+
" for cell in g[cat_col]:\n",
|
| 499 |
+
" votes += parse_emocat_votes(cell)\n",
|
| 500 |
+
" s = votes.sum()\n",
|
| 501 |
+
" cat = votes / s if s > 0 else np.full(len(EMOTIONS5), 0.2, dtype=np.float32)\n",
|
| 502 |
+
" for i in range(len(EMOTIONS5)):\n",
|
| 503 |
+
" rec[f\"cat{i}\"] = float(cat[i])\n",
|
| 504 |
+
" rows.append(rec)\n",
|
| 505 |
+
" return pd.DataFrame(rows)\n",
|
| 506 |
+
"\n",
|
| 507 |
+
"target_map = load_target_emotions()\n",
|
| 508 |
+
"train_df = load_train_labels()\n",
|
| 509 |
+
"HAS_VAD = bool(train_df[\"val\"].notna().any())\n",
|
| 510 |
+
"print(f\"Target: {len(target_map)} | wav train (gộp): {len(train_df)} | có VAD: {HAS_VAD}\")"
|
| 511 |
+
]
|
| 512 |
+
},
|
| 513 |
+
{
|
| 514 |
+
"cell_type": "markdown",
|
| 515 |
+
"id": "b5ab79d3",
|
| 516 |
+
"metadata": {},
|
| 517 |
+
"source": [
|
| 518 |
+
"## 5. Dataset / DataLoader (load wav theo batch — KHÔNG cache WavLM vì đang train)"
|
| 519 |
+
]
|
| 520 |
+
},
|
| 521 |
+
{
|
| 522 |
+
"cell_type": "code",
|
| 523 |
+
"execution_count": null,
|
| 524 |
+
"id": "9989f142",
|
| 525 |
+
"metadata": {},
|
| 526 |
+
"outputs": [],
|
| 527 |
+
"source": [
|
| 528 |
+
"from torch.utils.data import Dataset, DataLoader\n",
|
| 529 |
+
"\n",
|
| 530 |
+
"train_stems = [s for s in train_df[\"wavID\"] if target_map.get(s) is not None]\n",
|
| 531 |
+
"if LIMIT_TRAIN:\n",
|
| 532 |
+
" train_stems = train_stems[:LIMIT_TRAIN]\n",
|
| 533 |
+
"aud_tr = extract_audeering(train_stems, \"train\")\n",
|
| 534 |
+
"\n",
|
| 535 |
+
"lab = train_df.set_index(\"wavID\")\n",
|
| 536 |
+
"\n",
|
| 537 |
+
"def _zfit(arr):\n",
|
| 538 |
+
" a = np.asarray(arr, dtype=np.float32)\n",
|
| 539 |
+
" return float(np.nanmean(a)), float(np.nanstd(a) + 1e-6)\n",
|
| 540 |
+
"\n",
|
| 541 |
+
"if RESUME and resume_ckpt is not None:\n",
|
| 542 |
+
" # QUAN TRỌNG: lấy chuẩn hóa TỪ ckpt (head đã train theo thang này) — KHÔNG tính lại để khỏi lệch thang\n",
|
| 543 |
+
" emos_mu = float(resume_ckpt[\"emos_mu\"]); emos_sd = float(resume_ckpt[\"emos_sd\"])\n",
|
| 544 |
+
" vad_mu = np.asarray(resume_ckpt[\"vad_mu\"], dtype=np.float32)\n",
|
| 545 |
+
" vad_sd = np.asarray(resume_ckpt[\"vad_sd\"], dtype=np.float32)\n",
|
| 546 |
+
" print(f\"🔁 RESUME: dùng chuẩn hóa TỪ ckpt: emos μ={emos_mu:.3f} σ={emos_sd:.3f} | vad μ={np.round(vad_mu,2)}\")\n",
|
| 547 |
+
"else:\n",
|
| 548 |
+
" emos_mu, emos_sd = _zfit([lab.loc[s, \"emos\"] for s in train_stems])\n",
|
| 549 |
+
" if HAS_VAD:\n",
|
| 550 |
+
" vad_mu = np.array([_zfit([lab.loc[s, c] for s in train_stems])[0] for c in [\"val\", \"aro\", \"dom\"]], dtype=np.float32)\n",
|
| 551 |
+
" vad_sd = np.array([_zfit([lab.loc[s, c] for s in train_stems])[1] for c in [\"val\", \"aro\", \"dom\"]], dtype=np.float32)\n",
|
| 552 |
+
" else:\n",
|
| 553 |
+
" vad_mu = np.zeros(3, dtype=np.float32); vad_sd = np.ones(3, dtype=np.float32)\n",
|
| 554 |
+
"\n",
|
| 555 |
+
"def onehot_target(tgt):\n",
|
| 556 |
+
" v = np.zeros(len(EMOTIONS5), dtype=np.float32)\n",
|
| 557 |
+
" if tgt in EMOTIONS5:\n",
|
| 558 |
+
" v[EMOTIONS5.index(tgt)] = 1.0\n",
|
| 559 |
+
" return v\n",
|
| 560 |
+
"\n",
|
| 561 |
+
"class EmoDataset(Dataset):\n",
|
| 562 |
+
" def __init__(self, stems):\n",
|
| 563 |
+
" self.stems = [s for s in stems if (load_wav(s) is not None) and ((not USE_AUDEERING) or s in aud_tr)]\n",
|
| 564 |
+
" def __len__(self):\n",
|
| 565 |
+
" return len(self.stems)\n",
|
| 566 |
+
" def __getitem__(self, i):\n",
|
| 567 |
+
" s = self.stems[i]\n",
|
| 568 |
+
" wave = load_wav(s)\n",
|
| 569 |
+
" emos = (float(lab.loc[s, \"emos\"]) - emos_mu) / emos_sd\n",
|
| 570 |
+
" if HAS_VAD:\n",
|
| 571 |
+
" vad = (np.array([lab.loc[s, \"val\"], lab.loc[s, \"aro\"], lab.loc[s, \"dom\"]], np.float32) - vad_mu) / vad_sd\n",
|
| 572 |
+
" else:\n",
|
| 573 |
+
" vad = np.zeros(3, dtype=np.float32)\n",
|
| 574 |
+
" cat = np.array([lab.loc[s, f\"cat{j}\"] for j in range(len(EMOTIONS5))], dtype=np.float32)\n",
|
| 575 |
+
" aud = aud_tr[s] if USE_AUDEERING else np.zeros(0, dtype=np.float32)\n",
|
| 576 |
+
" return {\"wave\": wave, \"tgt\": onehot_target(target_map.get(s)), \"aud\": aud,\n",
|
| 577 |
+
" \"emos\": np.float32(emos), \"vad\": vad, \"cat\": cat,\n",
|
| 578 |
+
" \"emos_raw\": np.float32(lab.loc[s, \"emos\"]),\n",
|
| 579 |
+
" \"vad_raw\": np.array([lab.loc[s, \"val\"], lab.loc[s, \"aro\"], lab.loc[s, \"dom\"]], np.float32)}\n",
|
| 580 |
+
"\n",
|
| 581 |
+
"def collate(batch):\n",
|
| 582 |
+
" L = max(len(b[\"wave\"]) for b in batch)\n",
|
| 583 |
+
" waves = np.zeros((len(batch), L), dtype=np.float32)\n",
|
| 584 |
+
" mask = np.zeros((len(batch), L), dtype=np.float32)\n",
|
| 585 |
+
" for i, b in enumerate(batch):\n",
|
| 586 |
+
" waves[i, : len(b[\"wave\"])] = b[\"wave\"]; mask[i, : len(b[\"wave\"])] = 1.0\n",
|
| 587 |
+
" return {\n",
|
| 588 |
+
" \"input_values\": torch.from_numpy(waves), \"attn_mask\": torch.from_numpy(mask).long(),\n",
|
| 589 |
+
" \"tgt\": torch.from_numpy(np.stack([b[\"tgt\"] for b in batch])),\n",
|
| 590 |
+
" \"aud\": torch.from_numpy(np.stack([b[\"aud\"] for b in batch])) if USE_AUDEERING else None,\n",
|
| 591 |
+
" \"emos\": torch.from_numpy(np.stack([b[\"emos\"] for b in batch])).unsqueeze(1),\n",
|
| 592 |
+
" \"vad\": torch.from_numpy(np.stack([b[\"vad\"] for b in batch])),\n",
|
| 593 |
+
" \"cat\": torch.from_numpy(np.stack([b[\"cat\"] for b in batch])),\n",
|
| 594 |
+
" \"emos_raw\": np.stack([b[\"emos_raw\"] for b in batch]),\n",
|
| 595 |
+
" \"vad_raw\": np.stack([b[\"vad_raw\"] for b in batch]),\n",
|
| 596 |
+
" }\n",
|
| 597 |
+
"\n",
|
| 598 |
+
"from sklearn.model_selection import train_test_split\n",
|
| 599 |
+
"ds = EmoDataset(train_stems)\n",
|
| 600 |
+
"print(\"Dataset hợp lệ:\", len(ds), \"wav\")\n",
|
| 601 |
+
"tr_i, va_i = train_test_split(np.arange(len(ds)), test_size=VAL_FRAC, random_state=SEED)\n",
|
| 602 |
+
"tr_loader = DataLoader(torch.utils.data.Subset(ds, tr_i), batch_size=BATCH, shuffle=True, collate_fn=collate, num_workers=2)\n",
|
| 603 |
+
"va_loader = DataLoader(torch.utils.data.Subset(ds, va_i), batch_size=BATCH, shuffle=False, collate_fn=collate, num_workers=2)"
|
| 604 |
+
]
|
| 605 |
+
},
|
| 606 |
+
{
|
| 607 |
+
"cell_type": "markdown",
|
| 608 |
+
"id": "6006ec6c",
|
| 609 |
+
"metadata": {},
|
| 610 |
+
"source": [
|
| 611 |
+
"## 6a. Khối MAMBA (thuần PyTorch, fallback nếu không có `mamba-ssm`)\n",
|
| 612 |
+
"Theo \"mamba-minimal\" — đúng công thức selective SSM, chỉ chậm hơn kernel CUDA. Chạy trong fp32 cho ổn định."
|
| 613 |
+
]
|
| 614 |
+
},
|
| 615 |
+
{
|
| 616 |
+
"cell_type": "code",
|
| 617 |
+
"execution_count": null,
|
| 618 |
+
"id": "b9089952",
|
| 619 |
+
"metadata": {
|
| 620 |
+
"lines_to_next_cell": 1
|
| 621 |
+
},
|
| 622 |
+
"outputs": [],
|
| 623 |
+
"source": [
|
| 624 |
+
"import math\n",
|
| 625 |
+
"\n",
|
| 626 |
+
"try:\n",
|
| 627 |
+
" from mamba_ssm import Mamba as _OfficialMamba\n",
|
| 628 |
+
" _HAS_MAMBA_SSM = True\n",
|
| 629 |
+
" print(\"✅ Dùng mamba-ssm (CUDA kernel)\")\n",
|
| 630 |
+
"except Exception:\n",
|
| 631 |
+
" _HAS_MAMBA_SSM = False\n",
|
| 632 |
+
" print(\"ℹ️ Không có mamba-ssm → Mamba thuần PyTorch (chậm hơn khi fine-tune)\")\n",
|
| 633 |
+
"\n",
|
| 634 |
+
"class MambaBlockTorch(nn.Module):\n",
|
| 635 |
+
" def __init__(self, d_model, d_state=16, d_conv=4, expand=2):\n",
|
| 636 |
+
" super().__init__()\n",
|
| 637 |
+
" self.d_inner = expand * d_model\n",
|
| 638 |
+
" self.dt_rank = math.ceil(d_model / 16)\n",
|
| 639 |
+
" self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)\n",
|
| 640 |
+
" self.conv1d = nn.Conv1d(self.d_inner, self.d_inner, kernel_size=d_conv,\n",
|
| 641 |
+
" groups=self.d_inner, padding=d_conv - 1, bias=True)\n",
|
| 642 |
+
" self.x_proj = nn.Linear(self.d_inner, self.dt_rank + d_state * 2, bias=False)\n",
|
| 643 |
+
" self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)\n",
|
| 644 |
+
" A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)\n",
|
| 645 |
+
" self.A_log = nn.Parameter(torch.log(A))\n",
|
| 646 |
+
" self.D = nn.Parameter(torch.ones(self.d_inner))\n",
|
| 647 |
+
" self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)\n",
|
| 648 |
+
" self.d_state = d_state\n",
|
| 649 |
+
"\n",
|
| 650 |
+
" def forward(self, x): # x: (B, L, d_model)\n",
|
| 651 |
+
" B, L, _ = x.shape\n",
|
| 652 |
+
" xin, z = self.in_proj(x).chunk(2, dim=-1)\n",
|
| 653 |
+
" xin = xin.transpose(1, 2)\n",
|
| 654 |
+
" xin = self.conv1d(xin)[..., :L].transpose(1, 2)\n",
|
| 655 |
+
" xin = F.silu(xin)\n",
|
| 656 |
+
" y = self._ssm(xin) * F.silu(z)\n",
|
| 657 |
+
" return self.out_proj(y)\n",
|
| 658 |
+
"\n",
|
| 659 |
+
" def _ssm(self, x):\n",
|
| 660 |
+
" A = -torch.exp(self.A_log)\n",
|
| 661 |
+
" delta, Bm, Cm = torch.split(self.x_proj(x), [self.dt_rank, self.d_state, self.d_state], dim=-1)\n",
|
| 662 |
+
" delta = F.softplus(self.dt_proj(delta))\n",
|
| 663 |
+
" dA = torch.exp(delta.unsqueeze(-1) * A)\n",
|
| 664 |
+
" dB_x = delta.unsqueeze(-1) * Bm.unsqueeze(2) * x.unsqueeze(-1)\n",
|
| 665 |
+
" h = torch.zeros(x.shape[0], self.d_inner, self.d_state, device=x.device, dtype=x.dtype)\n",
|
| 666 |
+
" ys = []\n",
|
| 667 |
+
" for t in range(x.shape[1]):\n",
|
| 668 |
+
" h = dA[:, t] * h + dB_x[:, t]\n",
|
| 669 |
+
" ys.append((h * Cm[:, t].unsqueeze(1)).sum(-1))\n",
|
| 670 |
+
" return torch.stack(ys, dim=1) + x * self.D\n",
|
| 671 |
+
"\n",
|
| 672 |
+
"class MambaLayer(nn.Module):\n",
|
| 673 |
+
" def __init__(self, d_model, d_state):\n",
|
| 674 |
+
" super().__init__()\n",
|
| 675 |
+
" self.norm = nn.LayerNorm(d_model)\n",
|
| 676 |
+
" self.mix = _OfficialMamba(d_model=d_model, d_state=d_state, d_conv=4, expand=2) \\\n",
|
| 677 |
+
" if _HAS_MAMBA_SSM else MambaBlockTorch(d_model, d_state=d_state)\n",
|
| 678 |
+
" def forward(self, x):\n",
|
| 679 |
+
" return x + self.mix(self.norm(x))\n",
|
| 680 |
+
"\n",
|
| 681 |
+
"class MambaEncoder(nn.Module):\n",
|
| 682 |
+
" \"\"\"1024 → d_model → [Mamba ×L] (2 chiều) → attentive-pool (có mask) → Z_DIM.\"\"\"\n",
|
| 683 |
+
" def __init__(self, d_in, d_model, n_layers, d_state, z_dim, bidir):\n",
|
| 684 |
+
" super().__init__()\n",
|
| 685 |
+
" self.bidir = bidir\n",
|
| 686 |
+
" self.proj = nn.Linear(d_in, d_model)\n",
|
| 687 |
+
" self.fwd = nn.ModuleList([MambaLayer(d_model, d_state) for _ in range(n_layers)])\n",
|
| 688 |
+
" if bidir:\n",
|
| 689 |
+
" self.bwd = nn.ModuleList([MambaLayer(d_model, d_state) for _ in range(n_layers)])\n",
|
| 690 |
+
" self.attn = nn.Linear(d_model, 1)\n",
|
| 691 |
+
" self.out = nn.Linear(d_model, z_dim)\n",
|
| 692 |
+
"\n",
|
| 693 |
+
" @staticmethod\n",
|
| 694 |
+
" def _run(layers, h):\n",
|
| 695 |
+
" for L in layers:\n",
|
| 696 |
+
" h = L(h)\n",
|
| 697 |
+
" return h\n",
|
| 698 |
+
"\n",
|
| 699 |
+
" def forward(self, x, mask): # x:(B,L,1024) mask:(B,L) bool\n",
|
| 700 |
+
" with torch.cuda.amp.autocast(enabled=False): # SSM chạy fp32 cho ổn định\n",
|
| 701 |
+
" x = x.float()\n",
|
| 702 |
+
" h = self.proj(x)\n",
|
| 703 |
+
" out = self._run(self.fwd, h)\n",
|
| 704 |
+
" if self.bidir:\n",
|
| 705 |
+
" out = out + torch.flip(self._run(self.bwd, torch.flip(h, dims=[1])), dims=[1])\n",
|
| 706 |
+
" a = self.attn(out).squeeze(-1).masked_fill(~mask, float(\"-inf\"))\n",
|
| 707 |
+
" w = torch.softmax(a, dim=1).unsqueeze(-1)\n",
|
| 708 |
+
" return self.out((out * w).sum(1))"
|
| 709 |
+
]
|
| 710 |
+
},
|
| 711 |
+
{
|
| 712 |
+
"cell_type": "markdown",
|
| 713 |
+
"id": "ff1cec20",
|
| 714 |
+
"metadata": {},
|
| 715 |
+
"source": [
|
| 716 |
+
"## 6b. Head cảm xúc + train loop (AMP + grad-accum + uncertainty weighting)"
|
| 717 |
+
]
|
| 718 |
+
},
|
| 719 |
+
{
|
| 720 |
+
"cell_type": "code",
|
| 721 |
+
"execution_count": null,
|
| 722 |
+
"id": "c414e504",
|
| 723 |
+
"metadata": {
|
| 724 |
+
"lines_to_next_cell": 1
|
| 725 |
+
},
|
| 726 |
+
"outputs": [],
|
| 727 |
+
"source": [
|
| 728 |
+
"from scipy.stats import spearmanr\n",
|
| 729 |
+
"\n",
|
| 730 |
+
"torch.manual_seed(SEED); np.random.seed(SEED)\n",
|
| 731 |
+
"N_EMO = len(EMOTIONS5)\n",
|
| 732 |
+
"WAVLM_BRANCH = Z_DIM if USE_MAMBA else WAVLM_DIM\n",
|
| 733 |
+
"TRUNK_IN = WAVLM_BRANCH + (AUD_DIM if USE_AUDEERING else 0)\n",
|
| 734 |
+
"\n",
|
| 735 |
+
"enc = MambaEncoder(WAVLM_DIM, MAMBA_DMODEL, MAMBA_LAYERS, MAMBA_DSTATE, Z_DIM, BIDIRECTIONAL).to(device) \\\n",
|
| 736 |
+
" if USE_MAMBA else None\n",
|
| 737 |
+
"\n",
|
| 738 |
+
"class EmoHeads(nn.Module):\n",
|
| 739 |
+
" def __init__(self, d_in, trunk_h, head_h, p, n_emo):\n",
|
| 740 |
+
" super().__init__()\n",
|
| 741 |
+
" self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),\n",
|
| 742 |
+
" nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))\n",
|
| 743 |
+
" self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))\n",
|
| 744 |
+
" self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))\n",
|
| 745 |
+
" self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))\n",
|
| 746 |
+
" def forward(self, feat, tgt):\n",
|
| 747 |
+
" h = self.trunk(feat)\n",
|
| 748 |
+
" return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)\n",
|
| 749 |
+
"\n",
|
| 750 |
+
"heads = EmoHeads(TRUNK_IN, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device)\n",
|
| 751 |
+
"print(f\"Trunk input = {TRUNK_IN} (wavlm-branch {WAVLM_BRANCH} [{'Mamba' if USE_MAMBA else 'mean-pool'}] + aud {AUD_DIM if USE_AUDEERING else 0})\")\n",
|
| 752 |
+
"if USE_MAMBA:\n",
|
| 753 |
+
" print(f\"Mamba encoder: {sum(p.numel() for p in enc.parameters())/1e6:.2f}M param\")\n",
|
| 754 |
+
"\n",
|
| 755 |
+
"# ── RESUME: nạp heads (+ Mamba enc) từ checkpoint ──\n",
|
| 756 |
+
"if RESUME and resume_ckpt is not None:\n",
|
| 757 |
+
" hm, hu = heads.load_state_dict(resume_ckpt[\"heads\"], strict=False)\n",
|
| 758 |
+
" print(f\"🔁 RESUME load heads từ ckpt: thiếu {len(hm)} / dư {len(hu)} key (kỳ vọng 0)\")\n",
|
| 759 |
+
" if USE_MAMBA and resume_ckpt.get(\"enc\") is not None:\n",
|
| 760 |
+
" em, eu = enc.load_state_dict(resume_ckpt[\"enc\"], strict=False)\n",
|
| 761 |
+
" print(f\"🔁 RESUME load Mamba enc từ ckpt: thiếu {len(em)} / dư {len(eu)} key (kỳ vọng 0)\")\n",
|
| 762 |
+
" elif USE_MAMBA:\n",
|
| 763 |
+
" print(\" ⚠️ ckpt KHÔNG có 'enc' (Mamba) → Mamba head train lại từ đầu (chỉ resume backbone+heads).\")\n",
|
| 764 |
+
"\n",
|
| 765 |
+
"TASKS = [\"emos\", \"cat\", \"val\", \"aro\", \"dom\"]\n",
|
| 766 |
+
"log_var = nn.Parameter(torch.zeros(len(TASKS), device=device))\n",
|
| 767 |
+
"bb_params = [p for p in wavlm.parameters() if p.requires_grad]\n",
|
| 768 |
+
"head_params = list(heads.parameters()) + (list(enc.parameters()) if USE_MAMBA else []) \\\n",
|
| 769 |
+
" + ([log_var] if USE_UNCERTAINTY else [])\n",
|
| 770 |
+
"_lr_scale = RESUME_LR_SCALE if RESUME else 1.0\n",
|
| 771 |
+
"opt = torch.optim.AdamW([\n",
|
| 772 |
+
" {\"params\": bb_params, \"lr\": LR_BACKBONE * _lr_scale},\n",
|
| 773 |
+
" {\"params\": head_params, \"lr\": LR_HEAD * _lr_scale},\n",
|
| 774 |
+
"], weight_decay=WEIGHT_DECAY)\n",
|
| 775 |
+
"if RESUME and _lr_scale != 1.0:\n",
|
| 776 |
+
" print(f\"🔁 RESUME: LR ×{_lr_scale} → backbone {LR_BACKBONE*_lr_scale:.1e} · head {LR_HEAD*_lr_scale:.1e}\")\n",
|
| 777 |
+
"scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP and device == \"cuda\")\n",
|
| 778 |
+
"mse = nn.MSELoss()\n",
|
| 779 |
+
"\n",
|
| 780 |
+
"def soft_ce(logits, target_dist):\n",
|
| 781 |
+
" return -(target_dist * F.log_softmax(logits, dim=1)).sum(1).mean()\n",
|
| 782 |
+
"\n",
|
| 783 |
+
"def wavlm_branch(input_values, attn_mask):\n",
|
| 784 |
+
" out = wavlm(input_values, attention_mask=attn_mask).last_hidden_state # (B,T,D)\n",
|
| 785 |
+
" if USE_MAMBA:\n",
|
| 786 |
+
" return enc(out, frame_mask(out.shape[1], attn_mask)) # (B, Z_DIM)\n",
|
| 787 |
+
" return masked_mean(out, attn_mask) # (B, D)\n",
|
| 788 |
+
"\n",
|
| 789 |
+
"def forward_batch(b):\n",
|
| 790 |
+
" fw = wavlm_branch(b[\"input_values\"].to(device), b[\"attn_mask\"].to(device))\n",
|
| 791 |
+
" feat = torch.cat([fw, b[\"aud\"].to(device)], dim=1) if USE_AUDEERING else fw\n",
|
| 792 |
+
" return heads(feat, b[\"tgt\"].to(device))\n",
|
| 793 |
+
"\n",
|
| 794 |
+
"def pairwise_rank_loss(pred, target):\n",
|
| 795 |
+
" \"\"\"Hinge ranking trên MỌI cặp trong batch → tối ưu thẳng thứ hạng (≈ SRCC). Khả vi (backprop được).\n",
|
| 796 |
+
" Cần ≥2 mẫu/batch mới có cặp; batch càng to càng nhiều cặp → tín hiệu càng mạnh.\"\"\"\n",
|
| 797 |
+
" p = pred.reshape(-1); t = target.reshape(-1)\n",
|
| 798 |
+
" if p.numel() < 2:\n",
|
| 799 |
+
" return torch.zeros((), device=p.device)\n",
|
| 800 |
+
" sign = torch.sign(t.unsqueeze(0) - t.unsqueeze(1)) # +1 nếu câu i ĐÁNG cao hơn câu j\n",
|
| 801 |
+
" diff = p.unsqueeze(0) - p.unsqueeze(1) # chênh lệch model dự đoán\n",
|
| 802 |
+
" return torch.relu(-sign * diff).mean() # phạt khi xếp sai thứ tự\n",
|
| 803 |
+
"\n",
|
| 804 |
+
"def compute_loss(emos_p, cat_l, vad_p, b):\n",
|
| 805 |
+
" L = {\"emos\": mse(emos_p, b[\"emos\"].to(device)), \"cat\": soft_ce(cat_l, b[\"cat\"].to(device))}\n",
|
| 806 |
+
" if HAS_VAD:\n",
|
| 807 |
+
" vt = b[\"vad\"].to(device)\n",
|
| 808 |
+
" L[\"val\"] = mse(vad_p[:, 0:1], vt[:, 0:1]); L[\"aro\"] = mse(vad_p[:, 1:2], vt[:, 1:2]); L[\"dom\"] = mse(vad_p[:, 2:3], vt[:, 2:3])\n",
|
| 809 |
+
" else:\n",
|
| 810 |
+
" vt = None\n",
|
| 811 |
+
" z = torch.zeros((), device=device); L[\"val\"] = L[\"aro\"] = L[\"dom\"] = z\n",
|
| 812 |
+
" # Ranking loss CHỈ cho các cột chấm SRCC (emos/val/aro/dom). CAT là ERR phân bố → giữ soft-CE.\n",
|
| 813 |
+
" if RANK_LAMBDA > 0:\n",
|
| 814 |
+
" L[\"emos\"] = L[\"emos\"] + RANK_LAMBDA * pairwise_rank_loss(emos_p, b[\"emos\"].to(device))\n",
|
| 815 |
+
" if HAS_VAD:\n",
|
| 816 |
+
" L[\"val\"] = L[\"val\"] + RANK_LAMBDA * pairwise_rank_loss(vad_p[:, 0:1], vt[:, 0:1])\n",
|
| 817 |
+
" L[\"aro\"] = L[\"aro\"] + RANK_LAMBDA * pairwise_rank_loss(vad_p[:, 1:2], vt[:, 1:2])\n",
|
| 818 |
+
" L[\"dom\"] = L[\"dom\"] + RANK_LAMBDA * pairwise_rank_loss(vad_p[:, 2:3], vt[:, 2:3])\n",
|
| 819 |
+
" if USE_UNCERTAINTY:\n",
|
| 820 |
+
" return sum(torch.exp(-log_var[i]) * L[t] + log_var[i] for i, t in enumerate(TASKS))\n",
|
| 821 |
+
" return sum(L.values())\n",
|
| 822 |
+
"\n",
|
| 823 |
+
"def set_mode(train):\n",
|
| 824 |
+
" wavlm.train(train); heads.train(train)\n",
|
| 825 |
+
" if USE_MAMBA:\n",
|
| 826 |
+
" enc.train(train)\n",
|
| 827 |
+
"\n",
|
| 828 |
+
"@torch.no_grad()\n",
|
| 829 |
+
"def evaluate():\n",
|
| 830 |
+
" set_mode(False)\n",
|
| 831 |
+
" P = {\"emos\": [], \"val\": [], \"aro\": [], \"dom\": []}; Y = {\"emos\": [], \"val\": [], \"aro\": [], \"dom\": []}\n",
|
| 832 |
+
" catP, catY = [], []\n",
|
| 833 |
+
" for b in va_loader:\n",
|
| 834 |
+
" with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
|
| 835 |
+
" emos_p, cat_l, vad_p = forward_batch(b)\n",
|
| 836 |
+
" P[\"emos\"] += emos_p.float().cpu().numpy().ravel().tolist(); Y[\"emos\"] += b[\"emos_raw\"].tolist()\n",
|
| 837 |
+
" vad_p = vad_p.float().cpu().numpy()\n",
|
| 838 |
+
" for j, t in enumerate([\"val\", \"aro\", \"dom\"]):\n",
|
| 839 |
+
" P[t] += vad_p[:, j].tolist(); Y[t] += b[\"vad_raw\"][:, j].tolist()\n",
|
| 840 |
+
" catP.append(F.softmax(cat_l, 1).float().cpu().numpy()); catY.append(b[\"cat\"])\n",
|
| 841 |
+
" out = {t: spearmanr(P[t], Y[t]).correlation for t in [\"emos\"] + ([\"val\", \"aro\", \"dom\"] if HAS_VAD else [])}\n",
|
| 842 |
+
" q = np.concatenate(catP); p = np.concatenate(catY)\n",
|
| 843 |
+
" out[\"cat_err\"] = float(np.abs(q - p).sum(1).mean())\n",
|
| 844 |
+
" return out\n",
|
| 845 |
+
"\n",
|
| 846 |
+
"def mean_srcc(m):\n",
|
| 847 |
+
" keys = [\"emos\"] + ([\"val\", \"aro\", \"dom\"] if HAS_VAD else [])\n",
|
| 848 |
+
" return float(np.mean([m[k] for k in keys]))\n",
|
| 849 |
+
"\n",
|
| 850 |
+
"CKPT_PATH = os.path.join(OUT_DIR, \"ft_mamba_emotion_full.pt\")\n",
|
| 851 |
+
"def save_full_ckpt(state, val_emos=float(\"nan\")):\n",
|
| 852 |
+
" torch.save({\"wavlm\": state[\"wavlm\"], \"heads\": state[\"heads\"], \"enc\": state.get(\"enc\"),\n",
|
| 853 |
+
" \"USE_MAMBA\": USE_MAMBA, \"emos_mu\": emos_mu, \"emos_sd\": emos_sd,\n",
|
| 854 |
+
" \"vad_mu\": vad_mu, \"vad_sd\": vad_sd, \"WAVLM_DIM\": WAVLM_DIM, \"AUD_DIM\": AUD_DIM,\n",
|
| 855 |
+
" \"Z_DIM\": Z_DIM, \"UNFREEZE_TOP_LAYERS\": UNFREEZE_TOP_LAYERS,\n",
|
| 856 |
+
" \"val_emos\": float(val_emos)}, CKPT_PATH)\n",
|
| 857 |
+
"\n",
|
| 858 |
+
"def snapshot():\n",
|
| 859 |
+
" s = {\"wavlm\": {k: v.cpu().clone() for k, v in wavlm.state_dict().items()},\n",
|
| 860 |
+
" \"heads\": {k: v.cpu().clone() for k, v in heads.state_dict().items()}}\n",
|
| 861 |
+
" if USE_MAMBA:\n",
|
| 862 |
+
" s[\"enc\"] = {k: v.cpu().clone() for k, v in enc.state_dict().items()}\n",
|
| 863 |
+
" return s\n",
|
| 864 |
+
"\n",
|
| 865 |
+
"# RESUME: init best = điểm VAL của ckpt hiện tại → chỉ ghi đè nếu train tiếp TỐT HƠN (không sợ tụt)\n",
|
| 866 |
+
"if RESUME and resume_ckpt is not None:\n",
|
| 867 |
+
" m0 = evaluate(); best = mean_srcc(m0); best_state = snapshot(); bad = 0\n",
|
| 868 |
+
" print(f\"📍 RESUME — checkpoint hiện tại: mean SRCC={best:.4f} | \"\n",
|
| 869 |
+
" + \" \".join(f\"{k}={m0[k]:.3f}\" for k in ['emos', 'val', 'aro', 'dom'] if k in m0))\n",
|
| 870 |
+
"else:\n",
|
| 871 |
+
" m0 = None\n",
|
| 872 |
+
" best, best_state, bad = -1e9, None, 0\n",
|
| 873 |
+
"for ep in range(1, EPOCHS + 1):\n",
|
| 874 |
+
" set_mode(True)\n",
|
| 875 |
+
" opt.zero_grad(); run = 0.0; nb = 0\n",
|
| 876 |
+
" for step, b in enumerate(tqdm(tr_loader, desc=f\"epoch {ep}\")):\n",
|
| 877 |
+
" with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
|
| 878 |
+
" emos_p, cat_l, vad_p = forward_batch(b)\n",
|
| 879 |
+
" loss = compute_loss(emos_p, cat_l, vad_p, b) / ACCUM\n",
|
| 880 |
+
" scaler.scale(loss).backward()\n",
|
| 881 |
+
" if (step + 1) % ACCUM == 0:\n",
|
| 882 |
+
" scaler.step(opt); scaler.update(); opt.zero_grad()\n",
|
| 883 |
+
" run += loss.item() * ACCUM; nb += 1\n",
|
| 884 |
+
" m = evaluate(); sc = mean_srcc(m)\n",
|
| 885 |
+
" msg = \" \".join(f\"{k}={m[k]:.3f}\" for k in [\"emos\", \"val\", \"aro\", \"dom\"] if k in m)\n",
|
| 886 |
+
" print(f\"epoch {ep:2d} | loss {run/max(nb,1):.4f} | {msg} | cat_err {m['cat_err']:.3f} | mean {sc:.4f} (best {max(best,sc):.4f})\")\n",
|
| 887 |
+
" if sc > best:\n",
|
| 888 |
+
" best = sc; bad = 0\n",
|
| 889 |
+
" best_state = snapshot()\n",
|
| 890 |
+
" save_full_ckpt(best_state, m[\"emos\"])\n",
|
| 891 |
+
" print(f\" 💾 lưu best → {CKPT_PATH} (epoch {ep}, mean {sc:.4f})\")\n",
|
| 892 |
+
" else:\n",
|
| 893 |
+
" bad += 1\n",
|
| 894 |
+
" if bad >= PATIENCE:\n",
|
| 895 |
+
" print(f\"Early stop ở epoch {ep}.\"); break\n",
|
| 896 |
+
"\n",
|
| 897 |
+
"if best_state:\n",
|
| 898 |
+
" wavlm.load_state_dict(best_state[\"wavlm\"]); heads.load_state_dict(best_state[\"heads\"])\n",
|
| 899 |
+
" if USE_MAMBA:\n",
|
| 900 |
+
" enc.load_state_dict(best_state[\"enc\"])\n",
|
| 901 |
+
"final = evaluate()\n",
|
| 902 |
+
"if RESUME and m0 is not None:\n",
|
| 903 |
+
" print(f\"\\n🔁 RESUME: mean SRCC ckpt {mean_srcc(m0):.4f} → sau train tiếp {mean_srcc(final):.4f} \"\n",
|
| 904 |
+
" + (\"🚀 cải thiện → đã ghi đè ckpt\" if mean_srcc(final) > mean_srcc(m0) + 1e-4 else \"➖ không cải thiện (giữ best cũ)\"))\n",
|
| 905 |
+
"print(f\"\\n✅ VAL (nội bộ) — exp15 (Mamba={'ON' if USE_MAMBA else 'OFF'}):\")\n",
|
| 906 |
+
"print(f\" EMOS={final['emos']:.4f} (exp08 {EXP08['emos']})\")\n",
|
| 907 |
+
"if HAS_VAD:\n",
|
| 908 |
+
" print(f\" VAL/ARO/DOM={final['val']:.4f}/{final['aro']:.4f}/{final['dom']:.4f} \"\n",
|
| 909 |
+
" f\"(exp08 {EXP08['val']}/{EXP08['aro']}/{EXP08['dom']})\")\n",
|
| 910 |
+
"warn = [f\"EMOS {final['emos']:.3f}<{EXP08['emos']}\"] if final[\"emos\"] < EXP08[\"emos\"] - 0.005 else []\n",
|
| 911 |
+
"if HAS_VAD:\n",
|
| 912 |
+
" warn += [f\"{t.upper()} {final[t]:.3f}<{EXP08[t]}\" for t in [\"val\", \"aro\", \"dom\"] if final[t] < EXP08[t] - 0.005]\n",
|
| 913 |
+
"print(\" ⚠️ Mamba head CHƯA thắng exp08 ở:\", \"; \".join(warn), \"(vẫn là kết quả cho paper)\" if warn else \"\")\n",
|
| 914 |
+
"if not warn:\n",
|
| 915 |
+
" print(\" ✅ Mamba head thắng/ngang exp08 ở mọi cột → temporal modeling có ích!\")\n",
|
| 916 |
+
"save_full_ckpt(best_state if best_state else\n",
|
| 917 |
+
" {\"wavlm\": wavlm.state_dict(), \"heads\": heads.state_dict(),\n",
|
| 918 |
+
" \"enc\": enc.state_dict() if USE_MAMBA else None}, final[\"emos\"])\n",
|
| 919 |
+
"print(f\"✅ Đã lưu {CKPT_PATH} (CÓ backbone + Mamba + heads). NHỚ Save Version!\")"
|
| 920 |
+
]
|
| 921 |
+
},
|
| 922 |
+
{
|
| 923 |
+
"cell_type": "markdown",
|
| 924 |
+
"id": "9c748af2",
|
| 925 |
+
"metadata": {},
|
| 926 |
+
"source": [
|
| 927 |
+
"## 7. Dự đoán DEV → answer.txt (5 cột cảm xúc exp15; QMOS mượn exp07/UTMOSv2)"
|
| 928 |
+
]
|
| 929 |
+
},
|
| 930 |
+
{
|
| 931 |
+
"cell_type": "code",
|
| 932 |
+
"execution_count": null,
|
| 933 |
+
"id": "92d43e56",
|
| 934 |
+
"metadata": {
|
| 935 |
+
"lines_to_next_cell": 1
|
| 936 |
+
},
|
| 937 |
+
"outputs": [],
|
| 938 |
+
"source": [
|
| 939 |
+
"def list_dev():\n",
|
| 940 |
+
" with open(DEV_SCP) as f:\n",
|
| 941 |
+
" return [ln.strip() for ln in f if ln.strip()]\n",
|
| 942 |
+
"\n",
|
| 943 |
+
"dev_names = list_dev()\n",
|
| 944 |
+
"if LIMIT_DEV:\n",
|
| 945 |
+
" dev_names = dev_names[:LIMIT_DEV]\n",
|
| 946 |
+
"dev_stems = [stem(n) for n in dev_names]\n",
|
| 947 |
+
"print(\"DEV:\", len(dev_names), \"mẫu\")\n",
|
| 948 |
+
"aud_dev = extract_audeering(dev_stems, \"dev\")\n",
|
| 949 |
+
"\n",
|
| 950 |
+
"def load_exp07_qmos():\n",
|
| 951 |
+
" if EXP07_ANSWER and os.path.exists(EXP07_ANSWER):\n",
|
| 952 |
+
" import csv\n",
|
| 953 |
+
" d = {}\n",
|
| 954 |
+
" with open(EXP07_ANSWER) as f:\n",
|
| 955 |
+
" for row in csv.DictReader(f):\n",
|
| 956 |
+
" d[row[\"wav\"]] = float(row[\"QMOS\"]); d[stem(row[\"wav\"])] = float(row[\"QMOS\"])\n",
|
| 957 |
+
" print(f\"✅ Mượn QMOS exp07 ({EXP07_ANSWER}): {len(d)//2} wav\")\n",
|
| 958 |
+
" return d\n",
|
| 959 |
+
" return None\n",
|
| 960 |
+
"\n",
|
| 961 |
+
"qmos_map = load_exp07_qmos()\n",
|
| 962 |
+
"if qmos_map is None:\n",
|
| 963 |
+
" print(\"ℹ️ Không có answer.txt exp07 → chấm QMOS bằng UTMOSv2 (T05, vô địch VMC2024).\")\n",
|
| 964 |
+
" pip_install(\"git+https://github.com/sarulab-speech/UTMOSv2.git\")\n",
|
| 965 |
+
" import utmosv2\n",
|
| 966 |
+
" v2 = utmosv2.create_model(pretrained=True)\n",
|
| 967 |
+
" qmos_map = {}\n",
|
| 968 |
+
" for n in tqdm(dev_names, desc=\"UTMOSv2\"):\n",
|
| 969 |
+
" wav = os.path.join(WAV_DIR, n if str(n).endswith(\".wav\") else str(n) + \".wav\")\n",
|
| 970 |
+
" if not os.path.exists(wav):\n",
|
| 971 |
+
" continue\n",
|
| 972 |
+
" out = v2.predict(input_path=wav)\n",
|
| 973 |
+
" qmos_map[n] = float(out[\"predicted_mos\"]) if isinstance(out, dict) else float(out)\n",
|
| 974 |
+
" del v2; torch.cuda.empty_cache() if device == \"cuda\" else None\n",
|
| 975 |
+
"\n",
|
| 976 |
+
"@torch.no_grad()\n",
|
| 977 |
+
"def predict_emotion(sid):\n",
|
| 978 |
+
" wave = load_wav(sid)\n",
|
| 979 |
+
" if wave is None or (USE_AUDEERING and sid not in aud_dev):\n",
|
| 980 |
+
" return None\n",
|
| 981 |
+
" set_mode(False)\n",
|
| 982 |
+
" iv = torch.from_numpy(wave).unsqueeze(0).to(device)\n",
|
| 983 |
+
" am = torch.ones((1, len(wave)), dtype=torch.long, device=device)\n",
|
| 984 |
+
" tgt = torch.from_numpy(onehot_target(target_map.get(sid))).unsqueeze(0).to(device)\n",
|
| 985 |
+
" with torch.cuda.amp.autocast(enabled=USE_AMP and device == \"cuda\"):\n",
|
| 986 |
+
" fw = wavlm_branch(iv, am)\n",
|
| 987 |
+
" feat = torch.cat([fw, torch.from_numpy(aud_dev[sid]).unsqueeze(0).to(device)], dim=1) if USE_AUDEERING else fw\n",
|
| 988 |
+
" emos_p, cat_l, vad_p = heads(feat, tgt)\n",
|
| 989 |
+
" emos = float(emos_p.item()) * emos_sd + emos_mu\n",
|
| 990 |
+
" cat5 = F.softmax(cat_l, 1)[0].float().cpu().numpy()\n",
|
| 991 |
+
" vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu\n",
|
| 992 |
+
" return emos, cat5, vad3\n",
|
| 993 |
+
"\n",
|
| 994 |
+
"def fmt_cat(p5):\n",
|
| 995 |
+
" return \"|\".join(f\"{e}:{p5[i]:.6g}\" for i, e in enumerate(EMOTIONS5))\n",
|
| 996 |
+
"\n",
|
| 997 |
+
"def build_answer(out_path):\n",
|
| 998 |
+
" n_real = n_def = 0\n",
|
| 999 |
+
" with open(out_path, \"w\") as f:\n",
|
| 1000 |
+
" f.write(\"wav,QMOS,EMOS,CAT,VAL,ARO,DOM\\n\")\n",
|
| 1001 |
+
" for name in tqdm(dev_names, desc=\"answer\"):\n",
|
| 1002 |
+
" sid = stem(name)\n",
|
| 1003 |
+
" pr = predict_emotion(sid)\n",
|
| 1004 |
+
" if pr is None:\n",
|
| 1005 |
+
" emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0]); n_def += 1\n",
|
| 1006 |
+
" else:\n",
|
| 1007 |
+
" emos, cat5, vad3 = pr; n_real += 1\n",
|
| 1008 |
+
" qmos = qmos_map.get(name, qmos_map.get(sid, 3.0))\n",
|
| 1009 |
+
" f.write(f\"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\\n\")\n",
|
| 1010 |
+
" print(f\"Ghi {len(dev_names)} dòng → {out_path} | cảm xúc thật {n_real}, mặc định {n_def}\")\n",
|
| 1011 |
+
"\n",
|
| 1012 |
+
"answer_path = os.path.join(OUT_DIR, \"answer.txt\")\n",
|
| 1013 |
+
"build_answer(answer_path)"
|
| 1014 |
+
]
|
| 1015 |
+
},
|
| 1016 |
+
{
|
| 1017 |
+
"cell_type": "markdown",
|
| 1018 |
+
"id": "20ec4343",
|
| 1019 |
+
"metadata": {},
|
| 1020 |
+
"source": [
|
| 1021 |
+
"## 8. Validate + đóng zip"
|
| 1022 |
+
]
|
| 1023 |
+
},
|
| 1024 |
+
{
|
| 1025 |
+
"cell_type": "code",
|
| 1026 |
+
"execution_count": null,
|
| 1027 |
+
"id": "e289ea27",
|
| 1028 |
+
"metadata": {},
|
| 1029 |
+
"outputs": [],
|
| 1030 |
+
"source": [
|
| 1031 |
+
"def validate(path):\n",
|
| 1032 |
+
" import csv\n",
|
| 1033 |
+
" with open(path) as f:\n",
|
| 1034 |
+
" rows = list(csv.reader(f))\n",
|
| 1035 |
+
" assert rows[0][0] == \"wav\" and \"QMOS\" in rows[0] and \"EMOS\" in rows[0], \"Header sai\"\n",
|
| 1036 |
+
" for i, r in enumerate(rows[1:], 2):\n",
|
| 1037 |
+
" assert len(r) == len(rows[0]), f\"Dòng {i} sai số cột\"\n",
|
| 1038 |
+
" print(f\"OK: {len(rows)-1} dòng, header = {rows[0]}\")\n",
|
| 1039 |
+
"\n",
|
| 1040 |
+
"validate(answer_path)\n",
|
| 1041 |
+
"os.system(f\"cd {OUT_DIR} && zip -j submission_track2_exp15_mamba-emotion.zip answer.txt \"\n",
|
| 1042 |
+
" f\"&& unzip -l submission_track2_exp15_mamba-emotion.zip\")\n",
|
| 1043 |
+
"print(\"Sẵn sàng nộp:\", os.path.join(OUT_DIR, \"submission_track2_exp15_mamba-emotion.zip\"))"
|
| 1044 |
+
]
|
| 1045 |
+
},
|
| 1046 |
+
{
|
| 1047 |
+
"cell_type": "markdown",
|
| 1048 |
+
"id": "7aeeb9ea",
|
| 1049 |
+
"metadata": {},
|
| 1050 |
+
"source": [
|
| 1051 |
+
"## Ghi chú\n",
|
| 1052 |
+
"- **🔁 RESUME (train tiếp, không train lại từ đầu):** Add Input dataset chứa `ft_mamba_emotion_full.pt` của lần\n",
|
| 1053 |
+
" chạy trước (hoặc để nó nằm sẵn trong `/kaggle/working` khi chạy nối phiên) → notebook tự dò & train tiếp.\n",
|
| 1054 |
+
" `EPOCHS` lúc này là **số epoch train THÊM**. Val chững → đặt `RESUME_LR_SCALE=0.5`. Muốn ép train mới: `RESUME_CKPT=\"—\"`\n",
|
| 1055 |
+
" (đường dẫn không tồn tại) hoặc xóa ckpt khỏi input. ⚠️ `USE_MAMBA` phải KHỚP ckpt (code sẽ cảnh báo nếu lệch).\n",
|
| 1056 |
+
"- **Lần đầu** `LIMIT_TRAIN=300`, `LIMIT_DEV=20` → kiểm 1 epoch không OOM / không CheckpointError; rồi đặt `None`.\n",
|
| 1057 |
+
"- **Ablation chính cho paper:** chạy `USE_MAMBA=True` vs `USE_MAMBA=False` (=exp08) → so EMOS/VAL/ARO/DOM nội bộ\n",
|
| 1058 |
+
" → trả lời \"Mamba temporal head có hơn mean-pooling không?\".\n",
|
| 1059 |
+
"- **OOM / quá chậm trên T4 (nhất là khi dùng Mamba thuần PyTorch):** giảm theo thứ tự\n",
|
| 1060 |
+
" `MAX_SECONDS` (6→5) → `MAMBA_LAYERS` (2→1) → `UNFREEZE_TOP_LAYERS` (6→4) → `BATCH` (2→1, tăng `ACCUM`).\n",
|
| 1061 |
+
" Hoặc thử cài `mamba-ssm causal-conv1d` (nhanh + nhẹ RAM hơn nhiều) — code tự dùng nếu import được.\n",
|
| 1062 |
+
"- **Ranking loss (`RANK_LAMBDA`):** thêm pairwise ranking cho 4 cột SRCC (emos/val/aro/dom) → khớp metric\n",
|
| 1063 |
+
" UTT-SRCC hơn MSE. ⚠️ **Điểm yếu:** ranking tính trên các cặp TRONG 1 mini-batch; `BATCH=2` → mỗi forward\n",
|
| 1064 |
+
" chỉ có 1 cặp → tín hiệu YẾU. Muốn ranking mạnh: tăng `BATCH` (4→8 nếu VRAM chịu được). Ở các exp head\n",
|
| 1065 |
+
" ĐÓNG BĂNG (exp06/07, BATCH=64) ranking mạnh hơn nhiều. A/B `RANK_LAMBDA=0` vs `0.3` → bảng ablation cho paper.\n",
|
| 1066 |
+
"- **QMOS:** Add Input answer.txt exp07 vào `/kaggle/input/exp07-answer/answer.txt` để mượn QMOS 0.548;\n",
|
| 1067 |
+
" không có thì tự chấm UTMOSv2 (cần Internet On).\n",
|
| 1068 |
+
"- Ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (mục exp15)."
|
| 1069 |
+
]
|
| 1070 |
+
}
|
| 1071 |
+
],
|
| 1072 |
+
"metadata": {
|
| 1073 |
+
"jupytext": {
|
| 1074 |
+
"cell_metadata_filter": "-all",
|
| 1075 |
+
"main_language": "python",
|
| 1076 |
+
"notebook_metadata_filter": "-all"
|
| 1077 |
+
}
|
| 1078 |
+
},
|
| 1079 |
+
"nbformat": 4,
|
| 1080 |
+
"nbformat_minor": 5
|
| 1081 |
+
}
|
track2/exp15_wavlm_mamba_emotion_pipeline.py
ADDED
|
@@ -0,0 +1,920 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %% [markdown]
|
| 2 |
+
# # VMC2026 Track 2 — exp15 (WavLM FINE-TUNE + MAMBA head cho 5 cột cảm xúc) — Kaggle
|
| 3 |
+
#
|
| 4 |
+
# **Ý tưởng:** exp08 fine-tune WavLM nhưng vẫn **mean-pool** đặc trưng theo thời gian → 1 vector/wav
|
| 5 |
+
# (vứt bỏ động lực thời gian: lên/xuống giọng, ngắt quãng, run giọng — rất quan trọng cho cảm xúc).
|
| 6 |
+
# exp15 **thay mean-pool bằng MAMBA head** (bộ mã hóa chuỗi học được, độ phức tạp tuyến tính) → kỳ vọng
|
| 7 |
+
# nắm temporal dynamics tốt hơn. Tham khảo: MambaRate (AudioMOS 2025, arXiv:2507.12090).
|
| 8 |
+
#
|
| 9 |
+
# ## Kiến trúc (= exp08 đổi đúng 1 chỗ: pool → Mamba)
|
| 10 |
+
# ```
|
| 11 |
+
# wav ─► WavLM-large (SAILER warm-start, mở băng N lớp, TRAINABLE) ─► hidden states (B, T, 1024)
|
| 12 |
+
# │ (KHÔNG mean-pool)
|
| 13 |
+
# MambaEncoder (proj 1024→d, Mamba×L 2 chiều,
|
| 14 |
+
# attentive-pool có mask) ─► z (B, Z_DIM)
|
| 15 |
+
# │
|
| 16 |
+
# (tùy chọn) audeering MSP-dim FROZEN [emb|vad3] ──concat──► TRUNK ─┬─► EMOS (+ one-hot target)
|
| 17 |
+
# ├─► CAT (5, soft-CE)
|
| 18 |
+
# └─► VAD (3)
|
| 19 |
+
# QMOS: KHÔNG train ở đây → mượn cột QMOS exp07 (0.548) hoặc UTMOSv2.
|
| 20 |
+
# ```
|
| 21 |
+
# - **Cờ `USE_MAMBA`:** True = Mamba head; False = quay về `masked_mean` = **đúng exp08**
|
| 22 |
+
# → đây là **ablation chính cho paper** ("Mamba temporal head vs mean-pooling", CÙNG backbone fine-tune).
|
| 23 |
+
#
|
| 24 |
+
# ## ⚠️ Đánh đổi / gotcha (đã phòng trong code)
|
| 25 |
+
# - Fine-tune = chạy lại WavLM mỗi epoch (không cache được) → **lần đầu BẮT BUỘC `LIMIT_TRAIN=300`, `LIMIT_DEV=20`**.
|
| 26 |
+
# - `mamba-ssm` khó cài Kaggle → tự fallback **Mamba thuần PyTorch** (vòng-lặp-thời-gian). Bản này khi fine-tune
|
| 27 |
+
# **chậm + nặng RAM hơn** → cap `MAX_SECONDS=6`, `BATCH=2`. OOM/quá chậm → hạ MAX_SECONDS→5, MAMBA_LAYERS→1,
|
| 28 |
+
# hoặc thử cài `mamba-ssm causal-conv1d`.
|
| 29 |
+
# - `layerdrop=0` (tránh CheckpointError khi grad-ckpt — bài học exp12). KHÔNG đụng numpy (lệch ABI).
|
| 30 |
+
# - **Checkpoint lưu CẢ backbone + Mamba + heads mỗi best** (bài học exp08 mất backbone).
|
| 31 |
+
#
|
| 32 |
+
# ## 🔁 RESUME (yêu cầu của user): "nếu có checkpoint thì train TIẾP, không train lại từ đầu"
|
| 33 |
+
# - Notebook **tự dò** `ft_mamba_emotion_full.pt` trong `/kaggle/input` và `/kaggle/working` (hoặc trỏ tay `RESUME_CKPT`).
|
| 34 |
+
# - Có ckpt đủ (backbone WavLM + Mamba enc + heads) → **nạp lại trạng thái + thống kê chuẩn hóa TỪ ckpt** rồi train tiếp;
|
| 35 |
+
# `best` khởi tạo = điểm VAL của ckpt → chỉ ghi đè khi train tiếp **TỐT HƠN** (không sợ tụt). `RESUME_LR_SCALE<1` để hạ LR.
|
| 36 |
+
# - KHÔNG có ckpt → train mới từ SAILER warm-start như cũ (hành vi exp15 gốc giữ nguyên).
|
| 37 |
+
#
|
| 38 |
+
# **Cách chạy Kaggle:** GPU **T4** + Internet **On** → Add Input dataset Track 2 (+ Add Input checkpoint cũ nếu muốn resume)
|
| 39 |
+
# → sửa `DATA_ROOT` → Run All.
|
| 40 |
+
|
| 41 |
+
# %% [markdown]
|
| 42 |
+
# ## 0. Cấu hình — SỬA Ở ĐÂY
|
| 43 |
+
|
| 44 |
+
# %%
|
| 45 |
+
import os, glob
|
| 46 |
+
|
| 47 |
+
# ── TỰ DÒ DATA_ROOT (quét /kaggle/input tìm thư mục có sets/train.csv + wav/ + metadata.csv) ──
|
| 48 |
+
def find_data_root(search_root="/kaggle/input"):
|
| 49 |
+
cands = []
|
| 50 |
+
for train_csv in glob.glob(os.path.join(search_root, "**", "sets", "train.csv"), recursive=True):
|
| 51 |
+
root = os.path.dirname(os.path.dirname(train_csv)) # .../<root>/sets/train.csv → <root>
|
| 52 |
+
score = os.path.isdir(os.path.join(root, "wav")) + os.path.exists(os.path.join(root, "metadata.csv"))
|
| 53 |
+
cands.append((score, root))
|
| 54 |
+
cands.sort(reverse=True) # ưu tiên thư mục đủ wav + metadata
|
| 55 |
+
return cands
|
| 56 |
+
|
| 57 |
+
_cands = find_data_root("/kaggle/input")
|
| 58 |
+
if _cands:
|
| 59 |
+
print("🔎 Ứng viên DATA_ROOT (điểm cao = đủ wav+metadata):")
|
| 60 |
+
for sc, r in _cands:
|
| 61 |
+
print(f" [{sc}/2] {r}")
|
| 62 |
+
DATA_ROOT = _cands[0][1]
|
| 63 |
+
print(f"👉 Tự chọn DATA_ROOT = {DATA_ROOT}")
|
| 64 |
+
else:
|
| 65 |
+
DATA_ROOT = "/kaggle/input/datasets/minhtoan2" # dự phòng — sửa tay nếu auto-dò không thấy
|
| 66 |
+
print(f"❌ Không thấy sets/train.csv trong /kaggle/input → dùng dự phòng {DATA_ROOT} (đã Add Input chưa?)")
|
| 67 |
+
|
| 68 |
+
WAV_DIR = f"{DATA_ROOT}/wav"
|
| 69 |
+
METADATA_CSV = f"{DATA_ROOT}/metadata.csv" # wavID|emotion|transcript (KHÔNG header)
|
| 70 |
+
TRAIN_CSV = f"{DATA_ROOT}/sets/train.csv" # lisID|wavID|qMOS|emoCat|eMOS|val|dom|aro
|
| 71 |
+
DEV_SCP = f"{DATA_ROOT}/sets/dev.scp"
|
| 72 |
+
|
| 73 |
+
OUT_DIR = "/kaggle/working"
|
| 74 |
+
CACHE_DIR = "/kaggle/working/ft_cache" # cache audeering (.npz) — WavLM/Mamba KHÔNG cache (đang train)
|
| 75 |
+
os.makedirs(CACHE_DIR, exist_ok=True)
|
| 76 |
+
|
| 77 |
+
# (Tùy chọn) tái dùng cache audeering cũ (read-only /kaggle/input → copy sang working)
|
| 78 |
+
# Dataset cache_exp8: aud_*.npz nằm trong thư mục con archive/ → quét ĐỆ QUY để bắt mọi vị trí.
|
| 79 |
+
CACHE_INPUT = "/kaggle/input/cache-exp8" # << SỬA slug (dataset cache_exp8 → Kaggle đổi _→-); hoặc ""
|
| 80 |
+
if CACHE_INPUT and os.path.isdir(CACHE_INPUT):
|
| 81 |
+
import shutil
|
| 82 |
+
_n = 0
|
| 83 |
+
for _fp in glob.glob(os.path.join(CACHE_INPUT, "**", "aud_*.npz"), recursive=True):
|
| 84 |
+
shutil.copy(_fp, os.path.join(CACHE_DIR, os.path.basename(_fp))); _n += 1
|
| 85 |
+
print(f"📦 Tái dùng cache: copy {_n} file aud_*.npz (quét đệ quy {CACHE_INPUT})")
|
| 86 |
+
else:
|
| 87 |
+
print(f"ℹ️ Không thấy CACHE_INPUT={CACHE_INPUT} → sẽ tự trích audeering.")
|
| 88 |
+
|
| 89 |
+
# Mượn cột QMOS exp07 (0.548). Trỏ answer.txt exp07 nếu có; không thì UTMOSv2.
|
| 90 |
+
EXP07_ANSWER = "/kaggle/input/exp07-answer/answer.txt" # << (tùy chọn)
|
| 91 |
+
|
| 92 |
+
# ── Cờ Mamba (ablation chính) ────────────────────────────────────────────────
|
| 93 |
+
USE_MAMBA = True # True = Mamba head; False = mean-pool = ĐÚNG exp08
|
| 94 |
+
|
| 95 |
+
# ── Siêu tham số Mamba head ──────────────────────────────────────────────────
|
| 96 |
+
MAMBA_DMODEL = 256
|
| 97 |
+
MAMBA_LAYERS = 2
|
| 98 |
+
MAMBA_DSTATE = 16
|
| 99 |
+
BIDIRECTIONAL = True
|
| 100 |
+
Z_DIM = 256 # chiều vector ra sau attentive-pool, thay cho emb WavLM mean-pool
|
| 101 |
+
|
| 102 |
+
# ── Fine-tune / siêu tham số (kế thừa exp08) ─────────────────────────────────
|
| 103 |
+
DEVICE = "cuda"
|
| 104 |
+
SR = 16000
|
| 105 |
+
MAX_SECONDS = 6 # giảm từ 8 (exp08) vì Mamba backprop-through-time nặng RAM hơn
|
| 106 |
+
UNFREEZE_TOP_LAYERS = 6 # số lớp Transformer trên cùng được train (0 = freeze hết)
|
| 107 |
+
TRUNK_HIDDEN = 512
|
| 108 |
+
HEAD_HIDDEN = 128
|
| 109 |
+
DROPOUT = 0.3
|
| 110 |
+
LR_BACKBONE = 1e-5
|
| 111 |
+
LR_HEAD = 1e-3 # cho Mamba + trunk + head (train từ đầu)
|
| 112 |
+
WEIGHT_DECAY = 1e-5
|
| 113 |
+
EPOCHS = 12
|
| 114 |
+
PATIENCE = 3
|
| 115 |
+
BATCH = 2 # nhỏ (backbone to + Mamba); bù bằng ACCUM
|
| 116 |
+
ACCUM = 16 # effective batch = 32
|
| 117 |
+
VAL_FRAC = 0.10
|
| 118 |
+
SEED = 42
|
| 119 |
+
USE_AMP = True
|
| 120 |
+
USE_GRAD_CKPT = True
|
| 121 |
+
USE_AUDEERING = True
|
| 122 |
+
USE_UNCERTAINTY = True
|
| 123 |
+
RANK_LAMBDA = 0.3 # 0 = chỉ MSE (cũ). >0 = thêm pairwise ranking loss (tối ưu thẳng SRCC) cho emos/val/aro/dom
|
| 124 |
+
# ⚠️ ranking cần NHIỀU cặp/batch mới mạnh → BATCH nhỏ (2) thì tác dụng yếu (xem Ghi chú)
|
| 125 |
+
|
| 126 |
+
LIMIT_TRAIN = 300 # << LẦN ĐẦU 300; chạy thật None
|
| 127 |
+
LIMIT_DEV = 20 # << LẦN ĐẦU 20; chạy thật None
|
| 128 |
+
|
| 129 |
+
# ── RESUME — train TIẾP từ checkpoint, KHÔNG train lại từ đầu ─────────────────
|
| 130 |
+
# Để "" + auto-dò: nếu thấy `ft_mamba_emotion_full.pt` (đủ backbone+Mamba+heads) trong /kaggle/input
|
| 131 |
+
# hoặc /kaggle/working → nạp lại rồi train tiếp. Trỏ tay RESUME_CKPT nếu muốn chỉ định file cụ thể.
|
| 132 |
+
RESUME_CKPT = "" # << "" = auto-dò; hoặc "/kaggle/input/<slug>/ft_mamba_emotion_full.pt"
|
| 133 |
+
RESUME_LR_SCALE = 1.0 # <1.0 hạ LR khi train tiếp (vd 0.5 nếu val đã chững)
|
| 134 |
+
|
| 135 |
+
def find_resume_ckpt(explicit):
|
| 136 |
+
"""Tìm checkpoint exp15 để train tiếp. Ưu tiên đường dẫn user trỏ; không thì auto-dò.
|
| 137 |
+
Khớp cả tên bị Kaggle/Windows thêm hậu tố trùng, vd 'ft_mamba_emotion_full (2).pt'."""
|
| 138 |
+
if explicit and os.path.exists(explicit):
|
| 139 |
+
return explicit
|
| 140 |
+
for base in ["/kaggle/input", "/kaggle/working"]:
|
| 141 |
+
hits = sorted(glob.glob(os.path.join(base, "**", "ft_mamba_emotion_full*.pt"), recursive=True))
|
| 142 |
+
if hits:
|
| 143 |
+
return hits[0]
|
| 144 |
+
return ""
|
| 145 |
+
|
| 146 |
+
RESUME_CKPT = find_resume_ckpt(RESUME_CKPT)
|
| 147 |
+
RESUME = bool(RESUME_CKPT)
|
| 148 |
+
print("🔁 RESUME =", RESUME, ("→ train tiếp từ: " + RESUME_CKPT) if RESUME else "(không thấy ckpt → train MỚI từ đầu)")
|
| 149 |
+
|
| 150 |
+
# Mốc so (exp08 fine-tune + mean-pool — đối thủ trực tiếp của Mamba head)
|
| 151 |
+
EXP08 = {"emos": 0.811, "val": 0.659, "aro": 0.793, "dom": 0.751}
|
| 152 |
+
|
| 153 |
+
EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"]
|
| 154 |
+
|
| 155 |
+
_EMO_ALIAS = {
|
| 156 |
+
"angry": "angry", "anger": "angry",
|
| 157 |
+
"happy": "happy", "happiness": "happy", "joy": "happy",
|
| 158 |
+
"neutral": "neutral", "calm": "neutral",
|
| 159 |
+
"sad": "sad", "sadness": "sad",
|
| 160 |
+
"surprise": "surprised", "surprised": "surprised", "surprising": "surprised",
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
def norm_emotion(label):
|
| 164 |
+
key = str(label).strip().lower()
|
| 165 |
+
return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)
|
| 166 |
+
|
| 167 |
+
def stem(p):
|
| 168 |
+
return os.path.splitext(os.path.basename(str(p)))[0]
|
| 169 |
+
|
| 170 |
+
print("USE_MAMBA =", USE_MAMBA, "(False → ra đúng exp08)")
|
| 171 |
+
print("DATA_ROOT:", DATA_ROOT)
|
| 172 |
+
for p in [WAV_DIR, METADATA_CSV, TRAIN_CSV, DEV_SCP]:
|
| 173 |
+
print((" ✅ " if os.path.exists(p) else " ❌ THIẾU ") + p)
|
| 174 |
+
print(f"Fine-tune: mở băng {UNFREEZE_TOP_LAYERS} lớp · BATCH {BATCH}×ACCUM {ACCUM} · MAX {MAX_SECONDS}s")
|
| 175 |
+
|
| 176 |
+
# %% [markdown]
|
| 177 |
+
# ## 1. Cài đặt + tải code SAILER (clone + sys.path)
|
| 178 |
+
|
| 179 |
+
# %%
|
| 180 |
+
import sys, subprocess
|
| 181 |
+
|
| 182 |
+
def pip_install(*pkgs):
|
| 183 |
+
subprocess.run([sys.executable, "-m", "pip", "install", "-q", *pkgs], check=True)
|
| 184 |
+
|
| 185 |
+
pip_install("loralib", "speechbrain", "speechmos", "librosa", "soundfile",
|
| 186 |
+
"scipy", "scikit-learn", "pandas", "tqdm")
|
| 187 |
+
|
| 188 |
+
# Cài kernel CUDA Mamba (nhanh + nhẹ RAM hơn bản thuần PyTorch nhiều). Build hay lỗi/chậm trên Kaggle
|
| 189 |
+
# → bọc try/except: lỗi thì BỎ QUA, mục 6a tự fallback Mamba thuần PyTorch. KHÔNG để chết notebook.
|
| 190 |
+
INSTALL_MAMBA_SSM = True # đặt False nếu muốn BỎ QUA, dùng thẳng Mamba thuần PyTorch
|
| 191 |
+
if INSTALL_MAMBA_SSM and USE_MAMBA:
|
| 192 |
+
try:
|
| 193 |
+
# --no-build-isolation cho CẢ HAI → dùng torch+CUDA sẵn có của Kaggle để biên dịch (đừng kéo torch khác).
|
| 194 |
+
# Cần ninja để build nhanh. -q ẩn log nên bước này có thể "treo" vài phút khi đang compile — bình thường.
|
| 195 |
+
subprocess.run([sys.executable, "-m", "pip", "install", "-q", "ninja"], check=True)
|
| 196 |
+
subprocess.run([sys.executable, "-m", "pip", "install", "-q",
|
| 197 |
+
"--no-build-isolation", "causal-conv1d>=1.2.0"], check=True)
|
| 198 |
+
subprocess.run([sys.executable, "-m", "pip", "install", "-q",
|
| 199 |
+
"--no-build-isolation", "mamba-ssm"], check=True)
|
| 200 |
+
print("✅ Cài mamba-ssm + causal-conv1d xong (sẽ dùng kernel CUDA nếu import được).")
|
| 201 |
+
except Exception as e:
|
| 202 |
+
print("⚠️ Cài mamba-ssm thất bại:", repr(e), "→ dùng Mamba thuần PyTorch (chậm hơn).")
|
| 203 |
+
print(" ℹ️ Vẫn chạy bình thường. Nếu chạy THẬT (LIMIT=None) quá chậm → xem Ghi chú cuối notebook.")
|
| 204 |
+
|
| 205 |
+
REPO_DIR = "/kaggle/working/vox-profile-release"
|
| 206 |
+
if not os.path.exists(REPO_DIR):
|
| 207 |
+
subprocess.run(["git", "clone", "--depth", "1",
|
| 208 |
+
"https://github.com/tiantiaf0627/vox-profile-release.git", REPO_DIR], check=True)
|
| 209 |
+
if REPO_DIR not in sys.path:
|
| 210 |
+
sys.path.insert(0, REPO_DIR)
|
| 211 |
+
|
| 212 |
+
# %% [markdown]
|
| 213 |
+
# ## 2. Nạp SAILER → lấy backbone WavLM bên trong để FINE-TUNE (warm-start)
|
| 214 |
+
|
| 215 |
+
# %%
|
| 216 |
+
import torch
|
| 217 |
+
import torch.nn as nn
|
| 218 |
+
import torch.nn.functional as F
|
| 219 |
+
|
| 220 |
+
device = DEVICE if torch.cuda.is_available() else "cpu"
|
| 221 |
+
print("Device:", device, ("✅ " + torch.cuda.get_device_name(0)) if device == "cuda" else "⚠️ CPU (rất chậm!)")
|
| 222 |
+
|
| 223 |
+
def find_hf_backbone(module):
|
| 224 |
+
"""Tìm submodule kiểu HF WavLM backbone: có .feature_extractor và .encoder.layers."""
|
| 225 |
+
cands = []
|
| 226 |
+
for name, m in module.named_modules():
|
| 227 |
+
enc = getattr(m, "encoder", None)
|
| 228 |
+
if getattr(m, "feature_extractor", None) is not None and enc is not None \
|
| 229 |
+
and getattr(enc, "layers", None) is not None:
|
| 230 |
+
cands.append((name, m))
|
| 231 |
+
if not cands:
|
| 232 |
+
return None, None
|
| 233 |
+
cands.sort(key=lambda nm: sum(p.numel() for p in nm[1].parameters()), reverse=True)
|
| 234 |
+
return cands[0]
|
| 235 |
+
|
| 236 |
+
wavlm = None
|
| 237 |
+
try:
|
| 238 |
+
from src.model.emotion.wavlm_emotion import WavLMWrapper # noqa: E402
|
| 239 |
+
_wrapper = WavLMWrapper.from_pretrained("tiantiaf/wavlm-large-categorical-emotion")
|
| 240 |
+
name, wavlm = find_hf_backbone(_wrapper)
|
| 241 |
+
if wavlm is not None:
|
| 242 |
+
print(f"✅ Warm-start SAILER: backbone WavLM tại '.{name}' "
|
| 243 |
+
f"({sum(p.numel() for p in wavlm.parameters())/1e6:.0f}M params)")
|
| 244 |
+
else:
|
| 245 |
+
print("⚠️ Không tìm thấy backbone HF trong wrapper SAILER → fallback WavLM trắng.")
|
| 246 |
+
except Exception as e:
|
| 247 |
+
print("⚠️ Lỗi nạp SAILER wrapper:", repr(e), "→ fallback WavLM trắng.")
|
| 248 |
+
|
| 249 |
+
if wavlm is None:
|
| 250 |
+
from transformers import WavLMModel
|
| 251 |
+
wavlm = WavLMModel.from_pretrained("microsoft/wavlm-large")
|
| 252 |
+
print("ℹ️ Fallback: microsoft/wavlm-large (KHÔNG warm-start SAILER).")
|
| 253 |
+
|
| 254 |
+
wavlm = wavlm.to(device)
|
| 255 |
+
WAVLM_DIM = int(wavlm.config.hidden_size)
|
| 256 |
+
wavlm.config.layerdrop = 0.0 # ⚠️ tránh CheckpointError khi grad-ckpt (bài học exp12)
|
| 257 |
+
|
| 258 |
+
# ── RESUME: nạp trọng số backbone đã fine-tune từ checkpoint (đè lên warm-start SAILER) ──
|
| 259 |
+
resume_ckpt = None
|
| 260 |
+
if RESUME:
|
| 261 |
+
resume_ckpt = torch.load(RESUME_CKPT, map_location="cpu", weights_only=False) # ckpt có numpy → cần False
|
| 262 |
+
assert "wavlm" in resume_ckpt, ("❌ Checkpoint KHÔNG có 'wavlm' (backbone) → không resume được. "
|
| 263 |
+
"Dùng file ft_mamba_emotion_full.pt do exp15 lưu.")
|
| 264 |
+
if resume_ckpt.get("USE_MAMBA", USE_MAMBA) != USE_MAMBA:
|
| 265 |
+
print(f" ⚠️ ckpt USE_MAMBA={resume_ckpt.get('USE_MAMBA')} ≠ cấu hình hiện tại {USE_MAMBA} → kiến trúc LỆCH! "
|
| 266 |
+
"Đặt USE_MAMBA cho khớp ckpt.")
|
| 267 |
+
miss, unexp = wavlm.load_state_dict(resume_ckpt["wavlm"], strict=False)
|
| 268 |
+
print(f"🔁 RESUME load wavlm từ ckpt: thiếu {len(miss)} / dư {len(unexp)} key (kỳ vọng ~0). keys ckpt:", list(resume_ckpt.keys()))
|
| 269 |
+
if len(miss) > 20 or len(unexp) > 20:
|
| 270 |
+
print(" ⚠️ Lệch key nhiều → kiểm tra UNFREEZE_TOP_LAYERS / backbone có khớp ckpt không.")
|
| 271 |
+
|
| 272 |
+
# ── Đóng băng partial: feature-extractor + tất cả trừ UNFREEZE_TOP_LAYERS lớp trên ──
|
| 273 |
+
for p in wavlm.parameters():
|
| 274 |
+
p.requires_grad = False
|
| 275 |
+
enc_layers = wavlm.encoder.layers
|
| 276 |
+
n_layers = len(enc_layers)
|
| 277 |
+
for layer in enc_layers[max(0, n_layers - UNFREEZE_TOP_LAYERS):]:
|
| 278 |
+
for p in layer.parameters():
|
| 279 |
+
p.requires_grad = True
|
| 280 |
+
n_train = sum(p.numel() for p in wavlm.parameters() if p.requires_grad)
|
| 281 |
+
print(f"WavLM: {n_layers} lớp · mở băng {min(UNFREEZE_TOP_LAYERS, n_layers)} → {n_train/1e6:.1f}M param train (dim {WAVLM_DIM})")
|
| 282 |
+
|
| 283 |
+
if USE_GRAD_CKPT:
|
| 284 |
+
wavlm.gradient_checkpointing_enable()
|
| 285 |
+
if hasattr(wavlm, "enable_input_require_grads"):
|
| 286 |
+
wavlm.enable_input_require_grads()
|
| 287 |
+
|
| 288 |
+
def frame_mask(T, attn_mask):
|
| 289 |
+
"""attn_mask (B, Lwav) → frame-mask (B, T) bool (True=frame thật). Khớp downsample của WavLM."""
|
| 290 |
+
if attn_mask is None:
|
| 291 |
+
return torch.ones((1, T), dtype=torch.bool, device=device)
|
| 292 |
+
try:
|
| 293 |
+
fm = wavlm._get_feature_vector_attention_mask(T, attn_mask)
|
| 294 |
+
return fm.bool()
|
| 295 |
+
except Exception:
|
| 296 |
+
return torch.ones((attn_mask.shape[0], T), dtype=torch.bool, device=attn_mask.device)
|
| 297 |
+
|
| 298 |
+
def masked_mean(hidden, attn_mask):
|
| 299 |
+
"""Mean-pool theo thời gian bỏ pad (đường exp08 khi USE_MAMBA=False)."""
|
| 300 |
+
if attn_mask is None:
|
| 301 |
+
return hidden.mean(dim=1)
|
| 302 |
+
fm = frame_mask(hidden.shape[1], attn_mask).unsqueeze(-1).to(hidden.dtype)
|
| 303 |
+
return (hidden * fm).sum(1) / fm.sum(1).clamp(min=1e-6)
|
| 304 |
+
|
| 305 |
+
# %% [markdown]
|
| 306 |
+
# ## 3. Nạp audeering MSP-dim (FROZEN) — đặc trưng phụ (như exp08)
|
| 307 |
+
|
| 308 |
+
# %%
|
| 309 |
+
AUD_DIM = 0
|
| 310 |
+
aud_backbone = aud_head = aud_proc = None
|
| 311 |
+
if USE_AUDEERING:
|
| 312 |
+
from transformers import Wav2Vec2Model, Wav2Vec2Config, Wav2Vec2Processor
|
| 313 |
+
from huggingface_hub import hf_hub_download
|
| 314 |
+
AUD_NAME = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
|
| 315 |
+
aud_proc = Wav2Vec2Processor.from_pretrained(AUD_NAME)
|
| 316 |
+
aud_cfg = Wav2Vec2Config.from_pretrained(AUD_NAME)
|
| 317 |
+
aud_backbone = Wav2Vec2Model(aud_cfg)
|
| 318 |
+
try:
|
| 319 |
+
_sd = __import__("safetensors.torch", fromlist=["load_file"]).load_file(
|
| 320 |
+
hf_hub_download(AUD_NAME, "model.safetensors"))
|
| 321 |
+
except Exception:
|
| 322 |
+
_sd = torch.load(hf_hub_download(AUD_NAME, "pytorch_model.bin"), map_location="cpu")
|
| 323 |
+
bb_sd = {k[len("wav2vec2."):]: v for k, v in _sd.items() if k.startswith("wav2vec2.")}
|
| 324 |
+
aud_backbone.load_state_dict(bb_sd, strict=False)
|
| 325 |
+
_hid = _sd["classifier.dense.weight"].shape[0]
|
| 326 |
+
_out = _sd["classifier.out_proj.weight"].shape[0]
|
| 327 |
+
aud_head = nn.Sequential(nn.Linear(_hid, _hid), nn.Tanh(), nn.Linear(_hid, _out))
|
| 328 |
+
aud_head[0].weight.data.copy_(_sd["classifier.dense.weight"]); aud_head[0].bias.data.copy_(_sd["classifier.dense.bias"])
|
| 329 |
+
aud_head[2].weight.data.copy_(_sd["classifier.out_proj.weight"]); aud_head[2].bias.data.copy_(_sd["classifier.out_proj.bias"])
|
| 330 |
+
aud_backbone = aud_backbone.to(device).eval()
|
| 331 |
+
aud_head = aud_head.to(device).eval()
|
| 332 |
+
AUD_DIM = _hid + 3
|
| 333 |
+
print(f"✅ audeering frozen (đặc trưng phụ {AUD_DIM}-D = emb {_hid} + vad 3)")
|
| 334 |
+
|
| 335 |
+
# %%
|
| 336 |
+
import numpy as np
|
| 337 |
+
import librosa
|
| 338 |
+
from tqdm.auto import tqdm
|
| 339 |
+
|
| 340 |
+
def load_wav(name_or_stem):
|
| 341 |
+
p = name_or_stem if os.path.isabs(str(name_or_stem)) else os.path.join(
|
| 342 |
+
WAV_DIR, name_or_stem if str(name_or_stem).endswith(".wav") else str(name_or_stem) + ".wav")
|
| 343 |
+
if not os.path.exists(p):
|
| 344 |
+
return None
|
| 345 |
+
wave, _ = librosa.load(p, sr=SR, mono=True)
|
| 346 |
+
return wave[: MAX_SECONDS * SR].astype(np.float32)
|
| 347 |
+
|
| 348 |
+
@torch.no_grad()
|
| 349 |
+
def extract_audeering(stems, tag):
|
| 350 |
+
if not USE_AUDEERING:
|
| 351 |
+
return {}
|
| 352 |
+
cache_path = os.path.join(CACHE_DIR, f"aud_{tag}.npz")
|
| 353 |
+
store = {}
|
| 354 |
+
if os.path.exists(cache_path):
|
| 355 |
+
z = np.load(cache_path, allow_pickle=True)
|
| 356 |
+
store = {k: z[k] for k in z.files}
|
| 357 |
+
print(f"[aud/{tag}] nạp cache: {len(store)}")
|
| 358 |
+
todo = [s for s in stems if s not in store]
|
| 359 |
+
for i, s in enumerate(tqdm(todo, desc=f"audeering {tag}")):
|
| 360 |
+
wave = load_wav(s)
|
| 361 |
+
if wave is None:
|
| 362 |
+
continue
|
| 363 |
+
x = aud_proc(wave, sampling_rate=SR).input_values[0]
|
| 364 |
+
x = torch.from_numpy(np.asarray(x, dtype=np.float32)).unsqueeze(0).to(device)
|
| 365 |
+
h = aud_backbone(x)[0].mean(dim=1)
|
| 366 |
+
out = aud_head(h)[0].cpu().numpy() # [arousal, dominance, valence] ∈[0,1]
|
| 367 |
+
vad = np.array([1 + 4 * out[2], 1 + 4 * out[0], 1 + 4 * out[1]], dtype=np.float32) # [VAL,ARO,DOM]
|
| 368 |
+
store[s] = np.concatenate([h[0].cpu().numpy(), vad]).astype(np.float32)
|
| 369 |
+
if (i + 1) % 500 == 0:
|
| 370 |
+
np.savez(cache_path, **store)
|
| 371 |
+
if todo:
|
| 372 |
+
np.savez(cache_path, **store)
|
| 373 |
+
return store
|
| 374 |
+
|
| 375 |
+
# %% [markdown]
|
| 376 |
+
# ## 4. Đọc & gộp nhãn theo wavID (EMOS / VAD / CAT) — như exp08
|
| 377 |
+
|
| 378 |
+
# %%
|
| 379 |
+
import pandas as pd
|
| 380 |
+
|
| 381 |
+
def load_target_emotions():
|
| 382 |
+
tgt = {}
|
| 383 |
+
with open(METADATA_CSV, encoding="utf-8") as f:
|
| 384 |
+
for ln in f:
|
| 385 |
+
parts = ln.strip().split("|")
|
| 386 |
+
if len(parts) >= 2:
|
| 387 |
+
tgt[stem(parts[0])] = norm_emotion(parts[1])
|
| 388 |
+
return tgt
|
| 389 |
+
|
| 390 |
+
def _col(cols_map, *names, df=None, default_idx=None):
|
| 391 |
+
for n in names:
|
| 392 |
+
if n in cols_map:
|
| 393 |
+
return cols_map[n]
|
| 394 |
+
return list(df.columns)[default_idx] if default_idx is not None else None
|
| 395 |
+
|
| 396 |
+
def parse_emocat_votes(cell):
|
| 397 |
+
v = np.zeros(len(EMOTIONS5), dtype=np.float32)
|
| 398 |
+
for tok in str(cell).replace("/", ",").replace(";", ",").replace("|", ",").replace(" ", ",").split(","):
|
| 399 |
+
e = norm_emotion(tok)
|
| 400 |
+
if e in EMOTIONS5:
|
| 401 |
+
v[EMOTIONS5.index(e)] += 1.0
|
| 402 |
+
return v
|
| 403 |
+
|
| 404 |
+
def load_train_labels():
|
| 405 |
+
df = pd.read_csv(TRAIN_CSV, sep="|")
|
| 406 |
+
cols = {c.lower().strip(): c for c in df.columns}
|
| 407 |
+
wav_col = _col(cols, "wavid", "wav", df=df, default_idx=1)
|
| 408 |
+
emos_col = _col(cols, "emos", "emo", "emomos")
|
| 409 |
+
val_col = _col(cols, "val", "valence"); aro_col = _col(cols, "aro", "arousal"); dom_col = _col(cols, "dom", "dominance")
|
| 410 |
+
cat_col = _col(cols, "emocat", "cat", "emotion")
|
| 411 |
+
assert emos_col, f"Không thấy cột eMOS (cột: {list(df.columns)})"
|
| 412 |
+
df["_stem"] = df[wav_col].map(stem)
|
| 413 |
+
rows = []
|
| 414 |
+
for sid, g in df.groupby("_stem"):
|
| 415 |
+
rec = {"wavID": sid, "emos": float(g[emos_col].mean())}
|
| 416 |
+
rec["val"] = float(g[val_col].mean()) if val_col else np.nan
|
| 417 |
+
rec["aro"] = float(g[aro_col].mean()) if aro_col else np.nan
|
| 418 |
+
rec["dom"] = float(g[dom_col].mean()) if dom_col else np.nan
|
| 419 |
+
votes = np.zeros(len(EMOTIONS5), dtype=np.float32)
|
| 420 |
+
if cat_col:
|
| 421 |
+
for cell in g[cat_col]:
|
| 422 |
+
votes += parse_emocat_votes(cell)
|
| 423 |
+
s = votes.sum()
|
| 424 |
+
cat = votes / s if s > 0 else np.full(len(EMOTIONS5), 0.2, dtype=np.float32)
|
| 425 |
+
for i in range(len(EMOTIONS5)):
|
| 426 |
+
rec[f"cat{i}"] = float(cat[i])
|
| 427 |
+
rows.append(rec)
|
| 428 |
+
return pd.DataFrame(rows)
|
| 429 |
+
|
| 430 |
+
target_map = load_target_emotions()
|
| 431 |
+
train_df = load_train_labels()
|
| 432 |
+
HAS_VAD = bool(train_df["val"].notna().any())
|
| 433 |
+
print(f"Target: {len(target_map)} | wav train (gộp): {len(train_df)} | có VAD: {HAS_VAD}")
|
| 434 |
+
|
| 435 |
+
# %% [markdown]
|
| 436 |
+
# ## 5. Dataset / DataLoader (load wav theo batch — KHÔNG cache WavLM vì đang train)
|
| 437 |
+
|
| 438 |
+
# %%
|
| 439 |
+
from torch.utils.data import Dataset, DataLoader
|
| 440 |
+
|
| 441 |
+
train_stems = [s for s in train_df["wavID"] if target_map.get(s) is not None]
|
| 442 |
+
if LIMIT_TRAIN:
|
| 443 |
+
train_stems = train_stems[:LIMIT_TRAIN]
|
| 444 |
+
aud_tr = extract_audeering(train_stems, "train")
|
| 445 |
+
|
| 446 |
+
lab = train_df.set_index("wavID")
|
| 447 |
+
|
| 448 |
+
def _zfit(arr):
|
| 449 |
+
a = np.asarray(arr, dtype=np.float32)
|
| 450 |
+
return float(np.nanmean(a)), float(np.nanstd(a) + 1e-6)
|
| 451 |
+
|
| 452 |
+
if RESUME and resume_ckpt is not None:
|
| 453 |
+
# QUAN TRỌNG: lấy chuẩn hóa TỪ ckpt (head đã train theo thang này) — KHÔNG tính lại để khỏi lệch thang
|
| 454 |
+
emos_mu = float(resume_ckpt["emos_mu"]); emos_sd = float(resume_ckpt["emos_sd"])
|
| 455 |
+
vad_mu = np.asarray(resume_ckpt["vad_mu"], dtype=np.float32)
|
| 456 |
+
vad_sd = np.asarray(resume_ckpt["vad_sd"], dtype=np.float32)
|
| 457 |
+
print(f"🔁 RESUME: dùng chuẩn hóa TỪ ckpt: emos μ={emos_mu:.3f} σ={emos_sd:.3f} | vad μ={np.round(vad_mu,2)}")
|
| 458 |
+
else:
|
| 459 |
+
emos_mu, emos_sd = _zfit([lab.loc[s, "emos"] for s in train_stems])
|
| 460 |
+
if HAS_VAD:
|
| 461 |
+
vad_mu = np.array([_zfit([lab.loc[s, c] for s in train_stems])[0] for c in ["val", "aro", "dom"]], dtype=np.float32)
|
| 462 |
+
vad_sd = np.array([_zfit([lab.loc[s, c] for s in train_stems])[1] for c in ["val", "aro", "dom"]], dtype=np.float32)
|
| 463 |
+
else:
|
| 464 |
+
vad_mu = np.zeros(3, dtype=np.float32); vad_sd = np.ones(3, dtype=np.float32)
|
| 465 |
+
|
| 466 |
+
def onehot_target(tgt):
|
| 467 |
+
v = np.zeros(len(EMOTIONS5), dtype=np.float32)
|
| 468 |
+
if tgt in EMOTIONS5:
|
| 469 |
+
v[EMOTIONS5.index(tgt)] = 1.0
|
| 470 |
+
return v
|
| 471 |
+
|
| 472 |
+
class EmoDataset(Dataset):
|
| 473 |
+
def __init__(self, stems):
|
| 474 |
+
self.stems = [s for s in stems if (load_wav(s) is not None) and ((not USE_AUDEERING) or s in aud_tr)]
|
| 475 |
+
def __len__(self):
|
| 476 |
+
return len(self.stems)
|
| 477 |
+
def __getitem__(self, i):
|
| 478 |
+
s = self.stems[i]
|
| 479 |
+
wave = load_wav(s)
|
| 480 |
+
emos = (float(lab.loc[s, "emos"]) - emos_mu) / emos_sd
|
| 481 |
+
if HAS_VAD:
|
| 482 |
+
vad = (np.array([lab.loc[s, "val"], lab.loc[s, "aro"], lab.loc[s, "dom"]], np.float32) - vad_mu) / vad_sd
|
| 483 |
+
else:
|
| 484 |
+
vad = np.zeros(3, dtype=np.float32)
|
| 485 |
+
cat = np.array([lab.loc[s, f"cat{j}"] for j in range(len(EMOTIONS5))], dtype=np.float32)
|
| 486 |
+
aud = aud_tr[s] if USE_AUDEERING else np.zeros(0, dtype=np.float32)
|
| 487 |
+
return {"wave": wave, "tgt": onehot_target(target_map.get(s)), "aud": aud,
|
| 488 |
+
"emos": np.float32(emos), "vad": vad, "cat": cat,
|
| 489 |
+
"emos_raw": np.float32(lab.loc[s, "emos"]),
|
| 490 |
+
"vad_raw": np.array([lab.loc[s, "val"], lab.loc[s, "aro"], lab.loc[s, "dom"]], np.float32)}
|
| 491 |
+
|
| 492 |
+
def collate(batch):
|
| 493 |
+
L = max(len(b["wave"]) for b in batch)
|
| 494 |
+
waves = np.zeros((len(batch), L), dtype=np.float32)
|
| 495 |
+
mask = np.zeros((len(batch), L), dtype=np.float32)
|
| 496 |
+
for i, b in enumerate(batch):
|
| 497 |
+
waves[i, : len(b["wave"])] = b["wave"]; mask[i, : len(b["wave"])] = 1.0
|
| 498 |
+
return {
|
| 499 |
+
"input_values": torch.from_numpy(waves), "attn_mask": torch.from_numpy(mask).long(),
|
| 500 |
+
"tgt": torch.from_numpy(np.stack([b["tgt"] for b in batch])),
|
| 501 |
+
"aud": torch.from_numpy(np.stack([b["aud"] for b in batch])) if USE_AUDEERING else None,
|
| 502 |
+
"emos": torch.from_numpy(np.stack([b["emos"] for b in batch])).unsqueeze(1),
|
| 503 |
+
"vad": torch.from_numpy(np.stack([b["vad"] for b in batch])),
|
| 504 |
+
"cat": torch.from_numpy(np.stack([b["cat"] for b in batch])),
|
| 505 |
+
"emos_raw": np.stack([b["emos_raw"] for b in batch]),
|
| 506 |
+
"vad_raw": np.stack([b["vad_raw"] for b in batch]),
|
| 507 |
+
}
|
| 508 |
+
|
| 509 |
+
from sklearn.model_selection import train_test_split
|
| 510 |
+
ds = EmoDataset(train_stems)
|
| 511 |
+
print("Dataset hợp lệ:", len(ds), "wav")
|
| 512 |
+
tr_i, va_i = train_test_split(np.arange(len(ds)), test_size=VAL_FRAC, random_state=SEED)
|
| 513 |
+
tr_loader = DataLoader(torch.utils.data.Subset(ds, tr_i), batch_size=BATCH, shuffle=True, collate_fn=collate, num_workers=2)
|
| 514 |
+
va_loader = DataLoader(torch.utils.data.Subset(ds, va_i), batch_size=BATCH, shuffle=False, collate_fn=collate, num_workers=2)
|
| 515 |
+
|
| 516 |
+
# %% [markdown]
|
| 517 |
+
# ## 6a. Khối MAMBA (thuần PyTorch, fallback nếu không có `mamba-ssm`)
|
| 518 |
+
# Theo "mamba-minimal" — đúng công thức selective SSM, chỉ chậm hơn kernel CUDA. Chạy trong fp32 cho ổn định.
|
| 519 |
+
|
| 520 |
+
# %%
|
| 521 |
+
import math
|
| 522 |
+
|
| 523 |
+
try:
|
| 524 |
+
from mamba_ssm import Mamba as _OfficialMamba
|
| 525 |
+
_HAS_MAMBA_SSM = True
|
| 526 |
+
print("✅ Dùng mamba-ssm (CUDA kernel)")
|
| 527 |
+
except Exception:
|
| 528 |
+
_HAS_MAMBA_SSM = False
|
| 529 |
+
print("ℹ️ Không có mamba-ssm → Mamba thuần PyTorch (chậm hơn khi fine-tune)")
|
| 530 |
+
|
| 531 |
+
class MambaBlockTorch(nn.Module):
|
| 532 |
+
def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
|
| 533 |
+
super().__init__()
|
| 534 |
+
self.d_inner = expand * d_model
|
| 535 |
+
self.dt_rank = math.ceil(d_model / 16)
|
| 536 |
+
self.in_proj = nn.Linear(d_model, self.d_inner * 2, bias=False)
|
| 537 |
+
self.conv1d = nn.Conv1d(self.d_inner, self.d_inner, kernel_size=d_conv,
|
| 538 |
+
groups=self.d_inner, padding=d_conv - 1, bias=True)
|
| 539 |
+
self.x_proj = nn.Linear(self.d_inner, self.dt_rank + d_state * 2, bias=False)
|
| 540 |
+
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)
|
| 541 |
+
A = torch.arange(1, d_state + 1, dtype=torch.float32).repeat(self.d_inner, 1)
|
| 542 |
+
self.A_log = nn.Parameter(torch.log(A))
|
| 543 |
+
self.D = nn.Parameter(torch.ones(self.d_inner))
|
| 544 |
+
self.out_proj = nn.Linear(self.d_inner, d_model, bias=False)
|
| 545 |
+
self.d_state = d_state
|
| 546 |
+
|
| 547 |
+
def forward(self, x): # x: (B, L, d_model)
|
| 548 |
+
B, L, _ = x.shape
|
| 549 |
+
xin, z = self.in_proj(x).chunk(2, dim=-1)
|
| 550 |
+
xin = xin.transpose(1, 2)
|
| 551 |
+
xin = self.conv1d(xin)[..., :L].transpose(1, 2)
|
| 552 |
+
xin = F.silu(xin)
|
| 553 |
+
y = self._ssm(xin) * F.silu(z)
|
| 554 |
+
return self.out_proj(y)
|
| 555 |
+
|
| 556 |
+
def _ssm(self, x):
|
| 557 |
+
A = -torch.exp(self.A_log)
|
| 558 |
+
delta, Bm, Cm = torch.split(self.x_proj(x), [self.dt_rank, self.d_state, self.d_state], dim=-1)
|
| 559 |
+
delta = F.softplus(self.dt_proj(delta))
|
| 560 |
+
dA = torch.exp(delta.unsqueeze(-1) * A)
|
| 561 |
+
dB_x = delta.unsqueeze(-1) * Bm.unsqueeze(2) * x.unsqueeze(-1)
|
| 562 |
+
h = torch.zeros(x.shape[0], self.d_inner, self.d_state, device=x.device, dtype=x.dtype)
|
| 563 |
+
ys = []
|
| 564 |
+
for t in range(x.shape[1]):
|
| 565 |
+
h = dA[:, t] * h + dB_x[:, t]
|
| 566 |
+
ys.append((h * Cm[:, t].unsqueeze(1)).sum(-1))
|
| 567 |
+
return torch.stack(ys, dim=1) + x * self.D
|
| 568 |
+
|
| 569 |
+
class MambaLayer(nn.Module):
|
| 570 |
+
def __init__(self, d_model, d_state):
|
| 571 |
+
super().__init__()
|
| 572 |
+
self.norm = nn.LayerNorm(d_model)
|
| 573 |
+
self.mix = _OfficialMamba(d_model=d_model, d_state=d_state, d_conv=4, expand=2) \
|
| 574 |
+
if _HAS_MAMBA_SSM else MambaBlockTorch(d_model, d_state=d_state)
|
| 575 |
+
def forward(self, x):
|
| 576 |
+
return x + self.mix(self.norm(x))
|
| 577 |
+
|
| 578 |
+
class MambaEncoder(nn.Module):
|
| 579 |
+
"""1024 → d_model → [Mamba ×L] (2 chiều) → attentive-pool (có mask) → Z_DIM."""
|
| 580 |
+
def __init__(self, d_in, d_model, n_layers, d_state, z_dim, bidir):
|
| 581 |
+
super().__init__()
|
| 582 |
+
self.bidir = bidir
|
| 583 |
+
self.proj = nn.Linear(d_in, d_model)
|
| 584 |
+
self.fwd = nn.ModuleList([MambaLayer(d_model, d_state) for _ in range(n_layers)])
|
| 585 |
+
if bidir:
|
| 586 |
+
self.bwd = nn.ModuleList([MambaLayer(d_model, d_state) for _ in range(n_layers)])
|
| 587 |
+
self.attn = nn.Linear(d_model, 1)
|
| 588 |
+
self.out = nn.Linear(d_model, z_dim)
|
| 589 |
+
|
| 590 |
+
@staticmethod
|
| 591 |
+
def _run(layers, h):
|
| 592 |
+
for L in layers:
|
| 593 |
+
h = L(h)
|
| 594 |
+
return h
|
| 595 |
+
|
| 596 |
+
def forward(self, x, mask): # x:(B,L,1024) mask:(B,L) bool
|
| 597 |
+
with torch.cuda.amp.autocast(enabled=False): # SSM chạy fp32 cho ổn định
|
| 598 |
+
x = x.float()
|
| 599 |
+
h = self.proj(x)
|
| 600 |
+
out = self._run(self.fwd, h)
|
| 601 |
+
if self.bidir:
|
| 602 |
+
out = out + torch.flip(self._run(self.bwd, torch.flip(h, dims=[1])), dims=[1])
|
| 603 |
+
a = self.attn(out).squeeze(-1).masked_fill(~mask, float("-inf"))
|
| 604 |
+
w = torch.softmax(a, dim=1).unsqueeze(-1)
|
| 605 |
+
return self.out((out * w).sum(1))
|
| 606 |
+
|
| 607 |
+
# %% [markdown]
|
| 608 |
+
# ## 6b. Head cảm xúc + train loop (AMP + grad-accum + uncertainty weighting)
|
| 609 |
+
|
| 610 |
+
# %%
|
| 611 |
+
from scipy.stats import spearmanr
|
| 612 |
+
|
| 613 |
+
torch.manual_seed(SEED); np.random.seed(SEED)
|
| 614 |
+
N_EMO = len(EMOTIONS5)
|
| 615 |
+
WAVLM_BRANCH = Z_DIM if USE_MAMBA else WAVLM_DIM
|
| 616 |
+
TRUNK_IN = WAVLM_BRANCH + (AUD_DIM if USE_AUDEERING else 0)
|
| 617 |
+
|
| 618 |
+
enc = MambaEncoder(WAVLM_DIM, MAMBA_DMODEL, MAMBA_LAYERS, MAMBA_DSTATE, Z_DIM, BIDIRECTIONAL).to(device) \
|
| 619 |
+
if USE_MAMBA else None
|
| 620 |
+
|
| 621 |
+
class EmoHeads(nn.Module):
|
| 622 |
+
def __init__(self, d_in, trunk_h, head_h, p, n_emo):
|
| 623 |
+
super().__init__()
|
| 624 |
+
self.trunk = nn.Sequential(nn.Linear(d_in, trunk_h), nn.ReLU(), nn.Dropout(p),
|
| 625 |
+
nn.Linear(trunk_h, trunk_h), nn.ReLU(), nn.Dropout(p))
|
| 626 |
+
self.emos = nn.Sequential(nn.Linear(trunk_h + n_emo, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 1))
|
| 627 |
+
self.cat = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, n_emo))
|
| 628 |
+
self.vad = nn.Sequential(nn.Linear(trunk_h, head_h), nn.ReLU(), nn.Dropout(p), nn.Linear(head_h, 3))
|
| 629 |
+
def forward(self, feat, tgt):
|
| 630 |
+
h = self.trunk(feat)
|
| 631 |
+
return self.emos(torch.cat([h, tgt], 1)), self.cat(h), self.vad(h)
|
| 632 |
+
|
| 633 |
+
heads = EmoHeads(TRUNK_IN, TRUNK_HIDDEN, HEAD_HIDDEN, DROPOUT, N_EMO).to(device)
|
| 634 |
+
print(f"Trunk input = {TRUNK_IN} (wavlm-branch {WAVLM_BRANCH} [{'Mamba' if USE_MAMBA else 'mean-pool'}] + aud {AUD_DIM if USE_AUDEERING else 0})")
|
| 635 |
+
if USE_MAMBA:
|
| 636 |
+
print(f"Mamba encoder: {sum(p.numel() for p in enc.parameters())/1e6:.2f}M param")
|
| 637 |
+
|
| 638 |
+
# ── RESUME: nạp heads (+ Mamba enc) từ checkpoint ──
|
| 639 |
+
if RESUME and resume_ckpt is not None:
|
| 640 |
+
hm, hu = heads.load_state_dict(resume_ckpt["heads"], strict=False)
|
| 641 |
+
print(f"🔁 RESUME load heads từ ckpt: thiếu {len(hm)} / dư {len(hu)} key (kỳ vọng 0)")
|
| 642 |
+
if USE_MAMBA and resume_ckpt.get("enc") is not None:
|
| 643 |
+
em, eu = enc.load_state_dict(resume_ckpt["enc"], strict=False)
|
| 644 |
+
print(f"🔁 RESUME load Mamba enc từ ckpt: thiếu {len(em)} / dư {len(eu)} key (kỳ vọng 0)")
|
| 645 |
+
elif USE_MAMBA:
|
| 646 |
+
print(" ⚠️ ckpt KHÔNG có 'enc' (Mamba) → Mamba head train lại từ đầu (chỉ resume backbone+heads).")
|
| 647 |
+
|
| 648 |
+
TASKS = ["emos", "cat", "val", "aro", "dom"]
|
| 649 |
+
log_var = nn.Parameter(torch.zeros(len(TASKS), device=device))
|
| 650 |
+
bb_params = [p for p in wavlm.parameters() if p.requires_grad]
|
| 651 |
+
head_params = list(heads.parameters()) + (list(enc.parameters()) if USE_MAMBA else []) \
|
| 652 |
+
+ ([log_var] if USE_UNCERTAINTY else [])
|
| 653 |
+
_lr_scale = RESUME_LR_SCALE if RESUME else 1.0
|
| 654 |
+
opt = torch.optim.AdamW([
|
| 655 |
+
{"params": bb_params, "lr": LR_BACKBONE * _lr_scale},
|
| 656 |
+
{"params": head_params, "lr": LR_HEAD * _lr_scale},
|
| 657 |
+
], weight_decay=WEIGHT_DECAY)
|
| 658 |
+
if RESUME and _lr_scale != 1.0:
|
| 659 |
+
print(f"🔁 RESUME: LR ×{_lr_scale} → backbone {LR_BACKBONE*_lr_scale:.1e} · head {LR_HEAD*_lr_scale:.1e}")
|
| 660 |
+
scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP and device == "cuda")
|
| 661 |
+
mse = nn.MSELoss()
|
| 662 |
+
|
| 663 |
+
def soft_ce(logits, target_dist):
|
| 664 |
+
return -(target_dist * F.log_softmax(logits, dim=1)).sum(1).mean()
|
| 665 |
+
|
| 666 |
+
def wavlm_branch(input_values, attn_mask):
|
| 667 |
+
out = wavlm(input_values, attention_mask=attn_mask).last_hidden_state # (B,T,D)
|
| 668 |
+
if USE_MAMBA:
|
| 669 |
+
return enc(out, frame_mask(out.shape[1], attn_mask)) # (B, Z_DIM)
|
| 670 |
+
return masked_mean(out, attn_mask) # (B, D)
|
| 671 |
+
|
| 672 |
+
def forward_batch(b):
|
| 673 |
+
fw = wavlm_branch(b["input_values"].to(device), b["attn_mask"].to(device))
|
| 674 |
+
feat = torch.cat([fw, b["aud"].to(device)], dim=1) if USE_AUDEERING else fw
|
| 675 |
+
return heads(feat, b["tgt"].to(device))
|
| 676 |
+
|
| 677 |
+
def pairwise_rank_loss(pred, target):
|
| 678 |
+
"""Hinge ranking trên MỌI cặp trong batch → tối ưu thẳng thứ hạng (≈ SRCC). Khả vi (backprop được).
|
| 679 |
+
Cần ≥2 mẫu/batch mới có cặp; batch càng to càng nhiều cặp → tín hiệu càng mạnh."""
|
| 680 |
+
p = pred.reshape(-1); t = target.reshape(-1)
|
| 681 |
+
if p.numel() < 2:
|
| 682 |
+
return torch.zeros((), device=p.device)
|
| 683 |
+
sign = torch.sign(t.unsqueeze(0) - t.unsqueeze(1)) # +1 nếu câu i ĐÁNG cao hơn câu j
|
| 684 |
+
diff = p.unsqueeze(0) - p.unsqueeze(1) # chênh lệch model dự đoán
|
| 685 |
+
return torch.relu(-sign * diff).mean() # phạt khi xếp sai thứ tự
|
| 686 |
+
|
| 687 |
+
def compute_loss(emos_p, cat_l, vad_p, b):
|
| 688 |
+
L = {"emos": mse(emos_p, b["emos"].to(device)), "cat": soft_ce(cat_l, b["cat"].to(device))}
|
| 689 |
+
if HAS_VAD:
|
| 690 |
+
vt = b["vad"].to(device)
|
| 691 |
+
L["val"] = mse(vad_p[:, 0:1], vt[:, 0:1]); L["aro"] = mse(vad_p[:, 1:2], vt[:, 1:2]); L["dom"] = mse(vad_p[:, 2:3], vt[:, 2:3])
|
| 692 |
+
else:
|
| 693 |
+
vt = None
|
| 694 |
+
z = torch.zeros((), device=device); L["val"] = L["aro"] = L["dom"] = z
|
| 695 |
+
# Ranking loss CHỈ cho các cột chấm SRCC (emos/val/aro/dom). CAT là ERR phân bố → giữ soft-CE.
|
| 696 |
+
if RANK_LAMBDA > 0:
|
| 697 |
+
L["emos"] = L["emos"] + RANK_LAMBDA * pairwise_rank_loss(emos_p, b["emos"].to(device))
|
| 698 |
+
if HAS_VAD:
|
| 699 |
+
L["val"] = L["val"] + RANK_LAMBDA * pairwise_rank_loss(vad_p[:, 0:1], vt[:, 0:1])
|
| 700 |
+
L["aro"] = L["aro"] + RANK_LAMBDA * pairwise_rank_loss(vad_p[:, 1:2], vt[:, 1:2])
|
| 701 |
+
L["dom"] = L["dom"] + RANK_LAMBDA * pairwise_rank_loss(vad_p[:, 2:3], vt[:, 2:3])
|
| 702 |
+
if USE_UNCERTAINTY:
|
| 703 |
+
return sum(torch.exp(-log_var[i]) * L[t] + log_var[i] for i, t in enumerate(TASKS))
|
| 704 |
+
return sum(L.values())
|
| 705 |
+
|
| 706 |
+
def set_mode(train):
|
| 707 |
+
wavlm.train(train); heads.train(train)
|
| 708 |
+
if USE_MAMBA:
|
| 709 |
+
enc.train(train)
|
| 710 |
+
|
| 711 |
+
@torch.no_grad()
|
| 712 |
+
def evaluate():
|
| 713 |
+
set_mode(False)
|
| 714 |
+
P = {"emos": [], "val": [], "aro": [], "dom": []}; Y = {"emos": [], "val": [], "aro": [], "dom": []}
|
| 715 |
+
catP, catY = [], []
|
| 716 |
+
for b in va_loader:
|
| 717 |
+
with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
|
| 718 |
+
emos_p, cat_l, vad_p = forward_batch(b)
|
| 719 |
+
P["emos"] += emos_p.float().cpu().numpy().ravel().tolist(); Y["emos"] += b["emos_raw"].tolist()
|
| 720 |
+
vad_p = vad_p.float().cpu().numpy()
|
| 721 |
+
for j, t in enumerate(["val", "aro", "dom"]):
|
| 722 |
+
P[t] += vad_p[:, j].tolist(); Y[t] += b["vad_raw"][:, j].tolist()
|
| 723 |
+
catP.append(F.softmax(cat_l, 1).float().cpu().numpy()); catY.append(b["cat"])
|
| 724 |
+
out = {t: spearmanr(P[t], Y[t]).correlation for t in ["emos"] + (["val", "aro", "dom"] if HAS_VAD else [])}
|
| 725 |
+
q = np.concatenate(catP); p = np.concatenate(catY)
|
| 726 |
+
out["cat_err"] = float(np.abs(q - p).sum(1).mean())
|
| 727 |
+
return out
|
| 728 |
+
|
| 729 |
+
def mean_srcc(m):
|
| 730 |
+
keys = ["emos"] + (["val", "aro", "dom"] if HAS_VAD else [])
|
| 731 |
+
return float(np.mean([m[k] for k in keys]))
|
| 732 |
+
|
| 733 |
+
CKPT_PATH = os.path.join(OUT_DIR, "ft_mamba_emotion_full.pt")
|
| 734 |
+
def save_full_ckpt(state, val_emos=float("nan")):
|
| 735 |
+
torch.save({"wavlm": state["wavlm"], "heads": state["heads"], "enc": state.get("enc"),
|
| 736 |
+
"USE_MAMBA": USE_MAMBA, "emos_mu": emos_mu, "emos_sd": emos_sd,
|
| 737 |
+
"vad_mu": vad_mu, "vad_sd": vad_sd, "WAVLM_DIM": WAVLM_DIM, "AUD_DIM": AUD_DIM,
|
| 738 |
+
"Z_DIM": Z_DIM, "UNFREEZE_TOP_LAYERS": UNFREEZE_TOP_LAYERS,
|
| 739 |
+
"val_emos": float(val_emos)}, CKPT_PATH)
|
| 740 |
+
|
| 741 |
+
def snapshot():
|
| 742 |
+
s = {"wavlm": {k: v.cpu().clone() for k, v in wavlm.state_dict().items()},
|
| 743 |
+
"heads": {k: v.cpu().clone() for k, v in heads.state_dict().items()}}
|
| 744 |
+
if USE_MAMBA:
|
| 745 |
+
s["enc"] = {k: v.cpu().clone() for k, v in enc.state_dict().items()}
|
| 746 |
+
return s
|
| 747 |
+
|
| 748 |
+
# RESUME: init best = điểm VAL của ckpt hiện tại → chỉ ghi đè nếu train tiếp TỐT HƠN (không sợ tụt)
|
| 749 |
+
if RESUME and resume_ckpt is not None:
|
| 750 |
+
m0 = evaluate(); best = mean_srcc(m0); best_state = snapshot(); bad = 0
|
| 751 |
+
print(f"📍 RESUME — checkpoint hiện tại: mean SRCC={best:.4f} | "
|
| 752 |
+
+ " ".join(f"{k}={m0[k]:.3f}" for k in ['emos', 'val', 'aro', 'dom'] if k in m0))
|
| 753 |
+
else:
|
| 754 |
+
m0 = None
|
| 755 |
+
best, best_state, bad = -1e9, None, 0
|
| 756 |
+
for ep in range(1, EPOCHS + 1):
|
| 757 |
+
set_mode(True)
|
| 758 |
+
opt.zero_grad(); run = 0.0; nb = 0
|
| 759 |
+
for step, b in enumerate(tqdm(tr_loader, desc=f"epoch {ep}")):
|
| 760 |
+
with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
|
| 761 |
+
emos_p, cat_l, vad_p = forward_batch(b)
|
| 762 |
+
loss = compute_loss(emos_p, cat_l, vad_p, b) / ACCUM
|
| 763 |
+
scaler.scale(loss).backward()
|
| 764 |
+
if (step + 1) % ACCUM == 0:
|
| 765 |
+
scaler.step(opt); scaler.update(); opt.zero_grad()
|
| 766 |
+
run += loss.item() * ACCUM; nb += 1
|
| 767 |
+
m = evaluate(); sc = mean_srcc(m)
|
| 768 |
+
msg = " ".join(f"{k}={m[k]:.3f}" for k in ["emos", "val", "aro", "dom"] if k in m)
|
| 769 |
+
print(f"epoch {ep:2d} | loss {run/max(nb,1):.4f} | {msg} | cat_err {m['cat_err']:.3f} | mean {sc:.4f} (best {max(best,sc):.4f})")
|
| 770 |
+
if sc > best:
|
| 771 |
+
best = sc; bad = 0
|
| 772 |
+
best_state = snapshot()
|
| 773 |
+
save_full_ckpt(best_state, m["emos"])
|
| 774 |
+
print(f" 💾 lưu best → {CKPT_PATH} (epoch {ep}, mean {sc:.4f})")
|
| 775 |
+
else:
|
| 776 |
+
bad += 1
|
| 777 |
+
if bad >= PATIENCE:
|
| 778 |
+
print(f"Early stop ở epoch {ep}."); break
|
| 779 |
+
|
| 780 |
+
if best_state:
|
| 781 |
+
wavlm.load_state_dict(best_state["wavlm"]); heads.load_state_dict(best_state["heads"])
|
| 782 |
+
if USE_MAMBA:
|
| 783 |
+
enc.load_state_dict(best_state["enc"])
|
| 784 |
+
final = evaluate()
|
| 785 |
+
if RESUME and m0 is not None:
|
| 786 |
+
print(f"\n🔁 RESUME: mean SRCC ckpt {mean_srcc(m0):.4f} → sau train tiếp {mean_srcc(final):.4f} "
|
| 787 |
+
+ ("🚀 cải thiện → đã ghi đè ckpt" if mean_srcc(final) > mean_srcc(m0) + 1e-4 else "➖ không cải thiện (giữ best cũ)"))
|
| 788 |
+
print(f"\n✅ VAL (nội bộ) — exp15 (Mamba={'ON' if USE_MAMBA else 'OFF'}):")
|
| 789 |
+
print(f" EMOS={final['emos']:.4f} (exp08 {EXP08['emos']})")
|
| 790 |
+
if HAS_VAD:
|
| 791 |
+
print(f" VAL/ARO/DOM={final['val']:.4f}/{final['aro']:.4f}/{final['dom']:.4f} "
|
| 792 |
+
f"(exp08 {EXP08['val']}/{EXP08['aro']}/{EXP08['dom']})")
|
| 793 |
+
warn = [f"EMOS {final['emos']:.3f}<{EXP08['emos']}"] if final["emos"] < EXP08["emos"] - 0.005 else []
|
| 794 |
+
if HAS_VAD:
|
| 795 |
+
warn += [f"{t.upper()} {final[t]:.3f}<{EXP08[t]}" for t in ["val", "aro", "dom"] if final[t] < EXP08[t] - 0.005]
|
| 796 |
+
print(" ⚠️ Mamba head CHƯA thắng exp08 ở:", "; ".join(warn), "(vẫn là kết quả cho paper)" if warn else "")
|
| 797 |
+
if not warn:
|
| 798 |
+
print(" ✅ Mamba head thắng/ngang exp08 ở mọi cột → temporal modeling có ích!")
|
| 799 |
+
save_full_ckpt(best_state if best_state else
|
| 800 |
+
{"wavlm": wavlm.state_dict(), "heads": heads.state_dict(),
|
| 801 |
+
"enc": enc.state_dict() if USE_MAMBA else None}, final["emos"])
|
| 802 |
+
print(f"✅ Đã lưu {CKPT_PATH} (CÓ backbone + Mamba + heads). NHỚ Save Version!")
|
| 803 |
+
|
| 804 |
+
# %% [markdown]
|
| 805 |
+
# ## 7. Dự đoán DEV → answer.txt (5 cột cảm xúc exp15; QMOS mượn exp07/UTMOSv2)
|
| 806 |
+
|
| 807 |
+
# %%
|
| 808 |
+
def list_dev():
|
| 809 |
+
with open(DEV_SCP) as f:
|
| 810 |
+
return [ln.strip() for ln in f if ln.strip()]
|
| 811 |
+
|
| 812 |
+
dev_names = list_dev()
|
| 813 |
+
if LIMIT_DEV:
|
| 814 |
+
dev_names = dev_names[:LIMIT_DEV]
|
| 815 |
+
dev_stems = [stem(n) for n in dev_names]
|
| 816 |
+
print("DEV:", len(dev_names), "mẫu")
|
| 817 |
+
aud_dev = extract_audeering(dev_stems, "dev")
|
| 818 |
+
|
| 819 |
+
def load_exp07_qmos():
|
| 820 |
+
if EXP07_ANSWER and os.path.exists(EXP07_ANSWER):
|
| 821 |
+
import csv
|
| 822 |
+
d = {}
|
| 823 |
+
with open(EXP07_ANSWER) as f:
|
| 824 |
+
for row in csv.DictReader(f):
|
| 825 |
+
d[row["wav"]] = float(row["QMOS"]); d[stem(row["wav"])] = float(row["QMOS"])
|
| 826 |
+
print(f"✅ Mượn QMOS exp07 ({EXP07_ANSWER}): {len(d)//2} wav")
|
| 827 |
+
return d
|
| 828 |
+
return None
|
| 829 |
+
|
| 830 |
+
qmos_map = load_exp07_qmos()
|
| 831 |
+
if qmos_map is None:
|
| 832 |
+
print("ℹ️ Không có answer.txt exp07 → chấm QMOS bằng UTMOSv2 (T05, vô địch VMC2024).")
|
| 833 |
+
pip_install("git+https://github.com/sarulab-speech/UTMOSv2.git")
|
| 834 |
+
import utmosv2
|
| 835 |
+
v2 = utmosv2.create_model(pretrained=True)
|
| 836 |
+
qmos_map = {}
|
| 837 |
+
for n in tqdm(dev_names, desc="UTMOSv2"):
|
| 838 |
+
wav = os.path.join(WAV_DIR, n if str(n).endswith(".wav") else str(n) + ".wav")
|
| 839 |
+
if not os.path.exists(wav):
|
| 840 |
+
continue
|
| 841 |
+
out = v2.predict(input_path=wav)
|
| 842 |
+
qmos_map[n] = float(out["predicted_mos"]) if isinstance(out, dict) else float(out)
|
| 843 |
+
del v2; torch.cuda.empty_cache() if device == "cuda" else None
|
| 844 |
+
|
| 845 |
+
@torch.no_grad()
|
| 846 |
+
def predict_emotion(sid):
|
| 847 |
+
wave = load_wav(sid)
|
| 848 |
+
if wave is None or (USE_AUDEERING and sid not in aud_dev):
|
| 849 |
+
return None
|
| 850 |
+
set_mode(False)
|
| 851 |
+
iv = torch.from_numpy(wave).unsqueeze(0).to(device)
|
| 852 |
+
am = torch.ones((1, len(wave)), dtype=torch.long, device=device)
|
| 853 |
+
tgt = torch.from_numpy(onehot_target(target_map.get(sid))).unsqueeze(0).to(device)
|
| 854 |
+
with torch.cuda.amp.autocast(enabled=USE_AMP and device == "cuda"):
|
| 855 |
+
fw = wavlm_branch(iv, am)
|
| 856 |
+
feat = torch.cat([fw, torch.from_numpy(aud_dev[sid]).unsqueeze(0).to(device)], dim=1) if USE_AUDEERING else fw
|
| 857 |
+
emos_p, cat_l, vad_p = heads(feat, tgt)
|
| 858 |
+
emos = float(emos_p.item()) * emos_sd + emos_mu
|
| 859 |
+
cat5 = F.softmax(cat_l, 1)[0].float().cpu().numpy()
|
| 860 |
+
vad3 = vad_p[0].float().cpu().numpy() * vad_sd + vad_mu
|
| 861 |
+
return emos, cat5, vad3
|
| 862 |
+
|
| 863 |
+
def fmt_cat(p5):
|
| 864 |
+
return "|".join(f"{e}:{p5[i]:.6g}" for i, e in enumerate(EMOTIONS5))
|
| 865 |
+
|
| 866 |
+
def build_answer(out_path):
|
| 867 |
+
n_real = n_def = 0
|
| 868 |
+
with open(out_path, "w") as f:
|
| 869 |
+
f.write("wav,QMOS,EMOS,CAT,VAL,ARO,DOM\n")
|
| 870 |
+
for name in tqdm(dev_names, desc="answer"):
|
| 871 |
+
sid = stem(name)
|
| 872 |
+
pr = predict_emotion(sid)
|
| 873 |
+
if pr is None:
|
| 874 |
+
emos, cat5, vad3 = 3.0, np.full(5, 0.2, np.float32), np.array([3.0, 3.0, 3.0]); n_def += 1
|
| 875 |
+
else:
|
| 876 |
+
emos, cat5, vad3 = pr; n_real += 1
|
| 877 |
+
qmos = qmos_map.get(name, qmos_map.get(sid, 3.0))
|
| 878 |
+
f.write(f"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},{vad3[0]:.6g},{vad3[1]:.6g},{vad3[2]:.6g}\n")
|
| 879 |
+
print(f"Ghi {len(dev_names)} dòng → {out_path} | cảm xúc thật {n_real}, mặc định {n_def}")
|
| 880 |
+
|
| 881 |
+
answer_path = os.path.join(OUT_DIR, "answer.txt")
|
| 882 |
+
build_answer(answer_path)
|
| 883 |
+
|
| 884 |
+
# %% [markdown]
|
| 885 |
+
# ## 8. Validate + đóng zip
|
| 886 |
+
|
| 887 |
+
# %%
|
| 888 |
+
def validate(path):
|
| 889 |
+
import csv
|
| 890 |
+
with open(path) as f:
|
| 891 |
+
rows = list(csv.reader(f))
|
| 892 |
+
assert rows[0][0] == "wav" and "QMOS" in rows[0] and "EMOS" in rows[0], "Header sai"
|
| 893 |
+
for i, r in enumerate(rows[1:], 2):
|
| 894 |
+
assert len(r) == len(rows[0]), f"Dòng {i} sai số cột"
|
| 895 |
+
print(f"OK: {len(rows)-1} dòng, header = {rows[0]}")
|
| 896 |
+
|
| 897 |
+
validate(answer_path)
|
| 898 |
+
os.system(f"cd {OUT_DIR} && zip -j submission_track2_exp15_mamba-emotion.zip answer.txt "
|
| 899 |
+
f"&& unzip -l submission_track2_exp15_mamba-emotion.zip")
|
| 900 |
+
print("Sẵn sàng nộp:", os.path.join(OUT_DIR, "submission_track2_exp15_mamba-emotion.zip"))
|
| 901 |
+
|
| 902 |
+
# %% [markdown]
|
| 903 |
+
# ## Ghi chú
|
| 904 |
+
# - **🔁 RESUME (train tiếp, không train lại từ đầu):** Add Input dataset chứa `ft_mamba_emotion_full.pt` của lần
|
| 905 |
+
# chạy trước (hoặc để nó nằm sẵn trong `/kaggle/working` khi chạy nối phiên) → notebook tự dò & train tiếp.
|
| 906 |
+
# `EPOCHS` lúc này là **số epoch train THÊM**. Val chững → đặt `RESUME_LR_SCALE=0.5`. Muốn ép train mới: `RESUME_CKPT="—"`
|
| 907 |
+
# (đường dẫn không tồn tại) hoặc xóa ckpt khỏi input. ⚠️ `USE_MAMBA` phải KHỚP ckpt (code sẽ cảnh báo nếu lệch).
|
| 908 |
+
# - **Lần đầu** `LIMIT_TRAIN=300`, `LIMIT_DEV=20` → kiểm 1 epoch không OOM / không CheckpointError; rồi đặt `None`.
|
| 909 |
+
# - **Ablation chính cho paper:** chạy `USE_MAMBA=True` vs `USE_MAMBA=False` (=exp08) → so EMOS/VAL/ARO/DOM nội bộ
|
| 910 |
+
# → trả lời "Mamba temporal head có hơn mean-pooling không?".
|
| 911 |
+
# - **OOM / quá chậm trên T4 (nhất là khi dùng Mamba thuần PyTorch):** giảm theo thứ tự
|
| 912 |
+
# `MAX_SECONDS` (6→5) → `MAMBA_LAYERS` (2→1) → `UNFREEZE_TOP_LAYERS` (6→4) → `BATCH` (2→1, tăng `ACCUM`).
|
| 913 |
+
# Hoặc thử cài `mamba-ssm causal-conv1d` (nhanh + nhẹ RAM hơn nhiều) — code tự dùng nếu import được.
|
| 914 |
+
# - **Ranking loss (`RANK_LAMBDA`):** thêm pairwise ranking cho 4 cột SRCC (emos/val/aro/dom) → khớp metric
|
| 915 |
+
# UTT-SRCC hơn MSE. ⚠️ **Điểm yếu:** ranking tính trên các cặp TRONG 1 mini-batch; `BATCH=2` → mỗi forward
|
| 916 |
+
# chỉ có 1 cặp → tín hiệu YẾU. Muốn ranking mạnh: tăng `BATCH` (4→8 nếu VRAM chịu được). Ở các exp head
|
| 917 |
+
# ĐÓNG BĂNG (exp06/07, BATCH=64) ranking mạnh hơn nhiều. A/B `RANK_LAMBDA=0` vs `0.3` → bảng ablation cho paper.
|
| 918 |
+
# - **QMOS:** Add Input answer.txt exp07 vào `/kaggle/input/exp07-answer/answer.txt` để mượn QMOS 0.548;
|
| 919 |
+
# không có thì tự chấm UTMOSv2 (cần Internet On).
|
| 920 |
+
# - Ghi config → kết quả → nhận xét vào `docs/04_experiments_log.md` (mục exp15).
|
track2/exp16_llm_judge.ipynb
ADDED
|
@@ -0,0 +1,650 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"id": "7bae4e03",
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"source": [
|
| 8 |
+
"# exp16 — Audio-LLM-as-Judge cho MOS cảm xúc (Track 2)\n",
|
| 9 |
+
"\n",
|
| 10 |
+
"**Ý tưởng:** đưa thẳng audio cho một **audio-LLM** (Gemini / GPT-4o-audio) qua **API** + prompt có\n",
|
| 11 |
+
"cấu trúc → bắt nó chấm cả 6 cột (`QMOS, EMOS, CAT, VAL, ARO, DOM`) → ráp `answer.txt` → nộp CodaBench.\n",
|
| 12 |
+
"\n",
|
| 13 |
+
"**Mục tiêu chính = NOVELTY cho paper** (khảo sát có hệ thống audio-LLM-as-judge cho MOS cảm xúc),\n",
|
| 14 |
+
"so với hệ SSL đã train (exp07 QMOS 0.548 · exp08 EMOS 0.811…). KHÔNG cần GPU — thuần gọi API.\n",
|
| 15 |
+
"\n",
|
| 16 |
+
"| Đặc điểm | Giá trị |\n",
|
| 17 |
+
"|---|---|\n",
|
| 18 |
+
"| GPU | ❌ không cần (chỉ network I/O) |\n",
|
| 19 |
+
"| Tốn phí | ✅ API trả tiền theo token/audio → **cache + resume bắt buộc** |\n",
|
| 20 |
+
"| Provider | `gemini` (mặc định, đã có billing) · `openai` (GPT-4o-audio, để so 2 LLM) |\n",
|
| 21 |
+
"| Output | `answer.txt` 6 cột giống exp07 |\n",
|
| 22 |
+
"\n",
|
| 23 |
+
"**Cách dùng Kaggle:** Internet = **On**; Add-ons → Secrets: `GEMINI_API_KEY` (và `OPENAI_API_KEY`\n",
|
| 24 |
+
"nếu chạy provider openai). Settings GPU **không cần**. Sửa `DATA_ROOT` cho khớp slug rồi Run All.\n",
|
| 25 |
+
"\n",
|
| 26 |
+
"⚠️ **Model ID có thể đã đổi** theo thời gian → kiểm tra `GEMINI_MODEL` / `OPENAI_MODEL` còn nhận\n",
|
| 27 |
+
"audio không trước khi chạy full (xem mục 1)."
|
| 28 |
+
]
|
| 29 |
+
},
|
| 30 |
+
{
|
| 31 |
+
"cell_type": "markdown",
|
| 32 |
+
"id": "720a7dc2",
|
| 33 |
+
"metadata": {},
|
| 34 |
+
"source": [
|
| 35 |
+
"## 0. Cấu hình — SỬA Ở ĐÂY"
|
| 36 |
+
]
|
| 37 |
+
},
|
| 38 |
+
{
|
| 39 |
+
"cell_type": "code",
|
| 40 |
+
"execution_count": null,
|
| 41 |
+
"id": "c583b4dc",
|
| 42 |
+
"metadata": {},
|
| 43 |
+
"outputs": [],
|
| 44 |
+
"source": [
|
| 45 |
+
"import os, io, re, json, time, base64, glob\n",
|
| 46 |
+
"\n",
|
| 47 |
+
"# ── Data Track 2 trên Kaggle ────────────────────────────────────────────────\n",
|
| 48 |
+
"DATA_ROOT = \"/kaggle/input/vmc2026-track2-full/vmc2026-track2\" # << SỬA slug\n",
|
| 49 |
+
"WAV_DIR = f\"{DATA_ROOT}/wav\"\n",
|
| 50 |
+
"METADATA_CSV = f\"{DATA_ROOT}/metadata.csv\" # wavID|emotion|transcript (KHÔNG header) — nhãn cảm xúc target\n",
|
| 51 |
+
"DEV_SCP = f\"{DATA_ROOT}/sets/dev.scp\" # danh sách wav DEV cần nộp (train phase)\n",
|
| 52 |
+
"TRAIN_CSV = f\"{DATA_ROOT}/sets/train.csv\" # chỉ cần khi SHOT_MODE=\"few_shot\"\n",
|
| 53 |
+
"\n",
|
| 54 |
+
"OUT_DIR = \"/kaggle/working\"\n",
|
| 55 |
+
"CACHE_DIR = \"/kaggle/working/exp16_llm_cache\" # nên Save Version / lưu Dataset để KHÔNG gọi lại API\n",
|
| 56 |
+
"os.makedirs(CACHE_DIR, exist_ok=True)\n",
|
| 57 |
+
"\n",
|
| 58 |
+
"# ── Provider & model ────────────────────────────────────────────────────────\n",
|
| 59 |
+
"PROVIDER = \"gemini\" # \"gemini\" | \"openai\"\n",
|
| 60 |
+
"GEMINI_MODEL = \"gemini-2.5-flash\" # << xác nhận model audio hiện hành (baseline dùng họ gemini-*-flash)\n",
|
| 61 |
+
"OPENAI_MODEL = \"gpt-4o-audio-preview\" # << model audio của OpenAI; cần OPENAI_API_KEY\n",
|
| 62 |
+
"TEMPERATURE = 0.0 # cố định để TÁI LẬP (paper)\n",
|
| 63 |
+
"\n",
|
| 64 |
+
"# ── Chế độ chạy ─────────────────────────────────────────────────────────────\n",
|
| 65 |
+
"SHOT_MODE = \"zero_shot\" # \"zero_shot\" | \"few_shot\" (nhét K ví dụ audio có nhãn từ train.csv)\n",
|
| 66 |
+
"FEW_K = 2 # số ví dụ few-shot (mỗi ví dụ = 1 audio + nhãn vàng) — tốn thêm token!\n",
|
| 67 |
+
"LIMIT = 20 # << số nhỏ (20) để smoke test; None = full DEV (~2730) — CHẠY THỬ TRƯỚC\n",
|
| 68 |
+
"MAX_SECONDS = 12 # cắt audio cho rẻ + nhanh\n",
|
| 69 |
+
"WORKERS = 4 # luồng gọi song song (giảm nếu dính rate limit)\n",
|
| 70 |
+
"MAX_RETRY = 3 # số lần thử lại 1 wav khi lỗi mạng / JSON hỏng\n",
|
| 71 |
+
"RETRY_SLEEP = 2.0 # giây nghỉ giữa các lần thử\n",
|
| 72 |
+
"\n",
|
| 73 |
+
"TAG = f\"{PROVIDER}_{(GEMINI_MODEL if PROVIDER=='gemini' else OPENAI_MODEL)}_{SHOT_MODE}\".replace(\"/\", \"-\")\n",
|
| 74 |
+
"CACHE_PATH = os.path.join(CACHE_DIR, f\"{TAG}.jsonl\") # 1 dòng JSON / wav (raw + parsed) → resume\n",
|
| 75 |
+
"print(\"TAG:\", TAG, \"| cache:\", CACHE_PATH)"
|
| 76 |
+
]
|
| 77 |
+
},
|
| 78 |
+
{
|
| 79 |
+
"cell_type": "markdown",
|
| 80 |
+
"id": "f2cc876e",
|
| 81 |
+
"metadata": {},
|
| 82 |
+
"source": [
|
| 83 |
+
"## 0b. Nhãn cảm xúc target + chuẩn hóa lớp (tái dùng quy ước baseline)"
|
| 84 |
+
]
|
| 85 |
+
},
|
| 86 |
+
{
|
| 87 |
+
"cell_type": "code",
|
| 88 |
+
"execution_count": null,
|
| 89 |
+
"id": "5b5c7f92",
|
| 90 |
+
"metadata": {
|
| 91 |
+
"lines_to_next_cell": 1
|
| 92 |
+
},
|
| 93 |
+
"outputs": [],
|
| 94 |
+
"source": [
|
| 95 |
+
"EMOTIONS5 = [\"angry\", \"happy\", \"neutral\", \"sad\", \"surprised\"] # THỨ TỰ chuẩn cho cột CAT\n",
|
| 96 |
+
"\n",
|
| 97 |
+
"_EMO_ALIAS = {\n",
|
| 98 |
+
" \"angry\": \"angry\", \"anger\": \"angry\",\n",
|
| 99 |
+
" \"happy\": \"happy\", \"happiness\": \"happy\", \"joy\": \"happy\",\n",
|
| 100 |
+
" \"neutral\": \"neutral\", \"calm\": \"neutral\",\n",
|
| 101 |
+
" \"sad\": \"sad\", \"sadness\": \"sad\",\n",
|
| 102 |
+
" \"surprise\": \"surprised\", \"surprised\": \"surprised\", \"surprising\": \"surprised\",\n",
|
| 103 |
+
"}\n",
|
| 104 |
+
"\n",
|
| 105 |
+
"def norm_emotion(label):\n",
|
| 106 |
+
" key = str(label).strip().lower()\n",
|
| 107 |
+
" return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)\n",
|
| 108 |
+
"\n",
|
| 109 |
+
"def stem(name):\n",
|
| 110 |
+
" return os.path.splitext(os.path.basename(name))[0]\n",
|
| 111 |
+
"\n",
|
| 112 |
+
"def load_target_emotions():\n",
|
| 113 |
+
" \"\"\"metadata.csv (wavID|emotion|transcript, không header) → {stem: emotion_chuẩn}.\"\"\"\n",
|
| 114 |
+
" tgt = {}\n",
|
| 115 |
+
" if not (METADATA_CSV and os.path.exists(METADATA_CSV)):\n",
|
| 116 |
+
" print(\"⚠️ Không thấy metadata.csv → EMOS sẽ thiếu cảm xúc target.\")\n",
|
| 117 |
+
" return tgt\n",
|
| 118 |
+
" with open(METADATA_CSV, encoding=\"utf-8\") as f:\n",
|
| 119 |
+
" for ln in f:\n",
|
| 120 |
+
" parts = ln.strip().split(\"|\")\n",
|
| 121 |
+
" if len(parts) < 2:\n",
|
| 122 |
+
" continue\n",
|
| 123 |
+
" tgt[stem(parts[0])] = norm_emotion(parts[1])\n",
|
| 124 |
+
" return tgt\n",
|
| 125 |
+
"\n",
|
| 126 |
+
"target_map = load_target_emotions()\n",
|
| 127 |
+
"print(\"Nhãn cảm xúc target:\", len(target_map))\n",
|
| 128 |
+
"\n",
|
| 129 |
+
"def list_dev():\n",
|
| 130 |
+
" with open(DEV_SCP) as f:\n",
|
| 131 |
+
" return [ln.strip() for ln in f if ln.strip()]\n",
|
| 132 |
+
"\n",
|
| 133 |
+
"dev_names = list_dev()\n",
|
| 134 |
+
"if LIMIT:\n",
|
| 135 |
+
" dev_names = dev_names[:LIMIT]\n",
|
| 136 |
+
"print(\"DEV cần chấm:\", len(dev_names), \"mẫu\", \"| LIMIT =\", LIMIT)"
|
| 137 |
+
]
|
| 138 |
+
},
|
| 139 |
+
{
|
| 140 |
+
"cell_type": "markdown",
|
| 141 |
+
"id": "881a37a5",
|
| 142 |
+
"metadata": {},
|
| 143 |
+
"source": [
|
| 144 |
+
"## 1. Cài SDK + nạp key\n",
|
| 145 |
+
"\n",
|
| 146 |
+
"Gemini dùng SDK mới `google-genai`; OpenAI dùng `openai`. Trên Kaggle **Internet phải On**."
|
| 147 |
+
]
|
| 148 |
+
},
|
| 149 |
+
{
|
| 150 |
+
"cell_type": "code",
|
| 151 |
+
"execution_count": null,
|
| 152 |
+
"id": "a1d0c66b",
|
| 153 |
+
"metadata": {},
|
| 154 |
+
"outputs": [],
|
| 155 |
+
"source": [
|
| 156 |
+
"!pip -q install google-genai openai soundfile librosa\n",
|
| 157 |
+
"\n",
|
| 158 |
+
"def setup_keys():\n",
|
| 159 |
+
" \"\"\"Nạp API key từ Kaggle Secrets (fallback: biến môi trường đã set sẵn).\"\"\"\n",
|
| 160 |
+
" try:\n",
|
| 161 |
+
" from kaggle_secrets import UserSecretsClient\n",
|
| 162 |
+
" sec = UserSecretsClient()\n",
|
| 163 |
+
" for k in [\"GEMINI_API_KEY\", \"OPENAI_API_KEY\"]:\n",
|
| 164 |
+
" try:\n",
|
| 165 |
+
" os.environ[k] = sec.get_secret(k)\n",
|
| 166 |
+
" print(f\"Đã nạp {k} từ Secrets\")\n",
|
| 167 |
+
" except Exception:\n",
|
| 168 |
+
" pass\n",
|
| 169 |
+
" except Exception as e:\n",
|
| 170 |
+
" print(\"Không dùng được Kaggle Secrets:\", e, \"→ set tay os.environ[...] nếu cần\")\n",
|
| 171 |
+
"\n",
|
| 172 |
+
"setup_keys()"
|
| 173 |
+
]
|
| 174 |
+
},
|
| 175 |
+
{
|
| 176 |
+
"cell_type": "markdown",
|
| 177 |
+
"id": "a4ceeacf",
|
| 178 |
+
"metadata": {},
|
| 179 |
+
"source": [
|
| 180 |
+
"## 2. Đọc + chuẩn hóa audio (16kHz mono, cắt MAX_SECONDS) → bytes WAV trong RAM"
|
| 181 |
+
]
|
| 182 |
+
},
|
| 183 |
+
{
|
| 184 |
+
"cell_type": "code",
|
| 185 |
+
"execution_count": null,
|
| 186 |
+
"id": "68d431ff",
|
| 187 |
+
"metadata": {
|
| 188 |
+
"lines_to_next_cell": 1
|
| 189 |
+
},
|
| 190 |
+
"outputs": [],
|
| 191 |
+
"source": [
|
| 192 |
+
"import numpy as np\n",
|
| 193 |
+
"\n",
|
| 194 |
+
"def load_wav_bytes(path, sr=16000, max_seconds=MAX_SECONDS):\n",
|
| 195 |
+
" \"\"\"Trả (wav_bytes, base64_str). Cắt ≤ max_seconds, resample 16k mono, encode WAV PCM16.\"\"\"\n",
|
| 196 |
+
" import soundfile as sf\n",
|
| 197 |
+
" try:\n",
|
| 198 |
+
" import librosa\n",
|
| 199 |
+
" y, _ = librosa.load(path, sr=sr, mono=True)\n",
|
| 200 |
+
" except Exception:\n",
|
| 201 |
+
" y, in_sr = sf.read(path)\n",
|
| 202 |
+
" if y.ndim > 1:\n",
|
| 203 |
+
" y = y.mean(axis=1)\n",
|
| 204 |
+
" if in_sr != sr: # fallback resample tuyến tính nếu không có librosa\n",
|
| 205 |
+
" idx = np.linspace(0, len(y) - 1, int(len(y) * sr / in_sr))\n",
|
| 206 |
+
" y = np.interp(idx, np.arange(len(y)), y)\n",
|
| 207 |
+
" if max_seconds:\n",
|
| 208 |
+
" y = y[: int(sr * max_seconds)]\n",
|
| 209 |
+
" buf = io.BytesIO()\n",
|
| 210 |
+
" sf.write(buf, y.astype(np.float32), sr, format=\"WAV\", subtype=\"PCM_16\")\n",
|
| 211 |
+
" raw = buf.getvalue()\n",
|
| 212 |
+
" return raw, base64.b64encode(raw).decode(\"ascii\")"
|
| 213 |
+
]
|
| 214 |
+
},
|
| 215 |
+
{
|
| 216 |
+
"cell_type": "markdown",
|
| 217 |
+
"id": "9d846428",
|
| 218 |
+
"metadata": {},
|
| 219 |
+
"source": [
|
| 220 |
+
"## 3. Prompt — định nghĩa 6 metric + ép JSON nghiêm ngặt\n",
|
| 221 |
+
"\n",
|
| 222 |
+
"QMOS = chất lượng/độ tự nhiên (sạch, không méo/robot). EMOS = độ KHỚP với **cảm xúc target**.\n",
|
| 223 |
+
"CAT = phân phối vote 5 lớp. VAD = Valence/Arousal/Dominance. Tất cả thang **1–5** (CAT là tỉ lệ 0–1)."
|
| 224 |
+
]
|
| 225 |
+
},
|
| 226 |
+
{
|
| 227 |
+
"cell_type": "code",
|
| 228 |
+
"execution_count": null,
|
| 229 |
+
"id": "f7046919",
|
| 230 |
+
"metadata": {
|
| 231 |
+
"lines_to_next_cell": 1
|
| 232 |
+
},
|
| 233 |
+
"outputs": [],
|
| 234 |
+
"source": [
|
| 235 |
+
"SYSTEM_INSTRUCTION = (\n",
|
| 236 |
+
" \"You are an expert evaluator of emotional text-to-speech. \"\n",
|
| 237 |
+
" \"Listen to the audio and rate it. Respond with ONLY a compact JSON object, no prose.\"\n",
|
| 238 |
+
")\n",
|
| 239 |
+
"\n",
|
| 240 |
+
"def build_prompt(target_emo):\n",
|
| 241 |
+
" tgt = target_emo if target_emo else \"unknown\"\n",
|
| 242 |
+
" return (\n",
|
| 243 |
+
" \"Rate this speech utterance. The INTENDED (target) emotion is: \"\n",
|
| 244 |
+
" f\"\\\"{tgt}\\\".\\n\\n\"\n",
|
| 245 |
+
" \"Return a JSON object with EXACTLY these keys (numbers on a 1-5 scale unless stated):\\n\"\n",
|
| 246 |
+
" \" \\\"qmos\\\": overall audio QUALITY / naturalness (1=very unnatural/robotic/distorted, 5=clean & human-like).\\n\"\n",
|
| 247 |
+
" \" \\\"emos\\\": how well the emotion expressed MATCHES the target emotion above \"\n",
|
| 248 |
+
" \"(1=not matching at all, 5=perfectly matching).\\n\"\n",
|
| 249 |
+
" \" \\\"cat\\\": an object with probabilities (summing to 1.0) over the 5 perceived emotions: \"\n",
|
| 250 |
+
" \"{\\\"neutral\\\":_, \\\"happy\\\":_, \\\"sad\\\":_, \\\"angry\\\":_, \\\"surprised\\\":_}.\\n\"\n",
|
| 251 |
+
" \" \\\"val\\\": valence (1=very negative, 5=very positive).\\n\"\n",
|
| 252 |
+
" \" \\\"aro\\\": arousal (1=very calm, 5=very excited).\\n\"\n",
|
| 253 |
+
" \" \\\"dom\\\": dominance (1=very submissive, 5=very dominant).\\n\\n\"\n",
|
| 254 |
+
" \"Example format: \"\n",
|
| 255 |
+
" \"{\\\"qmos\\\":3.5,\\\"emos\\\":4.0,\"\n",
|
| 256 |
+
" \"\\\"cat\\\":{\\\"neutral\\\":0.1,\\\"happy\\\":0.7,\\\"sad\\\":0.0,\\\"angry\\\":0.1,\\\"surprised\\\":0.1},\"\n",
|
| 257 |
+
" \"\\\"val\\\":4.0,\\\"aro\\\":3.5,\\\"dom\\\":3.0}\\n\"\n",
|
| 258 |
+
" \"Respond with ONLY the JSON.\"\n",
|
| 259 |
+
" )"
|
| 260 |
+
]
|
| 261 |
+
},
|
| 262 |
+
{
|
| 263 |
+
"cell_type": "markdown",
|
| 264 |
+
"id": "fe8d1303",
|
| 265 |
+
"metadata": {},
|
| 266 |
+
"source": [
|
| 267 |
+
"## 3b. (tùy chọn) Few-shot — lấy K ví dụ audio có nhãn vàng từ train.csv\n",
|
| 268 |
+
"\n",
|
| 269 |
+
"Bật khi `SHOT_MODE=\"few_shot\"`. Mỗi ví dụ = 1 audio train + nhãn vàng (gộp TB theo wav). Tốn thêm token."
|
| 270 |
+
]
|
| 271 |
+
},
|
| 272 |
+
{
|
| 273 |
+
"cell_type": "code",
|
| 274 |
+
"execution_count": null,
|
| 275 |
+
"id": "5fdf89c7",
|
| 276 |
+
"metadata": {},
|
| 277 |
+
"outputs": [],
|
| 278 |
+
"source": [
|
| 279 |
+
"few_shot_examples = [] # list[(audio_b64, audio_bytes, gold_json_str)]\n",
|
| 280 |
+
"\n",
|
| 281 |
+
"def _agg_train_labels():\n",
|
| 282 |
+
" \"\"\"Gộp train.csv (sep='|') theo wavID → nhãn vàng trung bình; CAT = tỉ lệ vote.\"\"\"\n",
|
| 283 |
+
" import pandas as pd\n",
|
| 284 |
+
" df = pd.read_csv(TRAIN_CSV, sep=\"|\")\n",
|
| 285 |
+
" rows = {}\n",
|
| 286 |
+
" for wav, g in df.groupby(\"wavID\"):\n",
|
| 287 |
+
" votes = np.zeros(5, np.float32)\n",
|
| 288 |
+
" for cell in g[\"emoCat\"].astype(str):\n",
|
| 289 |
+
" for tok in cell.split(\",\"):\n",
|
| 290 |
+
" e = norm_emotion(tok)\n",
|
| 291 |
+
" if e in EMOTIONS5:\n",
|
| 292 |
+
" votes[EMOTIONS5.index(e)] += 1\n",
|
| 293 |
+
" s = votes.sum()\n",
|
| 294 |
+
" cat = (votes / s) if s > 0 else np.full(5, 0.2, np.float32)\n",
|
| 295 |
+
" rows[stem(wav)] = dict(\n",
|
| 296 |
+
" qmos=float(g[\"qMOS\"].mean()), emos=float(g[\"eMOS\"].mean()),\n",
|
| 297 |
+
" val=float(g[\"val\"].mean()), aro=float(g[\"aro\"].mean()), dom=float(g[\"dom\"].mean()),\n",
|
| 298 |
+
" cat={EMOTIONS5[i]: round(float(cat[i]), 4) for i in range(5)},\n",
|
| 299 |
+
" )\n",
|
| 300 |
+
" return rows\n",
|
| 301 |
+
"\n",
|
| 302 |
+
"def build_few_shot():\n",
|
| 303 |
+
" if SHOT_MODE != \"few_shot\":\n",
|
| 304 |
+
" return\n",
|
| 305 |
+
" labels = _agg_train_labels()\n",
|
| 306 |
+
" picked = list(labels.keys())[:FEW_K]\n",
|
| 307 |
+
" for sid in picked:\n",
|
| 308 |
+
" wavp = os.path.join(WAV_DIR, sid + \".wav\")\n",
|
| 309 |
+
" if not os.path.exists(wavp):\n",
|
| 310 |
+
" continue\n",
|
| 311 |
+
" raw, b64 = load_wav_bytes(wavp)\n",
|
| 312 |
+
" gold = labels[sid]\n",
|
| 313 |
+
" gold_json = json.dumps({\n",
|
| 314 |
+
" \"qmos\": round(gold[\"qmos\"], 2), \"emos\": round(gold[\"emos\"], 2),\n",
|
| 315 |
+
" \"cat\": gold[\"cat\"], \"val\": round(gold[\"val\"], 2),\n",
|
| 316 |
+
" \"aro\": round(gold[\"aro\"], 2), \"dom\": round(gold[\"dom\"], 2),\n",
|
| 317 |
+
" })\n",
|
| 318 |
+
" few_shot_examples.append((b64, raw, gold_json))\n",
|
| 319 |
+
" print(f\"Few-shot: {len(few_shot_examples)} ví dụ\")\n",
|
| 320 |
+
"\n",
|
| 321 |
+
"build_few_shot()"
|
| 322 |
+
]
|
| 323 |
+
},
|
| 324 |
+
{
|
| 325 |
+
"cell_type": "markdown",
|
| 326 |
+
"id": "4d3c1fef",
|
| 327 |
+
"metadata": {},
|
| 328 |
+
"source": [
|
| 329 |
+
"## 4. Gọi API — trừu tượng hóa provider (gemini / openai)\n",
|
| 330 |
+
"\n",
|
| 331 |
+
"Mỗi provider tự dựng message của nó (kèm few-shot nếu có). Trả về **text thô** để parse ở mục 5."
|
| 332 |
+
]
|
| 333 |
+
},
|
| 334 |
+
{
|
| 335 |
+
"cell_type": "code",
|
| 336 |
+
"execution_count": null,
|
| 337 |
+
"id": "ae85c4bf",
|
| 338 |
+
"metadata": {
|
| 339 |
+
"lines_to_next_cell": 1
|
| 340 |
+
},
|
| 341 |
+
"outputs": [],
|
| 342 |
+
"source": [
|
| 343 |
+
"_client = {\"gemini\": None, \"openai\": None}\n",
|
| 344 |
+
"\n",
|
| 345 |
+
"def _gemini_client():\n",
|
| 346 |
+
" if _client[\"gemini\"] is None:\n",
|
| 347 |
+
" from google import genai\n",
|
| 348 |
+
" _client[\"gemini\"] = genai.Client(api_key=os.environ[\"GEMINI_API_KEY\"])\n",
|
| 349 |
+
" return _client[\"gemini\"]\n",
|
| 350 |
+
"\n",
|
| 351 |
+
"def _openai_client():\n",
|
| 352 |
+
" if _client[\"openai\"] is None:\n",
|
| 353 |
+
" from openai import OpenAI\n",
|
| 354 |
+
" _client[\"openai\"] = OpenAI(api_key=os.environ[\"OPENAI_API_KEY\"])\n",
|
| 355 |
+
" return _client[\"openai\"]\n",
|
| 356 |
+
"\n",
|
| 357 |
+
"def call_gemini(audio_b64, audio_bytes, prompt):\n",
|
| 358 |
+
" from google.genai import types\n",
|
| 359 |
+
" client = _gemini_client()\n",
|
| 360 |
+
" contents = []\n",
|
| 361 |
+
" for ex_b64, ex_bytes, ex_gold in few_shot_examples: # few-shot: audio ví dụ + nhãn vàng\n",
|
| 362 |
+
" contents.append(types.Content(role=\"user\", parts=[\n",
|
| 363 |
+
" types.Part.from_bytes(data=ex_bytes, mime_type=\"audio/wav\"),\n",
|
| 364 |
+
" types.Part.from_text(text=build_prompt(None)),\n",
|
| 365 |
+
" ]))\n",
|
| 366 |
+
" contents.append(types.Content(role=\"model\", parts=[types.Part.from_text(text=ex_gold)]))\n",
|
| 367 |
+
" contents.append(types.Content(role=\"user\", parts=[\n",
|
| 368 |
+
" types.Part.from_bytes(data=audio_bytes, mime_type=\"audio/wav\"),\n",
|
| 369 |
+
" types.Part.from_text(text=prompt),\n",
|
| 370 |
+
" ]))\n",
|
| 371 |
+
" resp = client.models.generate_content(\n",
|
| 372 |
+
" model=GEMINI_MODEL, contents=contents,\n",
|
| 373 |
+
" config=types.GenerateContentConfig(\n",
|
| 374 |
+
" system_instruction=SYSTEM_INSTRUCTION, temperature=TEMPERATURE),\n",
|
| 375 |
+
" )\n",
|
| 376 |
+
" return resp.text\n",
|
| 377 |
+
"\n",
|
| 378 |
+
"def call_openai(audio_b64, audio_bytes, prompt):\n",
|
| 379 |
+
" client = _openai_client()\n",
|
| 380 |
+
" messages = [{\"role\": \"system\", \"content\": SYSTEM_INSTRUCTION}]\n",
|
| 381 |
+
" for ex_b64, ex_bytes, ex_gold in few_shot_examples:\n",
|
| 382 |
+
" messages.append({\"role\": \"user\", \"content\": [\n",
|
| 383 |
+
" {\"type\": \"text\", \"text\": build_prompt(None)},\n",
|
| 384 |
+
" {\"type\": \"input_audio\", \"input_audio\": {\"data\": ex_b64, \"format\": \"wav\"}},\n",
|
| 385 |
+
" ]})\n",
|
| 386 |
+
" messages.append({\"role\": \"assistant\", \"content\": ex_gold})\n",
|
| 387 |
+
" messages.append({\"role\": \"user\", \"content\": [\n",
|
| 388 |
+
" {\"type\": \"text\", \"text\": prompt},\n",
|
| 389 |
+
" {\"type\": \"input_audio\", \"input_audio\": {\"data\": audio_b64, \"format\": \"wav\"}},\n",
|
| 390 |
+
" ]})\n",
|
| 391 |
+
" resp = client.chat.completions.create(\n",
|
| 392 |
+
" model=OPENAI_MODEL, messages=messages, temperature=TEMPERATURE,\n",
|
| 393 |
+
" modalities=[\"text\"],\n",
|
| 394 |
+
" )\n",
|
| 395 |
+
" return resp.choices[0].message.content\n",
|
| 396 |
+
"\n",
|
| 397 |
+
"def call_llm(audio_b64, audio_bytes, prompt):\n",
|
| 398 |
+
" return call_gemini(audio_b64, audio_bytes, prompt) if PROVIDER == \"gemini\" \\\n",
|
| 399 |
+
" else call_openai(audio_b64, audio_bytes, prompt)"
|
| 400 |
+
]
|
| 401 |
+
},
|
| 402 |
+
{
|
| 403 |
+
"cell_type": "markdown",
|
| 404 |
+
"id": "f6ef0abc",
|
| 405 |
+
"metadata": {},
|
| 406 |
+
"source": [
|
| 407 |
+
"## 5. Parse JSON chịu lỗi → 6 cột; clamp [1,5]; chuẩn hóa CAT"
|
| 408 |
+
]
|
| 409 |
+
},
|
| 410 |
+
{
|
| 411 |
+
"cell_type": "code",
|
| 412 |
+
"execution_count": null,
|
| 413 |
+
"id": "74507a4a",
|
| 414 |
+
"metadata": {
|
| 415 |
+
"lines_to_next_cell": 1
|
| 416 |
+
},
|
| 417 |
+
"outputs": [],
|
| 418 |
+
"source": [
|
| 419 |
+
"def _clamp(x, lo=1.0, hi=5.0, default=3.0):\n",
|
| 420 |
+
" try:\n",
|
| 421 |
+
" v = float(x)\n",
|
| 422 |
+
" except Exception:\n",
|
| 423 |
+
" return default\n",
|
| 424 |
+
" return max(lo, min(hi, v))\n",
|
| 425 |
+
"\n",
|
| 426 |
+
"def parse_response(text):\n",
|
| 427 |
+
" \"\"\"text thô LLM → dict {qmos,emos,cat5(list theo EMOTIONS5),val,aro,dom} hoặc None nếu hỏng.\"\"\"\n",
|
| 428 |
+
" if not text:\n",
|
| 429 |
+
" return None\n",
|
| 430 |
+
" m = re.search(r\"\\{.*\\}\", text, re.DOTALL) # trích khối JSON đầu tiên\n",
|
| 431 |
+
" if not m:\n",
|
| 432 |
+
" return None\n",
|
| 433 |
+
" try:\n",
|
| 434 |
+
" d = json.loads(m.group(0))\n",
|
| 435 |
+
" except Exception:\n",
|
| 436 |
+
" return None\n",
|
| 437 |
+
" cat_in = d.get(\"cat\", {}) or {}\n",
|
| 438 |
+
" cat = np.zeros(5, np.float32)\n",
|
| 439 |
+
" for k, v in cat_in.items():\n",
|
| 440 |
+
" e = norm_emotion(k)\n",
|
| 441 |
+
" if e in EMOTIONS5:\n",
|
| 442 |
+
" try:\n",
|
| 443 |
+
" cat[EMOTIONS5.index(e)] = max(0.0, float(v))\n",
|
| 444 |
+
" except Exception:\n",
|
| 445 |
+
" pass\n",
|
| 446 |
+
" cat = cat / cat.sum() if cat.sum() > 0 else np.full(5, 0.2, np.float32)\n",
|
| 447 |
+
" return dict(\n",
|
| 448 |
+
" qmos=_clamp(d.get(\"qmos\")), emos=_clamp(d.get(\"emos\")),\n",
|
| 449 |
+
" cat5=cat.tolist(),\n",
|
| 450 |
+
" val=_clamp(d.get(\"val\")), aro=_clamp(d.get(\"aro\")), dom=_clamp(d.get(\"dom\")),\n",
|
| 451 |
+
" )"
|
| 452 |
+
]
|
| 453 |
+
},
|
| 454 |
+
{
|
| 455 |
+
"cell_type": "markdown",
|
| 456 |
+
"id": "462449d1",
|
| 457 |
+
"metadata": {},
|
| 458 |
+
"source": [
|
| 459 |
+
"## 6. Vòng chấm có CACHE + RESUME (KHÔNG gọi lại wav đã có trong cache)"
|
| 460 |
+
]
|
| 461 |
+
},
|
| 462 |
+
{
|
| 463 |
+
"cell_type": "code",
|
| 464 |
+
"execution_count": null,
|
| 465 |
+
"id": "ee30edbc",
|
| 466 |
+
"metadata": {
|
| 467 |
+
"lines_to_next_cell": 1
|
| 468 |
+
},
|
| 469 |
+
"outputs": [],
|
| 470 |
+
"source": [
|
| 471 |
+
"def load_cache():\n",
|
| 472 |
+
" done = {}\n",
|
| 473 |
+
" if os.path.exists(CACHE_PATH):\n",
|
| 474 |
+
" with open(CACHE_PATH, encoding=\"utf-8\") as f:\n",
|
| 475 |
+
" for ln in f:\n",
|
| 476 |
+
" try:\n",
|
| 477 |
+
" r = json.loads(ln)\n",
|
| 478 |
+
" done[r[\"stem\"]] = r\n",
|
| 479 |
+
" except Exception:\n",
|
| 480 |
+
" continue\n",
|
| 481 |
+
" return done\n",
|
| 482 |
+
"\n",
|
| 483 |
+
"def score_one(name):\n",
|
| 484 |
+
" \"\"\"Gọi LLM cho 1 wav, retry; trả record dict {stem,name,raw,parsed}.\"\"\"\n",
|
| 485 |
+
" sid = stem(name)\n",
|
| 486 |
+
" wavp = os.path.join(WAV_DIR, name if name.endswith(\".wav\") else name + \".wav\")\n",
|
| 487 |
+
" tgt = target_map.get(sid)\n",
|
| 488 |
+
" prompt = build_prompt(tgt)\n",
|
| 489 |
+
" last_err = None\n",
|
| 490 |
+
" for attempt in range(MAX_RETRY):\n",
|
| 491 |
+
" try:\n",
|
| 492 |
+
" _, b64 = (None, None)\n",
|
| 493 |
+
" raw_bytes, b64 = load_wav_bytes(wavp)\n",
|
| 494 |
+
" text = call_llm(b64, raw_bytes, prompt)\n",
|
| 495 |
+
" parsed = parse_response(text)\n",
|
| 496 |
+
" if parsed is not None:\n",
|
| 497 |
+
" return dict(stem=sid, name=name, raw=text, parsed=parsed, ok=True)\n",
|
| 498 |
+
" last_err = \"parse_fail\"\n",
|
| 499 |
+
" except Exception as e:\n",
|
| 500 |
+
" last_err = str(e)\n",
|
| 501 |
+
" time.sleep(RETRY_SLEEP * (attempt + 1))\n",
|
| 502 |
+
" return dict(stem=sid, name=name, raw=None, parsed=None, ok=False, err=last_err)\n",
|
| 503 |
+
"\n",
|
| 504 |
+
"def run_scoring():\n",
|
| 505 |
+
" from concurrent.futures import ThreadPoolExecutor, as_completed\n",
|
| 506 |
+
" done = load_cache()\n",
|
| 507 |
+
" todo = [n for n in dev_names if stem(n) not in done]\n",
|
| 508 |
+
" print(f\"Cache có {len(done)} | cần chấm thêm {len(todo)} | ước lượng {len(todo)} call API\")\n",
|
| 509 |
+
" if not todo:\n",
|
| 510 |
+
" return done\n",
|
| 511 |
+
" n_ok = n_bad = 0\n",
|
| 512 |
+
" with open(CACHE_PATH, \"a\", encoding=\"utf-8\") as fout, \\\n",
|
| 513 |
+
" ThreadPoolExecutor(max_workers=WORKERS) as ex:\n",
|
| 514 |
+
" futs = {ex.submit(score_one, n): n for n in todo}\n",
|
| 515 |
+
" for i, fut in enumerate(as_completed(futs), 1):\n",
|
| 516 |
+
" rec = fut.result()\n",
|
| 517 |
+
" fout.write(json.dumps(rec, ensure_ascii=False) + \"\\n\")\n",
|
| 518 |
+
" fout.flush()\n",
|
| 519 |
+
" done[rec[\"stem\"]] = rec\n",
|
| 520 |
+
" n_ok += int(rec[\"ok\"]); n_bad += int(not rec[\"ok\"])\n",
|
| 521 |
+
" if i % 50 == 0 or i == len(todo):\n",
|
| 522 |
+
" print(f\" {i}/{len(todo)} | ok={n_ok} bad={n_bad}\")\n",
|
| 523 |
+
" if n_bad:\n",
|
| 524 |
+
" print(f\"⚠️ {n_bad} wav hỏng (parse/API) → sẽ điền mặc định ở build_answer.\")\n",
|
| 525 |
+
" return done\n",
|
| 526 |
+
"\n",
|
| 527 |
+
"records = run_scoring()"
|
| 528 |
+
]
|
| 529 |
+
},
|
| 530 |
+
{
|
| 531 |
+
"cell_type": "markdown",
|
| 532 |
+
"id": "eda27285",
|
| 533 |
+
"metadata": {},
|
| 534 |
+
"source": [
|
| 535 |
+
"## 7. Ráp `answer.txt` 6 cột (giống exp07) + validate + zip"
|
| 536 |
+
]
|
| 537 |
+
},
|
| 538 |
+
{
|
| 539 |
+
"cell_type": "code",
|
| 540 |
+
"execution_count": null,
|
| 541 |
+
"id": "b9a4bd65",
|
| 542 |
+
"metadata": {
|
| 543 |
+
"lines_to_next_cell": 1
|
| 544 |
+
},
|
| 545 |
+
"outputs": [],
|
| 546 |
+
"source": [
|
| 547 |
+
"def fmt_cat(probs5):\n",
|
| 548 |
+
" return \"|\".join(f\"{e}:{probs5[i]:.6g}\" for i, e in enumerate(EMOTIONS5))\n",
|
| 549 |
+
"\n",
|
| 550 |
+
"def build_answer(out_path):\n",
|
| 551 |
+
" n_real = n_default = 0\n",
|
| 552 |
+
" with open(out_path, \"w\") as f:\n",
|
| 553 |
+
" f.write(\"wav,QMOS,EMOS,CAT,VAL,ARO,DOM\\n\")\n",
|
| 554 |
+
" for name in dev_names:\n",
|
| 555 |
+
" sid = stem(name)\n",
|
| 556 |
+
" rec = records.get(sid)\n",
|
| 557 |
+
" p = rec[\"parsed\"] if (rec and rec.get(\"parsed\")) else None\n",
|
| 558 |
+
" if p is None:\n",
|
| 559 |
+
" qmos = emos = val = aro = dom = 3.0\n",
|
| 560 |
+
" cat5 = [0.2] * 5\n",
|
| 561 |
+
" n_default += 1\n",
|
| 562 |
+
" else:\n",
|
| 563 |
+
" qmos, emos = p[\"qmos\"], p[\"emos\"]\n",
|
| 564 |
+
" val, aro, dom = p[\"val\"], p[\"aro\"], p[\"dom\"]\n",
|
| 565 |
+
" cat5 = p[\"cat5\"]; n_real += 1\n",
|
| 566 |
+
" f.write(f\"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},\"\n",
|
| 567 |
+
" f\"{val:.6g},{aro:.6g},{dom:.6g}\\n\")\n",
|
| 568 |
+
" print(f\"Ghi {len(dev_names)} dòng → {out_path} | LLM thật {n_real}, mặc định {n_default}\")\n",
|
| 569 |
+
"\n",
|
| 570 |
+
"answer_path = os.path.join(OUT_DIR, \"answer.txt\")\n",
|
| 571 |
+
"build_answer(answer_path)\n",
|
| 572 |
+
"\n",
|
| 573 |
+
"def validate(path):\n",
|
| 574 |
+
" import csv\n",
|
| 575 |
+
" with open(path) as f:\n",
|
| 576 |
+
" rows = list(csv.reader(f))\n",
|
| 577 |
+
" header = rows[0]\n",
|
| 578 |
+
" assert header[0] == \"wav\" and \"QMOS\" in header and \"EMOS\" in header, \"Header sai\"\n",
|
| 579 |
+
" for i, r in enumerate(rows[1:], 2):\n",
|
| 580 |
+
" assert len(r) == len(header), f\"Dòng {i} sai số cột\"\n",
|
| 581 |
+
" print(f\"OK: {len(rows)-1} dòng, header = {header}\")\n",
|
| 582 |
+
"\n",
|
| 583 |
+
"validate(answer_path)\n",
|
| 584 |
+
"!cd /kaggle/working && zip -j submission_track2_exp16.zip answer.txt && unzip -l submission_track2_exp16.zip\n",
|
| 585 |
+
"print(\"Sẵn sàng nộp: /kaggle/working/submission_track2_exp16.zip\")"
|
| 586 |
+
]
|
| 587 |
+
},
|
| 588 |
+
{
|
| 589 |
+
"cell_type": "markdown",
|
| 590 |
+
"id": "f8eeafb8",
|
| 591 |
+
"metadata": {},
|
| 592 |
+
"source": [
|
| 593 |
+
"## 8. (tùy chọn) Ensemble muộn: trộn THỨ HẠNG điểm LLM + hệ trained\n",
|
| 594 |
+
"\n",
|
| 595 |
+
"Trung bình rank của exp16 với một `answer.txt` đã có (vd bản trộn cột exp07+exp08) cho từng cột số.\n",
|
| 596 |
+
"Đa dạng nguồn → có thể giảm nhiễu. CHỈ chạy khi có sẵn file kia (đặt đường dẫn rồi bỏ comment)."
|
| 597 |
+
]
|
| 598 |
+
},
|
| 599 |
+
{
|
| 600 |
+
"cell_type": "code",
|
| 601 |
+
"execution_count": null,
|
| 602 |
+
"id": "d1a15c1f",
|
| 603 |
+
"metadata": {},
|
| 604 |
+
"outputs": [],
|
| 605 |
+
"source": [
|
| 606 |
+
"def ensemble_rank_average(answer_a, answer_b, out_path):\n",
|
| 607 |
+
" \"\"\"Trộn 2 answer.txt theo TRUNG BÌNH THỨ HẠNG cho 5 cột số (QMOS/EMOS/VAL/ARO/DOM); CAT lấy theo A.\"\"\"\n",
|
| 608 |
+
" import pandas as pd\n",
|
| 609 |
+
" num_cols = [\"QMOS\", \"EMOS\", \"VAL\", \"ARO\", \"DOM\"]\n",
|
| 610 |
+
" A = pd.read_csv(answer_a); B = pd.read_csv(answer_b)\n",
|
| 611 |
+
" A = A.set_index(\"wav\"); B = B.set_index(\"wav\").reindex(A.index)\n",
|
| 612 |
+
" out = A.copy()\n",
|
| 613 |
+
" for c in num_cols:\n",
|
| 614 |
+
" if c in A.columns and c in B.columns:\n",
|
| 615 |
+
" ra = A[c].rank(); rb = B[c].rank()\n",
|
| 616 |
+
" out[c] = ((ra + rb) / 2.0) # SRCC bất biến với scale → để nguyên rank trung bình\n",
|
| 617 |
+
" out.reset_index().to_csv(out_path, index=False)\n",
|
| 618 |
+
" print(\"Ensemble →\", out_path)\n",
|
| 619 |
+
"\n",
|
| 620 |
+
"# ensemble_rank_average(answer_path,\n",
|
| 621 |
+
"# \"/kaggle/input/.../exp_mix_q07_emo08/answer.txt\",\n",
|
| 622 |
+
"# os.path.join(OUT_DIR, \"answer_ens.txt\"))"
|
| 623 |
+
]
|
| 624 |
+
},
|
| 625 |
+
{
|
| 626 |
+
"cell_type": "markdown",
|
| 627 |
+
"id": "14172d70",
|
| 628 |
+
"metadata": {},
|
| 629 |
+
"source": [
|
| 630 |
+
"## Ghi chú nộp & paper\n",
|
| 631 |
+
"- Nộp: My Submissions → **Track 2** (bỏ chọn track khác) → `submission_track2_exp16.zip` → đọc SRCC 6 cột.\n",
|
| 632 |
+
"- **Bảng A (paper):** đặt SRCC exp16 (gemini/openai, zero-shot) cạnh exp07 (QMOS 0.548) + exp08\n",
|
| 633 |
+
" (EMOS 0.811 · CAT 0.133 · VAD 0.659/0.793/0.751). Kỳ vọng: LLM khá ở EMOS/CAT, yếu ở QMOS.\n",
|
| 634 |
+
"- **Bảng B:** chạy lại `SHOT_MODE=\"few_shot\"` (1 provider) → so zero vs few-shot.\n",
|
| 635 |
+
"- **Cache:** Save Version để giữ `exp16_llm_cache/*.jsonl` (không trả tiền lại). Lưu thành Kaggle\n",
|
| 636 |
+
" Dataset nếu muốn dùng cho eval phase.\n",
|
| 637 |
+
"- **Khai báo external resource** (API thương mại Gemini/OpenAI) trong `12_system_description.md`."
|
| 638 |
+
]
|
| 639 |
+
}
|
| 640 |
+
],
|
| 641 |
+
"metadata": {
|
| 642 |
+
"jupytext": {
|
| 643 |
+
"cell_metadata_filter": "-all",
|
| 644 |
+
"main_language": "python",
|
| 645 |
+
"notebook_metadata_filter": "-all"
|
| 646 |
+
}
|
| 647 |
+
},
|
| 648 |
+
"nbformat": 4,
|
| 649 |
+
"nbformat_minor": 5
|
| 650 |
+
}
|
track2/exp16_llm_judge_pipeline.py
ADDED
|
@@ -0,0 +1,480 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %% [markdown]
|
| 2 |
+
# # exp16 — Audio-LLM-as-Judge cho MOS cảm xúc (Track 2)
|
| 3 |
+
#
|
| 4 |
+
# **Ý tưởng:** đưa thẳng audio cho một **audio-LLM** (Gemini / GPT-4o-audio) qua **API** + prompt có
|
| 5 |
+
# cấu trúc → bắt nó chấm cả 6 cột (`QMOS, EMOS, CAT, VAL, ARO, DOM`) → ráp `answer.txt` → nộp CodaBench.
|
| 6 |
+
#
|
| 7 |
+
# **Mục tiêu chính = NOVELTY cho paper** (khảo sát có hệ thống audio-LLM-as-judge cho MOS cảm xúc),
|
| 8 |
+
# so với hệ SSL đã train (exp07 QMOS 0.548 · exp08 EMOS 0.811…). KHÔNG cần GPU — thuần gọi API.
|
| 9 |
+
#
|
| 10 |
+
# | Đặc điểm | Giá trị |
|
| 11 |
+
# |---|---|
|
| 12 |
+
# | GPU | ❌ không cần (chỉ network I/O) |
|
| 13 |
+
# | Tốn phí | ✅ API trả tiền theo token/audio → **cache + resume bắt buộc** |
|
| 14 |
+
# | Provider | `gemini` (mặc định, đã có billing) · `openai` (GPT-4o-audio, để so 2 LLM) |
|
| 15 |
+
# | Output | `answer.txt` 6 cột giống exp07 |
|
| 16 |
+
#
|
| 17 |
+
# **Cách dùng Kaggle:** Internet = **On**; Add-ons → Secrets: `GEMINI_API_KEY` (và `OPENAI_API_KEY`
|
| 18 |
+
# nếu chạy provider openai). Settings GPU **không cần**. Sửa `DATA_ROOT` cho khớp slug rồi Run All.
|
| 19 |
+
#
|
| 20 |
+
# ⚠️ **Model ID có thể đã đổi** theo thời gian → kiểm tra `GEMINI_MODEL` / `OPENAI_MODEL` còn nhận
|
| 21 |
+
# audio không trước khi chạy full (xem mục 1).
|
| 22 |
+
|
| 23 |
+
# %% [markdown]
|
| 24 |
+
# ## 0. Cấu hình — SỬA Ở ĐÂY
|
| 25 |
+
|
| 26 |
+
# %%
|
| 27 |
+
import os, io, re, json, time, base64, glob
|
| 28 |
+
|
| 29 |
+
# ── Data Track 2 trên Kaggle ────────────────────────────────────────────────
|
| 30 |
+
DATA_ROOT = "/kaggle/input/vmc2026-track2-full/vmc2026-track2" # << SỬA slug
|
| 31 |
+
WAV_DIR = f"{DATA_ROOT}/wav"
|
| 32 |
+
METADATA_CSV = f"{DATA_ROOT}/metadata.csv" # wavID|emotion|transcript (KHÔNG header) — nhãn cảm xúc target
|
| 33 |
+
DEV_SCP = f"{DATA_ROOT}/sets/dev.scp" # danh sách wav DEV cần nộp (train phase)
|
| 34 |
+
TRAIN_CSV = f"{DATA_ROOT}/sets/train.csv" # chỉ cần khi SHOT_MODE="few_shot"
|
| 35 |
+
|
| 36 |
+
OUT_DIR = "/kaggle/working"
|
| 37 |
+
CACHE_DIR = "/kaggle/working/exp16_llm_cache" # nên Save Version / lưu Dataset để KHÔNG gọi lại API
|
| 38 |
+
os.makedirs(CACHE_DIR, exist_ok=True)
|
| 39 |
+
|
| 40 |
+
# ── Provider & model ────────────────────────────────────────────────────────
|
| 41 |
+
PROVIDER = "gemini" # "gemini" | "openai"
|
| 42 |
+
GEMINI_MODEL = "gemini-2.5-flash" # << xác nhận model audio hiện hành (baseline dùng họ gemini-*-flash)
|
| 43 |
+
OPENAI_MODEL = "gpt-4o-audio-preview" # << model audio của OpenAI; cần OPENAI_API_KEY
|
| 44 |
+
TEMPERATURE = 0.0 # cố định để TÁI LẬP (paper)
|
| 45 |
+
|
| 46 |
+
# ── Chế độ chạy ─────────────────────────────────────────────────────────────
|
| 47 |
+
SHOT_MODE = "zero_shot" # "zero_shot" | "few_shot" (nhét K ví dụ audio có nhãn từ train.csv)
|
| 48 |
+
FEW_K = 2 # số ví dụ few-shot (mỗi ví dụ = 1 audio + nhãn vàng) — tốn thêm token!
|
| 49 |
+
LIMIT = 20 # << số nhỏ (20) để smoke test; None = full DEV (~2730) — CHẠY THỬ TRƯỚC
|
| 50 |
+
MAX_SECONDS = 12 # cắt audio cho rẻ + nhanh
|
| 51 |
+
WORKERS = 4 # luồng gọi song song (giảm nếu dính rate limit)
|
| 52 |
+
MAX_RETRY = 3 # số lần thử lại 1 wav khi lỗi mạng / JSON hỏng
|
| 53 |
+
RETRY_SLEEP = 2.0 # giây nghỉ giữa các lần thử
|
| 54 |
+
|
| 55 |
+
TAG = f"{PROVIDER}_{(GEMINI_MODEL if PROVIDER=='gemini' else OPENAI_MODEL)}_{SHOT_MODE}".replace("/", "-")
|
| 56 |
+
CACHE_PATH = os.path.join(CACHE_DIR, f"{TAG}.jsonl") # 1 dòng JSON / wav (raw + parsed) → resume
|
| 57 |
+
print("TAG:", TAG, "| cache:", CACHE_PATH)
|
| 58 |
+
|
| 59 |
+
# %% [markdown]
|
| 60 |
+
# ## 0b. Nhãn cảm xúc target + chuẩn hóa lớp (tái dùng quy ước baseline)
|
| 61 |
+
|
| 62 |
+
# %%
|
| 63 |
+
EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"] # THỨ TỰ chuẩn cho cột CAT
|
| 64 |
+
|
| 65 |
+
_EMO_ALIAS = {
|
| 66 |
+
"angry": "angry", "anger": "angry",
|
| 67 |
+
"happy": "happy", "happiness": "happy", "joy": "happy",
|
| 68 |
+
"neutral": "neutral", "calm": "neutral",
|
| 69 |
+
"sad": "sad", "sadness": "sad",
|
| 70 |
+
"surprise": "surprised", "surprised": "surprised", "surprising": "surprised",
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
def norm_emotion(label):
|
| 74 |
+
key = str(label).strip().lower()
|
| 75 |
+
return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)
|
| 76 |
+
|
| 77 |
+
def stem(name):
|
| 78 |
+
return os.path.splitext(os.path.basename(name))[0]
|
| 79 |
+
|
| 80 |
+
def load_target_emotions():
|
| 81 |
+
"""metadata.csv (wavID|emotion|transcript, không header) → {stem: emotion_chuẩn}."""
|
| 82 |
+
tgt = {}
|
| 83 |
+
if not (METADATA_CSV and os.path.exists(METADATA_CSV)):
|
| 84 |
+
print("⚠️ Không thấy metadata.csv → EMOS sẽ thiếu cảm xúc target.")
|
| 85 |
+
return tgt
|
| 86 |
+
with open(METADATA_CSV, encoding="utf-8") as f:
|
| 87 |
+
for ln in f:
|
| 88 |
+
parts = ln.strip().split("|")
|
| 89 |
+
if len(parts) < 2:
|
| 90 |
+
continue
|
| 91 |
+
tgt[stem(parts[0])] = norm_emotion(parts[1])
|
| 92 |
+
return tgt
|
| 93 |
+
|
| 94 |
+
target_map = load_target_emotions()
|
| 95 |
+
print("Nhãn cảm xúc target:", len(target_map))
|
| 96 |
+
|
| 97 |
+
def list_dev():
|
| 98 |
+
with open(DEV_SCP) as f:
|
| 99 |
+
return [ln.strip() for ln in f if ln.strip()]
|
| 100 |
+
|
| 101 |
+
dev_names = list_dev()
|
| 102 |
+
if LIMIT:
|
| 103 |
+
dev_names = dev_names[:LIMIT]
|
| 104 |
+
print("DEV cần chấm:", len(dev_names), "mẫu", "| LIMIT =", LIMIT)
|
| 105 |
+
|
| 106 |
+
# %% [markdown]
|
| 107 |
+
# ## 1. Cài SDK + nạp key
|
| 108 |
+
#
|
| 109 |
+
# Gemini dùng SDK mới `google-genai`; OpenAI dùng `openai`. Trên Kaggle **Internet phải On**.
|
| 110 |
+
|
| 111 |
+
# %%
|
| 112 |
+
# !pip -q install google-genai openai soundfile librosa
|
| 113 |
+
|
| 114 |
+
def setup_keys():
|
| 115 |
+
"""Nạp API key từ Kaggle Secrets (fallback: biến môi trường đã set sẵn)."""
|
| 116 |
+
try:
|
| 117 |
+
from kaggle_secrets import UserSecretsClient
|
| 118 |
+
sec = UserSecretsClient()
|
| 119 |
+
for k in ["GEMINI_API_KEY", "OPENAI_API_KEY"]:
|
| 120 |
+
try:
|
| 121 |
+
os.environ[k] = sec.get_secret(k)
|
| 122 |
+
print(f"Đã nạp {k} từ Secrets")
|
| 123 |
+
except Exception:
|
| 124 |
+
pass
|
| 125 |
+
except Exception as e:
|
| 126 |
+
print("Không dùng được Kaggle Secrets:", e, "→ set tay os.environ[...] nếu cần")
|
| 127 |
+
|
| 128 |
+
setup_keys()
|
| 129 |
+
|
| 130 |
+
# %% [markdown]
|
| 131 |
+
# ## 2. Đọc + chuẩn hóa audio (16kHz mono, cắt MAX_SECONDS) → bytes WAV trong RAM
|
| 132 |
+
|
| 133 |
+
# %%
|
| 134 |
+
import numpy as np
|
| 135 |
+
|
| 136 |
+
def load_wav_bytes(path, sr=16000, max_seconds=MAX_SECONDS):
|
| 137 |
+
"""Trả (wav_bytes, base64_str). Cắt ≤ max_seconds, resample 16k mono, encode WAV PCM16."""
|
| 138 |
+
import soundfile as sf
|
| 139 |
+
try:
|
| 140 |
+
import librosa
|
| 141 |
+
y, _ = librosa.load(path, sr=sr, mono=True)
|
| 142 |
+
except Exception:
|
| 143 |
+
y, in_sr = sf.read(path)
|
| 144 |
+
if y.ndim > 1:
|
| 145 |
+
y = y.mean(axis=1)
|
| 146 |
+
if in_sr != sr: # fallback resample tuyến tính nếu không có librosa
|
| 147 |
+
idx = np.linspace(0, len(y) - 1, int(len(y) * sr / in_sr))
|
| 148 |
+
y = np.interp(idx, np.arange(len(y)), y)
|
| 149 |
+
if max_seconds:
|
| 150 |
+
y = y[: int(sr * max_seconds)]
|
| 151 |
+
buf = io.BytesIO()
|
| 152 |
+
sf.write(buf, y.astype(np.float32), sr, format="WAV", subtype="PCM_16")
|
| 153 |
+
raw = buf.getvalue()
|
| 154 |
+
return raw, base64.b64encode(raw).decode("ascii")
|
| 155 |
+
|
| 156 |
+
# %% [markdown]
|
| 157 |
+
# ## 3. Prompt — định nghĩa 6 metric + ép JSON nghiêm ngặt
|
| 158 |
+
#
|
| 159 |
+
# QMOS = chất lượng/độ tự nhiên (sạch, không méo/robot). EMOS = độ KHỚP với **cảm xúc target**.
|
| 160 |
+
# CAT = phân phối vote 5 lớp. VAD = Valence/Arousal/Dominance. Tất cả thang **1–5** (CAT là tỉ lệ 0–1).
|
| 161 |
+
|
| 162 |
+
# %%
|
| 163 |
+
SYSTEM_INSTRUCTION = (
|
| 164 |
+
"You are an expert evaluator of emotional text-to-speech. "
|
| 165 |
+
"Listen to the audio and rate it. Respond with ONLY a compact JSON object, no prose."
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
def build_prompt(target_emo):
|
| 169 |
+
tgt = target_emo if target_emo else "unknown"
|
| 170 |
+
return (
|
| 171 |
+
"Rate this speech utterance. The INTENDED (target) emotion is: "
|
| 172 |
+
f"\"{tgt}\".\n\n"
|
| 173 |
+
"Return a JSON object with EXACTLY these keys (numbers on a 1-5 scale unless stated):\n"
|
| 174 |
+
" \"qmos\": overall audio QUALITY / naturalness (1=very unnatural/robotic/distorted, 5=clean & human-like).\n"
|
| 175 |
+
" \"emos\": how well the emotion expressed MATCHES the target emotion above "
|
| 176 |
+
"(1=not matching at all, 5=perfectly matching).\n"
|
| 177 |
+
" \"cat\": an object with probabilities (summing to 1.0) over the 5 perceived emotions: "
|
| 178 |
+
"{\"neutral\":_, \"happy\":_, \"sad\":_, \"angry\":_, \"surprised\":_}.\n"
|
| 179 |
+
" \"val\": valence (1=very negative, 5=very positive).\n"
|
| 180 |
+
" \"aro\": arousal (1=very calm, 5=very excited).\n"
|
| 181 |
+
" \"dom\": dominance (1=very submissive, 5=very dominant).\n\n"
|
| 182 |
+
"Example format: "
|
| 183 |
+
"{\"qmos\":3.5,\"emos\":4.0,"
|
| 184 |
+
"\"cat\":{\"neutral\":0.1,\"happy\":0.7,\"sad\":0.0,\"angry\":0.1,\"surprised\":0.1},"
|
| 185 |
+
"\"val\":4.0,\"aro\":3.5,\"dom\":3.0}\n"
|
| 186 |
+
"Respond with ONLY the JSON."
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
# %% [markdown]
|
| 190 |
+
# ## 3b. (tùy chọn) Few-shot — lấy K ví dụ audio có nhãn vàng từ train.csv
|
| 191 |
+
#
|
| 192 |
+
# Bật khi `SHOT_MODE="few_shot"`. Mỗi ví dụ = 1 audio train + nhãn vàng (gộp TB theo wav). Tốn thêm token.
|
| 193 |
+
|
| 194 |
+
# %%
|
| 195 |
+
few_shot_examples = [] # list[(audio_b64, audio_bytes, gold_json_str)]
|
| 196 |
+
|
| 197 |
+
def _agg_train_labels():
|
| 198 |
+
"""Gộp train.csv (sep='|') theo wavID → nhãn vàng trung bình; CAT = tỉ lệ vote."""
|
| 199 |
+
import pandas as pd
|
| 200 |
+
df = pd.read_csv(TRAIN_CSV, sep="|")
|
| 201 |
+
rows = {}
|
| 202 |
+
for wav, g in df.groupby("wavID"):
|
| 203 |
+
votes = np.zeros(5, np.float32)
|
| 204 |
+
for cell in g["emoCat"].astype(str):
|
| 205 |
+
for tok in cell.split(","):
|
| 206 |
+
e = norm_emotion(tok)
|
| 207 |
+
if e in EMOTIONS5:
|
| 208 |
+
votes[EMOTIONS5.index(e)] += 1
|
| 209 |
+
s = votes.sum()
|
| 210 |
+
cat = (votes / s) if s > 0 else np.full(5, 0.2, np.float32)
|
| 211 |
+
rows[stem(wav)] = dict(
|
| 212 |
+
qmos=float(g["qMOS"].mean()), emos=float(g["eMOS"].mean()),
|
| 213 |
+
val=float(g["val"].mean()), aro=float(g["aro"].mean()), dom=float(g["dom"].mean()),
|
| 214 |
+
cat={EMOTIONS5[i]: round(float(cat[i]), 4) for i in range(5)},
|
| 215 |
+
)
|
| 216 |
+
return rows
|
| 217 |
+
|
| 218 |
+
def build_few_shot():
|
| 219 |
+
if SHOT_MODE != "few_shot":
|
| 220 |
+
return
|
| 221 |
+
labels = _agg_train_labels()
|
| 222 |
+
picked = list(labels.keys())[:FEW_K]
|
| 223 |
+
for sid in picked:
|
| 224 |
+
wavp = os.path.join(WAV_DIR, sid + ".wav")
|
| 225 |
+
if not os.path.exists(wavp):
|
| 226 |
+
continue
|
| 227 |
+
raw, b64 = load_wav_bytes(wavp)
|
| 228 |
+
gold = labels[sid]
|
| 229 |
+
gold_json = json.dumps({
|
| 230 |
+
"qmos": round(gold["qmos"], 2), "emos": round(gold["emos"], 2),
|
| 231 |
+
"cat": gold["cat"], "val": round(gold["val"], 2),
|
| 232 |
+
"aro": round(gold["aro"], 2), "dom": round(gold["dom"], 2),
|
| 233 |
+
})
|
| 234 |
+
few_shot_examples.append((b64, raw, gold_json))
|
| 235 |
+
print(f"Few-shot: {len(few_shot_examples)} ví dụ")
|
| 236 |
+
|
| 237 |
+
build_few_shot()
|
| 238 |
+
|
| 239 |
+
# %% [markdown]
|
| 240 |
+
# ## 4. Gọi API — trừu tượng hóa provider (gemini / openai)
|
| 241 |
+
#
|
| 242 |
+
# Mỗi provider tự dựng message của nó (kèm few-shot nếu có). Trả về **text thô** để parse ở mục 5.
|
| 243 |
+
|
| 244 |
+
# %%
|
| 245 |
+
_client = {"gemini": None, "openai": None}
|
| 246 |
+
|
| 247 |
+
def _gemini_client():
|
| 248 |
+
if _client["gemini"] is None:
|
| 249 |
+
from google import genai
|
| 250 |
+
_client["gemini"] = genai.Client(api_key=os.environ["GEMINI_API_KEY"])
|
| 251 |
+
return _client["gemini"]
|
| 252 |
+
|
| 253 |
+
def _openai_client():
|
| 254 |
+
if _client["openai"] is None:
|
| 255 |
+
from openai import OpenAI
|
| 256 |
+
_client["openai"] = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
|
| 257 |
+
return _client["openai"]
|
| 258 |
+
|
| 259 |
+
def call_gemini(audio_b64, audio_bytes, prompt):
|
| 260 |
+
from google.genai import types
|
| 261 |
+
client = _gemini_client()
|
| 262 |
+
contents = []
|
| 263 |
+
for ex_b64, ex_bytes, ex_gold in few_shot_examples: # few-shot: audio ví dụ + nhãn vàng
|
| 264 |
+
contents.append(types.Content(role="user", parts=[
|
| 265 |
+
types.Part.from_bytes(data=ex_bytes, mime_type="audio/wav"),
|
| 266 |
+
types.Part.from_text(text=build_prompt(None)),
|
| 267 |
+
]))
|
| 268 |
+
contents.append(types.Content(role="model", parts=[types.Part.from_text(text=ex_gold)]))
|
| 269 |
+
contents.append(types.Content(role="user", parts=[
|
| 270 |
+
types.Part.from_bytes(data=audio_bytes, mime_type="audio/wav"),
|
| 271 |
+
types.Part.from_text(text=prompt),
|
| 272 |
+
]))
|
| 273 |
+
resp = client.models.generate_content(
|
| 274 |
+
model=GEMINI_MODEL, contents=contents,
|
| 275 |
+
config=types.GenerateContentConfig(
|
| 276 |
+
system_instruction=SYSTEM_INSTRUCTION, temperature=TEMPERATURE),
|
| 277 |
+
)
|
| 278 |
+
return resp.text
|
| 279 |
+
|
| 280 |
+
def call_openai(audio_b64, audio_bytes, prompt):
|
| 281 |
+
client = _openai_client()
|
| 282 |
+
messages = [{"role": "system", "content": SYSTEM_INSTRUCTION}]
|
| 283 |
+
for ex_b64, ex_bytes, ex_gold in few_shot_examples:
|
| 284 |
+
messages.append({"role": "user", "content": [
|
| 285 |
+
{"type": "text", "text": build_prompt(None)},
|
| 286 |
+
{"type": "input_audio", "input_audio": {"data": ex_b64, "format": "wav"}},
|
| 287 |
+
]})
|
| 288 |
+
messages.append({"role": "assistant", "content": ex_gold})
|
| 289 |
+
messages.append({"role": "user", "content": [
|
| 290 |
+
{"type": "text", "text": prompt},
|
| 291 |
+
{"type": "input_audio", "input_audio": {"data": audio_b64, "format": "wav"}},
|
| 292 |
+
]})
|
| 293 |
+
resp = client.chat.completions.create(
|
| 294 |
+
model=OPENAI_MODEL, messages=messages, temperature=TEMPERATURE,
|
| 295 |
+
modalities=["text"],
|
| 296 |
+
)
|
| 297 |
+
return resp.choices[0].message.content
|
| 298 |
+
|
| 299 |
+
def call_llm(audio_b64, audio_bytes, prompt):
|
| 300 |
+
return call_gemini(audio_b64, audio_bytes, prompt) if PROVIDER == "gemini" \
|
| 301 |
+
else call_openai(audio_b64, audio_bytes, prompt)
|
| 302 |
+
|
| 303 |
+
# %% [markdown]
|
| 304 |
+
# ## 5. Parse JSON chịu lỗi → 6 cột; clamp [1,5]; chuẩn hóa CAT
|
| 305 |
+
|
| 306 |
+
# %%
|
| 307 |
+
def _clamp(x, lo=1.0, hi=5.0, default=3.0):
|
| 308 |
+
try:
|
| 309 |
+
v = float(x)
|
| 310 |
+
except Exception:
|
| 311 |
+
return default
|
| 312 |
+
return max(lo, min(hi, v))
|
| 313 |
+
|
| 314 |
+
def parse_response(text):
|
| 315 |
+
"""text thô LLM → dict {qmos,emos,cat5(list theo EMOTIONS5),val,aro,dom} hoặc None nếu hỏng."""
|
| 316 |
+
if not text:
|
| 317 |
+
return None
|
| 318 |
+
m = re.search(r"\{.*\}", text, re.DOTALL) # trích khối JSON đầu tiên
|
| 319 |
+
if not m:
|
| 320 |
+
return None
|
| 321 |
+
try:
|
| 322 |
+
d = json.loads(m.group(0))
|
| 323 |
+
except Exception:
|
| 324 |
+
return None
|
| 325 |
+
cat_in = d.get("cat", {}) or {}
|
| 326 |
+
cat = np.zeros(5, np.float32)
|
| 327 |
+
for k, v in cat_in.items():
|
| 328 |
+
e = norm_emotion(k)
|
| 329 |
+
if e in EMOTIONS5:
|
| 330 |
+
try:
|
| 331 |
+
cat[EMOTIONS5.index(e)] = max(0.0, float(v))
|
| 332 |
+
except Exception:
|
| 333 |
+
pass
|
| 334 |
+
cat = cat / cat.sum() if cat.sum() > 0 else np.full(5, 0.2, np.float32)
|
| 335 |
+
return dict(
|
| 336 |
+
qmos=_clamp(d.get("qmos")), emos=_clamp(d.get("emos")),
|
| 337 |
+
cat5=cat.tolist(),
|
| 338 |
+
val=_clamp(d.get("val")), aro=_clamp(d.get("aro")), dom=_clamp(d.get("dom")),
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
# %% [markdown]
|
| 342 |
+
# ## 6. Vòng chấm có CACHE + RESUME (KHÔNG gọi lại wav đã có trong cache)
|
| 343 |
+
|
| 344 |
+
# %%
|
| 345 |
+
def load_cache():
|
| 346 |
+
done = {}
|
| 347 |
+
if os.path.exists(CACHE_PATH):
|
| 348 |
+
with open(CACHE_PATH, encoding="utf-8") as f:
|
| 349 |
+
for ln in f:
|
| 350 |
+
try:
|
| 351 |
+
r = json.loads(ln)
|
| 352 |
+
done[r["stem"]] = r
|
| 353 |
+
except Exception:
|
| 354 |
+
continue
|
| 355 |
+
return done
|
| 356 |
+
|
| 357 |
+
def score_one(name):
|
| 358 |
+
"""Gọi LLM cho 1 wav, retry; trả record dict {stem,name,raw,parsed}."""
|
| 359 |
+
sid = stem(name)
|
| 360 |
+
wavp = os.path.join(WAV_DIR, name if name.endswith(".wav") else name + ".wav")
|
| 361 |
+
tgt = target_map.get(sid)
|
| 362 |
+
prompt = build_prompt(tgt)
|
| 363 |
+
last_err = None
|
| 364 |
+
for attempt in range(MAX_RETRY):
|
| 365 |
+
try:
|
| 366 |
+
_, b64 = (None, None)
|
| 367 |
+
raw_bytes, b64 = load_wav_bytes(wavp)
|
| 368 |
+
text = call_llm(b64, raw_bytes, prompt)
|
| 369 |
+
parsed = parse_response(text)
|
| 370 |
+
if parsed is not None:
|
| 371 |
+
return dict(stem=sid, name=name, raw=text, parsed=parsed, ok=True)
|
| 372 |
+
last_err = "parse_fail"
|
| 373 |
+
except Exception as e:
|
| 374 |
+
last_err = str(e)
|
| 375 |
+
time.sleep(RETRY_SLEEP * (attempt + 1))
|
| 376 |
+
return dict(stem=sid, name=name, raw=None, parsed=None, ok=False, err=last_err)
|
| 377 |
+
|
| 378 |
+
def run_scoring():
|
| 379 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 380 |
+
done = load_cache()
|
| 381 |
+
todo = [n for n in dev_names if stem(n) not in done]
|
| 382 |
+
print(f"Cache có {len(done)} | cần chấm thêm {len(todo)} | ước lượng {len(todo)} call API")
|
| 383 |
+
if not todo:
|
| 384 |
+
return done
|
| 385 |
+
n_ok = n_bad = 0
|
| 386 |
+
with open(CACHE_PATH, "a", encoding="utf-8") as fout, \
|
| 387 |
+
ThreadPoolExecutor(max_workers=WORKERS) as ex:
|
| 388 |
+
futs = {ex.submit(score_one, n): n for n in todo}
|
| 389 |
+
for i, fut in enumerate(as_completed(futs), 1):
|
| 390 |
+
rec = fut.result()
|
| 391 |
+
fout.write(json.dumps(rec, ensure_ascii=False) + "\n")
|
| 392 |
+
fout.flush()
|
| 393 |
+
done[rec["stem"]] = rec
|
| 394 |
+
n_ok += int(rec["ok"]); n_bad += int(not rec["ok"])
|
| 395 |
+
if i % 50 == 0 or i == len(todo):
|
| 396 |
+
print(f" {i}/{len(todo)} | ok={n_ok} bad={n_bad}")
|
| 397 |
+
if n_bad:
|
| 398 |
+
print(f"⚠️ {n_bad} wav hỏng (parse/API) → sẽ điền mặc định ở build_answer.")
|
| 399 |
+
return done
|
| 400 |
+
|
| 401 |
+
records = run_scoring()
|
| 402 |
+
|
| 403 |
+
# %% [markdown]
|
| 404 |
+
# ## 7. Ráp `answer.txt` 6 cột (giống exp07) + validate + zip
|
| 405 |
+
|
| 406 |
+
# %%
|
| 407 |
+
def fmt_cat(probs5):
|
| 408 |
+
return "|".join(f"{e}:{probs5[i]:.6g}" for i, e in enumerate(EMOTIONS5))
|
| 409 |
+
|
| 410 |
+
def build_answer(out_path):
|
| 411 |
+
n_real = n_default = 0
|
| 412 |
+
with open(out_path, "w") as f:
|
| 413 |
+
f.write("wav,QMOS,EMOS,CAT,VAL,ARO,DOM\n")
|
| 414 |
+
for name in dev_names:
|
| 415 |
+
sid = stem(name)
|
| 416 |
+
rec = records.get(sid)
|
| 417 |
+
p = rec["parsed"] if (rec and rec.get("parsed")) else None
|
| 418 |
+
if p is None:
|
| 419 |
+
qmos = emos = val = aro = dom = 3.0
|
| 420 |
+
cat5 = [0.2] * 5
|
| 421 |
+
n_default += 1
|
| 422 |
+
else:
|
| 423 |
+
qmos, emos = p["qmos"], p["emos"]
|
| 424 |
+
val, aro, dom = p["val"], p["aro"], p["dom"]
|
| 425 |
+
cat5 = p["cat5"]; n_real += 1
|
| 426 |
+
f.write(f"{name},{qmos:.6g},{emos:.6g},{fmt_cat(cat5)},"
|
| 427 |
+
f"{val:.6g},{aro:.6g},{dom:.6g}\n")
|
| 428 |
+
print(f"Ghi {len(dev_names)} dòng → {out_path} | LLM thật {n_real}, mặc định {n_default}")
|
| 429 |
+
|
| 430 |
+
answer_path = os.path.join(OUT_DIR, "answer.txt")
|
| 431 |
+
build_answer(answer_path)
|
| 432 |
+
|
| 433 |
+
def validate(path):
|
| 434 |
+
import csv
|
| 435 |
+
with open(path) as f:
|
| 436 |
+
rows = list(csv.reader(f))
|
| 437 |
+
header = rows[0]
|
| 438 |
+
assert header[0] == "wav" and "QMOS" in header and "EMOS" in header, "Header sai"
|
| 439 |
+
for i, r in enumerate(rows[1:], 2):
|
| 440 |
+
assert len(r) == len(header), f"Dòng {i} sai số cột"
|
| 441 |
+
print(f"OK: {len(rows)-1} dòng, header = {header}")
|
| 442 |
+
|
| 443 |
+
validate(answer_path)
|
| 444 |
+
# !cd /kaggle/working && zip -j submission_track2_exp16.zip answer.txt && unzip -l submission_track2_exp16.zip
|
| 445 |
+
print("Sẵn sàng nộp: /kaggle/working/submission_track2_exp16.zip")
|
| 446 |
+
|
| 447 |
+
# %% [markdown]
|
| 448 |
+
# ## 8. (tùy chọn) Ensemble muộn: trộn THỨ HẠNG điểm LLM + hệ trained
|
| 449 |
+
#
|
| 450 |
+
# Trung bình rank của exp16 với một `answer.txt` đã có (vd bản trộn cột exp07+exp08) cho từng cột số.
|
| 451 |
+
# Đa dạng nguồn → có thể giảm nhiễu. CHỈ chạy khi có sẵn file kia (đặt đường dẫn rồi bỏ comment).
|
| 452 |
+
|
| 453 |
+
# %%
|
| 454 |
+
def ensemble_rank_average(answer_a, answer_b, out_path):
|
| 455 |
+
"""Trộn 2 answer.txt theo TRUNG BÌNH THỨ HẠNG cho 5 cột số (QMOS/EMOS/VAL/ARO/DOM); CAT lấy theo A."""
|
| 456 |
+
import pandas as pd
|
| 457 |
+
num_cols = ["QMOS", "EMOS", "VAL", "ARO", "DOM"]
|
| 458 |
+
A = pd.read_csv(answer_a); B = pd.read_csv(answer_b)
|
| 459 |
+
A = A.set_index("wav"); B = B.set_index("wav").reindex(A.index)
|
| 460 |
+
out = A.copy()
|
| 461 |
+
for c in num_cols:
|
| 462 |
+
if c in A.columns and c in B.columns:
|
| 463 |
+
ra = A[c].rank(); rb = B[c].rank()
|
| 464 |
+
out[c] = ((ra + rb) / 2.0) # SRCC bất biến với scale → để nguyên rank trung bình
|
| 465 |
+
out.reset_index().to_csv(out_path, index=False)
|
| 466 |
+
print("Ensemble →", out_path)
|
| 467 |
+
|
| 468 |
+
# ensemble_rank_average(answer_path,
|
| 469 |
+
# "/kaggle/input/.../exp_mix_q07_emo08/answer.txt",
|
| 470 |
+
# os.path.join(OUT_DIR, "answer_ens.txt"))
|
| 471 |
+
|
| 472 |
+
# %% [markdown]
|
| 473 |
+
# ## Ghi chú nộp & paper
|
| 474 |
+
# - Nộp: My Submissions → **Track 2** (bỏ chọn track khác) → `submission_track2_exp16.zip` → đọc SRCC 6 cột.
|
| 475 |
+
# - **Bảng A (paper):** đặt SRCC exp16 (gemini/openai, zero-shot) cạnh exp07 (QMOS 0.548) + exp08
|
| 476 |
+
# (EMOS 0.811 · CAT 0.133 · VAD 0.659/0.793/0.751). Kỳ vọng: LLM khá ở EMOS/CAT, yếu ở QMOS.
|
| 477 |
+
# - **Bảng B:** chạy lại `SHOT_MODE="few_shot"` (1 provider) → so zero vs few-shot.
|
| 478 |
+
# - **Cache:** Save Version để giữ `exp16_llm_cache/*.jsonl` (không trả tiền lại). Lưu thành Kaggle
|
| 479 |
+
# Dataset nếu muốn dùng cho eval phase.
|
| 480 |
+
# - **Khai báo external resource** (API thương mại Gemini/OpenAI) trong `12_system_description.md`.
|
track2/track2_baseline.ipynb
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": "# VMC2026 Track 2 — Baseline Pipeline (Kaggle)\n\nQMOS (SpeechMOS) + EmoCat (emotion2vec) + **EMOS (emotion2vec target-prob, mặc định offline)** → gộp `answer.txt`.\n\n**Trước khi chạy:** Accelerator = **GPU T4**, Internet = **On**.\n- **+ Add Input** → tab Datasets → dataset Track 2 đã upload (Kaggle tự giải nén → có thư mục `vmc2026-track2/`).\n- Với mặc định `EMOS_METHOD='emotion2vec'`: **KHÔNG cần** `GEMINI_API_KEY`. Chỉ cần Secrets khi đổi sang `'gemini'` (để có thêm VAD).\n\nChạy được ngay: **QMOS + EmoCat + EMOS** (chỉ cần wav + `metadata.csv` chứa cảm xúc target).\n\n> ⚠️ Train phase: dự đoán tập **DEV** (`sets/dev.scp`, ~2730 mẫu). Thư mục `wav/` có cả train+dev nên KHÔNG glob hết — chỉ lấy đúng dev.scp."
|
| 7 |
+
},
|
| 8 |
+
{
|
| 9 |
+
"cell_type": "markdown",
|
| 10 |
+
"metadata": {},
|
| 11 |
+
"source": [
|
| 12 |
+
"## 0. Config — SỬA Ở ĐÂY"
|
| 13 |
+
]
|
| 14 |
+
},
|
| 15 |
+
{
|
| 16 |
+
"cell_type": "code",
|
| 17 |
+
"execution_count": null,
|
| 18 |
+
"metadata": {},
|
| 19 |
+
"outputs": [],
|
| 20 |
+
"source": "import os, glob\n\n# ── Data Track 2 trên Kaggle (dataset đã upload, KHÔNG có thư mục con lồng) ──\nDATA_ROOT = '/kaggle/input/vmc2026-track2-full' # << slug dataset bạn upload\nWAV_DIR = f'{DATA_ROOT}/wav'\nMETADATA_CSV = f'{DATA_ROOT}/metadata.csv' # wavID|emotion|transcript (KHÔNG header)\nDEV_SCP = f'{DATA_ROOT}/sets/dev.scp' # danh sách wav tập DEV (tập cần nộp ở train phase)\n\n# Test nhanh trên ESD: trỏ WAV_DIR vào ESD, đặt DEV_SCP=None và METADATA_CSV=None.\n# WAV_DIR = '/kaggle/input/datasets/nguyenthanhlim/emotional-speech-dataset-esd/Emotion Speech Dataset'\n# DEV_SCP = None; METADATA_CSV = None\n\nLIMIT = 20 # << 20 = chạy THỬ nhanh. Đổi None để chạy TOÀN BỘ DEV rồi nộp.\n\n# ── Cách tính EMOS ──────────────────────────────────────────────────────────\n# 'emotion2vec': OFFLINE, MIỄN PHÍ (exp01, khuyến nghị) — P(cảm xúc target) từ emotion2vec → scale 1–5.\n# 'gemini' : LLM-as-judge qua Gemini API (cần GEMINI_API_KEY, tốn phí). Chỉ cách này có VAD.\nEMOS_METHOD = 'emotion2vec'\n\nOUT_DIR = '/kaggle/working'\nRUN_QMOS, RUN_EMOCAT = True, True\n_have_meta = bool(METADATA_CSV) and os.path.exists(METADATA_CSV)\nRUN_EMOS = _have_meta # cả 2 cách đều cần target từ metadata\nRUN_VAD = _have_meta and EMOS_METHOD == 'gemini' # VAD chỉ có ở Gemini\nEMOTIONS5 = ['angry', 'happy', 'neutral', 'sad', 'surprised']\n\n# Chuẩn hóa nhãn cảm xúc target (metadata) → đúng 1 trong 5 lớp của emotion2vec.\n_EMO_ALIAS = {'angry':'angry','anger':'angry','happy':'happy','happiness':'happy','joy':'happy',\n 'neutral':'neutral','calm':'neutral','sad':'sad','sadness':'sad',\n 'surprise':'surprised','surprised':'surprised','surprising':'surprised'}\ndef norm_emotion(label):\n key = str(label).strip().lower()\n return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)\n\ndef list_wavs(d):\n # Có DEV_SCP → đọc danh sách tên file tập DEV (wav nằm phẳng trong wav/).\n # Không có → quét đệ quy mọi .wav (chế độ test ESD, lồng speaker/emotion).\n if DEV_SCP and os.path.exists(DEV_SCP):\n with open(DEV_SCP) as f:\n names = [ln.strip() for ln in f if ln.strip()]\n wavs = [os.path.join(d, n) for n in names]\n else:\n wavs = sorted(glob.glob(os.path.join(d, '**', '*.wav'), recursive=True))\n return wavs[:LIMIT] if LIMIT else wavs\n\nprint('WAV_DIR:', WAV_DIR, '| EMOS_METHOD:', EMOS_METHOD)\nprint('Số wav:', len(list_wavs(WAV_DIR)) if os.path.isdir(WAV_DIR) else '(chưa thấy thư mục)')\nprint('Chế độ DEV (dev.scp):', bool(DEV_SCP and os.path.exists(DEV_SCP)))\nif METADATA_CSV and os.path.exists(METADATA_CSV) and DEV_SCP and os.path.exists(DEV_SCP):\n n_meta = sum(1 for _ in open(METADATA_CSV))\n n_dev = sum(1 for _ in open(DEV_SCP))\n print(f'metadata.csv: {n_meta} dòng | dev.scp: {n_dev} dòng')"
|
| 21 |
+
},
|
| 22 |
+
{
|
| 23 |
+
"cell_type": "markdown",
|
| 24 |
+
"metadata": {},
|
| 25 |
+
"source": [
|
| 26 |
+
"## 1. Cài đặt"
|
| 27 |
+
]
|
| 28 |
+
},
|
| 29 |
+
{
|
| 30 |
+
"cell_type": "code",
|
| 31 |
+
"execution_count": null,
|
| 32 |
+
"metadata": {},
|
| 33 |
+
"outputs": [],
|
| 34 |
+
"source": [
|
| 35 |
+
"!pip install -q speechmos funasr librosa soundfile pandas google-genai loguru tqdm"
|
| 36 |
+
]
|
| 37 |
+
},
|
| 38 |
+
{
|
| 39 |
+
"cell_type": "markdown",
|
| 40 |
+
"metadata": {},
|
| 41 |
+
"source": [
|
| 42 |
+
"## 2. QMOS — SpeechMOS (UTMOS, không cần fairseq)"
|
| 43 |
+
]
|
| 44 |
+
},
|
| 45 |
+
{
|
| 46 |
+
"cell_type": "code",
|
| 47 |
+
"execution_count": null,
|
| 48 |
+
"metadata": {},
|
| 49 |
+
"outputs": [],
|
| 50 |
+
"source": "def run_qmos(wav_dir):\n import torch, librosa\n dev = 'cuda' if torch.cuda.is_available() else 'cpu'\n predictor = torch.hub.load('tarepan/SpeechMOS:v1.2.0', 'utmos22_strong', trust_repo=True).to(dev) # << GPU\n print('QMOS device:', dev)\n scores, missing = {}, 0\n for w in list_wavs(wav_dir): # w là đường dẫn đầy đủ\n if not os.path.exists(w): # mẫu ESD/DailyTalk chưa lấy ngoài → bỏ qua, không crash\n missing += 1; continue\n wave, _ = librosa.load(w, sr=16000, mono=True)\n wave_t = torch.from_numpy(wave).unsqueeze(0).to(dev) # << đưa input lên GPU\n scores[w] = float(predictor(wave_t, sr=16000).mean().item())\n if missing: print(f'[QMOS] Bỏ qua {missing} file thiếu (chưa có ESD/DailyTalk) → điểm mặc định.')\n return scores\n\nqmos_scores = run_qmos(WAV_DIR) if RUN_QMOS else {}\nprint('QMOS xong:', len(qmos_scores))\nlist(qmos_scores.items())[:3]"
|
| 51 |
+
},
|
| 52 |
+
{
|
| 53 |
+
"cell_type": "markdown",
|
| 54 |
+
"metadata": {},
|
| 55 |
+
"source": [
|
| 56 |
+
"## 3. EmoCat — emotion2vec+ large\n",
|
| 57 |
+
"Đã sửa bug bản gốc + lọc 5 lớp + chuẩn hóa tổng = 1."
|
| 58 |
+
]
|
| 59 |
+
},
|
| 60 |
+
{
|
| 61 |
+
"cell_type": "code",
|
| 62 |
+
"execution_count": null,
|
| 63 |
+
"metadata": {},
|
| 64 |
+
"outputs": [],
|
| 65 |
+
"source": "def run_emocat(wav_dir):\n import torch\n from funasr import AutoModel\n dev = 'cuda:0' if torch.cuda.is_available() else 'cpu'\n model = AutoModel(model='iic/emotion2vec_plus_large', hub='hf', device=dev) # << chạy GPU\n print('EmoCat device:', dev)\n results, missing = {}, 0\n for w in list_wavs(wav_dir): # w là đường dẫn đầy đủ\n if not os.path.exists(w): # mẫu ESD/DailyTalk chưa lấy ngoài → bỏ qua\n missing += 1; continue\n rec = model.generate(w, granularity='utterance', extract_embedding=False)\n probs = {e: 0.0 for e in EMOTIONS5}\n for lab, sc in zip(rec[0]['labels'], rec[0]['scores']):\n name = lab.split('/')[-1]\n if name in probs:\n probs[name] = float(sc)\n total = sum(probs.values())\n if total > 0:\n probs = {k: v / total for k, v in probs.items()}\n results[w] = probs\n if missing: print(f'[EmoCat] Bỏ qua {missing} file thiếu (chưa có ESD/DailyTalk) → phân bố mặc định.')\n return results\n\nemocat_probs = run_emocat(WAV_DIR) if RUN_EMOCAT else {}\nprint('EmoCat xong:', len(emocat_probs))\nlist(emocat_probs.items())[:2]"
|
| 66 |
+
},
|
| 67 |
+
{
|
| 68 |
+
"cell_type": "markdown",
|
| 69 |
+
"metadata": {},
|
| 70 |
+
"source": "## 4. EMOS — emotion2vec target-prob (mặc định) hoặc Gemini\n**emotion2vec (exp01, offline):** lấy P(cảm xúc target) từ emotion2vec (đã tính ở cell EmoCat), scale [0,1]→[1,5]. Chấm đủ 2.730 mẫu, KHÔNG cần API. SRCC chỉ quan tâm thứ hạng nên scale tuyến tính không đổi tương quan.\n\n**Gemini (`EMOS_METHOD='gemini'`):** LLM-as-judge, cần `GEMINI_API_KEY` + credit; tự lọc metadata về DEV để đỡ tốn. Chỉ cách này có VAD."
|
| 71 |
+
},
|
| 72 |
+
{
|
| 73 |
+
"cell_type": "code",
|
| 74 |
+
"execution_count": null,
|
| 75 |
+
"metadata": {},
|
| 76 |
+
"outputs": [],
|
| 77 |
+
"source": "emos_scores, vad_scores = {}, {} # key = TÊN FILE wav (uttID, có .wav)\n\n# Đọc cảm xúc target từ metadata.csv → {stem: emotion_chuẩn}\ndef load_target_emotions():\n tgt = {}\n if not (METADATA_CSV and os.path.exists(METADATA_CSV)):\n return tgt\n with open(METADATA_CSV, encoding='utf-8') as f:\n for ln in f:\n parts = ln.strip().split('|')\n if len(parts) >= 2:\n stem = os.path.splitext(os.path.basename(parts[0]))[0]\n tgt[stem] = norm_emotion(parts[1])\n return tgt\n\ntarget_map = load_target_emotions()\nprint('Nhãn cảm xúc target đọc được:', len(target_map))\n\nif RUN_EMOS and EMOS_METHOD == 'emotion2vec':\n # ── EMOS OFFLINE: P(cảm xúc target) từ emotion2vec (cell EmoCat), scale [0,1]→[1,5] ──\n assert RUN_EMOCAT and emocat_probs, 'EMOS theo emotion2vec cần chạy cell EmoCat (mục 3) trước.'\n miss_t = miss_p = 0\n for w in list_wavs(WAV_DIR):\n name = os.path.basename(w)\n tgt = target_map.get(os.path.splitext(name)[0])\n probs = emocat_probs.get(w)\n if tgt is None:\n miss_t += 1; continue\n if not probs:\n miss_p += 1; continue\n emos_scores[name] = 1.0 + 4.0 * probs.get(tgt, 0.0) # p=0→1 điểm, p=1→5 điểm\n if miss_t: print(f'[EMOS-e2v] {miss_t} mẫu thiếu nhãn target → mặc định 3.')\n if miss_p: print(f'[EMOS-e2v] {miss_p} mẫu thiếu prob emotion2vec → mặc định 3.')\n print(f'✅ EMOS (emotion2vec) cho {len(emos_scores)} mẫu — không cần API.')\n\nelif RUN_EMOS or RUN_VAD: # EMOS_METHOD == 'gemini'\n try:\n from kaggle_secrets import UserSecretsClient\n os.environ['GEMINI_API_KEY'] = UserSecretsClient().get_secret('GEMINI_API_KEY')\n print('Đã nạp GEMINI_API_KEY từ Secrets')\n except Exception as e:\n print('Chưa nạp được key:', e)\n\n # ── Lọc metadata.csv → CHỈ giữ mẫu thuộc DEV (tránh trả tiền Gemini cho mẫu train) ──\n dev_stems = {os.path.splitext(n.strip())[0] for n in open(DEV_SCP) if n.strip()}\n META_DEV = '/kaggle/working/metadata_dev.csv'\n kept = 0\n with open(METADATA_CSV) as fin, open(META_DEV, 'w') as fout:\n for line in fin:\n if not line.strip():\n continue\n stem = os.path.splitext(os.path.basename(line.split('|')[0].strip()))[0]\n if stem in dev_stems:\n fout.write(line); kept += 1\n print(f'metadata_dev.csv: {kept} dòng (kỳ vọng ~{len(dev_stems)})')\n\n GEMINI_ROWS = f'--end-row {LIMIT}' if LIMIT else ''\n !git clone -q https://github.com/voicemos-challenge/vmc2026-baselines.git /kaggle/working/vmc2026-baselines\n !cd /kaggle/working/vmc2026-baselines/track2/EMOS && python Gemini_EMOS.py --metadata-path $META_DEV --base-path $WAV_DIR --output-file /kaggle/working/emos.csv --workers 4 --resume $GEMINI_ROWS\n !cd /kaggle/working/vmc2026-baselines/track2/VAD && python Gemini_VAD.py --metadata-path $META_DEV --base-path $WAV_DIR --output-file /kaggle/working/vad.csv --workers 4 --resume $GEMINI_ROWS\n\n import pandas as pd\n if os.path.exists('/kaggle/working/emos.csv'):\n d = pd.read_csv('/kaggle/working/emos.csv'); emos_scores = dict(zip(d['uttID'], d['emos']))\n if os.path.exists('/kaggle/working/vad.csv'):\n d = pd.read_csv('/kaggle/working/vad.csv') # cột chuẩn: uttID, val, aro, dom\n for _, r in d.iterrows():\n vad_scores[r['uttID']] = (r['val'], r['aro'], r['dom'])\n\n if emos_scores:\n dev_bases = {os.path.basename(w) for w in list_wavs(WAV_DIR)}\n if not (set(emos_scores) & dev_bases):\n print('⚠️ KEY LỆCH: uttID không khớp tên file dev → EMOS/VAD sẽ về mặc định!')\n else:\n print('✅ Key khớp — EMOS/VAD sẽ gộp đúng.')\n\nprint('EMOS:', len(emos_scores), '| VAD:', len(vad_scores))"
|
| 78 |
+
},
|
| 79 |
+
{
|
| 80 |
+
"cell_type": "markdown",
|
| 81 |
+
"metadata": {},
|
| 82 |
+
"source": [
|
| 83 |
+
"## 5. Gộp answer.txt (tự bỏ cột thiếu)"
|
| 84 |
+
]
|
| 85 |
+
},
|
| 86 |
+
{
|
| 87 |
+
"cell_type": "code",
|
| 88 |
+
"execution_count": null,
|
| 89 |
+
"metadata": {},
|
| 90 |
+
"outputs": [],
|
| 91 |
+
"source": "def fmt_cat(p):\n return '|'.join(f'{e}:{p[e]:.6g}' for e in EMOTIONS5)\n\ndef build_answer(out_path):\n wavs = list_wavs(WAV_DIR)\n have_cat = RUN_EMOCAT and len(emocat_probs) > 0\n have_vad = RUN_VAD and len(vad_scores) > 0\n cols = ['wav', 'QMOS', 'EMOS']\n if have_cat: cols.append('CAT')\n if have_vad: cols += ['VAL', 'ARO', 'DOM']\n with open(out_path, 'w') as f:\n f.write(','.join(cols) + '\\n')\n for w in wavs:\n name = os.path.basename(w) # tên file = cột wav & key của emos/vad\n row = [name, f\"{qmos_scores.get(w, 3.0):.6g}\", str(emos_scores.get(name, 3))]\n if have_cat: row.append(fmt_cat(emocat_probs.get(w, {e: 0.2 for e in EMOTIONS5})))\n if have_vad:\n v = vad_scores.get(name, (3, 3, 3)); row += [str(v[0]), str(v[1]), str(v[2])]\n f.write(','.join(row) + '\\n')\n print(f'Ghi {len(wavs)} dòng → {out_path} | cột: {cols}')\n\nanswer_path = os.path.join(OUT_DIR, 'answer.txt')\nbuild_answer(answer_path)\n!head -3 {answer_path}"
|
| 92 |
+
},
|
| 93 |
+
{
|
| 94 |
+
"cell_type": "markdown",
|
| 95 |
+
"metadata": {},
|
| 96 |
+
"source": [
|
| 97 |
+
"## 6. Validate + zip"
|
| 98 |
+
]
|
| 99 |
+
},
|
| 100 |
+
{
|
| 101 |
+
"cell_type": "code",
|
| 102 |
+
"execution_count": null,
|
| 103 |
+
"metadata": {},
|
| 104 |
+
"outputs": [],
|
| 105 |
+
"source": [
|
| 106 |
+
"import csv\n",
|
| 107 |
+
"with open(answer_path) as f:\n",
|
| 108 |
+
" rows = list(csv.reader(f))\n",
|
| 109 |
+
"header = rows[0]\n",
|
| 110 |
+
"assert header[0] == 'wav' and 'QMOS' in header and 'EMOS' in header, 'Header sai'\n",
|
| 111 |
+
"for i, r in enumerate(rows[1:], 2):\n",
|
| 112 |
+
" assert len(r) == len(header), f'Dòng {i} sai số cột'\n",
|
| 113 |
+
"print(f'OK: {len(rows)-1} dòng, header = {header}')\n",
|
| 114 |
+
"!cd /kaggle/working && zip -j submission_track2.zip answer.txt && unzip -l submission_track2.zip"
|
| 115 |
+
]
|
| 116 |
+
}
|
| 117 |
+
],
|
| 118 |
+
"metadata": {
|
| 119 |
+
"kernelspec": {
|
| 120 |
+
"display_name": "Python 3",
|
| 121 |
+
"language": "python",
|
| 122 |
+
"name": "python3"
|
| 123 |
+
},
|
| 124 |
+
"language_info": {
|
| 125 |
+
"name": "python"
|
| 126 |
+
}
|
| 127 |
+
},
|
| 128 |
+
"nbformat": 4,
|
| 129 |
+
"nbformat_minor": 5
|
| 130 |
+
}
|
track2/track2_baseline_pipeline.py
ADDED
|
@@ -0,0 +1,321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %% [markdown]
|
| 2 |
+
# # VMC2026 Track 2 — Baseline Pipeline (Kaggle)
|
| 3 |
+
#
|
| 4 |
+
# Chạy 4 baseline → gộp thành `answer.txt` đúng chuẩn nộp CodaBench.
|
| 5 |
+
#
|
| 6 |
+
# | Sub-task | Baseline | GPU | Cần label? |
|
| 7 |
+
# |---|---|---|---|
|
| 8 |
+
# | QMOS | SpeechMOS (UTMOS bản pip) | có (nhẹ) | không (chỉ cần wav) |
|
| 9 |
+
# | EmoCat (CAT) | emotion2vec+ large (funasr) | có (nhẹ) | không (chỉ cần wav) |
|
| 10 |
+
# | EMOS | **emotion2vec target-prob** (mặc định, offline) HOẶC Gemini | có (nhẹ) | cần `metadata.csv` (nhãn target) |
|
| 11 |
+
# | VAD | Gemini LLM-as-judge (chỉ khi EMOS_METHOD="gemini") | không | cần `metadata.csv` + API key |
|
| 12 |
+
#
|
| 13 |
+
# **Cách dùng trên Kaggle:**
|
| 14 |
+
# 1. Tạo Notebook, Settings → Accelerator = **GPU T4**, Internet = **On** (cần verify phone).
|
| 15 |
+
# 2. **+ Add Input** → chọn dataset Track 2 đã upload (Kaggle tự giải nén → có thư mục `vmc2026-track2/`).
|
| 16 |
+
# 3. Add-ons → Secrets: thêm `GEMINI_API_KEY` (cho EMOS/VAD).
|
| 17 |
+
# 4. Sửa `DATA_ROOT` ở cell 0 cho khớp slug dataset, rồi chạy lần lượt từng cell.
|
| 18 |
+
#
|
| 19 |
+
# Format đích `answer.txt`: `wav,QMOS,EMOS,CAT,VAL,ARO,DOM` — xem `08_track2_spec.md`.
|
| 20 |
+
# QMOS & EMOS bắt buộc; CAT/VAD tùy chọn. Có thể nộp tập con cột.
|
| 21 |
+
#
|
| 22 |
+
# > ⚠️ Ở **training phase** ta dự đoán cho tập **DEV** (`sets/dev.scp`, ~2730 mẫu) rồi nộp.
|
| 23 |
+
# > Thư mục `wav/` chứa cả train + dev nên KHÔNG glob hết — chỉ lấy đúng danh sách dev.scp.
|
| 24 |
+
|
| 25 |
+
# %% [markdown]
|
| 26 |
+
# ## 0. Cấu hình đường dẫn — SỬA Ở ĐÂY
|
| 27 |
+
|
| 28 |
+
# %%
|
| 29 |
+
import os, glob
|
| 30 |
+
|
| 31 |
+
# ── Data Track 2 trên Kaggle ────────────────────────────────────────────────
|
| 32 |
+
# Kaggle TỰ giải nén .tar.gz khi tạo Dataset → có sẵn thư mục `vmc2026-track2/`.
|
| 33 |
+
# Đổi <track2-data> thành slug dataset của bạn (xem thanh path bên phải khi Add Input).
|
| 34 |
+
DATA_ROOT = "/kaggle/input/<track2-data>/vmc2026-track2" # << SỬA slug
|
| 35 |
+
WAV_DIR = f"{DATA_ROOT}/wav"
|
| 36 |
+
METADATA_CSV = f"{DATA_ROOT}/metadata.csv" # định dạng: wavID|emotion|transcript (KHÔNG header)
|
| 37 |
+
DEV_SCP = f"{DATA_ROOT}/sets/dev.scp" # danh sách wav của tập DEV (tập cần nộp ở train phase)
|
| 38 |
+
|
| 39 |
+
# Muốn test nhanh trên ESD (chưa có data thật): trỏ WAV_DIR vào ESD, đặt DEV_SCP=None, METADATA_CSV=None.
|
| 40 |
+
# WAV_DIR = "/kaggle/input/datasets/nguyenthanhlim/emotional-speech-dataset-esd/Emotion Speech Dataset"
|
| 41 |
+
# DEV_SCP = None; METADATA_CSV = None
|
| 42 |
+
|
| 43 |
+
LIMIT = None # << số nhỏ (vd 20) để chạy thử cho nhanh; None = chạy toàn bộ tập DEV
|
| 44 |
+
|
| 45 |
+
OUT_DIR = "/kaggle/working"
|
| 46 |
+
|
| 47 |
+
# ── Cách tính EMOS ──────────────────────────────────────────────────────────
|
| 48 |
+
# "emotion2vec": OFFLINE, MIỄN PHÍ (exp01, khuyến nghị) — lấy P(cảm xúc target) từ
|
| 49 |
+
# emotion2vec (model đã chạy cho CAT) rồi scale về thang 1–5. Không cần API.
|
| 50 |
+
# "gemini" : LLM-as-judge qua Gemini API (cần GEMINI_API_KEY, tốn phí). Chỉ cách này có VAD.
|
| 51 |
+
EMOS_METHOD = "emotion2vec"
|
| 52 |
+
|
| 53 |
+
_have_meta = bool(METADATA_CSV) and os.path.exists(METADATA_CSV) # cần nhãn cảm xúc target
|
| 54 |
+
RUN_QMOS = True
|
| 55 |
+
RUN_EMOCAT = True
|
| 56 |
+
RUN_EMOS = _have_meta # cả 2 cách đều cần target từ metadata
|
| 57 |
+
RUN_VAD = _have_meta and EMOS_METHOD == "gemini" # VAD chỉ có ở Gemini
|
| 58 |
+
|
| 59 |
+
EMOTIONS5 = ["angry", "happy", "neutral", "sad", "surprised"]
|
| 60 |
+
|
| 61 |
+
# Chuẩn hóa nhãn cảm xúc target (metadata) → đúng 1 trong 5 lớp của emotion2vec.
|
| 62 |
+
_EMO_ALIAS = {
|
| 63 |
+
"angry": "angry", "anger": "angry",
|
| 64 |
+
"happy": "happy", "happiness": "happy", "joy": "happy",
|
| 65 |
+
"neutral": "neutral", "calm": "neutral",
|
| 66 |
+
"sad": "sad", "sadness": "sad",
|
| 67 |
+
"surprise": "surprised", "surprised": "surprised", "surprising": "surprised",
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
def norm_emotion(label):
|
| 71 |
+
"""Đưa nhãn cảm xúc bất kỳ về 1 trong EMOTIONS5; None nếu không khớp."""
|
| 72 |
+
key = str(label).strip().lower()
|
| 73 |
+
return _EMO_ALIAS.get(key, key if key in EMOTIONS5 else None)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def list_wavs(d):
|
| 77 |
+
"""Trả về list đường dẫn .wav đầy đủ cần dự đoán.
|
| 78 |
+
- Có DEV_SCP → đọc danh sách tên file tập DEV (đúng tập cần nộp, wav nằm phẳng trong wav/).
|
| 79 |
+
- Không có → quét đệ quy mọi .wav trong thư mục (chế độ test ESD, lồng speaker/emotion)."""
|
| 80 |
+
if DEV_SCP and os.path.exists(DEV_SCP):
|
| 81 |
+
with open(DEV_SCP) as f:
|
| 82 |
+
names = [ln.strip() for ln in f if ln.strip()]
|
| 83 |
+
wavs = [os.path.join(d, n) for n in names]
|
| 84 |
+
else:
|
| 85 |
+
wavs = sorted(glob.glob(os.path.join(d, "**", "*.wav"), recursive=True))
|
| 86 |
+
return wavs[:LIMIT] if LIMIT else wavs
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
print("WAV_DIR:", WAV_DIR)
|
| 90 |
+
print("Số wav:", len(list_wavs(WAV_DIR)) if os.path.isdir(WAV_DIR) else "(chưa thấy thư mục)")
|
| 91 |
+
print("Chế độ DEV (dev.scp):", bool(DEV_SCP and os.path.exists(DEV_SCP)))
|
| 92 |
+
|
| 93 |
+
# %% [markdown]
|
| 94 |
+
# ## 1. Cài đặt
|
| 95 |
+
|
| 96 |
+
# %%
|
| 97 |
+
# !pip install -q speechmos funasr librosa soundfile pandas google-genai loguru tqdm
|
| 98 |
+
|
| 99 |
+
# %% [markdown]
|
| 100 |
+
# ## 2. QMOS — SpeechMOS (UTMOS)
|
| 101 |
+
# Dùng SpeechMOS qua torch.hub (không cần fairseq). Output: dict {wav: score 1-5}.
|
| 102 |
+
|
| 103 |
+
# %%
|
| 104 |
+
def run_qmos(wav_dir):
|
| 105 |
+
import torch, librosa
|
| 106 |
+
# SpeechMOS yêu cầu 16kHz; input shape (Batch, Time); sr truyền dạng keyword.
|
| 107 |
+
predictor = torch.hub.load(
|
| 108 |
+
"tarepan/SpeechMOS:v1.2.0", "utmos22_strong", trust_repo=True
|
| 109 |
+
)
|
| 110 |
+
scores = {}
|
| 111 |
+
missing = 0
|
| 112 |
+
for w in list_wavs(wav_dir):
|
| 113 |
+
if not os.path.exists(w): # mẫu ESD/DailyTalk chưa lấy ngoài → bỏ qua, không crash
|
| 114 |
+
missing += 1
|
| 115 |
+
continue
|
| 116 |
+
wave, _ = librosa.load(w, sr=16000, mono=True)
|
| 117 |
+
wave_t = torch.from_numpy(wave).unsqueeze(0) # (1, Time)
|
| 118 |
+
score = predictor(wave_t, sr=16000) # → tensor shape (1,)
|
| 119 |
+
scores[w] = float(score.mean().item())
|
| 120 |
+
if missing:
|
| 121 |
+
print(f"[QMOS] Bỏ qua {missing} file thiếu (chưa có ESD/DailyTalk) → sẽ nhận điểm mặc định.")
|
| 122 |
+
return scores
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
qmos_scores = run_qmos(WAV_DIR) if RUN_QMOS else {}
|
| 126 |
+
print("QMOS xong:", len(qmos_scores), "mẫu")
|
| 127 |
+
list(qmos_scores.items())[:3]
|
| 128 |
+
|
| 129 |
+
# %% [markdown]
|
| 130 |
+
# ## 3. EmoCat — emotion2vec+ large (funasr)
|
| 131 |
+
# Sửa bug bản gốc + lọc 5 lớp + **chuẩn hóa tổng = 1** (đúng format CAT).
|
| 132 |
+
# Output: dict {wav: {angry:p, happy:p, neutral:p, sad:p, surprised:p}}.
|
| 133 |
+
|
| 134 |
+
# %%
|
| 135 |
+
def run_emocat(wav_dir):
|
| 136 |
+
from funasr import AutoModel
|
| 137 |
+
model = AutoModel(model="iic/emotion2vec_plus_large", hub="hf")
|
| 138 |
+
results = {}
|
| 139 |
+
missing = 0
|
| 140 |
+
for w in list_wavs(wav_dir):
|
| 141 |
+
if not os.path.exists(w): # mẫu ESD/DailyTalk chưa lấy ngoài → bỏ qua
|
| 142 |
+
missing += 1
|
| 143 |
+
continue
|
| 144 |
+
rec = model.generate(
|
| 145 |
+
w,
|
| 146 |
+
granularity="utterance",
|
| 147 |
+
extract_embedding=False,
|
| 148 |
+
)
|
| 149 |
+
labels = rec[0]["labels"]
|
| 150 |
+
scores = rec[0]["scores"]
|
| 151 |
+
# gom điểm 5 lớp quan tâm (label có thể dạng "xx/angry")
|
| 152 |
+
probs = {e: 0.0 for e in EMOTIONS5}
|
| 153 |
+
for lab, sc in zip(labels, scores):
|
| 154 |
+
name = lab.split("/")[-1]
|
| 155 |
+
if name in probs:
|
| 156 |
+
probs[name] = float(sc)
|
| 157 |
+
total = sum(probs.values())
|
| 158 |
+
if total > 0: # chuẩn hóa lại trên 5 lớp
|
| 159 |
+
probs = {k: v / total for k, v in probs.items()}
|
| 160 |
+
results[w] = probs
|
| 161 |
+
if missing:
|
| 162 |
+
print(f"[EmoCat] Bỏ qua {missing} file thiếu (chưa có ESD/DailyTalk) → sẽ nhận phân bố mặc định.")
|
| 163 |
+
return results
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
emocat_probs = run_emocat(WAV_DIR) if RUN_EMOCAT else {}
|
| 167 |
+
print("EmoCat xong:", len(emocat_probs), "mẫu")
|
| 168 |
+
list(emocat_probs.items())[:2]
|
| 169 |
+
|
| 170 |
+
# %% [markdown]
|
| 171 |
+
# ## 4. EMOS — emotion2vec target-prob (mặc định) hoặc Gemini
|
| 172 |
+
# **Cách emotion2vec (exp01, offline):** với mỗi wav, lấy xác suất emotion2vec gán cho ĐÚNG cảm xúc
|
| 173 |
+
# target (đọc từ `metadata.csv`), scale [0,1] → [1,5]. Vì EMOS chấm bằng SRCC (thứ hạng) nên scale
|
| 174 |
+
# tuyến tính không đổi tương quan — chỉ cần thứ tự đúng. KHÔNG cần train, KHÔNG cần API.
|
| 175 |
+
# **Cách Gemini:** gọi script baseline gốc (cần `GEMINI_API_KEY`); chỉ cách này mới có VAD.
|
| 176 |
+
# `metadata.csv` dạng `wavID|emotion|transcript` (không header); `emotion` = cảm xúc target.
|
| 177 |
+
|
| 178 |
+
# %%
|
| 179 |
+
def load_target_emotions():
|
| 180 |
+
"""Đọc metadata.csv (wavID|emotion|transcript, không header) → {stem: emotion_chuẩn}."""
|
| 181 |
+
tgt = {}
|
| 182 |
+
if not (METADATA_CSV and os.path.exists(METADATA_CSV)):
|
| 183 |
+
return tgt
|
| 184 |
+
with open(METADATA_CSV, encoding="utf-8") as f:
|
| 185 |
+
for ln in f:
|
| 186 |
+
parts = ln.strip().split("|")
|
| 187 |
+
if len(parts) < 2:
|
| 188 |
+
continue
|
| 189 |
+
stem = os.path.splitext(os.path.basename(parts[0]))[0]
|
| 190 |
+
tgt[stem] = norm_emotion(parts[1])
|
| 191 |
+
return tgt
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def run_emos_emotion2vec(wav_dir, target_map):
|
| 195 |
+
"""EMOS offline = P(cảm xúc target) từ emotion2vec, scale [0,1] → [1,5].
|
| 196 |
+
Dùng lại emocat_probs (đã tính ở mục 3) nên KHÔNG tốn thêm tính toán."""
|
| 197 |
+
out, miss_tgt, miss_prob = {}, 0, 0
|
| 198 |
+
for w in list_wavs(wav_dir):
|
| 199 |
+
name = os.path.basename(w)
|
| 200 |
+
tgt = target_map.get(os.path.splitext(name)[0])
|
| 201 |
+
probs = emocat_probs.get(w)
|
| 202 |
+
if tgt is None:
|
| 203 |
+
miss_tgt += 1; continue
|
| 204 |
+
if not probs:
|
| 205 |
+
miss_prob += 1; continue
|
| 206 |
+
out[name] = 1.0 + 4.0 * probs.get(tgt, 0.0) # p=0→1 điểm, p=1→5 điểm
|
| 207 |
+
if miss_tgt: print(f"[EMOS-e2v] {miss_tgt} mẫu thiếu nhãn target → mặc định 3.")
|
| 208 |
+
if miss_prob: print(f"[EMOS-e2v] {miss_prob} mẫu thiếu prob emotion2vec → mặc định 3.")
|
| 209 |
+
return out
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def setup_gemini_key():
|
| 213 |
+
try:
|
| 214 |
+
from kaggle_secrets import UserSecretsClient
|
| 215 |
+
os.environ["GEMINI_API_KEY"] = UserSecretsClient().get_secret("GEMINI_API_KEY")
|
| 216 |
+
print("Đã nạp GEMINI_API_KEY từ Kaggle Secrets")
|
| 217 |
+
except Exception as e:
|
| 218 |
+
print("Chưa nạp được key từ Secrets:", e, "→ set thủ công os.environ['GEMINI_API_KEY']")
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
emos_scores = {} # {uttID(tên file .wav): điểm EMOS}
|
| 222 |
+
vad_scores = {} # {uttID: (val, aro, dom)} — chỉ có khi dùng Gemini
|
| 223 |
+
|
| 224 |
+
target_map = load_target_emotions()
|
| 225 |
+
print("Nhãn cảm xúc target đọc được:", len(target_map))
|
| 226 |
+
|
| 227 |
+
if RUN_EMOS and EMOS_METHOD == "emotion2vec":
|
| 228 |
+
assert RUN_EMOCAT and emocat_probs, "EMOS theo emotion2vec cần EmoCat chạy trước (RUN_EMOCAT=True)."
|
| 229 |
+
emos_scores = run_emos_emotion2vec(WAV_DIR, target_map)
|
| 230 |
+
|
| 231 |
+
elif RUN_EMOS and EMOS_METHOD == "gemini":
|
| 232 |
+
setup_gemini_key()
|
| 233 |
+
# !git clone -q https://github.com/voicemos-challenge/vmc2026-baselines.git /kaggle/working/vmc2026-baselines
|
| 234 |
+
# Chạy (1-based, inclusive). Với eval lớn nên chia batch + giảm --workers do quota free tier.
|
| 235 |
+
# Gemini chỉ chấm các dòng metadata.csv trùng với DEV; cách đơn giản: chạy hết rồi lọc lại ở build_answer.
|
| 236 |
+
# !cd /kaggle/working/vmc2026-baselines/track2/EMOS && python Gemini_EMOS.py \
|
| 237 |
+
# --metadata-path {METADATA_CSV} --base-path {WAV_DIR} \
|
| 238 |
+
# --output-file /kaggle/working/emos.csv --workers 4 --resume
|
| 239 |
+
# !cd /kaggle/working/vmc2026-baselines/track2/VAD && python Gemini_VAD.py \
|
| 240 |
+
# --metadata-path {METADATA_CSV} --base-path {WAV_DIR} \
|
| 241 |
+
# --output-file /kaggle/working/vad.csv --workers 4 --resume
|
| 242 |
+
import pandas as pd
|
| 243 |
+
if os.path.exists("/kaggle/working/emos.csv"):
|
| 244 |
+
df = pd.read_csv("/kaggle/working/emos.csv")
|
| 245 |
+
emos_scores = dict(zip(df["uttID"], df["emos"]))
|
| 246 |
+
if os.path.exists("/kaggle/working/vad.csv"):
|
| 247 |
+
df = pd.read_csv("/kaggle/working/vad.csv")
|
| 248 |
+
# cột output VAD chuẩn của Gemini_VAD.py: uttID, val, aro, dom
|
| 249 |
+
for _, r in df.iterrows():
|
| 250 |
+
vad_scores[r["uttID"]] = (r["val"], r["aro"], r["dom"])
|
| 251 |
+
|
| 252 |
+
print("EMOS:", len(emos_scores), "| VAD:", len(vad_scores))
|
| 253 |
+
|
| 254 |
+
# %% [markdown]
|
| 255 |
+
# ## 5. Gộp thành `answer.txt`
|
| 256 |
+
# QMOS & EMOS bắt buộc. Tự bỏ cột nếu thiếu dữ liệu (nộp tập con hợp lệ).
|
| 257 |
+
# Lưu ý key: qmos/emocat theo path đầy đủ; emos/vad theo TÊN FILE → tra cứu bằng basename.
|
| 258 |
+
|
| 259 |
+
# %%
|
| 260 |
+
def fmt_cat(p):
|
| 261 |
+
return "|".join(f"{e}:{p[e]:.6g}" for e in EMOTIONS5)
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def build_answer(out_path):
|
| 265 |
+
wavs = list_wavs(WAV_DIR)
|
| 266 |
+
have_emos = RUN_EMOS and len(emos_scores) > 0
|
| 267 |
+
have_cat = RUN_EMOCAT and len(emocat_probs) > 0
|
| 268 |
+
have_vad = RUN_VAD and len(vad_scores) > 0
|
| 269 |
+
|
| 270 |
+
cols = ["wav", "QMOS", "EMOS"] # QMOS+EMOS bắt buộc
|
| 271 |
+
if have_cat: cols.append("CAT")
|
| 272 |
+
if have_vad: cols += ["VAL", "ARO", "DOM"]
|
| 273 |
+
|
| 274 |
+
n = 0
|
| 275 |
+
with open(out_path, "w") as f:
|
| 276 |
+
f.write(",".join(cols) + "\n")
|
| 277 |
+
for w in wavs:
|
| 278 |
+
name = os.path.basename(w) # tên file = key của emos/vad, và là giá trị cột wav
|
| 279 |
+
row = [name,
|
| 280 |
+
f"{qmos_scores.get(w, 3.0):.6g}",
|
| 281 |
+
str(emos_scores.get(name, 3))]
|
| 282 |
+
if have_cat:
|
| 283 |
+
row.append(fmt_cat(emocat_probs.get(w, {e: 0.2 for e in EMOTIONS5})))
|
| 284 |
+
if have_vad:
|
| 285 |
+
v = vad_scores.get(name, (3, 3, 3))
|
| 286 |
+
row += [str(v[0]), str(v[1]), str(v[2])]
|
| 287 |
+
f.write(",".join(row) + "\n")
|
| 288 |
+
n += 1
|
| 289 |
+
print(f"Ghi {n} dòng → {out_path} | cột: {cols}")
|
| 290 |
+
return cols
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
answer_path = os.path.join(OUT_DIR, "answer.txt")
|
| 294 |
+
cols = build_answer(answer_path)
|
| 295 |
+
|
| 296 |
+
# %% [markdown]
|
| 297 |
+
# ## 6. Validate + đóng zip
|
| 298 |
+
|
| 299 |
+
# %%
|
| 300 |
+
def validate(path):
|
| 301 |
+
import csv
|
| 302 |
+
with open(path) as f:
|
| 303 |
+
rows = list(csv.reader(f))
|
| 304 |
+
header = rows[0]
|
| 305 |
+
assert header[0] == "wav" and "QMOS" in header and "EMOS" in header, "Header sai"
|
| 306 |
+
for i, r in enumerate(rows[1:], 2):
|
| 307 |
+
assert len(r) == len(header), f"Dòng {i} sai số cột"
|
| 308 |
+
print(f"OK: {len(rows)-1} dòng, header = {header}")
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
validate(answer_path)
|
| 312 |
+
# !cd /kaggle/working && zip -j submission_track2.zip answer.txt && unzip -l submission_track2.zip
|
| 313 |
+
print("Sẵn sàng nộp: /kaggle/working/submission_track2.zip (chứa answer.txt)")
|
| 314 |
+
|
| 315 |
+
# %% [markdown]
|
| 316 |
+
# ## Ghi chú
|
| 317 |
+
# - Nộp: My Submissions → chọn **Track 2**, **bỏ chọn** track khác → upload `submission_track2.zip`.
|
| 318 |
+
# - `metadata.csv` (wavID|emotion|transcript, không header) chứa nhãn cảm xúc target cho Gemini EMOS/VAD.
|
| 319 |
+
# - Train phase: dự đoán tập DEV (`sets/dev.scp`). `sets/train.csv` có nhãn người nghe để train mô hình riêng.
|
| 320 |
+
# - Quota Gemini free tier dễ hết với eval lớn → chia batch `--start-row/--end-row`, giảm `--workers`, dùng `--resume`.
|
| 321 |
+
# - Khi có data thật: sửa `DATA_ROOT` ở cell 0 rồi chạy lại từ đầu.
|
track2/track2_prepare_data.ipynb
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# VMC2026 Track 2 — Chuẩn bị data (gộp ESD + DailyTalk) trên Kaggle\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"Gói Track 2 thiếu **1.417 mẫu giọng thật** (license tách ra):\n",
|
| 10 |
+
"- **sys006** = ESD (1.379 file) · **sys001** = DailyTalk (38 file)\n",
|
| 11 |
+
"\n",
|
| 12 |
+
"Notebook này: cài **SoX** + build **sv56** → gom đúng utterance từ ESD/DailyTalk →\n",
|
| 13 |
+
"**chuẩn hóa âm lượng** (giống mẫu TTS) → ráp vào `wav/` đủ **15.477 file**.\n",
|
| 14 |
+
"\n",
|
| 15 |
+
"### Cách dùng\n",
|
| 16 |
+
"1. Settings → **Internet = On** (cần tải/biên dịch sv56). GPU không bắt buộc.\n",
|
| 17 |
+
"2. **+ Add Input** 3 dataset:\n",
|
| 18 |
+
" - Gói Track 2 (`vmc2026_track2_..._v3.tar.gz` — Kaggle tự giải nén ra `vmc2026-track2/`).\n",
|
| 19 |
+
" - ESD: `Emotional Speech Dataset (ESD).zip` (Kaggle tự giải nén ra `Emotion Speech Dataset/`).\n",
|
| 20 |
+
" - DailyTalk: `dailytalk.zip` (giải nén ra `dailytalk/data/...`).\n",
|
| 21 |
+
"3. **Run All**. Xong → **Save Version** (Commit) để lưu `wav/` ra output.\n",
|
| 22 |
+
"4. Từ output đó → **Create Dataset** → dùng làm input cho notebook train/baseline.\n",
|
| 23 |
+
"\n",
|
| 24 |
+
"> Notebook tự dò vị trí ESD/DailyTalk dù Kaggle giải nén ra thư mục tên gì."
|
| 25 |
+
]
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"cell_type": "markdown",
|
| 29 |
+
"metadata": {},
|
| 30 |
+
"source": [
|
| 31 |
+
"## 0. Tìm gói Track 2 + copy ra thư mục ghi được"
|
| 32 |
+
]
|
| 33 |
+
},
|
| 34 |
+
{
|
| 35 |
+
"cell_type": "code",
|
| 36 |
+
"execution_count": null,
|
| 37 |
+
"metadata": {},
|
| 38 |
+
"outputs": [],
|
| 39 |
+
"source": [
|
| 40 |
+
"import os, glob, shutil, subprocess\n",
|
| 41 |
+
"\n",
|
| 42 |
+
"# Tự dò thư mục vmc2026-track2 trong mọi input đã add.\n",
|
| 43 |
+
"_cands = glob.glob(\"/kaggle/input/*/vmc2026-track2\") + glob.glob(\"/kaggle/input/**/vmc2026-track2\", recursive=True)\n",
|
| 44 |
+
"TRACK2_SRC = _cands[0] if _cands else None\n",
|
| 45 |
+
"assert TRACK2_SRC, \"Không thấy thư mục vmc2026-track2 — đã Add Input gói Track 2 chưa?\"\n",
|
| 46 |
+
"\n",
|
| 47 |
+
"WORK = \"/kaggle/working/vmc2026-track2\" # bản ghi được (input là read-only)\n",
|
| 48 |
+
"print(\"Track2 source :\", TRACK2_SRC)\n",
|
| 49 |
+
"\n",
|
| 50 |
+
"# Copy toàn bộ gói ra working (gồm wav/ 14060 file + scripts + csv). Mất vài phút.\n",
|
| 51 |
+
"if not os.path.exists(WORK):\n",
|
| 52 |
+
" print(\"Đang copy gói Track 2 ra working (vài phút)...\")\n",
|
| 53 |
+
" shutil.copytree(TRACK2_SRC, WORK)\n",
|
| 54 |
+
"print(\"Số wav hiện có:\", len(os.listdir(f\"{WORK}/wav\")))"
|
| 55 |
+
]
|
| 56 |
+
},
|
| 57 |
+
{
|
| 58 |
+
"cell_type": "markdown",
|
| 59 |
+
"metadata": {},
|
| 60 |
+
"source": [
|
| 61 |
+
"## 1. Cài SoX + build sv56 (cần Internet = On)\n",
|
| 62 |
+
"sv56 = công cụ chuẩn hóa âm lượng của ITU-T, build từ source openitu/STL."
|
| 63 |
+
]
|
| 64 |
+
},
|
| 65 |
+
{
|
| 66 |
+
"cell_type": "code",
|
| 67 |
+
"execution_count": null,
|
| 68 |
+
"metadata": {},
|
| 69 |
+
"outputs": [],
|
| 70 |
+
"source": [
|
| 71 |
+
"def sh(cmd):\n",
|
| 72 |
+
" print(\"$\", cmd)\n",
|
| 73 |
+
" print(subprocess.run(cmd, shell=True, capture_output=True, text=True).stdout[-2000:])\n",
|
| 74 |
+
"\n",
|
| 75 |
+
"# SoX + trình biên dịch (để build sv56)\n",
|
| 76 |
+
"sh(\"apt-get -qq update && apt-get -qq install -y sox make gcc\")\n",
|
| 77 |
+
"sh(\"which sox && sox --version\")\n",
|
| 78 |
+
"\n",
|
| 79 |
+
"# sv56demo\n",
|
| 80 |
+
"SV56_DIR = \"/kaggle/working/STL-2009\"\n",
|
| 81 |
+
"SV56_BIN_DIR = f\"{SV56_DIR}/src/sv56\"\n",
|
| 82 |
+
"if not os.path.exists(f\"{SV56_BIN_DIR}/sv56demo\"):\n",
|
| 83 |
+
" sh(\"cd /kaggle/working && wget -q https://github.com/openitu/STL/archive/refs/tags/v2009.tar.gz\")\n",
|
| 84 |
+
" sh(\"cd /kaggle/working && tar -xf v2009.tar.gz\")\n",
|
| 85 |
+
" sh(f\"cd {SV56_BIN_DIR} && make -f makefile.unx\")\n",
|
| 86 |
+
"assert os.path.exists(f\"{SV56_BIN_DIR}/sv56demo\"), \"Build sv56 thất bại — kiểm tra Internet=On + log make.\"\n",
|
| 87 |
+
"\n",
|
| 88 |
+
"# Đưa cả sox và sv56demo vào PATH cho các script .sh dùng được\n",
|
| 89 |
+
"os.environ[\"PATH\"] = SV56_BIN_DIR + \":\" + os.environ[\"PATH\"]\n",
|
| 90 |
+
"sh(\"which sv56demo\")"
|
| 91 |
+
]
|
| 92 |
+
},
|
| 93 |
+
{
|
| 94 |
+
"cell_type": "markdown",
|
| 95 |
+
"metadata": {},
|
| 96 |
+
"source": [
|
| 97 |
+
"## 2. Dò vị trí ESD + DailyTalk (tự tìm dù tên thư mục khác nhau)"
|
| 98 |
+
]
|
| 99 |
+
},
|
| 100 |
+
{
|
| 101 |
+
"cell_type": "code",
|
| 102 |
+
"execution_count": null,
|
| 103 |
+
"metadata": {},
|
| 104 |
+
"outputs": [],
|
| 105 |
+
"source": [
|
| 106 |
+
"def find_root(rel_path):\n",
|
| 107 |
+
" \"\"\"Tìm thư mục ROOT trong /kaggle/input sao cho ROOT/rel_path tồn tại.\"\"\"\n",
|
| 108 |
+
" base = os.path.basename(rel_path)\n",
|
| 109 |
+
" for hit in glob.glob(f\"/kaggle/input/**/{base}\", recursive=True):\n",
|
| 110 |
+
" if hit.endswith(rel_path.replace(\"/\", os.sep)) or hit.endswith(rel_path):\n",
|
| 111 |
+
" return hit[: -len(rel_path)].rstrip(\"/\")\n",
|
| 112 |
+
" return None\n",
|
| 113 |
+
"\n",
|
| 114 |
+
"# ESD: dòng CSV \"0014/Angry/000381.wav\" → file thật là \"0014/Angry/0014_000381.wav\"\n",
|
| 115 |
+
"_esd_first = open(f\"{WORK}/ESD_utts_train_dev.csv\").readline().strip().split(\",\")[0]\n",
|
| 116 |
+
"_p = _esd_first.split(\"/\") # [spk, emo, uttID.wav]\n",
|
| 117 |
+
"ESD_REL = f\"{_p[0]}/{_p[1]}/{_p[0]}_{_p[2]}\"\n",
|
| 118 |
+
"ESD_ROOT = find_root(ESD_REL)\n",
|
| 119 |
+
"\n",
|
| 120 |
+
"# DailyTalk: dòng CSV \"1020/0_1_d1020.wav\" → file thật \".../data/1020/0_1_d1020.wav\"\n",
|
| 121 |
+
"_dt_first = open(f\"{WORK}/DT_utts_train_dev.csv\").readline().strip().split(\",\")[0]\n",
|
| 122 |
+
"DT_ROOT = find_root(_dt_first) # ROOT sao cho ROOT/1020/0_1_d1020.wav tồn tại\n",
|
| 123 |
+
"\n",
|
| 124 |
+
"print(\"ESD_REL :\", ESD_REL)\n",
|
| 125 |
+
"print(\"ESD_ROOT :\", ESD_ROOT)\n",
|
| 126 |
+
"print(\"DT_ROOT :\", DT_ROOT)\n",
|
| 127 |
+
"assert ESD_ROOT, \"Không thấy ESD — đã Add Input 'Emotional Speech Dataset (ESD).zip' chưa?\"\n",
|
| 128 |
+
"assert DT_ROOT, \"Không thấy DailyTalk — đã Add Input 'dailytalk.zip' chưa?\""
|
| 129 |
+
]
|
| 130 |
+
},
|
| 131 |
+
{
|
| 132 |
+
"cell_type": "markdown",
|
| 133 |
+
"metadata": {},
|
| 134 |
+
"source": [
|
| 135 |
+
"## 3. Gom các utterance cần dùng → thư mục gathered/"
|
| 136 |
+
]
|
| 137 |
+
},
|
| 138 |
+
{
|
| 139 |
+
"cell_type": "code",
|
| 140 |
+
"execution_count": null,
|
| 141 |
+
"metadata": {},
|
| 142 |
+
"outputs": [],
|
| 143 |
+
"source": [
|
| 144 |
+
"GATHERED = f\"{WORK}/gathered\"\n",
|
| 145 |
+
"os.makedirs(GATHERED, exist_ok=True)\n",
|
| 146 |
+
"\n",
|
| 147 |
+
"# ESD: copy ROOT/spk/emo/spk_uttID → gathered/<tên vmc2026>\n",
|
| 148 |
+
"n_esd = 0\n",
|
| 149 |
+
"for line in open(f\"{WORK}/ESD_utts_train_dev.csv\"):\n",
|
| 150 |
+
" src_rel, dst = line.strip().split(\",\")\n",
|
| 151 |
+
" p = src_rel.split(\"/\")\n",
|
| 152 |
+
" src = f\"{ESD_ROOT}/{p[0]}/{p[1]}/{p[0]}_{p[2]}\"\n",
|
| 153 |
+
" if os.path.exists(src):\n",
|
| 154 |
+
" shutil.copy(src, f\"{GATHERED}/{dst}\")\n",
|
| 155 |
+
" n_esd += 1\n",
|
| 156 |
+
" else:\n",
|
| 157 |
+
" print(\"ESD thiếu:\", src)\n",
|
| 158 |
+
"\n",
|
| 159 |
+
"# DailyTalk: copy ROOT/parts[0] → gathered/<tên vmc2026>\n",
|
| 160 |
+
"n_dt = 0\n",
|
| 161 |
+
"for line in open(f\"{WORK}/DT_utts_train_dev.csv\"):\n",
|
| 162 |
+
" src_rel, dst = line.strip().split(\",\")\n",
|
| 163 |
+
" src = f\"{DT_ROOT}/{src_rel}\"\n",
|
| 164 |
+
" if os.path.exists(src):\n",
|
| 165 |
+
" shutil.copy(src, f\"{GATHERED}/{dst}\")\n",
|
| 166 |
+
" n_dt += 1\n",
|
| 167 |
+
" else:\n",
|
| 168 |
+
" print(\"DailyTalk thiếu:\", src)\n",
|
| 169 |
+
"\n",
|
| 170 |
+
"print(f\"Đã gom: ESD {n_esd}/1379 · DailyTalk {n_dt}/38 · tổng {len(os.listdir(GATHERED))}\")"
|
| 171 |
+
]
|
| 172 |
+
},
|
| 173 |
+
{
|
| 174 |
+
"cell_type": "markdown",
|
| 175 |
+
"metadata": {},
|
| 176 |
+
"source": [
|
| 177 |
+
"## 4. Chuẩn hóa âm lượng bằng sv56 (mức -26 dB, giữ nguyên sample rate)"
|
| 178 |
+
]
|
| 179 |
+
},
|
| 180 |
+
{
|
| 181 |
+
"cell_type": "code",
|
| 182 |
+
"execution_count": null,
|
| 183 |
+
"metadata": {},
|
| 184 |
+
"outputs": [],
|
| 185 |
+
"source": [
|
| 186 |
+
"# Dùng chính script gốc trong gói: batch_normRMSE.sh → tạo *_norm.wav trong gathered/\n",
|
| 187 |
+
"sh(f\"bash {WORK}/sv56scripts/batch_normRMSE.sh {GATHERED}\")\n",
|
| 188 |
+
"\n",
|
| 189 |
+
"# Move các file đã chuẩn hóa vào wav/ với đúng tên (bỏ hậu tố _norm)\n",
|
| 190 |
+
"moved = 0\n",
|
| 191 |
+
"for n in os.listdir(GATHERED):\n",
|
| 192 |
+
" if n.endswith(\"_norm.wav\"):\n",
|
| 193 |
+
" final = \"_\".join(n.split(\"_\")[:-1]) + \".wav\" # bỏ \"_norm\"\n",
|
| 194 |
+
" shutil.move(f\"{GATHERED}/{n}\", f\"{WORK}/wav/{final}\")\n",
|
| 195 |
+
" moved += 1\n",
|
| 196 |
+
"print(\"Đã chuẩn hóa & move vào wav/:\", moved, \"file\")"
|
| 197 |
+
]
|
| 198 |
+
},
|
| 199 |
+
{
|
| 200 |
+
"cell_type": "markdown",
|
| 201 |
+
"metadata": {},
|
| 202 |
+
"source": [
|
| 203 |
+
"## 5. Kiểm tra đủ 15.477 + dọn rác"
|
| 204 |
+
]
|
| 205 |
+
},
|
| 206 |
+
{
|
| 207 |
+
"cell_type": "code",
|
| 208 |
+
"execution_count": null,
|
| 209 |
+
"metadata": {},
|
| 210 |
+
"outputs": [],
|
| 211 |
+
"source": [
|
| 212 |
+
"total = len(glob.glob(f\"{WORK}/wav/*.wav\"))\n",
|
| 213 |
+
"print(\"Tổng wav trong wav/:\", total)\n",
|
| 214 |
+
"if total == 15477:\n",
|
| 215 |
+
" shutil.rmtree(GATHERED, ignore_errors=True)\n",
|
| 216 |
+
" # Dọn artifact build để dataset lưu ra GỌN (chỉ giữ vmc2026-track2/)\n",
|
| 217 |
+
" shutil.rmtree(SV56_DIR, ignore_errors=True)\n",
|
| 218 |
+
" for f in glob.glob(\"/kaggle/working/v2009.tar.gz\"):\n",
|
| 219 |
+
" os.remove(f)\n",
|
| 220 |
+
" print(\"✅ ĐỦ 15.477 file. Sẵn sàng Save Version → tạo Dataset.\")\n",
|
| 221 |
+
" print(\" Dùng dataset này cho notebook baseline: DATA_ROOT = '/kaggle/input/<dataset-mới>/vmc2026-track2'\")\n",
|
| 222 |
+
"else:\n",
|
| 223 |
+
" print(f\"⚠️ Chưa đủ (đang {total}). Kiểm tra log bước 3-4 xem ESD/DailyTalk có thiếu file nào.\")"
|
| 224 |
+
]
|
| 225 |
+
},
|
| 226 |
+
{
|
| 227 |
+
"cell_type": "markdown",
|
| 228 |
+
"metadata": {},
|
| 229 |
+
"source": [
|
| 230 |
+
"## Ghi chú\n",
|
| 231 |
+
"- Output nặng (~2-3GB do 15.477 wav). `wav/` đã gồm cả train+dev nên dùng được cho cả fine-tune lẫn inference.\n",
|
| 232 |
+
"- sv56 chuẩn hóa để mẫu ESD/DailyTalk cùng mức âm lượng với mẫu TTS → tránh model bị nhiễu bởi độ to.\n",
|
| 233 |
+
"- Nếu Internet Off: SoX có thể có sẵn nhưng KHÔNG build được sv56 → bắt buộc bật Internet."
|
| 234 |
+
]
|
| 235 |
+
}
|
| 236 |
+
],
|
| 237 |
+
"metadata": {
|
| 238 |
+
"kernelspec": {
|
| 239 |
+
"display_name": "Python 3",
|
| 240 |
+
"language": "python",
|
| 241 |
+
"name": "python3"
|
| 242 |
+
},
|
| 243 |
+
"language_info": {
|
| 244 |
+
"name": "python"
|
| 245 |
+
}
|
| 246 |
+
},
|
| 247 |
+
"nbformat": 4,
|
| 248 |
+
"nbformat_minor": 5
|
| 249 |
+
}
|
track2/track2_prepare_data_pipeline.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# %% [markdown]
|
| 2 |
+
# # VMC2026 Track 2 — Chuẩn bị data (gộp ESD + DailyTalk) trên Kaggle
|
| 3 |
+
#
|
| 4 |
+
# Gói Track 2 thiếu **1.417 mẫu giọng thật** (license tách ra):
|
| 5 |
+
# - **sys006** = ESD (1.379 file) · **sys001** = DailyTalk (38 file)
|
| 6 |
+
#
|
| 7 |
+
# Notebook này: cài **SoX** + build **sv56** → gom đúng utterance từ ESD/DailyTalk →
|
| 8 |
+
# **chuẩn hóa âm lượng** (giống mẫu TTS) → ráp vào `wav/` đủ **15.477 file**.
|
| 9 |
+
#
|
| 10 |
+
# ### Cách dùng
|
| 11 |
+
# 1. Settings → **Internet = On** (cần tải/biên dịch sv56). GPU không bắt buộc.
|
| 12 |
+
# 2. **+ Add Input** 3 dataset:
|
| 13 |
+
# - Gói Track 2 (`vmc2026_track2_..._v3.tar.gz` — Kaggle tự giải nén ra `vmc2026-track2/`).
|
| 14 |
+
# - ESD: `Emotional Speech Dataset (ESD).zip` (Kaggle tự giải nén ra `Emotion Speech Dataset/`).
|
| 15 |
+
# - DailyTalk: `dailytalk.zip` (giải nén ra `dailytalk/data/...`).
|
| 16 |
+
# 3. **Run All**. Xong → **Save Version** (Commit) để lưu `wav/` ra output.
|
| 17 |
+
# 4. Từ output đó → **Create Dataset** → dùng làm input cho notebook train/baseline.
|
| 18 |
+
#
|
| 19 |
+
# > Notebook tự dò vị trí ESD/DailyTalk dù Kaggle giải nén ra thư mục tên gì.
|
| 20 |
+
|
| 21 |
+
# %% [markdown]
|
| 22 |
+
# ## 0. Tìm gói Track 2 + copy ra thư mục ghi được
|
| 23 |
+
|
| 24 |
+
# %%
|
| 25 |
+
import os, glob, shutil, subprocess
|
| 26 |
+
|
| 27 |
+
# Tự dò thư mục vmc2026-track2 trong mọi input đã add.
|
| 28 |
+
_cands = glob.glob("/kaggle/input/*/vmc2026-track2") + glob.glob("/kaggle/input/**/vmc2026-track2", recursive=True)
|
| 29 |
+
TRACK2_SRC = _cands[0] if _cands else None
|
| 30 |
+
assert TRACK2_SRC, "Không thấy thư mục vmc2026-track2 — đã Add Input gói Track 2 chưa?"
|
| 31 |
+
|
| 32 |
+
WORK = "/kaggle/working/vmc2026-track2" # bản ghi được (input là read-only)
|
| 33 |
+
print("Track2 source :", TRACK2_SRC)
|
| 34 |
+
|
| 35 |
+
# Copy toàn bộ gói ra working (gồm wav/ 14060 file + scripts + csv). Mất vài phút.
|
| 36 |
+
if not os.path.exists(WORK):
|
| 37 |
+
print("Đang copy gói Track 2 ra working (vài phút)...")
|
| 38 |
+
shutil.copytree(TRACK2_SRC, WORK)
|
| 39 |
+
print("Số wav hiện có:", len(os.listdir(f"{WORK}/wav")))
|
| 40 |
+
|
| 41 |
+
# %% [markdown]
|
| 42 |
+
# ## 1. Cài SoX + build sv56 (cần Internet = On)
|
| 43 |
+
# sv56 = công cụ chuẩn hóa âm lượng của ITU-T, build từ source openitu/STL.
|
| 44 |
+
|
| 45 |
+
# %%
|
| 46 |
+
def sh(cmd):
|
| 47 |
+
print("$", cmd)
|
| 48 |
+
print(subprocess.run(cmd, shell=True, capture_output=True, text=True).stdout[-2000:])
|
| 49 |
+
|
| 50 |
+
# SoX + trình biên dịch (để build sv56)
|
| 51 |
+
sh("apt-get -qq update && apt-get -qq install -y sox make gcc")
|
| 52 |
+
sh("which sox && sox --version")
|
| 53 |
+
|
| 54 |
+
# sv56demo
|
| 55 |
+
SV56_DIR = "/kaggle/working/STL-2009"
|
| 56 |
+
SV56_BIN_DIR = f"{SV56_DIR}/src/sv56"
|
| 57 |
+
if not os.path.exists(f"{SV56_BIN_DIR}/sv56demo"):
|
| 58 |
+
sh("cd /kaggle/working && wget -q https://github.com/openitu/STL/archive/refs/tags/v2009.tar.gz")
|
| 59 |
+
sh("cd /kaggle/working && tar -xf v2009.tar.gz")
|
| 60 |
+
sh(f"cd {SV56_BIN_DIR} && make -f makefile.unx")
|
| 61 |
+
assert os.path.exists(f"{SV56_BIN_DIR}/sv56demo"), "Build sv56 thất bại — kiểm tra Internet=On + log make."
|
| 62 |
+
|
| 63 |
+
# Đưa cả sox và sv56demo vào PATH cho các script .sh dùng được
|
| 64 |
+
os.environ["PATH"] = SV56_BIN_DIR + ":" + os.environ["PATH"]
|
| 65 |
+
sh("which sv56demo")
|
| 66 |
+
|
| 67 |
+
# %% [markdown]
|
| 68 |
+
# ## 2. Dò vị trí ESD + DailyTalk (tự tìm dù tên thư mục khác nhau)
|
| 69 |
+
|
| 70 |
+
# %%
|
| 71 |
+
def find_root(rel_path):
|
| 72 |
+
"""Tìm thư mục ROOT trong /kaggle/input sao cho ROOT/rel_path tồn tại."""
|
| 73 |
+
base = os.path.basename(rel_path)
|
| 74 |
+
for hit in glob.glob(f"/kaggle/input/**/{base}", recursive=True):
|
| 75 |
+
if hit.endswith(rel_path.replace("/", os.sep)) or hit.endswith(rel_path):
|
| 76 |
+
return hit[: -len(rel_path)].rstrip("/")
|
| 77 |
+
return None
|
| 78 |
+
|
| 79 |
+
# ESD: dòng CSV "0014/Angry/000381.wav" → file thật là "0014/Angry/0014_000381.wav"
|
| 80 |
+
_esd_first = open(f"{WORK}/ESD_utts_train_dev.csv").readline().strip().split(",")[0]
|
| 81 |
+
_p = _esd_first.split("/") # [spk, emo, uttID.wav]
|
| 82 |
+
ESD_REL = f"{_p[0]}/{_p[1]}/{_p[0]}_{_p[2]}"
|
| 83 |
+
ESD_ROOT = find_root(ESD_REL)
|
| 84 |
+
|
| 85 |
+
# DailyTalk: dòng CSV "1020/0_1_d1020.wav" → file thật ".../data/1020/0_1_d1020.wav"
|
| 86 |
+
_dt_first = open(f"{WORK}/DT_utts_train_dev.csv").readline().strip().split(",")[0]
|
| 87 |
+
DT_ROOT = find_root(_dt_first) # ROOT sao cho ROOT/1020/0_1_d1020.wav tồn tại
|
| 88 |
+
|
| 89 |
+
print("ESD_REL :", ESD_REL)
|
| 90 |
+
print("ESD_ROOT :", ESD_ROOT)
|
| 91 |
+
print("DT_ROOT :", DT_ROOT)
|
| 92 |
+
assert ESD_ROOT, "Không thấy ESD — đã Add Input 'Emotional Speech Dataset (ESD).zip' chưa?"
|
| 93 |
+
assert DT_ROOT, "Không thấy DailyTalk — đã Add Input 'dailytalk.zip' chưa?"
|
| 94 |
+
|
| 95 |
+
# %% [markdown]
|
| 96 |
+
# ## 3. Gom các utterance cần dùng → thư mục gathered/
|
| 97 |
+
|
| 98 |
+
# %%
|
| 99 |
+
GATHERED = f"{WORK}/gathered"
|
| 100 |
+
os.makedirs(GATHERED, exist_ok=True)
|
| 101 |
+
|
| 102 |
+
# ESD: copy ROOT/spk/emo/spk_uttID → gathered/<tên vmc2026>
|
| 103 |
+
n_esd = 0
|
| 104 |
+
for line in open(f"{WORK}/ESD_utts_train_dev.csv"):
|
| 105 |
+
src_rel, dst = line.strip().split(",")
|
| 106 |
+
p = src_rel.split("/")
|
| 107 |
+
src = f"{ESD_ROOT}/{p[0]}/{p[1]}/{p[0]}_{p[2]}"
|
| 108 |
+
if os.path.exists(src):
|
| 109 |
+
shutil.copy(src, f"{GATHERED}/{dst}")
|
| 110 |
+
n_esd += 1
|
| 111 |
+
else:
|
| 112 |
+
print("ESD thiếu:", src)
|
| 113 |
+
|
| 114 |
+
# DailyTalk: copy ROOT/parts[0] → gathered/<tên vmc2026>
|
| 115 |
+
n_dt = 0
|
| 116 |
+
for line in open(f"{WORK}/DT_utts_train_dev.csv"):
|
| 117 |
+
src_rel, dst = line.strip().split(",")
|
| 118 |
+
src = f"{DT_ROOT}/{src_rel}"
|
| 119 |
+
if os.path.exists(src):
|
| 120 |
+
shutil.copy(src, f"{GATHERED}/{dst}")
|
| 121 |
+
n_dt += 1
|
| 122 |
+
else:
|
| 123 |
+
print("DailyTalk thiếu:", src)
|
| 124 |
+
|
| 125 |
+
print(f"Đã gom: ESD {n_esd}/1379 · DailyTalk {n_dt}/38 · tổng {len(os.listdir(GATHERED))}")
|
| 126 |
+
|
| 127 |
+
# %% [markdown]
|
| 128 |
+
# ## 4. Chuẩn hóa âm lượng bằng sv56 (mức -26 dB, giữ nguyên sample rate)
|
| 129 |
+
|
| 130 |
+
# %%
|
| 131 |
+
# Dùng chính script gốc trong gói: batch_normRMSE.sh → tạo *_norm.wav trong gathered/
|
| 132 |
+
sh(f"bash {WORK}/sv56scripts/batch_normRMSE.sh {GATHERED}")
|
| 133 |
+
|
| 134 |
+
# Move các file đã chuẩn hóa vào wav/ với đúng tên (bỏ hậu tố _norm)
|
| 135 |
+
moved = 0
|
| 136 |
+
for n in os.listdir(GATHERED):
|
| 137 |
+
if n.endswith("_norm.wav"):
|
| 138 |
+
final = "_".join(n.split("_")[:-1]) + ".wav" # bỏ "_norm"
|
| 139 |
+
shutil.move(f"{GATHERED}/{n}", f"{WORK}/wav/{final}")
|
| 140 |
+
moved += 1
|
| 141 |
+
print("Đã chuẩn hóa & move vào wav/:", moved, "file")
|
| 142 |
+
|
| 143 |
+
# %% [markdown]
|
| 144 |
+
# ## 5. Kiểm tra đủ 15.477 + dọn rác
|
| 145 |
+
|
| 146 |
+
# %%
|
| 147 |
+
total = len(glob.glob(f"{WORK}/wav/*.wav"))
|
| 148 |
+
print("Tổng wav trong wav/:", total)
|
| 149 |
+
if total == 15477:
|
| 150 |
+
shutil.rmtree(GATHERED, ignore_errors=True)
|
| 151 |
+
# Dọn artifact build để dataset lưu ra GỌN (chỉ giữ vmc2026-track2/)
|
| 152 |
+
shutil.rmtree(SV56_DIR, ignore_errors=True)
|
| 153 |
+
for f in glob.glob("/kaggle/working/v2009.tar.gz"):
|
| 154 |
+
os.remove(f)
|
| 155 |
+
print("✅ ĐỦ 15.477 file. Sẵn sàng Save Version → tạo Dataset.")
|
| 156 |
+
print(" Dùng dataset này cho notebook baseline: DATA_ROOT = '/kaggle/input/<dataset-mới>/vmc2026-track2'")
|
| 157 |
+
else:
|
| 158 |
+
print(f"⚠️ Chưa đủ (đang {total}). Kiểm tra log bước 3-4 xem ESD/DailyTalk có thiếu file nào.")
|
| 159 |
+
|
| 160 |
+
# %% [markdown]
|
| 161 |
+
# ## Ghi chú
|
| 162 |
+
# - Output nặng (~2-3GB do 15.477 wav). `wav/` đã gồm cả train+dev nên dùng được cho cả fine-tune lẫn inference.
|
| 163 |
+
# - sv56 chuẩn hóa để mẫu ESD/DailyTalk cùng mức âm lượng với mẫu TTS → tránh model bị nhiễu bởi độ to.
|
| 164 |
+
# - Nếu Internet Off: SoX có thể có sẵn nhưng KHÔNG build được sv56 → bắt buộc bật Internet.
|