dieKarotte commited on
Commit
bf04039
·
verified ·
1 Parent(s): dd39446

Add files using upload-large-folder tool

Browse files
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ Evaluation_Results/Comparing_with_the_SOTA_Single_Models.png filter=lfs diff=lfs merge=lfs -text
37
+ Evaluation_Results/Comparing_Different_Pre-Training_Targets.png filter=lfs diff=lfs merge=lfs -text
38
+ Evaluation_Results/Comparing_with_the_SOTA_Ensemble_Models.png filter=lfs diff=lfs merge=lfs -text
Evaluation_Results/Comparing_Different_Pre-Training_Targets.png ADDED

Git LFS Details

  • SHA256: c2f5ea4f6c904c39e72d28b5587423d4f28260ef536cca4655568aebb70332ac
  • Pointer size: 131 Bytes
  • Size of remote file: 242 kB
Evaluation_Results/Comparing_with_the_SOTA_Ensemble_Models.png ADDED

Git LFS Details

  • SHA256: 4e4916c2f7c1d6cc32ce3093a8eb3a97cf52cc8fe181a91ed25dc7e1d908324a
  • Pointer size: 131 Bytes
  • Size of remote file: 151 kB
Evaluation_Results/Comparing_with_the_SOTA_Single_Models.png ADDED

Git LFS Details

  • SHA256: e3b0e623169fa02769ed21d0e821ef5be5d7f9e1fd7aaaef0e2e386b61e52fe3
  • Pointer size: 131 Bytes
  • Size of remote file: 437 kB
__pycache__/BEATs.cpython-310.pyc ADDED
Binary file (4.13 kB). View file
 
__pycache__/BEATs.cpython-312.pyc ADDED
Binary file (7.34 kB). View file
 
__pycache__/modules.cpython-311.pyc ADDED
Binary file (11.1 kB). View file
 
__pycache__/modules.cpython-312.pyc ADDED
Binary file (10.2 kB). View file
 
__pycache__/spatial_beats.cpython-310.pyc ADDED
Binary file (49.1 kB). View file
 
__pycache__/spatial_dataset.cpython-311.pyc ADDED
Binary file (76 kB). View file
 
__pycache__/test_vectorized_matching.cpython-311.pyc ADDED
Binary file (19 kB). View file
 
docs/00_START_HERE.md ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🚀 START HERE — Spatial-BEATs Documentation Guide
2
+
3
+ **Welcome!** You have access to comprehensive analysis of the Spatial-BEATs codebase. This guide will direct you to exactly what you need.
4
+
5
+ ---
6
+
7
+ ## ⚡ Quick Pick Your Task
8
+
9
+ ### "I have 5 minutes"
10
+ → Read: [`SPATIAL_FRAMEWORKS_QUICK_REFERENCE.md`](SPATIAL_FRAMEWORKS_QUICK_REFERENCE.md)
11
+
12
+ ### "I have 15 minutes"
13
+ → Read: [`README_DOCUMENTATION_INDEX.md`](README_DOCUMENTATION_INDEX.md) then [`ANALYSIS_COMPLETION_SUMMARY.md`](ANALYSIS_COMPLETION_SUMMARY.md)
14
+
15
+ ### "I have 30 minutes"
16
+ → Choose one:
17
+ - **New to codebase?** Read: Part 1 of [`SPATIAL_AUDIO_FRAMEWORKS_ANALYSIS.md`](SPATIAL_AUDIO_FRAMEWORKS_ANALYSIS.md)
18
+ - **Debugging DOA gap?** Read: [`doa_train_valid_gap_analysis.md`](doa_train_valid_gap_analysis.md) Executive Summary
19
+ - **Planning experiments?** Read: [`0427_v11_series.md`](0427_v11_series.md) Section 1-2
20
+
21
+ ### "I have 1-2 hours"
22
+ → Full reading path for your role:
23
+ - **Researcher**: QUICK_REF → 0427_v11_series.md → ANALYSIS Part 2-3
24
+ - **Contributor**: QUICK_REF → ANALYSIS Part 1-2 → Pick component → read code
25
+ - **Investigator**: DOA_GAP Executive → Part 6 → Part 8 → Appendix
26
+
27
+ ---
28
+
29
+ ## 📚 The Five Documents
30
+
31
+ ### 1. **README_DOCUMENTATION_INDEX.md**
32
+ 🏠 **Navigation hub** — Where to find what
33
+ - Use case lookup (choose your problem)
34
+ - Code component quick reference
35
+ - Reading order for different roles
36
+ - Cross-reference guide
37
+
38
+ **👉 Read this first if**: You're not sure where to start
39
+
40
+ ---
41
+
42
+ ### 2. **SPATIAL_FRAMEWORKS_QUICK_REFERENCE.md**
43
+ ⚡ **Practitioner's card** — Fast lookup
44
+ - Framework table (4 frameworks, 1 page)
45
+ - Route A/B/C comparison
46
+ - Version series highlights
47
+ - Code locations by component
48
+ - Loss weight patterns
49
+ - When to use each configuration
50
+
51
+ **👉 Read this for**: Quick answers, practitioner reference
52
+
53
+ ---
54
+
55
+ ### 3. **SPATIAL_AUDIO_FRAMEWORKS_ANALYSIS.md**
56
+ 📖 **Deep technical reference** — Architecture bible
57
+ - Part 1: Four spatial frameworks (Spatial-AST, DCASE SELD, EINV2, DETR)
58
+ - Part 2: Routes A/B/C with full specifications
59
+ - Part 3: Version evolution (v7→v11)
60
+ - Part 4-10: Implementation, configs, metrics, future work
61
+ - Appendix: Code reference table
62
+
63
+ **👉 Read this for**: Deep understanding, architecture details, code paths
64
+
65
+ ---
66
+
67
+ ### 4. **doa_train_valid_gap_analysis.md**
68
+ 🔍 **Diagnostic & fix guide** — Root cause analysis
69
+ - Executive Summary: 6 critical mechanisms
70
+ - Part 1: Data pipeline analysis
71
+ - Part 2: Loss computation asymmetry
72
+ - Part 3: Training configuration (v9/v10)
73
+ - Part 4: Validation metrics
74
+ - **Part 6: Root causes ranked by severity**
75
+ - **Part 7: Diagnostic checklist**
76
+ - **Part 8: Recommended fixes (prioritized)**
77
+ - Appendix: Code reference with exact line numbers
78
+
79
+ **👉 Read this for**: Debugging train/val gaps, understanding root causes
80
+
81
+ ---
82
+
83
+ ### 5. **ANALYSIS_COMPLETION_SUMMARY.md**
84
+ 📋 **Executive overview** — What was found
85
+ - Deliverables summary (5 docs, 1,883 lines)
86
+ - Key findings (frameworks, routes, v11 series)
87
+ - Next steps (immediate vs experimental)
88
+ - How to use documents
89
+ - Verification checklist
90
+
91
+ **👉 Read this for**: Overview, decision-making, what comes next
92
+
93
+ ---
94
+
95
+ ## 🎯 Choose Your Path
96
+
97
+ ### Path 1: "I want to understand the architecture (60 min)"
98
+ 1. SPATIAL_FRAMEWORKS_QUICK_REFERENCE.md (5 min)
99
+ 2. SPATIAL_AUDIO_FRAMEWORKS_ANALYSIS.md Part 1-2 (25 min)
100
+ 3. Pick a component from Part 6 Appendix, find in code (20 min)
101
+ 4. spatial_beats_ov123_frame_routes.md if curious (10 min)
102
+
103
+ **Outcome**: Can navigate codebase, understand paradigms, modify code confidently
104
+
105
+ ---
106
+
107
+ ### Path 2: "I need to debug a train/val gap (30 min)"
108
+ 1. doa_train_valid_gap_analysis.md Executive Summary (2 min)
109
+ 2. Part 6: Check which mechanisms apply to your situation (5 min)
110
+ 3. Part 7: Diagnostics—check your logs (10 min)
111
+ 4. Part 8: Pick a fix priority (5 min)
112
+ 5. Appendix: Get code locations (2 min)
113
+ 6. SPATIAL_FRAMEWORKS_QUICK_REFERENCE.md if you need to modify (optional)
114
+
115
+ **Outcome**: Root cause identified, fix strategy chosen, code locations ready
116
+
117
+ ---
118
+
119
+ ### Path 3: "I want to run an experiment (v11 series) (45 min)"
120
+ 1. SPATIAL_FRAMEWORKS_QUICK_REFERENCE.md (5 min)
121
+ 2. 0427_v11_series.md Section 1-2 (15 min)
122
+ 3. SPATIAL_AUDIO_FRAMEWORKS_ANALYSIS.md Part 3 (15 min)
123
+ 4. 0427_v11_series.md Part 4 (verification method) (5 min)
124
+ 5. Copy shell script from QUICK_REF (5 min)
125
+
126
+ **Outcome**: Experiment ready to launch, understanding of success metrics
127
+
128
+ ---
129
+
130
+ ### Path 4: "I'm new, I want to understand everything (2 hours)"
131
+ 1. SPATIAL_FRAMEWORKS_QUICK_REFERENCE.md (10 min)
132
+ 2. SPATIAL_AUDIO_FRAMEWORKS_ANALYSIS.md Part 1-3 (40 min)
133
+ 3. spatial_beats_ov123_frame_routes.md (25 min)
134
+ 4. spatial_beats_training_overview.md (20 min)
135
+ 5. Pick component, trace through code with Part 6 references (20 min)
136
+ 6. doa_train_valid_gap_analysis.md Part 6 for context (5 min)
137
+
138
+ **Outcome**: Comprehensive understanding of system, ready to contribute
139
+
140
+ ---
141
+
142
+ ## 🔑 Key Findings at a Glance
143
+
144
+ ### 4 Spatial Frameworks in Codebase
145
+ - **Spatial-AST**: Task tokens (pre-trunk)
146
+ - **DCASE SELD**: Per-class activity+DOA
147
+ - **EINV2**: Learnable track queries
148
+ - **DETR**: Per-frame K-slot allocation
149
+
150
+ ### 3 Parallel Routes (A/B/C)
151
+ - **Route A**: Per-frame K-slot, per-step Hungarian
152
+ - **Route B**: Learnable queries, clip-level Hungarian (PRODUCTION v9)
153
+ - **Route C**: Per-class vectors (PROTOTYPE, v11c test)
154
+
155
+ ### DOA Train/Val Gap Root Causes
156
+ 1. ⚠️⚠️⚠️ **ZERO spatial augmentation (rotations)** — 40-60% of variance
157
+ 2. ⚠️⚠️ **SpecAugment train-only** — 10-20% variance
158
+ 3. ⚠️⚠️ **v10 freezes direction head** — 30-40% on multi-source
159
+ 4. ⚠️ **Regression sensitivity** — 5-15% variance
160
+ 5. ⚠️ **Detached prediction asymmetry** — 2-5% variance
161
+
162
+ ### v11 Experiments (Parallel Runs)
163
+ - **v11a**: DOA demixer → ov2 angles ↓ 5pp+
164
+ - **v11b**: LocalSpatial pre-pool → test IV necessity
165
+ - **v11c**: ACCDOA paradigm → ov3 binding ↓ 5pp+
166
+ - **v11d**: Post-hoc calibration → ov1 ranking ↑ 5pp+
167
+
168
+ ---
169
+
170
+ ## 📞 FAQ
171
+
172
+ **Q: Where do I find the direction head loss?**
173
+ A: `SPATIAL_AUDIO_FRAMEWORKS_ANALYSIS.md` Appendix → search "direction loss" → `spatial_loss.py:1562-1565`
174
+
175
+ **Q: What's the difference between routes?**
176
+ A: Compare table in `SPATIAL_FRAMEWORKS_QUICK_REFERENCE.md` or detailed Part 2 of `SPATIAL_AUDIO_FRAMEWORKS_ANALYSIS.md`
177
+
178
+ **Q: Should I implement fix #1, #2, or #3?**
179
+ A: Read `doa_train_valid_gap_analysis.md` Part 6, pick based on your gap size and risk tolerance.
180
+
181
+ **Q: How do I run v11a?**
182
+ A: Shell script in `SPATIAL_FRAMEWORKS_QUICK_REFERENCE.md` v11 section + spec in `0427_v11_series.md` Section 2.2
183
+
184
+ **Q: I'm stuck on a component. Where's the code?**
185
+ A: `SPATIAL_AUDIO_FRAMEWORKS_ANALYSIS.md` Part 6 has complete reference table with file:line for every component.
186
+
187
+ ---
188
+
189
+ ## 🎁 You Now Have
190
+
191
+ ✅ **Navigation guide** for all documents
192
+ ✅ **Quick reference card** with all the essentials
193
+ ✅ **Architecture bible** with code paths
194
+ ✅ **Diagnostic guide** for train/val gaps
195
+ ✅ **Experimental specifications** for v11 series
196
+ ✅ **Comprehensive metadata** (1,883 lines, 77KB)
197
+ ✅ **All findings tied to exact code locations**
198
+
199
+ ---
200
+
201
+ ## 🚀 Next Steps
202
+
203
+ 1. **Choose your path above** based on how much time you have
204
+ 2. **Follow the reading order** in that path
205
+ 3. **Use cross-references** when you need more detail
206
+ 4. **Check Appendices** for exact code locations
207
+ 5. **Reference Part 6/Part 8** when implementing
208
+
209
+ ---
210
+
211
+ ## 📊 Document Overview
212
+
213
+ | Document | Size | Time | Purpose |
214
+ |----------|------|------|---------|
215
+ | README_DOCUMENTATION_INDEX | 12KB | 5-10m | Navigation hub |
216
+ | SPATIAL_FRAMEWORKS_QUICK_REFERENCE | 7KB | 5-10m | Quick lookup |
217
+ | SPATIAL_AUDIO_FRAMEWORKS_ANALYSIS | 28KB | 30-45m | Deep reference |
218
+ | doa_train_valid_gap_analysis | 19KB | 20-30m | Diagnostics |
219
+ | ANALYSIS_COMPLETION_SUMMARY | 11KB | 10m | Executive summary |
220
+ | **TOTAL** | **77KB** | **2-4 hours** | **Complete set** |
221
+
222
+ ---
223
+
224
+ **Status**: ✅ Complete and ready for use
225
+ **Created**: 2026-04-27
226
+ **Next update**: After v11 experiments
227
+
228
+ 👉 **Pick your path above and start reading!**
docs/0427_v11_series.md ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 2026-04-27 — v11 系列实验:DOA demixer / ACCDOA 范式 / 校准对照
2
+
3
+ 本文档对应「v9 之后做什么」的对话定型。把 docs/0424.md 的诊断结论拆成四个独立、可单独评估、可并行跑的实验:v11a / v11b / v11c / v11d。
4
+
5
+ ## 1. 上一轮诊断回顾
6
+
7
+ 参考 docs/0423.md(v9 design)+ docs/0424.md(v9 real dump 拆解)。三个真实 split 的失败模式互不重合:
8
+
9
+ | split | 主要症状 | 责任层 |
10
+ | -------- | --------------------------------------------- | ------------------------------- |
11
+ | real_ov1 | raw 4-track 100% 同类候选,但 act>=0.5 后 37% GT 丢同类预测 | 排序 / activity 校准(**非架构**) |
12
+ | real_ov2 | 73.9% 的预测「同类有,但角度 >20°」 | direction head 本体(**架构**) |
13
+ | real_ov3 | 24.5% raw 层无同类候选 + avg_pred 1.82 < avg_gt 2.88 | binding(query→source)+ 少亮轨(**架构**) |
14
+
15
+ v9 已经给 class head 加过 `ClassHeadSpectralDemixer`(Fix C:track latent → BEATs trunk pre-pool grid 的频率轴 cross-attn,零门控残差)。但 **direction / distance head 没有对应通路**——它们只看 `track_time_features` 这一个 D 维向量,输入再往前是 `fused_spatial_embeddings`,已经被 `FrequencyPool(mean)` 平均掉。多源情况下,频率维被平均后单 D 向量无法同时表达两个方向,这是 real_ov2 的物理根因。
16
+
17
+ ## 2. 实验定义
18
+
19
+ ### v11a — DOA / 距离对称 demixer(最小改动)
20
+
21
+ **目的**:直接验证「v9 Fix C 没补到 DOA」是否就是 real_ov2 angle 错的根因。
22
+
23
+ **做法**:
24
+
25
+ * 在 `FrameTrackPredictionHeads` 里再加一个 `ClassHeadSpectralDemixer` 实例(参数独立,结构完全相同),命名 `spatial_head_demixer`,作用于 `direction_head` 和 `distance_head` 的输入。
26
+ * KV 与 class demixer 共用:BEATs trunk 的 pre-pool tokens `[B, T_p*F_p, D]` + grid_size `(T_p, F_p)` + pre-pool 时间 mask。
27
+ * 零门控初始化:`out_proj.weight = out_proj.bias = 0`,`gate = 1e-2`,所以 epoch-0 forward 与 v9 bit-equivalent。
28
+
29
+ **热启动**:`RESUME_CKPT` = v9 best.pt;`strict=False`;新增 13 个参数走默认零门控初值。`--no-resume-optimizer --reset-epoch-on-resume --reset-best-on-resume`。
30
+
31
+ **预期**:
32
+
33
+ * real_ov2:`class_right_angle_wrong` 从 73.9% 显著下降;`mean_best_angle_when_same_class_exists` 缩小。
34
+ * sim split(`valid__hm3d__`)的 `F20 / LE_CD / ocls` 不退化(零门控保证 epoch-0 安全,训练只能修不能炸)。
35
+ * real_ov1 / real_ov3 不一定改善——它们卡的不是 DOA 本身。
36
+
37
+ **证明的事**:v9 在 class 上加 demixer 是有效的,但 DOA 也需要同样的通路才能拿到等价收益;后 frequency_pool 单向量是 DOA 多源 demix 的硬瓶颈。
38
+
39
+ **入口**:`run_ov1_v11a_ov123_top4.sh` → preset `ov1_local_spatial_v11a_ov123_top4`。
40
+
41
+ ### v11b — DOA demixer 的 KV 换成 LocalSpatial pre-pool
42
+
43
+ **目的**:进一步追问——v11a 的 KV 是 BEATs 的 mono fbank pre-pool,本身没有方向信息(IV 是后面 `local_spatial_fuser` 才混进来的);如果让 DOA demixer 直接 attend 到 7 通道 FOA + IV 的 CNN pre-pool,会不会比 v11a 更好?
44
+
45
+ **做法**:
46
+
47
+ * 让 `LocalSpatialEncoder.forward(foa_feat, return_pre_pool=True)` 额外返回 4D CNN 特征 `[B, D_s, T_f, F_cnn]`(在 `mean(dim=-1)` 频率塌缩之前)。
48
+ * `build_local_spatial_fusion(..., return_local_pre_pool=True)` 透传,并 reshape 到 `[B, T_f*F_cnn, D_s]` + grid `(T_f, F_cnn)`。
49
+ * 新增 `local_spatial_pre_pool_proj: Linear(D_s -> D=768)`,xavier 初始化乘以 `local_spatial_proj_scale_init`,bias=0。
50
+ * `FrameTrackPredictionHeads.forward` 接 `spatial_pre_pool_features / spatial_pre_pool_grid_size / spatial_pre_pool_time_mask`;当传入时,DOA demixer 用这条 KV,否则回落到 v11a 的 BEATs trunk pre-pool。
51
+
52
+ **热启动**:同样 v9 best.pt + strict=False。`spatial_head_demixer` 仍然零门控;`local_spatial_pre_pool_proj` 是新参数,但 demixer gate=1e-2,`out_proj=0` 保证 epoch-0 数值与 v9 一致。
53
+
54
+ **预期**:
55
+
56
+ * 如果 v11b 比 v11a 在 real_ov2 上明显更好 → 物理 IV 信号确实是 DOA 必要输入,BEATs trunk pre-pool 信息不足。
57
+ * 如果 v11b ≈ v11a → BEATs trunk pre-pool 已经携带足够 spatial 上下文(`local_spatial_fuser` 把 IV 混回去了),DOA 头的瓶颈只在 demixer 本身。
58
+ * 如果 v11b < v11a → 新 KV 引入太多噪声 / projection 没充分训。
59
+
60
+ **证明的事**:架构里「方向先验来自哪里」的问题——是 fuser 后已经够,还是必须 pre-fuser 直读 IV。
61
+
62
+ **入口**:`run_ov1_v11b_ov123_top4.sh` → preset `ov1_local_spatial_v11b_ov123_top4`。
63
+
64
+ ### v11c — ACCDOA 范式对照
65
+
66
+ **目的**:质询「K-track DETR 范式」本身。v11a/v11b 都在补 head;但 real_ov3 24.5% 的 GT 在 raw 4-track 里就找不到同类候选——这意味着问题不在 head,而在 **��哪个 query 该负责哪个 source」** 的 binding 阶段。把范式整体换掉看是否绕得过去。
67
+
68
+ **做法**:
69
+
70
+ * `readout_scheme = local_spatial_accdoa`,per-class 3D 向量场 `v_c` :`||v_c||` = activity_c,`v_c/||v_c||` = DOA_c。无 query、无 Hungarian。
71
+ * ov2/ov3 的 same-class-in-same-frame 几乎为零,所以 per-class 输出是无歧义的。
72
+ * 接现有 `accdoa_heads` + `compute_frame_accdoa_losses`(仓库已实现)。
73
+
74
+ **冷启动**:拓扑与 v9 不兼容(无 `source_query_decoder`、无 `FrameTrackPredictionHeads`)。改用 `--init-from-spatial-ckpt` 从 ov1 local_spatial warmup ckpt 初始化(继承 BEATs trunk + LocalSpatialEncoder + fuser),strict=False。
75
+
76
+ **调度**:24 epochs,`lr=3e-5`(默认 1e-4 是 ov1 单源用的,没在多源上调过;3e-5 与 v9 同档,更安全)。
77
+
78
+ **预期**:
79
+
80
+ * real_ov3 `no_same_class_pred_but_other_preds_exist` 显著下降:query binding 不存在了,每个类天然有自己的 vector slot。
81
+ * real_ov2 也可能改善——因为 DOA 来自 per-class 向量,已经按类分离,避免多源 head 共享单 D 向量。
82
+ * sim ov1 上 ocls 可能略低于 v9(ACCDOA 把 activity/DOA 耦合,class 信号被向量幅值稀释);这个代价是已知的。
83
+
84
+ **证明的事**:query-binding 阶段是不是 real_ov3 的真正瓶颈。如果 v11c 就把 ov3 拉起来了,说明 head 修补(v11a/v11b)治不好 ov3;如果 v11c 也救不了 ov3,说明问题在更前——可能是 `LocalSpatialEncoder` 时空分辨率不够。
85
+
86
+ **入口**:`run_ov1_v11c_ov123_accdoa.sh` → preset `ov1_local_spatial_v11c_ov123_accdoa`。
87
+
88
+ ### v11d — activity 校准 + Top-K̂ 解码(纯后处理)
89
+
90
+ **目的**:real_ov1 的失血点在「raw 100% 有同类候选 → act>=0.5 后 37% 丢失」,是阈值 / 排序问题,**不是模型**。所以不改架构、不重训,只改 decode。同步给 v9 / v11a / v11b 出个 pareto,避免后续把 ranking 收益错算到 head 改动上。
91
+
92
+ **做法**:
93
+
94
+ `scripts/calibrate_activity.py` 重读已 dump 的 `*__pred.csv`(这些 CSV 里 `eval_v7k_real_valid.py` 已经写入 `activity_prob` 和——v10 head 在的话——`num_active_pred`)。三种 decode 模式:
95
+
96
+ * `threshold`:固定阈值(扫 0.3 / 0.4 / 0.5 / 0.6)。
97
+ * `topk_hat`:每帧按 `activity_prob` 降序取前 K̂ 个,K̂ = `num_active_pred`(v10 的 num_active_head argmax)。
98
+ * `topk_hat_min`:上面两条的 AND(K̂ 之内还要过最低阈值)。
99
+
100
+ 每个 (split, mode, thr) 计算与 `analyze_csv_dump.py` 同口径的指标:`hit_share`、`class_right_angle_wrong`、`matched_tp_precision/recall`、`mean_best_angle_when_same_class_exists`。再用 `pick_best` 给每个 split 找最优配置。
101
+
102
+ **输入**:任何已经包含 `__pred.csv / __gt.csv` 的目录(v9 best 的 dump、v11a/b/c 任一 epoch 的 dump 都行)。
103
+
104
+ **预期**:
105
+
106
+ * real_ov1:降阈值或换 `topk_hat` 后 `no_same_class_pred_but_other_preds_exist` 显著下降;存在「降阈值 → recall 大涨而 precision 小掉」的清晰拐点。
107
+ * real_ov2 / real_ov3:阈值调节收益有限——它们的问题不在 ranking。
108
+ * 出一个表,给每个 split 单独决定上线 decode 配置(线上不必三 split 共用同一阈值)。
109
+
110
+ **证明的事**:real_ov1 的 37% 丢失是 decode 过紧 / activity 分布漂移的可校正问题,**不需要拿训练侧资源去解**;同时给 v11a/v11b/v11c 报告时分离「ranking 收益」与「head 改动收益」,避免归因混淆。
111
+
112
+ **入口**:`scripts/calibrate_activity.py --dump-dir <csv_dir>`。
113
+
114
+ ## 3. 改动清单
115
+
116
+ 代码改动集中在四个文件:
117
+
118
+ * `spatial_modules.py`
119
+ * `LocalSpatialEncoder.forward` 新增 `return_pre_pool=False` 参数。
120
+ * `FrameTrackPredictionHeads.__init__` 新增 5 个 spatial demixer 配置项;`forward` 新增 3 个 `spatial_pre_pool_*` kwargs,复用 `ClassHeadSpectralDemixer` 类作 `spatial_head_demixer`。
121
+ * `spatial_beats.py`
122
+ * `SpatialBEATsConfig` 新增 5 个字段:`use_spatial_head_demixer / spatial_head_demixer_layers / heads / dropout / spatial_demixer_use_local_spatial_kv`。
123
+ * `__init__` 两处 `FrameTrackPredictionHeads` 构造点透传新 kwargs;同处按需创建 `local_spatial_pre_pool_proj`。
124
+ * `build_local_spatial_fusion` 新增 `return_local_pre_pool` 选项,6-tuple 返回;reshape 4D CNN feature 到 `[B, T_f*F_cnn, D_s]`。
125
+ * `forward` 两条 readout 分支(`local_spatial`、`local_spatial_track`)按配置切换 6-tuple,并把 `local_spatial_pre_pool_proj` 投影后的张量传给 frame_track head。
126
+ * `train_spatial_beats.py`
127
+ * 新增 3 个 preset factory:`make_ov1_local_spatial_v11a_ov123_top4_config`(继承 v9 + `use_spatial_head_demixer=True`),`v11b`(v11a + `spatial_demixer_use_local_spatial_kv=True`),`v11c`(包装 `make_ov123_local_spatial_accdoa_config` + 24 epochs + lr=3e-5)。
128
+ * 三处 dispatch 分支 + argparse `--preset` choices。
129
+ * `scripts/calibrate_activity.py`(新文件,纯 stdlib)
130
+ * 复用 `analyze_csv_dump.py` 的几何 / 计数逻辑,但读 `num_active_pred` 列做 Top-K̂ 解码;输出 per-split 最优配置。
131
+
132
+ 新增 shell 入口:
133
+
134
+ * `run_ov1_v11a_ov123_top4.sh`(master_port 29561)
135
+ * `run_ov1_v11b_ov123_top4.sh`(master_port 29562)
136
+ * `run_ov1_v11c_ov123_accdoa.sh`(master_port 29563)
137
+
138
+ ## 4. 验证方法
139
+
140
+ 每个实验跑完后用同一套口径回验:
141
+
142
+ ```bash
143
+ # 1. dump real valid CSV
144
+ python3 scripts/eval_v7k_real_valid.py \
145
+ --ckpt <ckpt_path> \
146
+ --dump-pred-dir <dump_dir> \
147
+ --dump-splits real_ov1,real_ov2,real_ov3 \
148
+ --activity-threshold 0.5
149
+
150
+ # 2. 阈值口径下的指标
151
+ python3 scripts/analyze_csv_dump.py \
152
+ --dump-dir <dump_dir> \
153
+ --threshold 0.5 \
154
+ --threshold-sweep 0.3 0.4 0.6
155
+
156
+ # 3. (v11d 用) 全 decode 模式扫描
157
+ python3 scripts/calibrate_activity.py \
158
+ --dump-dir <dump_dir> \
159
+ --thresholds 0.3 0.4 0.5 0.6 \
160
+ --modes threshold topk_hat topk_hat_min \
161
+ --json-out calibration.json
162
+ ```
163
+
164
+ 判 pass 的硬指标:
165
+
166
+ | 实验 | 主要观测 | 副作用看护 |
167
+ | ---- | ---------------------------------------------------------- | -------------------------------- |
168
+ | v11a | real_ov2 `class_right_angle_wrong` 下降 ≥ 5 pp | sim ov1 ocls / F20 不退化 |
169
+ | v11b | real_ov2 / real_ov3 vs v11a 是否更优 | 同 v11a + projection 不爆显存 |
170
+ | v11c | real_ov3 `no_same_class_pred_but_other_preds_exist` 下降 ≥ 5 pp | sim ov1 ocls 退化 < 3 pp |
171
+ | v11d | real_ov1 `hit_share` 在最佳 decode 下提升 ≥ 5 pp | precision 不崩(>= 0.6×baseline) |
172
+
173
+ ## 5. 跑的顺序建议
174
+
175
+ 1. **v11a**(最便宜、最可能直接命中 real_ov2,hot-start 完整保留)。
176
+ 2. **v11d** 并行:用 v9 best 的现成 dump 出 ranking pareto,拿 real_ov1 的「decode 收益基线」。这样 v11a 的 dump 出来后能立刻分离「DOA 收益」vs「decode 收益」。
177
+ 3. **v11b**:仅在 v11a 收益不达标、或想验「IV 直读 vs fuser 后」时启动。
178
+ 4. **v11c**:作为范式对照单独跑,主要看 real_ov3。和 v11a/v11b 不可比同一基础(拓扑不同),但 sim ov1 应保持在可接受退化内。
179
+
180
+ ## 6. 不在本轮范围内的事
181
+
182
+ * K=4 → 6/8 + query 正交正则(计划里的 P3)。代价高且需要重启动;放在 v11c 结果出来之后再决定要不要做。
183
+ * `SourceQueryDecoder` memory 升级为 `[B, T_s*F_p, D]` pre-freq-pool(计划里的 P2a)。改动面太大,先用 v11a/v11b 的 head-side demixer 取等效收益。
184
+ * 真实数据 finetune 调度(`run_ov1_v9_real_balanced_*.sh` 已经在跑),与 v11 系列正交。
docs/0429_v11a_with_dynamic.md ADDED
@@ -0,0 +1,475 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 0429 · v11a_with_dynamic_10hz —— 动态 DOA 监督 + 真实/QA 数据扩展
2
+
3
+ > 记录当前 `ov1_local_spatial_v11a_with_dynamic_10hz` preset 的完整链路:
4
+ > 数据 → loader → 模型 → loss;并标明相对于 `v11a_real_balanced_10hz` 的每一处差异。
5
+ > 对应代码入口:
6
+ > - Preset: `train_spatial_beats.py::make_ov1_local_spatial_v11a_with_dynamic_10hz_config`
7
+ > - Run script: `run_ov1_v11a_with_dynamic_10hz.sh`
8
+ > - DCASE 转换器: `tools/dcase_starss_to_jsonl.py`
9
+
10
+ ---
11
+
12
+ ## 1. 一句话定位
13
+
14
+ **v11a_with_dynamic_10hz = v11a_real_balanced_10hz 的训练数据扩展 + loader/loss 升级到逐帧 target。**
15
+
16
+ - **模型结构不变**:仍然是 `local_spatial_track` + `SourceQueryDecoder (K=4)` + spatial_head_demixer
17
+ (v11a 的新组件),10 Hz token rate,ov123 top4 目录。
18
+ - **预测侧不变**:`FrameTrackPredictionOutput` 仍然是 `[B, K, T_s, ...]` 的逐帧四元组
19
+ `(activity, class, direction, distance)`。
20
+ - **监督侧升级**:target 张量从 `[B, N_gt]` 扩展为 `[B, N_gt, T_s]`——静态源沿 T_s 轴广播
21
+ (和旧行为一致),动态源按每帧轨迹线性插值到 10 Hz 栅格。
22
+ - **数据扩展**:新增 5 个训练 manifest(qa_moving / qa_counting / qa_lr_pair /
23
+ qa_same_doa / dcase_starss_foa.train)和 1 个验证 manifest(dcase_starss_foa.valid)。
24
+ - **Hot-start**:默认从 `v11a_real_balanced_10hz/03_ov123_top4/best.pt` 继续训练,`strict=False`
25
+ 且不继承 optimizer/epoch/best。上游 ov123 静态 clip 的 epoch 0 loss 应与 v11a 吻合
26
+ (逐帧 target 对静态源退化为标量广播)。
27
+
28
+ ---
29
+
30
+ ## 2. 数据:新增的 manifest 与样本格式
31
+
32
+ ### 2.1 训练集组成
33
+
34
+ | manifest | 记录数 | 类型 | DOA 来源 | distance 有效 | 复制次数 |
35
+ | --- | --- | --- | --- | --- | --- |
36
+ | `ov1_foa.jsonl` (sim) | — | ov1 静态 | scalar | ✓ | 1 |
37
+ | `ov2_foa.jsonl` (sim) | — | ov2 静态 | scalar | ✓ | 3 |
38
+ | `ov3_foa.jsonl` (sim) | — | ov3 静态 | scalar | ✓ | 3 |
39
+ | `ov1_real_static_foa_mapped.jsonl` | — | ov1 real 静态 | scalar | ✗ (null) | 4 |
40
+ | `ov2_real_static_foa_mapped.jsonl` | — | ov2 real 静态 | scalar | ✗ | 8 |
41
+ | `ov3_real_static_foa_mapped.jsonl` | — | ov3 real 静态 | scalar | ✗ | 8 |
42
+ | **`qa_moving.jsonl`** | 19 597 | QA sim 动态(单源平滑轨迹) | `frames[]` per-frame | ✓ | **2** |
43
+ | **`qa_counting.jsonl`** | 2 428 | QA sim 静态(多源 2-5 个) | scalar | ✓ | **1** |
44
+ | **`qa_lr_pair.jsonl`** | 6 631 | QA sim 静态(左右成对) | scalar | ✓ | **1** |
45
+ | **`qa_same_doa.jsonl`** | 7 896 | QA sim 静态(同一方向多源) | scalar | ✓ | **1** |
46
+ | **`dcase_starss_foa.train.jsonl`** | 12 805 | DCASE 真录 20s 多源动态 | `frames[]` per-frame | ✗ (-1) | **2** |
47
+
48
+ 对应的 `train_manifest_replication = (1, 3, 3, 4, 8, 8, 2, 1, 1, 1, 2)`。
49
+ - 总训练 clip 数(未复制)= ov123sim + ov123real + 5 × 新 manifest ≈ ov123 基础 + 49 357。
50
+ - 验证集相比 v11a 增加 `dcase_starss_foa.valid.jsonl` (4 560 clips),作为真实录音的统一评估入口。
51
+
52
+ ### 2.2 manifest schema(以 qa_moving / DCASE 为例)
53
+
54
+ **qa_moving.jsonl**(来自 `build_qa_foa_moving.py` 合成管道)
55
+
56
+ ```jsonc
57
+ {
58
+ "scene_id": "...",
59
+ "output_foa_path": "/abs/.../foa.wav",
60
+ "output_duration_seconds": 10.0,
61
+ "sample_rate": 16000,
62
+ "frame_rate": 10.0,
63
+ "num_frames": 100,
64
+ "sources": [
65
+ {
66
+ "source_index": 0,
67
+ "is_moving": true,
68
+ "mono_target_label": "speech", // FSD50K 63-class 名字
69
+ "active_time": [0.0, 10.0],
70
+ "doa": null, // ← 动态源 scalar doa 为空
71
+ "distance_cm": 150.0, // clip 级 fallback 距离
72
+ "trajectory": "...", // sweep_arc / lshape / ...
73
+ "frames": [
74
+ {"frame_idx": 0, "doa": {"azimuth_deg": 175.6, "elevation_deg": -3.2},
75
+ "distance_cm": 150.2},
76
+ {"frame_idx": 1, "doa": {"azimuth_deg": 177.9, "elevation_deg": -3.1},
77
+ "distance_cm": 150.1},
78
+ ...
79
+ ]
80
+ }
81
+ ]
82
+ }
83
+ ```
84
+
85
+ **dcase_starss_foa.{train,valid,test}.jsonl**(由 `tools/dcase_starss_to_jsonl.py` 生成)
86
+
87
+ ```jsonc
88
+ {
89
+ "scene_id": "fold1_starss22__fold4_room10_mix001_0",
90
+ "dataset_source": "dcase_starss",
91
+ "split": "train",
92
+ "output_foa_path": "/abs/.../foa.wav",
93
+ "output_duration_seconds": 20.0,
94
+ "sample_rate": 16000,
95
+ "frame_rate": 10.0,
96
+ "sources": [
97
+ {
98
+ "source_index": 0,
99
+ "is_moving": true,
100
+ "dcase_class_idx": 5,
101
+ "dcase_source_idx": 3,
102
+ "mono_target_label": "speech", // 经 DCASE_TO_FSD50K 重映射
103
+ "mono_primary_label": "male_speech",
104
+ "active_time": [1.3, 4.8],
105
+ "full_time": [0.0, 20.0],
106
+ "doa": null,
107
+ "distance_cm": -1, // DCASE 没有距离
108
+ "distance_valid": false,
109
+ "frames": [
110
+ {"frame_idx": 13, "time_s": 1.3,
111
+ "doa": {"azimuth_deg": -45.0, "elevation_deg": 10.0},
112
+ "distance_cm": -1},
113
+ ...
114
+ ]
115
+ },
116
+ ... // 可能 4+ 个 track,但逐帧同时 active 的 ≤ 4(K=4 由 matcher 保证)
117
+ ]
118
+ }
119
+ ```
120
+
121
+ ### 2.3 生成 DCASE manifest 的一次性步骤
122
+
123
+ ```bash
124
+ python tools/dcase_starss_to_jsonl.py \
125
+ --dcase-root /apdcephfs_cq10/.../DCASE2024_seld_baseline/prepared_datasets/starss23_foa_plus_29cls_20s \
126
+ --output /apdcephfs_cq10/.../data/metadata/dcase_starss_foa.jsonl \
127
+ --per-split-output
128
+ ```
129
+
130
+ - 扫描 `metadata_dev/<dataset>/<stem>.csv`,每行 `frame_idx, class_idx, source_idx, az_deg, el_deg, dist_cm`。
131
+ - 按 `(class_idx, source_idx)` 分 track,若相邻 labelled frame 间隔 > `gap_split_frames` (默认 50 帧 = 5s),
132
+ 就把 track 切成多段 `SourceEvent`,避免在静默区间乱插值。
133
+ - **类别空间压缩**:DCASE 29 类 → FSD50K 63 类的语义最近邻映射
134
+ (`DCASE_TO_FSD50K` dict,见文件头部);碰不上 FSD50K 词表的 DCASE 类(如 `unknown_*`)整 track 丢弃。
135
+ - 输出统计:18 061 CSV → train 12 805 / valid 4 560 / test 505。
136
+
137
+ ### 2.4 FSD50K 63 类别名
138
+
139
+ qa_*/DCASE manifest 里可能出现细粒度标签(`male_singing` / `female_singing`),
140
+ 但 v11a 的 vocab 只有压缩后的 63 类(含 `singing` 不含性别变体)。在 `spatial_dataset.py` 的
141
+ `_resolve_class_index` / `_resolve_class_label` 之前先跑一次 `_LABEL_ALIASES.get(raw, raw)` 归一化:
142
+
143
+ ```python
144
+ _LABEL_ALIASES = {
145
+ "male_singing": "singing",
146
+ "female_singing": "singing",
147
+ }
148
+ ```
149
+
150
+ 这样不用重新生成 jsonl,就能把 508 条 `male_singing` + 527 条 `female_singing` 折到 `singing` 上。
151
+
152
+ ---
153
+
154
+ ## 3. Loader:`spatial_dataset.py` 的逐帧化改造
155
+
156
+ ### 3.1 `SourceEvent` 新增 5 个可选字段
157
+
158
+ ```python
159
+ @dataclass
160
+ class SourceEvent:
161
+ class_index: int
162
+ class_label: str
163
+ azimuth_deg: float # 静态 scalar;动态时是 frames[0] 的 fallback
164
+ elevation_deg: float
165
+ distance: float # 动态时是第一个 valid frame 的 fallback
166
+ distance_valid: bool
167
+ start_time_seconds: float
168
+ end_time_seconds: float
169
+ # ---- 动态轨迹(仅动态源设置)----
170
+ frame_times_s: Optional[Tensor] = None # [N_f] 秒,相对 clip 起点
171
+ frame_azi_deg: Optional[Tensor] = None # [N_f] 度,未 unwrap
172
+ frame_ele_deg: Optional[Tensor] = None # [N_f] 度
173
+ frame_distance_m: Optional[Tensor] = None # [N_f] 米
174
+ frame_distance_valid: Optional[Tensor] = None # [N_f] bool
175
+ ```
176
+
177
+ ### 3.2 `_parse_frame_trajectory`
178
+
179
+ 从 manifest 的 `frames[]` 里抽出 5 个 1D tensor,同时支持两种 layout:
180
+
181
+ 1. **qa_moving**:每帧带 `frame_idx`(不带 `time_s`),clip 级 `frame_rate` 用于换算 `time_s = frame_idx / frame_rate`。
182
+ 2. **DCASE 转换器输出**:每帧直接给 `time_s`,跳过 `frame_idx / frame_rate` 换算。
183
+
184
+ 距离单位处理:优先读 `distance_cm`(除以 100 得米),`-1` 或缺失标记为 `distance_valid=False`;
185
+ 其次读 `distance_m`。
186
+
187
+ ### 3.3 `_build_source_event_from_nested_entry` 的 fallback
188
+
189
+ - 动态源 top-level `doa` 通常是 `null`,所以把 `_get_float` 换成 `_maybe_get_float`,
190
+ 再用 `frames[0]` 的 DOA 补 `azimuth_deg` / `elevation_deg`(scalar fallback,只有在
191
+ loss 层碰到静态路径时才会用到)。
192
+ - 距离同理:若 source-level `distance_valid=False`,但 `frames[]` 里有至少一个 `distance_cm >= 0`,
193
+ 就用第一个 valid frame 的距离作为 scalar fallback;否则保留 `distance_valid=False`。
194
+
195
+ ### 3.4 `_maybe_crop_sample` 的轨迹裁剪
196
+
197
+ 随机/中心裁剪时,除了裁 waveform 和更新 `start/end_time_seconds`,还要:
198
+
199
+ - 用 `new_start/new_end` 窗口过滤 `frame_times_s`,把留下来的帧时间重置到新 clip 起点(`- crop_start_seconds`)。
200
+ - `frame_azi_deg` / `frame_ele_deg` / `frame_distance_m` / `frame_distance_valid` 一起按索引截断。
201
+
202
+ 保证裁剪后的 `SourceEvent` 时间轴仍和 waveform 同源。
203
+
204
+ ### 3.5 Collate:`[B, N_gt, T_s]` 逐帧 target
205
+
206
+ `collate_spatial_batch` 相对旧实现的关键变化:
207
+
208
+ ```python
209
+ t_s_max = int(target_num_steps.max()) # batch 内最大 token 数
210
+ source_azimuth_deg = zeros(B, N_gt_max, t_s_max) # 原来是 (B, N_gt_max)
211
+ source_elevation_deg = zeros(B, N_gt_max, t_s_max)
212
+ source_distance = zeros(B, N_gt_max, t_s_max)
213
+ source_distance_valid = ones (B, N_gt_max, t_s_max, dtype=bool) # 默认 True
214
+
215
+ for b, sample in enumerate(samples):
216
+ t_axis = arange(t_s_i) / target_token_rate # 该 sample 的有效时间轴
217
+ for s, source in enumerate(sample.sources):
218
+ azi_row, ele_row, dist_row, dist_valid_row = _build_per_frame_targets(
219
+ source=source, t_axis=t_axis, t_s_max=t_s_max,
220
+ )
221
+ source_azimuth_deg[b, s] = azi_row
222
+ source_elevation_deg[b, s] = ele_row
223
+ source_distance[b, s] = dist_row
224
+ source_distance_valid[b, s] = dist_valid_row
225
+ ```
226
+
227
+ `_build_per_frame_targets` 的两条路径:
228
+
229
+ - **静态源**(`frame_times_s is None`���:在 `[0:t_s_i)` 填入 scalar;`[t_s_i:t_s_max)` 填零(padding)。
230
+ 行为等价于旧版广播。
231
+ - **动态源**:对 `t_axis` 做线性插值。方位角先用 `_unwrap_deg` 去掉 ±180° 的跳变
232
+ (qa_moving / DCASE 都可能有 170° → -170° 这样跨接的情况),插值后再 wrap 回 `[-180, 180]`;
233
+ elevation / distance 直接线性插值;`distance_valid` 用**两端都 valid 才 valid**的逻辑
234
+ (`_linear_interp_valid_mask`),避免在未知距离段里猜出假的 valid。
235
+
236
+ ### 3.6 `SpatialBatch` 的契约变化
237
+
238
+ ```python
239
+ @dataclass
240
+ class SpatialBatch:
241
+ ...
242
+ source_azimuth_deg: Tensor # [B, N_gt_max, T_s_max] ← 原 [B, N_gt_max]
243
+ source_elevation_deg: Tensor # [B, N_gt_max, T_s_max]
244
+ source_distance: Tensor # [B, N_gt_max, T_s_max]
245
+ source_distance_valid: Tensor # [B, N_gt_max, T_s_max] 新字段
246
+ source_class_indices: Tensor # [B, N_gt_max] (class 仍是 clip 级)
247
+ source_start_time_seconds: Tensor # [B, N_gt_max]
248
+ source_end_time_seconds: Tensor # [B, N_gt_max]
249
+ source_valid_mask: Tensor # [B, N_gt_max]
250
+ ```
251
+
252
+ `source_class_indices` 保持 clip 级:v11a 没有「同一 source 换类」的需求,且对应 track 内 class 恒定。
253
+
254
+ ---
255
+
256
+ ## 4. 模型:和 v11a 完全一致
257
+
258
+ **没改**,为了让 hot-start 生效。这里简要记录一下 v11a 已有的配置,便于对照:
259
+
260
+ ```
261
+ FOA 4-ch waveform @ 16 kHz
262
+
263
+
264
+ SpatialBEATsPreprocessor (mel, iv feat)
265
+
266
+ ┌───────────────────┴──────────────────┐
267
+ ▼ ▼
268
+ SpatialPatchEmbedding SpatialDeltaPatchAdapter
269
+ (mel → 768-d patch tokens) (+IV residual contribution)
270
+ └───────────────┬──────────────────────┘
271
+
272
+ BEATs TransformerEncoder (12 层,冻结)
273
+
274
+
275
+ LocalSpatialEncoder (IV-aware conv over (T_p, F_p))
276
+
277
+
278
+ TemporalResampler → fused_spatial_embeddings [B, T_s, 768]
279
+ (T_s @ 10 Hz ,cfg.target_token_rate=10)
280
+
281
+
282
+ SourceQueryDecoder (K=4 queries × T_s 次 decode,两段式)
283
+ • track_latents: [B, K, D]
284
+ • track_time_feat:[B, K, T_s, D]
285
+
286
+
287
+ FrameTrackHeads (+ SpatialHeadDemixer 1 层 attn refine, heads=8)
288
+ • pred_activity: [B, K, T_s]
289
+ • pred_class_logits: [B, K, T_s, 63]
290
+ • pred_direction: [B, K, T_s, 3] L2-normed
291
+ • pred_distance: [B, K, T_s] softplus 米
292
+ • pred_num_active_logits: [B, T_s, K+1] (v10 num_active head)
293
+ ```
294
+
295
+ v11a 相对 v9 的关键新组件(均保留):
296
+ - `use_spatial_head_demixer=True`(1 层 self-attn,8 heads,dropout 0.1)—— 在 FrameTrack head 输出后做一次
297
+ track 维解相关。
298
+ - `local_spatial_lr_scale=1.0` —— LocalSpatialEncoder 和 head 用相同 LR(v9 默认 0.3 偏低)。
299
+
300
+ ---
301
+
302
+ ## 5. Loss:逐帧 target + distance valid mask
303
+
304
+ 入口 `compute_frame_track_losses(prediction_output, batch, temporal_padding_mask, config)`。
305
+
306
+ ### 5.1 target 抽取
307
+
308
+ ```python
309
+ targets = _frame_source_target_tensors(batch, t_s_max, device)
310
+ # 返回:
311
+ # window_mask: [B, N_gt, T_s] (active_time 内为 True)
312
+ # source_valid: [B, N_gt]
313
+ # source_class: [B, N_gt]
314
+ # source_direction: [B, N_gt, T_s, 3] ← 逐帧 unit vector
315
+ # source_distance: [B, N_gt, T_s] ← 逐帧米
316
+ # source_distance_valid: [B, N_gt, T_s] ← 逐帧 bool
317
+ ```
318
+
319
+ 对于来自 loader 的 `source_azimuth_deg` / `source_elevation_deg`,`_align_t` 处理长度不匹配:
320
+ batch 内 `t_s_max` 可能与 loader 构造时的大小不同(不同 DataLoader 的 collate 边界),
321
+ 短则 pad 末帧的值,长则截断。
322
+
323
+ ### 5.2 Hungarian 匹配(代价按每帧)
324
+
325
+ `_match_frame_tracks`(`per_frame` 或 `segment` 两种策略,preset 里走 `segment`)
326
+ 在 `[B, N, K, T]` 的代价张量上做匹配:
327
+
328
+ ```
329
+ cost[b, n, k, t] = class_cost_w * NLL(pred_class, target_class[b, n])
330
+ + dir_cost_w * (1 - pred_direction[b,k,t] · target_direction[b,n,t])
331
+ + dist_cost_w * |pred_distance[b,k,t] - target_distance[b,n,t]|
332
+ + (1 - σ(pred_activity[b,k,t])) # include_activity_cost
333
+ ```
334
+
335
+ **关键点**:`target_direction` / `target_distance` 从 `[B, N, T, *]` 广播到 `[B, N, 1, T, *]`��之前是 clip 级标量),
336
+ 代价按每帧独立累加,所以动态源在不同帧的 best-match track 可以不同。segment matching
337
+ 额外加了一个 −2.0 的 continuity bonus,让同一 GT 在连续的相同 active-set segment 里尽量停留在同一 track。
338
+
339
+ ### 5.3 监督张量的构建
340
+
341
+ ```python
342
+ matched_track: [B, N_gt, T_s] (匹配结果 k∈[0,K) 或 -1)
343
+ valid_assign = matched_track >= 0
344
+ idx_b, idx_gt, idx_t = valid_assign.nonzero(as_tuple=True)
345
+ idx_k = matched_track[idx_b, idx_gt, idx_t]
346
+
347
+ activity_target[idx_b, idx_k, idx_t] = 1.0
348
+ class_target [idx_b, idx_k, idx_t] = targets["source_class"][idx_b, idx_gt]
349
+ direction_target[idx_b, idx_k, idx_t] = targets["source_direction"][idx_b, idx_gt, idx_t] # ← 3D 索引
350
+ distance_target [idx_b, idx_k, idx_t] = targets["source_distance" ][idx_b, idx_gt, idx_t] # ← 3D 索引
351
+ dist_supervise_mask[idx_b, idx_k, idx_t] = targets["source_distance_valid"][idx_b, idx_gt, idx_t]
352
+ ```
353
+
354
+ 相对旧版(`targets["source_direction"][idx_b, idx_gt]` 是 `[M, 3]`)的差别是把第三个 axis 替换为具体
355
+ 的 `idx_t`,真正拿到逐帧 GT。`supervise_mask` 是 activity-winning 的 mask,`dist_supervise_mask`
356
+ 在其基础上再 AND 一个**逐帧 distance validity**:STARSS/DCASE 整源为 False 时,
357
+ distance loss 就不会回传任何梯度。
358
+
359
+ ### 5.4 各项损失
360
+
361
+ | 项 | 公式 | mask |
362
+ | --- | --- | --- |
363
+ | activity | `BCE_with_logits(pred_activity, activity_target, pos_weight=dyn)` | `valid_time` 扩到 [B, K, T_s] |
364
+ | num_active (v10) | `CE(pred_num_active_logits, active_count)` | `valid_time` |
365
+ | class | `CE(pred_class_logits, class_target)` + 可选 ontology smoothing | `supervise_mask` |
366
+ | direction | `mean(1 - pred · target)` | `supervise_mask` |
367
+ | distance | `smooth_l1(pred, target)` | `dist_supervise_mask` ← **逐帧 validity** |
368
+
369
+ **ADPIT duplicate** & **nonwinner soft activity**(v9/v10 的两个辅助)同样全部走逐帧 `source_direction[..., t]`
370
+ 和 `source_distance_valid[..., t]` 索引;旧版的 `batch.source_azimuth_deg[:, 0]` 之类的 2D 访问被替换成
371
+ `[:, 0, 0]`(共 44 处)以避免 shape 冲突。
372
+
373
+ 最终汇总:
374
+ ```
375
+ loss_total = λ_act · loss_activity
376
+ + λ_cls · loss_class
377
+ + λ_dir · loss_direction
378
+ + λ_dist · loss_distance
379
+ + λ_na · loss_num_active
380
+ ```
381
+ λ 与 v11a 完全相同(在 `SpatialLossConfig` 里走 v10 phase-2 的基线数值)。
382
+
383
+ ### 5.5 静态源的退化等价性
384
+
385
+ 因为 loader 把静态源沿 T_s 轴广播,`target_direction[b, gt, 0:T_s_i]` 每一帧都一致,
386
+ Hungarian 代价 `(1 - pred·target)` 和 per-clip 版本逐项相等;distance 同理。因此
387
+ **ov123 sim/real clip 的 epoch 0 loss 与 v11a 数值吻合**,这也是为什么可以直接从 v11a best.pt 热启动。
388
+
389
+ ---
390
+
391
+ ## 6. 训练配置(preset diff)
392
+
393
+ ```python
394
+ def make_ov1_local_spatial_v11a_with_dynamic_10hz_config(...):
395
+ cfg = make_ov1_local_spatial_v11a_real_balanced_10hz_config(...)
396
+
397
+ # —— 只改了数据和轮次 ——
398
+ cfg.train_manifest_paths = (
399
+ ov1_sim, ov2_sim, ov3_sim,
400
+ ov1_real, ov2_real, ov3_real,
401
+ qa_moving, qa_counting, qa_lr_pair, qa_same_doa,
402
+ dcase_starss_train,
403
+ )
404
+ cfg.train_manifest_replication = (1, 3, 3, 4, 8, 8, 2, 1, 1, 1, 2)
405
+
406
+ cfg.val_manifest_paths = (
407
+ ov1_sim, ov2_sim, ov3_sim,
408
+ ov1_real, ov2_real, ov3_real,
409
+ dcase_starss_valid,
410
+ )
411
+ cfg.test_manifest_paths = cfg.val_manifest_paths
412
+
413
+ cfg.num_epochs = 15
414
+ cfg.output_dir = "checkpoints/spatial_beats_ov1_local_spatial_v11a_with_dynamic_10hz_exp/03_ov123_top4"
415
+ return cfg
416
+ ```
417
+
418
+ 运行入口(`run_ov1_v11a_with_dynamic_10hz.sh`)默认:
419
+ ```
420
+ GPUS=8 BATCH_SIZE=4 NUM_WORKERS=8
421
+ SPATIAL_EPOCHS=15 SPATIAL_LR=1.5e-5 AMP=fp32
422
+ RESUME_CKPT=checkpoints/spatial_beats_ov1_local_spatial_v11a_real_balanced_10hz_exp/03_ov123_top4/best.pt
423
+ --no-resume-optimizer --reset-epoch-on-resume --reset-best-on-resume
424
+ ```
425
+
426
+ ---
427
+
428
+ ## 7. 与 v11a_real_balanced_10hz 的差异一览
429
+
430
+ | 维度 | v11a_real_balanced_10hz | **v11a_with_dynamic_10hz** |
431
+ | --- | --- | --- |
432
+ | 训练 manifest 数 | 6 (ov123 sim + ov123 real) | **11**(+5 个 QA/DCASE) |
433
+ | 训练 clip 量(未复制) | ~O(20k) | +49 357 |
434
+ | 真实录音监督 | ov123 real 静态 scalar | **+DCASE STARSS 20s 逐帧** |
435
+ | 动态源监督 | 无 | **qa_moving + DCASE 动态 track** |
436
+ | `SourceEvent` 字段 | 无 `frame_*` | +5 个 `frame_times_s/azi/ele/dist/distance_valid` |
437
+ | Loader target shape | `source_*: [B, N_gt]` | **`source_*: [B, N_gt, T_s]`**(静态广播,动态插值) |
438
+ | `source_distance_valid` | `[B, N_gt]` | **`[B, N_gt, T_s]`**(逐帧) |
439
+ | Hungarian 代价 | dir/dist 代价用 clip 级标量 target | **用逐帧 `[B, N, T, *]` target 广播** |
440
+ | distance 监督 mask | 按源 `distance_valid` | **按 (源, 帧) `dist_supervise_mask`**,STARSS/DCASE 不回传梯度 |
441
+ | Class 词表别名 | 无 | **`_LABEL_ALIASES` 处理 `male_singing`/`female_singing → singing`** |
442
+ | 验证集 | ov123 sim+real | **+ DCASE valid (4 560 clips)** |
443
+ | `num_epochs` | 15 (继承 v9) | 15(不变) |
444
+ | 模型结构 | local_spatial_track + K=4 + demixer | **完全一致** |
445
+ | 热启动 | v9_real_balanced_10hz best.pt | **v11a_real_balanced_10hz best.pt**, strict=False |
446
+ | 输出目录 | `.../v11a_real_balanced_10hz_exp/03_ov123_top4` | `.../v11a_with_dynamic_10hz_exp/03_ov123_top4` |
447
+
448
+ ---
449
+
450
+ ## 8. 已知坑 & 兼容性备忘
451
+
452
+ 1. **`pred_*` 的 t_s 与 loader 的 `T_s_max` 有可能不一致**。loss 侧 `_align_t` 会 pad/truncate GT 的最后一维;
453
+ loader 侧已保证 `target_num_steps = round(duration × target_token_rate)` 与模型 temporal resampler 对齐,
454
+ 正常只有在 batch 内最长样本决定 `t_s_max` 而 frame-track head 输出更短时需要截断。
455
+ 2. **方位角跨 ±180°**:qa_moving 里观测到 `175.6° → -177.3°` 的自然轨迹,
456
+ `_unwrap_deg` 会把它解包成 `175.6° → 182.7°` 插值,最后 wrap 回 `[-180, 180]`,不会出现 -340° 的错误跨越。
457
+ 3. **DCASE 一个 clip 可以出现 6+ 个 track**,但任意帧同时 active 的 ≤ 4(DCASE 规范约束);
458
+ `_match_frame_tracks_per_frame` 里 `active_count.clamp(max=K)` 是一层防御。
459
+ 4. **distance=-1 的 clip** 在 loss 内走 `dist_supervise_mask` 全 False 的分支:
460
+ `loss_distance = pred_distance.sum() * 0.0`,梯度为 0,这一 clip 只贡献 activity/class/direction loss。
461
+ 5. **`male_singing/female_singing`** 是 loader 侧别名,若以后又有新细粒度标签冒出来,直接在
462
+ `spatial_dataset.py` 的 `_LABEL_ALIASES` 里加映射即可,不需要重跑数据。
463
+ 6. **ov123 静态 clip 的 epoch 0 loss 必须与 v11a 一致**——这是验证 loader/loss 升级没有引入
464
+ 回归的烟雾测试;曾经跑过 smoke script 确认 variance=0 for 静态 target、variance>0 for qa_moving target。
465
+
466
+ ---
467
+
468
+ ## 9. 相关脚本索引
469
+
470
+ - `tools/dcase_starss_to_jsonl.py` —— DCASE CSV → jsonl + FSD50K 类别映射
471
+ - `run_ov1_v11a_with_dynamic_10hz.sh` —— 训练入口(含 6 个动态 manifest 存在性 warning)
472
+ - `spatial_dataset.py::_parse_frame_trajectory` / `_build_per_frame_targets` —— 动态 target 构建
473
+ - `spatial_loss.py::_frame_source_target_tensors` / `compute_frame_track_losses` —— 逐帧 loss
474
+ - Preset: `train_spatial_beats.py::make_ov1_local_spatial_v11a_with_dynamic_10hz_config`
475
+ (CLI 名: `--preset ov1_local_spatial_v11a_with_dynamic_10hz`)
docs/V11_QUICK_START.md ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # v11 Architecture - Quick Start Guide
2
+
3
+ ## What is v11?
4
+
5
+ The v11 series represents a major architectural enhancement addressing the classification accuracy plateau at ~51% in the Spatial-BEATs model. It introduces:
6
+
7
+ 1. **SpatialDeltaPatchAdapterV2**: Enhanced front-end spatial encoder (17.4M params)
8
+ 2. **SpatialAdapterLayer**: In-trunk spatial conditioning (1.2M params total)
9
+ 3. **Multiple routing options**: Route A/B (track-based) or Route C (ACCDOA class-based)
10
+
11
+ ---
12
+
13
+ ## Four v11 Variants
14
+
15
+ ### 1. v11_phase1_cls - Phase 1 Classification Refinement
16
+
17
+ **Use this first** to diagnose if the new adapter improves classification accuracy.
18
+
19
+ **What it does:**
20
+ - Enables SpatialDeltaPatchAdapterV2 only
21
+ - Freezes direction/distance heads
22
+ - Trains classification + num_active heads
23
+ - Hot-starts from v10 phase-1 best checkpoint
24
+
25
+ **Command:**
26
+ ```bash
27
+ ./run_ov1_v11_phase1_cls.sh
28
+ ```
29
+
30
+ **Environment variables:**
31
+ ```bash
32
+ SPATIAL_EPOCHS=10 # Default: 10 epochs
33
+ SPATIAL_LR=7.5e-6 # Default: 7.5e-6
34
+ BATCH_SIZE=8 # Default: 8
35
+ GPUS=8 # Default: 8
36
+ ```
37
+
38
+ **Expected output:**
39
+ ```
40
+ Epoch 1: cls_acc=0.720 (should be at least v10 level)
41
+ Epoch 5: cls_acc=0.755 (expect improvement trend)
42
+ Epoch 10: cls_acc=0.78+ (best of phase-1)
43
+ ```
44
+
45
+ **What to look for:**
46
+ - Does cls_acc improve beyond v10 phase-1 peak (0.78)?
47
+ - How quickly does it converge?
48
+ - Does val_loss plateau or continue improving?
49
+
50
+ ---
51
+
52
+ ### 2. v11a_ov123_top4 - Route B + Spatial Demixer (Full Architecture)
53
+
54
+ **Use after v11_phase1_cls confirms improvement.**
55
+
56
+ **What it does:**
57
+ - Enables SpatialDeltaPatchAdapterV2 + trunk adapters + spatial_head_demixer
58
+ - Trains all heads (activity, class, direction, distance)
59
+ - Uses demixer for both class AND spatial heads
60
+ - Hot-starts from v9 ov123_top4 best checkpoint
61
+
62
+ **Command:**
63
+ ```bash
64
+ ./run_ov1_v11a_ov123_top4.sh
65
+ ```
66
+
67
+ **Key metrics:**
68
+ - `azi_mae_deg`: Azimuth mean absolute error (primary DOA metric)
69
+ - `class_acc`: Matched-source class accuracy
70
+ - `activity_f1`: Source presence F1-score
71
+
72
+ **Expected improvement:**
73
+ ```
74
+ Metric v9 Baseline v11a Target Expected Delta
75
+ ────────────────────────────────────────────────────────────────────
76
+ azi_mae_deg (train) 10° 8-9° -1 to -2°
77
+ azi_mae_deg (val) 30° 24-26° -4 to -6°
78
+ class_acc (val) 73% 75%+ +2%
79
+ ```
80
+
81
+ **What to look for:**
82
+ - Validation azimuth error should be significantly lower
83
+ - Train/val gap should narrow (from ~20° toward ~15°)
84
+ - No collapse in accuracy metrics
85
+
86
+ ---
87
+
88
+ ### 3. v11b_ov123_top4 - Route B + LocalSpatial Demixer KV
89
+
90
+ **Use for comparison with v11a.**
91
+
92
+ **What it does:**
93
+ - Same as v11a, BUT
94
+ - Demixer attends to LocalSpatial's 7-channel pre-pool (FOA + IV)
95
+ - Instead of BEATs mono mel-filterbank features
96
+ - Hypothesis: Spatial features better for DOA decomposition
97
+
98
+ **Command:**
99
+ ```bash
100
+ ./run_ov1_v11b_ov123_top4.sh
101
+ ```
102
+
103
+ **Comparison with v11a:**
104
+ ```
105
+ Aspect v11a (BEATs KV) v11b (LocalSpatial KV)
106
+ ─────────────────────────────────────────────────────────────────
107
+ Demixer KV source BEATs trunk LocalSpatial pre-pool
108
+ Channels 1 (mono fbank) 7 (4 FOA + 3 IV)
109
+ Prior knowledge Semantic Spatial physics
110
+ Expected advantage Better for class Better for direction
111
+ Computational cost Lower Higher
112
+ ```
113
+
114
+ **When to pick v11b over v11a:**
115
+ - If DOA error (azi_mae_deg) is more important than class accuracy
116
+ - If you have GPU budget for extra feature processing
117
+ - For acoustic scenes where spatial features matter more
118
+
119
+ ---
120
+
121
+ ### 4. v11c_ov123_accdoa - Paradigm Shift to ACCDOA (Route C)
122
+
123
+ **Use as a "simplicity first" baseline.**
124
+
125
+ **What it does:**
126
+ - Enables SpatialDeltaPatchAdapterV2 + trunk adapters
127
+ - Replaces query decoder + Hungarian matching with per-class ACCDOA heads
128
+ - Each class gets its own spatial slot (no matching needed)
129
+ - Activity encoded in vector magnitude, direction in unit vector
130
+
131
+ **Command:**
132
+ ```bash
133
+ ./run_ov1_v11c_ov123_accdoa.sh
134
+ ```
135
+
136
+ **Key differences from v11a:**
137
+ ```
138
+ Aspect v11a (Route B, Track) v11c (Route C, ACCDOA)
139
+ ─────────────────────────────────────────────────────────────────────
140
+ Paradigm K learnable tracks Per-class slots
141
+ Matching Hungarian (clip-level) None (inherent per-class)
142
+ Activity loss Binary cross-entropy MSE on magnitude
143
+ Direction repr. L2 normalized vector Unit vector (normalized)
144
+ Scalability O(K×T_s) per-frame O(num_classes×T_s)
145
+ ov2/ov3 fit Good (overlap ambiguity) Better (same-class=0)
146
+ ```
147
+
148
+ **When to pick v11c:**
149
+ - For DCASE evaluation (uses official SELD metrics)
150
+ - If Hungarian matching is a bottleneck
151
+ - For datasets with no overlapping same-class sources (ov2/ov3 constraints)
152
+ - For interpretability (each class = one direction)
153
+
154
+ ---
155
+
156
+ ## Decision Tree: Which v11 to Run?
157
+
158
+ ```
159
+ START
160
+
161
+ ├─→ "Do I want to diagnose if new adapters help classification?"
162
+ │ └─→ YES: Run v11_phase1_cls
163
+ │ ↓ (wait for results)
164
+ │ Does cls_acc improve?
165
+ │ ├─→ YES ✓
166
+ │ │ └─→ Proceed to multi-head experiments
167
+ │ └─→ NO ✗
168
+ │ └─→ Back to drawing board (architecture issue)
169
+
170
+ ├─→ "Is direction-of-arrival (DOA) error my primary concern?"
171
+ │ ├─→ YES: Need DOA focus
172
+ │ │ ├─→ "Do I have GPU budget for LocalSpatial features?"
173
+ │ │ │ ├─→ YES: Run v11b_ov123_top4
174
+ │ │ │ └─→ NO: Run v11a_ov123_top4
175
+ │ └─→ NO: Skip v11a/v11b
176
+
177
+ └─→ "Am I targeting DCASE evaluation / ov2/ov3 constraints?"
178
+ ├─→ YES: Run v11c_ov123_accdoa
179
+ └─→ NO: Run v11a_ov123_top4 (default full-featured)
180
+ ```
181
+
182
+ ---
183
+
184
+ ## Monitoring Experiments
185
+
186
+ ### Key Metrics to Track
187
+
188
+ **Classification**:
189
+ - `class_acc`: Top-1 accuracy on matched sources
190
+ - `class_precision`: Per-class precision
191
+ - `class_recall`: Per-class recall
192
+
193
+ **Direction (DOA)**:
194
+ - `azi_mae_deg`: **Primary metric** - azimuth mean absolute error
195
+ - `ele_mae_deg`: Elevation mean absolute error
196
+ - `azi_std_deg`: Azimuth error standard deviation
197
+
198
+ **Distance**:
199
+ - `dist_mae_m`: Distance mean absolute error
200
+
201
+ **Activity**:
202
+ - `activity_f1`: Source presence F1-score
203
+ - `num_active_mae`: Mean absolute error in source count
204
+
205
+ **Gap Analysis**:
206
+ - `train_azi_mae_deg`: Training set azimuth error
207
+ - `val_azi_mae_deg`: Validation set azimuth error
208
+ - `gap = val - train`: **Gap should decrease with v11**
209
+
210
+ ### TensorBoard Visualization
211
+
212
+ ```bash
213
+ tensorboard --logdir=checkpoints/spatial_beats_v11_phase1_cls_exp/ov123_top4 --port=6006
214
+ ```
215
+
216
+ **Plots to monitor:**
217
+ - `metrics/val_azi_mae_deg`: Should decrease smoothly
218
+ - `metrics/train_azi_mae_deg`: Should decrease with training
219
+ - `loss/total`: Should follow training dynamics (may oscillate)
220
+ - `loss/frame_direction`: DOA-specific loss component
221
+
222
+ ---
223
+
224
+ ## Checkpoint Management
225
+
226
+ ### Hot-Start Strategy
227
+
228
+ Each v11 variant is designed to hot-start from a previous checkpoint:
229
+
230
+ **v11_phase1_cls**:
231
+ ```
232
+ Loads from: v10_phase1_cls best.pt
233
+ Missing params: V2 adapter + trunk adapters
234
+ Initialize with: Zero-init adapters (identity at epoch-0)
235
+ Benefit: Inherits v10's frozen classification features
236
+ ```
237
+
238
+ **v11a_ov123_top4**:
239
+ ```
240
+ Loads from: v9_ov123_top4 best.pt
241
+ Missing params: V2 + trunk adapters + spatial_demixer (added to heads)
242
+ Initialize with: Zero-init everything (identity at epoch-0)
243
+ Benefit: Inherits v9's proven multi-head balance
244
+ ```
245
+
246
+ **v11b_ov123_top4**:
247
+ ```
248
+ Same as v11a, but adds LocalSpatial pre-pool processing
249
+ ```
250
+
251
+ **v11c_ov123_accdoa**:
252
+ ```
253
+ Loads from: ov1_local_spatial baseline (v9 incompatible)
254
+ Missing params: ACCDOAHeads (entire head replacement)
255
+ Initialize with: Zero-init (no class/spatial heads to inherit)
256
+ Benefit: Simpler routing = faster convergence
257
+ ```
258
+
259
+ ### How to Load a Checkpoint Manually
260
+
261
+ ```python
262
+ import torch
263
+ from train_spatial_beats import make_ov1_local_spatial_v11a_ov123_top4_config
264
+ from spatial_beats import SpatialBEATs
265
+
266
+ # Create model with v11a config
267
+ cfg = make_ov1_local_spatial_v11a_ov123_top4_config()
268
+ model = SpatialBEATs(cfg)
269
+
270
+ # Load v9 checkpoint (strict=False ignores new params)
271
+ ckpt = torch.load('checkpoints/.../v9_best.pt')
272
+ model.load_state_dict(ckpt['model'], strict=False)
273
+
274
+ # New params are zero-initialized (identity behavior)
275
+ # Ready to train!
276
+ model.train()
277
+ ```
278
+
279
+ ---
280
+
281
+ ## Troubleshooting
282
+
283
+ ### Issue: "CUDA out of memory"
284
+ **Solution**: Reduce batch size or sequence length
285
+ ```bash
286
+ BATCH_SIZE=4 ./run_ov1_v11a_ov123_top4.sh
287
+ ```
288
+
289
+ ### Issue: "ClassHeadSpectralDemixer not initialized"
290
+ **Solution**: Ensure config enables it:
291
+ ```python
292
+ cfg.use_class_head_demixer = True # For v11a
293
+ cfg.use_spatial_head_demixer = True # For v11a (added in v11)
294
+ ```
295
+
296
+ ### Issue: "Large train/val gap not shrinking"
297
+ **Diagnosis steps**:
298
+ 1. Check if Dropout is OFF during evaluation
299
+ 2. Verify SpecAugment is applied only during training
300
+ 3. Run diagnostic: evaluate same checkpoint in train/eval modes
301
+ ```bash
302
+ python -c "
303
+ model.eval()
304
+ val_error_no_dropout = evaluate(model, val_loader)
305
+ model.train()
306
+ val_error_with_dropout = evaluate(model, val_loader)
307
+ print(f'Dropout effect: {val_error_with_dropout - val_error_no_dropout:.1f}°')
308
+ "
309
+ ```
310
+
311
+ ### Issue: "Trunk adapters not being applied"
312
+ **Check**: Verify config flag is True
313
+ ```python
314
+ if not cfg.use_trunk_spatial_adapters:
315
+ print("WARNING: Trunk adapters disabled!")
316
+ cfg.use_trunk_spatial_adapters = True
317
+ ```
318
+
319
+ ---
320
+
321
+ ## Next Steps After v11 Experiments
322
+
323
+ 1. **Analyze results** (docs/V11_IMPLEMENTATION_SUMMARY.md contains diagnostic templates)
324
+ 2. **Pick best variant** based on your primary metric
325
+ 3. **Fine-tune hyperparameters** (learning rate, dropout rate if you modify later)
326
+ 4. **Run official evaluation** on test set using DCASE metrics
327
+ 5. **Consider multi-stage training**:
328
+ - Stage 1: Classification only (v11_phase1_cls)
329
+ - Stage 2: Full pipeline (v11a/b/c)
330
+ - Stage 3: Fine-tuning (reduce LR, increase epochs)
331
+
332
+ ---
333
+
334
+ ## Citation & References
335
+
336
+ This architecture is built on:
337
+ - **BEATs** (Microsoft): Base semantic encoder (https://arxiv.org/abs/2212.09058)
338
+ - **DCASE SELD**: Official evaluation metrics (https://github.com/sharathadavanne/seld-dcase2023)
339
+ - **EINV2 paradigm**: Track-based source modeling
340
+ - **Spatial audio physics**: FOA (First-Order Ambisonics) + Intensity Vectors
341
+
342
+ For detailed technical justification, see:
343
+ - docs/V11_IMPLEMENTATION_SUMMARY.md
344
+ - docs/doa_train_valid_gap_analysis.md
345
+ - SPATIAL_AUDIO_FRAMEWORKS_ANALYSIS_COMPREHENSIVE.md
docs/gemini.md ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Spatial-BEATs 最终实施指南 (Reference Implementation Guide)
2
+
3
+ 本文档定义了 `Spatial-BEATs` 的模型架构、特征工程与训练流程的最终技术细节,作为代码实现的唯一参照。
4
+
5
+ ## 1. 模型架构细节 (Architecture Specification)
6
+
7
+ ### 1.1 输入前端 (Stem)
8
+ - **输入特征图**: $7 \times 128 \times 1024$ (Channels $\times$ Mel-bins $\times$ Time-frames)。
9
+ - **通道定义**:
10
+ - `[0:4]`: W, X, Y, Z 的 Log-mel。
11
+ - `[4:7]`: IVx, IVy, IVz (Intensity Vector),按时间/频率对齐。
12
+ - **Patch Embedding**:
13
+ - 结构: `nn.Conv2d(7, embed_dim, kernel_size=16, stride=16)`。
14
+ - 初始化: 通道 0 (W) 复用 BEATs 预训练权重,通道 1-6 随机初始化。
15
+
16
+ ### 1.2 空间 Token 提取 (Source Queries)
17
+ - **Token 数量 ($K$)**: 4 个。
18
+ - **实现方式**:
19
+ - 定义 `nn.Parameter(torch.randn(1, 4, embed_dim))` 作为 Source Queries。
20
+ - 使用 2 层 Transformer Decoder 层。
21
+ - **Query**: Source Queries。
22
+ - **Key/Value**: BEATs Trunk 的输出序列 (Dense Patch Tokens)。
23
+ - **输出**: 4 个维度为 `embed_dim` 的 `Spatial Tokens`。
24
+
25
+ ### 1.3 预测头 (Prediction Heads)
26
+ 每个 Spatial Token 独立连接以下 MLP 层:
27
+ - **Objectness**: `Linear -> Sigmoid` (1 unit)。
28
+ - **Azimuth**: `Linear -> tanh` (2 units: $\sin, \cos$)。计算角度使用 `atan2`。
29
+ - **Elevation**: `Linear -> tanh` (2 units: $\sin, \cos$)。
30
+ - **Distance**: `Linear` (1 unit, 单位:**Centimeters**)。
31
+ - **Class**: `Linear -> Softmax` (N units, 对应 FSD50k 类别)。
32
+
33
+ ## 2. 坐标系与物理特征 (Spatial Physics)
34
+
35
+ ### 2.1 坐标系 (DCASE Standard)
36
+ - **轴向**: +x 前, +y 左, +z 上。
37
+ - **方位角 (Azimuth)**: $[-180, 180]$,逆时针增加。+90 度为左,-90 度为右。
38
+ - **仰角 (Elevation)**: $[-90, 90]$,向上增加。
39
+ - **距离 (Distance)**: 以 **厘米 (cm)** 为单位进行回归。
40
+
41
+ ### 2.2 IV 计算 (Intensity Vector)
42
+ 在特征提取阶段,按以下逻辑计算 IV:
43
+ - $I_x = \text{Re}\{W^* \cdot X\}$
44
+ - $I_y = \text{Re}\{W^* \cdot Y\}$
45
+ - $I_z = \text{Re}\{W^* \cdot Z\}$
46
+ - 所有的 $I$ 均通过 Mel 滤波器组进行映射,以匹配 Log-mel 的分辨率。
47
+
48
+ ## 3. 训练策略 (Training Recipe)
49
+
50
+ ### 3.1 损失函数 (Hungarian Loss)
51
+ - **匹配算法**: 使用 `scipy.optimize.linear_sum_assignment` (Hungarian Matching) 匹配 4 个预测 Token 与 $N$ 个 GT 声源 ($N \le 4$)。
52
+ - **匹配代价 (Matching Cost)**: 综合位置误差 (Az/El/Dist)、类别误差和 Objectness 分数。
53
+ - **总损失**:
54
+ - 对匹配成功的 Token:计算 $L_{MSE}(pos) + L_{BCE}(obj) + L_{CrossEntropy}(cls)$。
55
+ - 对未匹配的 Token:计算 $L_{BCE}(obj, 0)$。
56
+
57
+ ### 3.2 训练阶段
58
+ 1. **Stage 1 (Stem & Head Warmup)**: 冻结 BEATs Trunk (Transformer 层),仅训练新 Patch Embedding 和 Spatial Decoder/Heads。
59
+ 2. **Stage 2 (Joint Fine-tuning)**: 以 $1 \times 10^{-5}$ 的低学习率解冻整个 Trunk 进行微调。
60
+
61
+ ## 4. LLM 接入接口 (LLM Interface)
62
+ - 提取后的 4 个 `Spatial Tokens` 将通过一个 `Linear` 投影层对齐到 LLM 的隐藏层空间。
63
+ - 在 Prompt 中,这 4 个 tokens 将按 object-wise 顺序排列,代表音频中的空间实体。
docs/spatial_beats_implementation_spec.md ADDED
@@ -0,0 +1,706 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Spatial-BEATs 实现规格
2
+
3
+ ## 1. 目标
4
+
5
+ 本规格文档用于将前期讨论收敛为一个可以直接实施的 `Spatial-BEATs` 方案。
6
+
7
+ 目标是构建一个独立的 `Spatial Encoder`:
8
+
9
+ - 输入为完整 `FOA` 音频及其派生空间特征
10
+ - 完整的 `FOA` 特征经过 `BEATs backbone`
11
+ - 最大化复用 `BEATs` 预训练权重
12
+ - 输出一组 `source-level spatial tokens`
13
+ - 这些 token 作为独立模态输入给 LLM
14
+ - 原有语义 audio encoder 保持不动
15
+
16
+ 这里的关键原则是:
17
+
18
+ > 不是让 `W-only` 走主干,再外挂一个小空间 adapter;而是让完整 FOA 空间特征真正进入 BEATs 主干,并在主干之后产出结构化空间 token。
19
+
20
+ ## 2. 最终任务定义
21
+
22
+ ### 2.1 核心任务
23
+
24
+ `Spatial-BEATs` 的主任务定义为:
25
+
26
+ - 给定一个多源 `FOA` 音频片段
27
+ - 预测其中最多 `K` 个潜在声源的空间表示
28
+ - 每个表示对应一个 `source token`
29
+
30
+ 每个 source token 至少承载:
31
+
32
+ - `objectness`
33
+ - `azimuth`
34
+ - `elevation`
35
+ - `distance`
36
+
37
+ 可选承载:
38
+
39
+ - `source class auxiliary logits`
40
+ - `source embedding`
41
+
42
+ ### 2.2 推荐监督形式
43
+
44
+ 如果训练数据中每个源都有标注,则推荐采用:
45
+
46
+ - `set prediction`
47
+ - `K` 个预测 token 对 `N` 个 GT sources
48
+ - 用 `Hungarian matching` 做一一匹配
49
+
50
+ 不建议采用:
51
+
52
+ - 单一 scene-level spatial token
53
+ - 仅回归整段音频的全局空间摘要
54
+
55
+ 原因是这会损失多源结构,不利于后续 LLM 做关系推理。
56
+
57
+ ## 3. 最终架构
58
+
59
+ 推荐最终架构:
60
+
61
+ ```text
62
+ FOA waveform
63
+ -> SpatialBEATsPreprocessor
64
+ -> FOA feature map [B, C_foa, T, F]
65
+ -> FOA patch embedding
66
+ -> BEATs trunk
67
+ -> Spatial query decoder
68
+ -> K source tokens
69
+ -> Spatial prediction heads
70
+ -> LLM projector
71
+ ```
72
+
73
+ 为了最大化复用 BEATs 主干,本方案尽量不改 trunk 内部的 Transformer 结构。
74
+
75
+ ## 4. 输入特征定义
76
+
77
+ ### 4.1 默认推荐特征
78
+
79
+ 第一版推荐输入通道:
80
+
81
+ - `W_logmel`
82
+ - `X_logmel`
83
+ - `Y_logmel`
84
+ - `Z_logmel`
85
+ - `IVx`
86
+ - `IVy`
87
+ - `IVz`
88
+
89
+ 即:
90
+
91
+ - `C_foa = 7`
92
+
93
+ 这是默认推荐方案。
94
+
95
+ ### 4.2 备选输入特征
96
+
97
+ 若希望先降低复杂度,可以使用:
98
+
99
+ - `WXYZ logmel`
100
+
101
+ 即:
102
+
103
+ - `C_foa = 4`
104
+
105
+ 但这只适合最小原型。
106
+ 如果目标是稳定学习空间方向与结构,优先使用 `WXYZ + IV`。
107
+
108
+ ### 4.3 前端参数建议
109
+
110
+ 为了最大化复用 BEATs 主干,推荐保持与 BEATs 接近的时频分辨率:
111
+
112
+ - sample rate:优先 `16k`
113
+ - mel bins:`128`
114
+ - frame length:`25 ms`
115
+ - frame shift:`10 ms`
116
+
117
+ 原因:
118
+
119
+ - 这能让 trunk 看到与原始 BEATs 更接近的 patch 几何结构
120
+ - patch embedding 和后续序列长度更容易保持一致
121
+ - 预训练权重复用更稳定
122
+
123
+ ### 4.4 为什么不沿用 Spatial-AST 的 binaural 前端
124
+
125
+ Spatial-AST 采用的是:
126
+
127
+ - 双耳 log-mel
128
+ - IPD
129
+
130
+ 这适合 binaural,不适合直接迁移到 FOA。
131
+
132
+ FOA 下应优先利用:
133
+
134
+ - ambisonic 通道本身
135
+ - intensity vector
136
+ - 其他 FOA 物理特征
137
+
138
+ ## 5. 对 BEATs 具体修改哪些模块
139
+
140
+ 下面按模块说明修改方案。
141
+
142
+ ### 5.1 保留不动的模块
143
+
144
+ 建议尽量保留:
145
+
146
+ - `TransformerEncoder`
147
+ - `TransformerSentenceEncoderLayer`
148
+ - `MultiheadAttention`
149
+ - `conv_pos`
150
+ - `LayerNorm`
151
+ - `FFN`
152
+ - `post_extract_proj`
153
+
154
+ 也就是 `backbone.py` 内的主干结构和 `BEATs.py` 中的 trunk 逻辑尽量不动。
155
+
156
+ ### 5.2 必须修改的模块
157
+
158
+ 必须重做:
159
+
160
+ 1. `preprocess`
161
+ 2. `patch_embedding`
162
+ 3. `extract_features` 输出头部逻辑
163
+ 4. 下游 `predictor`
164
+
165
+ ### 5.3 推荐新增的模块
166
+
167
+ 建议新增:
168
+
169
+ 1. `SpatialBEATsPreprocessor`
170
+ 2. `SpatialPatchEmbedding`
171
+ 3. `SpatialQueryDecoder`
172
+ 4. `SpatialPredictionHead`
173
+ 5. `SpatialTokenProjector`
174
+ 6. `HungarianMatcher`
175
+ 7. `SpatialSetCriterion`
176
+
177
+ ## 6. 代码级映射建议
178
+
179
+ ### 6.1 现有文件建议
180
+
181
+ 建议保留和复用:
182
+
183
+ - [BEATs.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/BEATs.py)
184
+ - [backbone.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/backbone.py)
185
+
186
+ 建议新增:
187
+
188
+ - `spatial_beats.py`
189
+ - `spatial_modules.py`
190
+ - `spatial_loss.py`
191
+ - `spatial_dataset.py`
192
+ - `train_spatial_beats.py`
193
+
194
+ ### 6.2 `spatial_beats.py` 建议包含
195
+
196
+ 建议实现:
197
+
198
+ - `SpatialBEATsConfig`
199
+ - `SpatialBEATs`
200
+ - `SpatialBEATs.extract_spatial_tokens()`
201
+ - `SpatialBEATs.forward()`
202
+
203
+ ### 6.3 `spatial_modules.py` 建议包含
204
+
205
+ 建议实现:
206
+
207
+ - `SpatialBEATsPreprocessor`
208
+ - `SpatialPatchEmbedding`
209
+ - `SpatialQueryDecoder`
210
+ - `SpatialPredictionHead`
211
+ - `SpatialTokenProjector`
212
+
213
+ ### 6.4 `spatial_loss.py` 建议包含
214
+
215
+ 建议实现:
216
+
217
+ - `HungarianMatcher`
218
+ - `SpatialSetCriterion`
219
+
220
+ ## 7. 预训练权重如何复用
221
+
222
+ ## 7.1 默认推荐权重
223
+
224
+ 默认推荐:
225
+
226
+ - `BEATs_iter3+ (AS2M) pre-trained`
227
+
228
+ 而不是:
229
+
230
+ - fine-tuned checkpoints
231
+
232
+ 原因:
233
+
234
+ - `pre-trained` 更适合作为 trunk 初始化
235
+ - `fine-tuned` 更偏向 AudioSet 分类判别
236
+ - 你这里的 spatial encoder 应与原语义 encoder 职责分离
237
+
238
+ ### 7.2 必须直接加载的层
239
+
240
+ 这些层建议直接加载原 BEATs checkpoint:
241
+
242
+ - `post_extract_proj`
243
+ - `encoder.pos_conv`
244
+ - `encoder.layers.*`
245
+ - `encoder.layer_norm`
246
+ - `layer_norm`
247
+
248
+ 即除了输入 stem 和输出头,主干参数都尽量继承。
249
+
250
+ ### 7.3 需要特殊初始化的层
251
+
252
+ 以下层因为 shape 不同,不能直接 strict load:
253
+
254
+ - `patch_embedding`
255
+ - 新增的 `query decoder`
256
+ - 新增的 `spatial heads`
257
+ - 新增的 `LLM projector`
258
+
259
+ ### 7.4 新 patch embedding 的初始化策略
260
+
261
+ 原 BEATs stem 是:
262
+
263
+ - `Conv2d(1, embed_dim, kernel_size=patch, stride=patch)`
264
+
265
+ 新 stem 建议是:
266
+
267
+ - `Conv2d(C_foa, embed_dim, kernel_size=patch, stride=patch)`
268
+
269
+ 推荐初始化策略:
270
+
271
+ #### 方案 A:保守初始化,默认推荐
272
+
273
+ - `W_logmel` 通道继承原 stem 权重
274
+ - 其他空间通道初始化为 `0` 或较小随机值
275
+
276
+ 优点:
277
+
278
+ - 最大程度保留原 BEATs 初始分布
279
+ - trunk 适配更稳
280
+
281
+ 缺点:
282
+
283
+ - 训练初期空间通道利用较慢
284
+
285
+ #### 方案 B:通道 inflation
286
+
287
+ - 把原 stem 权重复制到全部输入通道
288
+ - 再按通道数做归一化
289
+
290
+ 优点:
291
+
292
+ - 所有通道一开始都能进入主干
293
+
294
+ 缺点:
295
+
296
+ - 初始统计更可能偏离原 BEATs
297
+
298
+ 最终推荐:
299
+
300
+ - 第一版用 `方案 A`
301
+ - 后续做 ablation 再比较 `方案 B`
302
+
303
+ ## 8. Spatial token 模块的最终设计
304
+
305
+ ### 8.1 为什么不用全局池化
306
+
307
+ 原始 BEATs 的输出方式更接近:
308
+
309
+ - patch sequence
310
+ - mean pooling
311
+ - clip-level prediction
312
+
313
+ 这不适合多源空间任务。
314
+
315
+ ### 8.2 最终推荐:Query Decoder
316
+
317
+ 在 trunk 输出后新增:
318
+
319
+ - `K` 个 learnable source queries
320
+ - 一个轻量 `cross-attention decoder`
321
+
322
+ 输入:
323
+
324
+ - encoder memory:`H in R^{B x T x D}`
325
+ - source queries:`Q in R^{B x K x D}`
326
+
327
+ 输出:
328
+
329
+ - `Z in R^{B x K x D}`
330
+
331
+ 这里的 `Z[:, i, :]` 即第 `i` 个 `source token`
332
+
333
+ ### 8.3 为什么 query decoder 是当前最优解
334
+
335
+ 它的优点:
336
+
337
+ - 不改 trunk 内部结构
338
+ - 仍然让完整 FOA 特征经过 backbone
339
+ - 适合多源 set prediction
340
+ - 最利于最大化复用 trunk 权重
341
+
342
+ ## 9. 输出头设计
343
+
344
+ 对每个 source token `z_i`,预测:
345
+
346
+ - `objectness`
347
+ - `azimuth`
348
+ - `elevation`
349
+ - `distance`
350
+ - 可选 `class_aux`
351
+
352
+ ### 9.1 离散还是连续
353
+
354
+ 第一版推荐全部使用离散分类头:
355
+
356
+ - `azimuth`: 360 bins
357
+ - `elevation`: 180 bins
358
+ - `distance`: 按数据分桶,例如 `0.5m` 一档
359
+
360
+ 原因:
361
+
362
+ - 与已有 Spatial-AST/BAT 经验一致
363
+ - 分类头更稳
364
+ - 更便于构造离散坐标 embedding
365
+
366
+ ### 9.2 objectness 头
367
+
368
+ 推荐增加:
369
+
370
+ - `objectness_head: D -> 1`
371
+
372
+ 用于:
373
+
374
+ - 判断当前 token 是否对应真实声源
375
+ - 作为 Hungarian matching 的一部分
376
+ - 推理时做 token 保留/裁剪
377
+
378
+ ### 9.3 类别头
379
+
380
+ 类别头建议作为:
381
+
382
+ - `auxiliary head`
383
+
384
+ 而不是最终 LLM 的主要输入内容。
385
+
386
+ 这样做的作用:
387
+
388
+ - 让 query token 更容易学会 source slot 对齐
389
+ - 但不把 Spatial-BEATs 变成第二个强语义 encoder
390
+
391
+ ## 10. Loss 设计
392
+
393
+ 推荐总损失:
394
+
395
+ ```text
396
+ L_total =
397
+ lambda_obj * L_obj
398
+ + lambda_azi * L_azi
399
+ + lambda_ele * L_ele
400
+ + lambda_dist * L_dist
401
+ + lambda_cls * L_cls_aux
402
+ ```
403
+
404
+ ### 10.1 匹配方式
405
+
406
+ 使用 `Hungarian matching`:
407
+
408
+ - 预测:`K` 个 token
409
+ - GT:`N` 个 sources
410
+ - 成本由以下项构成:
411
+ - objectness cost
412
+ - azimuth cost
413
+ - elevation cost
414
+ - distance cost
415
+ - optional class cost
416
+
417
+ ### 10.2 损失项定义
418
+
419
+ 推荐:
420
+
421
+ - `L_obj`: BCE 或 focal loss
422
+ - `L_azi`: cross entropy
423
+ - `L_ele`: cross entropy
424
+ - `L_dist`: cross entropy
425
+ - `L_cls_aux`: cross entropy 或 BCE
426
+
427
+ ### 10.3 初始 loss 权重建议
428
+
429
+ 第一版建议从以下权重起步:
430
+
431
+ ```text
432
+ lambda_obj = 1.0
433
+ lambda_azi = 2.0
434
+ lambda_ele = 2.0
435
+ lambda_dist = 1.0
436
+ lambda_cls = 0.25
437
+ ```
438
+
439
+ 解释:
440
+
441
+ - 方向任务通常更关键
442
+ - 距离次之
443
+ - objectness 必须稳定
444
+ - 类别监督只作为辅助
445
+
446
+ ### 10.4 不建议的做法
447
+
448
+ 第一版不建议:
449
+
450
+ - 重分类损失压倒空间损失
451
+ - 直接照搬 Spatial-AST 的 `1250 * cls`
452
+
453
+ 原因:
454
+
455
+ - Spatial-AST 的目标之一是保住 sound event detection
456
+ - 这里 `Spatial-BEATs` 的主要目标是空间 token
457
+ - 原项目已有独立语义 encoder
458
+
459
+ ## 11. 训练策略
460
+
461
+ ### 11.1 第一阶段是否需要 SSL
462
+
463
+ 当前最终结论:
464
+
465
+ - 第一版 **不需要** 重新做 BEATs 式 SSL
466
+
467
+ 因为当前已经有:
468
+
469
+ - 多源监督
470
+ - 每个源的空间标注
471
+ - 可复用的 BEATs 主干预训练
472
+
473
+ 所以第一阶段应优先做:
474
+
475
+ - `supervised multi-source spatial training`
476
+
477
+ ### 11.2 分阶段训练建议
478
+
479
+ #### Stage A:Warmup
480
+
481
+ 冻结:
482
+
483
+ - 大部分 trunk
484
+
485
+ 只训练:
486
+
487
+ - FOA preprocessor
488
+ - patch embedding
489
+ - query decoder
490
+ - spatial heads
491
+ - LLM projector
492
+
493
+ 目的:
494
+
495
+ - 让新输入 stem 和新输出头稳定接入 trunk
496
+
497
+ #### Stage B:Upper-trunk finetune
498
+
499
+ 解冻:
500
+
501
+ - trunk 上层若干层
502
+
503
+ 目的:
504
+
505
+ - 让主干逐步适应 FOA 空间任务
506
+
507
+ #### Stage C:Near-full finetune
508
+
509
+ 进一步解冻:
510
+
511
+ - 更多 encoder layers
512
+
513
+ 目的:
514
+
515
+ - 提升空间表示上限
516
+
517
+ ### 11.3 学习率建议
518
+
519
+ 推荐:
520
+
521
+ - trunk:较小 lr
522
+ - 新模块:较大学习率
523
+
524
+ 例如:
525
+
526
+ ```text
527
+ lr_trunk = 1e-5 ~ 5e-5
528
+ lr_new = 1e-4 ~ 5e-4
529
+ ```
530
+
531
+ 并配合:
532
+
533
+ - layer-wise lr decay
534
+
535
+ ## 12. 最终输出给 LLM 的 spatial token 形式
536
+
537
+ 这是本项目最关键的接口定义之一。
538
+
539
+ ### 12.1 内部 token 形式
540
+
541
+ `Spatial-BEATs` 内部输出:
542
+
543
+ - `Z in R^{B x K x D}`
544
+
545
+ 其中:
546
+
547
+ - `B`: batch size
548
+ - `K`: source token 数
549
+ - `D`: Spatial-BEATs hidden dim,建议与 BEATs trunk 一致
550
+
551
+ ### 12.2 不建议直接把 raw logits 喂给 LLM
552
+
553
+ 不建议直接给 LLM:
554
+
555
+ - azimuth logits
556
+ - elevation logits
557
+ - distance logits
558
+ - objectness logits
559
+
560
+ 这些是监督头,不是最终模态表示。
561
+
562
+ ### 12.3 最终推荐的 LLM spatial token 形式
563
+
564
+ 最终推荐送给 LLM 的每个 token 形式为:
565
+
566
+ ```text
567
+ s_i = Proj([z_i ; e_azi(i) ; e_ele(i) ; e_dist(i) ; e_obj(i)])
568
+ ```
569
+
570
+ 其中:
571
+
572
+ - `z_i`: query decoder 输出的 latent token
573
+ - `e_azi(i)`: 由预测 azimuth bin 查表得到的 embedding
574
+ - `e_ele(i)`: 由预测 elevation bin 查表得到的 embedding
575
+ - `e_dist(i)`: 由预测 distance bin 查表得到的 embedding
576
+ - `e_obj(i)`: 由 objectness/confidence 产生的 embedding
577
+ - `Proj`: 投影到 LLM hidden size 的 MLP/Linear
578
+
579
+ 最终:
580
+
581
+ - `s_i in R^{d_llm}`
582
+
583
+ ### 12.4 为什么采用“latent + structured embedding”的混合形式
584
+
585
+ 原因:
586
+
587
+ 1. `z_i` 保留丰富的隐式空间结构信息
588
+ 2. `坐标 embedding` 给 LLM 显式离散空间线索
589
+ 3. `confidence` 有助于 LLM 区分可靠/不可靠 token
590
+
591
+ 这比单纯只传:
592
+
593
+ - raw latent token
594
+
595
+ 或者只传:
596
+
597
+ - 显式坐标 one-hot / scalar
598
+
599
+ 都更合适。
600
+
601
+ ### 12.5 最终序列形式
602
+
603
+ 送入 LLM 时推荐:
604
+
605
+ ```text
606
+ <SPATIAL_START>, s_1, s_2, ..., s_K, <SPATIAL_END>
607
+ ```
608
+
609
+ 并且:
610
+
611
+ - 按 `objectness` 从高到低排序
612
+ - 对低置信 token 可直接截断或 mask
613
+
614
+ ### 12.6 是否保留全部 K 个 token
615
+
616
+ 默认推荐:
617
+
618
+ - 训练时保留全部 `K`
619
+ - 推理时按 `objectness` 过滤
620
+
621
+ 例如:
622
+
623
+ - 保留前 `K_keep`
624
+ - 或保留 `obj > threshold` 的 token
625
+
626
+ ## 13. 与原语义 audio encoder 的关系
627
+
628
+ 为了避免“两个 encoder 在做同样的事”,推荐如下职责划分:
629
+
630
+ - 原语义 audio encoder:负责 `what`
631
+ - Spatial-BEATs:负责 `where / spatial structure / relations`
632
+
633
+ ### 13.1 是否允许 Spatial-BEATs 学类别
634
+
635
+ 允许,但只作为辅助。
636
+
637
+ 建议:
638
+
639
+ - 类别头只用于训练
640
+ - 最终输入给 LLM 的空间 token 不直接暴露完整类别 logits
641
+
642
+ ### 13.2 是否需要和语义 encoder 做对齐
643
+
644
+ 第一版不是必须。
645
+
646
+ 若后续希望更强的 source grounding,可进一步加入:
647
+
648
+ - semantic distillation
649
+ - cross-encoder alignment
650
+ - source-wise contrastive loss
651
+
652
+ 但这些应放到第二阶段。
653
+
654
+ ## 14. 第一版推荐配置
655
+
656
+ 第一版默认建议:
657
+
658
+ - 输入特征:`WXYZ + IVxyz`
659
+ - `C_foa = 7`
660
+ - 采样率:`16k`
661
+ - mel bins:`128`
662
+ - patch 配置:与 BEATs 保持一致
663
+ - 预训练权重:`BEATs_iter3+ AS2M pre-trained`
664
+ - trunk:最大化加载
665
+ - patch stem:`W` 继承,其余通道小初始化
666
+ - 输出:`K` 个 source tokens
667
+ - token 解码:轻量 query decoder
668
+ - 监督:Hungarian matching + 多头空间分类
669
+ - LLM 输入:`latent + structured coordinate embedding` 的混合 token
670
+
671
+ ## 15. 实现优先级
672
+
673
+ 推荐按如下优先级推进:
674
+
675
+ 1. 实现 `FOA preprocessor`
676
+ 2. 实现多通道 `patch embedding`
677
+ 3. 完成 trunk ckpt 加载
678
+ 4. 实现 `query decoder`
679
+ 5. 实现 `objectness / azi / ele / dist` heads
680
+ 6. 实现 `Hungarian matcher + criterion`
681
+ 7. 实现 `LLM projector`
682
+ 8. 完成训练脚本
683
+
684
+ ## 16. 当前仍需用户确认的问题
685
+
686
+ 以下问题会直接影响第一版实现细节:
687
+
688
+ 1. `FOA` 数据当前主要采样率是多少?是 `16k`、`24k`、`32k` 还是 `48k`?
689
+ 2. 每个样本中 `最大同时源数` 大概是多少?这会影响 `K` 的默认设定。
690
+ 3. 每个源是否都有 `source-level class label`?如果有,类别头和匹配会更稳。
691
+ 4. 你希望 `distance` 是离散分类还是连续回归?当前默认推荐离散分类。
692
+ 5. 下游 LLM 的 hidden size 是多少?是否已有固定的 audio token projector?
693
+ 6. 你是否希望 Spatial-BEATs 在第一版就具备一定的 source semantic 辅助能力,还是严格只做空间?
694
+
695
+ ## 17. 结论
696
+
697
+ 当前最终方案已经明确:
698
+
699
+ - **完整 FOA 特征进入 BEATs 主干**
700
+ - **最大化复用 trunk 预训练**
701
+ - **重做输入 stem**
702
+ - **重做输出为多源 spatial tokens**
703
+ - **第一版采用监督式 set prediction**
704
+ - **最终给 LLM 的不是 raw logits,而是融合 latent 与坐标 embedding 的 spatial tokens**
705
+
706
+ 这是当前最符合项目目标、也最稳妥的 `Spatial-BEATs` 方案。
docs/spatial_beats_training_overview.md ADDED
@@ -0,0 +1,608 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Spatial-BEATs Training And Architecture Overview
2
+
3
+ This document summarizes the current `Spatial-BEATs` implementation in this repository:
4
+
5
+ - model architecture
6
+ - tensor shape flow
7
+ - dataset contract
8
+ - variable-length batching
9
+ - supervision and losses
10
+ - stage-1 training setup
11
+ - current `ov1/ov2/ov3` presets
12
+
13
+ The implementation described here corresponds to:
14
+
15
+ - [spatial_beats.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/spatial_beats.py)
16
+ - [spatial_modules.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/spatial_modules.py)
17
+ - [spatial_dataset.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/spatial_dataset.py)
18
+ - [spatial_loss.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/spatial_loss.py)
19
+ - [train_spatial_beats.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/train_spatial_beats.py)
20
+ - [spatial_beats_ov123_stage1_config.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/spatial_beats_ov123_stage1_config.py)
21
+
22
+ ## 1. Goal
23
+
24
+ `Spatial-BEATs` is a separate spatial encoder for FOA audio.
25
+
26
+ It is designed to:
27
+
28
+ - reuse the BEATs backbone and pretrained weights
29
+ - take full FOA input instead of only the `W` channel
30
+ - learn spatial structure through explicit supervision
31
+ - output fixed-rate spatial tokens for an LLM
32
+ - stay separate from the original audio encoder used for semantic audio understanding
33
+
34
+ The current implementation follows the simplified design:
35
+
36
+ - the main objective is to train the FOA front-end and BEATs trunk to produce spatially informative embeddings
37
+ - the supervision heads are lightweight readout heads
38
+ - the final LLM tokens are taken from the encoder-side spatial embeddings, not from the final logits
39
+
40
+ ## 2. High-Level Architecture
41
+
42
+ The end-to-end model path is:
43
+
44
+ ```text
45
+ FOA waveform
46
+ -> FOA spatial preprocessor
47
+ -> multi-channel patch embedding
48
+ -> BEATs trunk
49
+ -> frequency pooling
50
+ -> temporal resampling to 2.5 Hz
51
+ -> shallow temporal readout
52
+ -> spatial embeddings
53
+ -> fixed-slot supervision heads
54
+ -> projector
55
+ -> LLM spatial tokens
56
+ ```
57
+
58
+ More concretely:
59
+
60
+ ```text
61
+ [B, 4, T]
62
+ -> [B, 7, T_f, 128]
63
+ -> [B, N_p, 512]
64
+ -> [B, N_p, 768]
65
+ -> [B, T_p, 768]
66
+ -> [B, T_s_max, 768]
67
+ -> [B, T_s_max, 768]
68
+ -> [B, T_s_max, 4, 768]
69
+ -> [B, T_s_max, d_llm]
70
+ ```
71
+
72
+ Where:
73
+
74
+ - `B`: batch size
75
+ - `T`: waveform length in samples
76
+ - `T_f`: acoustic frame count before patching
77
+ - `N_p`: number of BEATs patches
78
+ - `T_p`: time-axis patch count after frequency pooling
79
+ - `T_s_max`: padded token count in the batch after resampling to `2.5 Hz`
80
+ - `d_llm`: spatial token width sent to the LLM
81
+
82
+ ## 3. Input And Front-End
83
+
84
+ ### 3.1 Input audio
85
+
86
+ The model expects:
87
+
88
+ - FOA waveform
89
+ - shape `[B, 4, T]`
90
+ - channel order: `W, X, Y, Z`
91
+ - sample rate: `16 kHz`
92
+
93
+ ### 3.2 Qwen-like low-level mel setup
94
+
95
+ The current front-end is aligned to the Qwen-2.5-Omni audio tower style low-level parameters:
96
+
97
+ - `sample_rate = 16000`
98
+ - `num_mel_bins = 128`
99
+ - `n_fft = 400`
100
+ - `win_length = 400`
101
+ - `hop_length = 160`
102
+ - `dither = 0.0`
103
+
104
+ These parameters are shared between:
105
+
106
+ - `SpatialBEATsConfig`
107
+ - `SpatialDatasetConfig`
108
+
109
+ This keeps the data pipeline and the model front-end consistent.
110
+
111
+ ### 3.3 FOA feature construction
112
+
113
+ The preprocessor converts FOA waveform into a 7-channel feature map:
114
+
115
+ - `W_logmel`
116
+ - `X_logmel`
117
+ - `Y_logmel`
118
+ - `Z_logmel`
119
+ - `IVx`
120
+ - `IVy`
121
+ - `IVz`
122
+
123
+ Output shape:
124
+
125
+ - `foa_feat: [B, 7, T_f, 128]`
126
+
127
+ This allows the whole FOA structure to enter the backbone instead of relying on only `W`.
128
+
129
+ ## 4. Backbone And Spatial Embedding Path
130
+
131
+ ### 4.1 Spatial patch embedding
132
+
133
+ The model replaces the original single-channel patch stem with a 7-channel patch embedding:
134
+
135
+ - input: `foa_feat [B, 7, T_f, 128]`
136
+ - output: `patch_tokens [B, N_p, 512]`
137
+ - also returns `grid_size = (T_p, F_p)`
138
+
139
+ This is the first modified entry point for reusing BEATs on FOA input.
140
+
141
+ ### 4.2 Reused BEATs trunk
142
+
143
+ The trunk reuses BEATs pretrained components:
144
+
145
+ - `layer_norm`
146
+ - `post_extract_proj`
147
+ - `encoder.pos_conv`
148
+ - all transformer layers
149
+ - `encoder.layer_norm`
150
+
151
+ Flow:
152
+
153
+ - input: `patch_tokens [B, N_p, 512]`
154
+ - output: `encoder_memory [B, N_p, 768]`
155
+
156
+ ### 4.3 Frequency pooling
157
+
158
+ The patch sequence is reshaped back into a patch grid and pooled over the frequency axis:
159
+
160
+ - input: `encoder_memory [B, N_p, 768]` with `grid_size=(T_p, F_p)`
161
+ - reshaped internally to `[B, T_p, F_p, 768]`
162
+ - pooled output: `temporal_patch_tokens [B, T_p, 768]`
163
+
164
+ This produces a time-aligned sequence before the final token-rate conversion.
165
+
166
+ ### 4.4 Temporal resampling
167
+
168
+ The temporal resampler converts the patch-rate sequence into the final spatial token rate:
169
+
170
+ - target token rate: `2.5 Hz`
171
+ - per-sample target length:
172
+
173
+ ```text
174
+ T_s_i = round(duration_i * 2.5)
175
+ ```
176
+
177
+ Batch handling:
178
+
179
+ - each sample is resampled independently
180
+ - the batch is padded to `T_s_max = max_i(T_s_i)`
181
+ - a temporal mask is produced
182
+
183
+ Outputs:
184
+
185
+ - `temporal_tokens: [B, T_s_max, 768]`
186
+ - `temporal_padding_mask: [B, T_s_max]`
187
+
188
+ Mask convention:
189
+
190
+ - `False`: valid time step
191
+ - `True`: padded time step
192
+
193
+ ### 4.5 Shallow temporal readout
194
+
195
+ The shallow temporal readout refines the resampled sequence with a lightweight transformer encoder:
196
+
197
+ - input: `temporal_tokens [B, T_s_max, 768]`
198
+ - output: `spatial_embeddings [B, T_s_max, 768]`
199
+
200
+ This is the main representation used for both:
201
+
202
+ - spatial supervision
203
+ - final projection to LLM tokens
204
+
205
+ ## 5. Supervision Heads
206
+
207
+ The current stage-1 design does not use a heavy decoder.
208
+
209
+ Instead, it uses a fixed-slot readout for supervision only.
210
+
211
+ ### 5.1 Fixed-slot readout
212
+
213
+ The readout expands each time step into a small number of internal supervision slots:
214
+
215
+ - max slots per step: `K = 4`
216
+ - input: `spatial_embeddings [B, T_s_max, 768]`
217
+ - output: `slot_latents [B, T_s_max, 4, 768]`
218
+
219
+ Important:
220
+
221
+ - `K=4` is only a supervision capacity
222
+ - it does not change the final LLM token count
223
+ - the final LLM-visible token rate is still `2.5 Hz`
224
+
225
+ ### 5.2 Prediction heads
226
+
227
+ Each supervision slot predicts:
228
+
229
+ - `pred_activity: [B, T_s_max, 4]`
230
+ - `pred_azi_logits: [B, T_s_max, 4, 360]`
231
+ - `pred_ele_logits: [B, T_s_max, 4, 180]`
232
+ - `pred_dist: [B, T_s_max, 4, 1]`
233
+ - `pred_class_logits: [B, T_s_max, 4, C]`
234
+
235
+ Where:
236
+
237
+ - `C = 65`
238
+ - the class vocabulary comes from:
239
+ - `/apdcephfs_cq12/share_302080740/user/schmittzhu/data/fsd50k/FSD50K.ground_truth/final_vocabulary.csv`
240
+
241
+ These heads are used to supply explicit training loss and push the front-end plus BEATs trunk to learn spatial structure.
242
+
243
+ ## 6. LLM Spatial Tokens
244
+
245
+ The final LLM tokens are not taken from slot logits.
246
+
247
+ They are projected from the encoder-side spatial embeddings:
248
+
249
+ - input: `spatial_embeddings [B, T_s_max, 768]`
250
+ - output: `llm_spatial_tokens [B, T_s_max, d_llm]`
251
+
252
+ Therefore:
253
+
254
+ - `2.5 Hz` means final LLM-visible tokens arrive at `2.5 tokens/second`
255
+ - a `20 s` clip produces about `50` spatial tokens
256
+ - a `10 s` clip produces about `25` spatial tokens
257
+
258
+ This is the externally visible spatial token interface.
259
+
260
+ ## 7. Pretrained Weight Reuse
261
+
262
+ The model initializes from `BEATs_iter3+ AS2M`.
263
+
264
+ Current pretrained loading logic:
265
+
266
+ - selectively load BEATs trunk modules
267
+ - skip task-specific components that do not match
268
+ - inflate the old single-channel patch embedding into the new 7-channel stem
269
+
270
+ Patch stem initialization rule:
271
+
272
+ - original BEATs patch weight is copied into channel `0` of the new 7-channel stem
273
+ - remaining channels start from zero
274
+
275
+ This is a conservative initialization intended to preserve BEATs trunk stability while enabling FOA adaptation.
276
+
277
+ ## 8. Dataset Contract
278
+
279
+ ### 8.1 Supported manifests
280
+
281
+ The dataset loader currently supports:
282
+
283
+ - `ov1_foa.jsonl`
284
+ - `ov2_foa.jsonl`
285
+ - `ov3_foa.jsonl`
286
+
287
+ It handles:
288
+
289
+ - single-source top-level manifest style
290
+ - nested multi-source manifest style with `sources`
291
+
292
+ ### 8.2 Required scene-level data
293
+
294
+ At scene level the dataset expects one FOA path, typically:
295
+
296
+ - `output_foa_path`
297
+
298
+ or compatible fallback names already handled in the parser.
299
+
300
+ ### 8.3 Required source-level data
301
+
302
+ For each source, the loader extracts:
303
+
304
+ - source class
305
+ - azimuth
306
+ - elevation
307
+ - distance
308
+ - weak time window
309
+
310
+ Internally each source is converted into a `SourceEvent` containing:
311
+
312
+ - `class_index`
313
+ - `class_label`
314
+ - `azimuth_deg`
315
+ - `elevation_deg`
316
+ - `distance_m`
317
+ - `start_time_seconds`
318
+ - `end_time_seconds`
319
+
320
+ ### 8.4 Vocabulary mapping
321
+
322
+ Source labels are mapped to `final_vocabulary.csv`.
323
+
324
+ The loader supports several field aliases, including:
325
+
326
+ - `mono_target_label`
327
+ - `mono_primary_label`
328
+ - `final_label`
329
+ - `source_label`
330
+ - `label`
331
+
332
+ and several id-style aliases if an integer class index is already present.
333
+
334
+ ## 9. Variable-Length Batching
335
+
336
+ Handling mixed-length FOA clips is a core part of the current implementation.
337
+
338
+ ### 9.1 Waveform padding
339
+
340
+ At batch time:
341
+
342
+ - each waveform is padded to the batch maximum waveform length
343
+ - the padded tensor has shape `[B, 4, T_max]`
344
+ - a waveform padding mask is created:
345
+
346
+ ```text
347
+ waveform_padding_mask: [B, T_max]
348
+ ```
349
+
350
+ Mask convention:
351
+
352
+ - `False`: valid waveform sample
353
+ - `True`: padded sample
354
+
355
+ dui
356
+ ### 9.2 Temporal token padding
357
+
358
+ After temporal resampling:
359
+
360
+ - each sample has its own `T_s_i = round(duration_i * 2.5)`
361
+ - the batch is padded to `T_s_max`
362
+ - the model returns:
363
+
364
+ ```text
365
+ temporal_padding_mask: [B, T_s_max]
366
+ target_num_steps: [B]
367
+ ```
368
+
369
+ All temporal supervision, matching, and loss computation respect these lengths.
370
+
371
+ ### 9.3 Long clip truncation
372
+
373
+ The current training presets cap clip duration at:
374
+
375
+ - `20.0 seconds`
376
+
377
+ The dataset applies cropping before batching.
378
+
379
+ Preset crop policy:
380
+
381
+ - `crop_mode = "start"`
382
+
383
+ This means:
384
+
385
+ - clips longer than 20 seconds are truncated from the beginning
386
+ - training and validation follow the same deterministic sequence policy
387
+
388
+ If needed later, the dataset also supports:
389
+
390
+ - `random`
391
+ - `center`
392
+ - `none`
393
+
394
+ ## 10. Matching And Losses
395
+
396
+ ### 10.1 Weak temporal supervision
397
+
398
+ The model uses weak source windows:
399
+
400
+ - each source provides `start_time_seconds` and `end_time_seconds`
401
+ - these define a valid supervision window, not guaranteed frame-level activity
402
+
403
+ The loss code first converts source windows into a time-window mask:
404
+
405
+ ```text
406
+ window_mask: [B, N_gt, T_s_max]
407
+ ```
408
+
409
+ ### 10.2 Per-step fixed-slot matching
410
+
411
+ Matching is performed per time step:
412
+
413
+ - only on valid temporal positions
414
+ - only within each source's weak time window
415
+ - between active GT sources and the `K=4` slot predictions
416
+
417
+ The current matcher uses a detached cost built from:
418
+
419
+ - activity
420
+ - class
421
+ - azimuth
422
+ - elevation
423
+ - distance
424
+
425
+ The output contains the assigned GT target for each valid slot-time pair.
426
+
427
+ ### 10.3 Multi-task loss terms
428
+
429
+ Current loss terms are:
430
+
431
+ - `loss_activity`
432
+ - `loss_azi`
433
+ - `loss_ele`
434
+ - `loss_dist`
435
+ - `loss_cls_aux`
436
+ - `loss_temp`
437
+
438
+ Their roles:
439
+
440
+ - `loss_activity`
441
+ - `BCEWithLogits` on slot activity
442
+ - computed over valid time steps
443
+ - `loss_azi`
444
+ - cross-entropy over 360 azimuth bins
445
+ - `loss_ele`
446
+ - cross-entropy over 180 elevation bins
447
+ - `loss_dist`
448
+ - `SmoothL1Loss` on continuous distance regression
449
+ - `loss_cls_aux`
450
+ - auxiliary source class cross-entropy
451
+ - `loss_temp`
452
+ - temporal smoothness regularization over valid consecutive steps
453
+
454
+ The total loss is the weighted sum defined in `SpatialLossConfig`.
455
+
456
+ ## 11. Stage-1 Training Flow
457
+
458
+ The current training entry is stage-1 encoder-focused training.
459
+
460
+ High-level flow per step:
461
+
462
+ ```text
463
+ batch
464
+ -> SpatialBEATs.forward()
465
+ -> match_fixed_slots()
466
+ -> compute_spatial_losses()
467
+ -> backward()
468
+ -> optimizer.step()
469
+ ```
470
+
471
+ ### 11.1 Trainable modules in stage 1
472
+
473
+ By default, stage 1 trains:
474
+
475
+ - `preprocessor`
476
+ - `patch_embedding`
477
+ - `frequency_pool`
478
+ - `temporal_resampler`
479
+ - `temporal_readout`
480
+ - `slot_readout`
481
+ - `prediction_heads`
482
+
483
+ It can also unfreeze the BEATs trunk.
484
+
485
+ The projector is kept frozen by default in stage 1.
486
+
487
+ ### 11.2 Optimizer
488
+
489
+ The current trainer uses:
490
+
491
+ - `AdamW`
492
+
493
+ Default preset values:
494
+
495
+ - `batch_size = 4`
496
+ - `num_epochs = 20`
497
+ - `learning_rate = 1e-4`
498
+ - `weight_decay = 0.05`
499
+
500
+ ## 12. Current Presets
501
+
502
+ ### 12.1 `OV123_STAGE1_CFG`
503
+
504
+ Defined in:
505
+
506
+ - [spatial_beats_ov123_stage1_config.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/spatial_beats_ov123_stage1_config.py)
507
+
508
+ This preset is intended to train on:
509
+
510
+ - `ov1_foa.jsonl`
511
+ - `ov2_foa.jsonl`
512
+ - `ov3_foa.jsonl`
513
+
514
+ with split filtering:
515
+
516
+ - train: `("train",)`
517
+ - val: `("valid",)`
518
+ - test: `("test",)`
519
+
520
+ and clip truncation:
521
+
522
+ - `max_clip_duration_seconds = 20.0`
523
+ - `crop_mode = "start"`
524
+
525
+ ### 12.2 `OV23_STAGE1_CFG`
526
+
527
+ This is the safer baseline preset using only:
528
+
529
+ - `ov2_foa.jsonl`
530
+ - `ov3_foa.jsonl`
531
+
532
+ It uses the same split and truncation policy.
533
+
534
+ ### 12.3 Important note on `ov1`
535
+
536
+ The trainer is already written to use `split` filtering for `ov1`.
537
+
538
+ If the active `ov1` manifest at the configured path does not yet contain `split`, then:
539
+
540
+ - the `OV123` preset will not automatically include those samples in train, valid, or test
541
+ - the fix is simply to point the preset at the updated `ov1` manifest path
542
+
543
+ The code path itself already supports split-aware loading.
544
+
545
+ ## 13. Current Runtime Status
546
+
547
+ The current implementation has already been checked on:
548
+
549
+ - a real FOA waveform file
550
+ - mixed-length real manifest samples
551
+ - full forward pass
552
+ - fixed-slot matching
553
+ - multi-task loss computation
554
+ - BEATs pretrained weight loading
555
+
556
+ The following paths are already operational:
557
+
558
+ - dataset parsing
559
+ - waveform batching
560
+ - mixed-length temporal masking
561
+ - model forward
562
+ - matching
563
+ - loss computation
564
+ - stage-1 optimization loop
565
+
566
+ ## 14. Recommended Launch Pattern
567
+
568
+ Example usage:
569
+
570
+ ```python
571
+ from spatial_beats_ov123_stage1_config import OV123_STAGE1_CFG
572
+ from train_spatial_beats import main
573
+
574
+ main(OV123_STAGE1_CFG)
575
+ ```
576
+
577
+ If `ov1` still needs a different manifest path, update only:
578
+
579
+ ```python
580
+ OV123_STAGE1_CFG.train_manifest_paths
581
+ OV123_STAGE1_CFG.val_manifest_paths
582
+ OV123_STAGE1_CFG.test_manifest_paths
583
+ ```
584
+
585
+ or rebuild the config through:
586
+
587
+ ```python
588
+ from train_spatial_beats import make_ov123_stage1_config
589
+ ```
590
+
591
+ ## 15. Summary
592
+
593
+ The current `Spatial-BEATs` implementation is a FOA-first BEATs-based spatial encoder with:
594
+
595
+ - Qwen-like low-level mel settings
596
+ - a 7-channel FOA front-end
597
+ - reused BEATs trunk
598
+ - fixed-rate `2.5 Hz` spatial token output
599
+ - fixed-slot supervision heads
600
+ - variable-length batching
601
+ - split-aware `ov1/ov2/ov3` training presets
602
+
603
+ The central training idea is:
604
+
605
+ - use explicit spatial supervision to shape the front-end and BEATs trunk
606
+ - keep the supervision head lightweight
607
+ - use encoder-side spatial embeddings as the final source of LLM spatial tokens
608
+
docs/v13_honest_postmortem.md ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # v13 系列实验 Postmortem — 诚实记录我的错误判断
2
+
3
+ > 日期:2026-05-02
4
+ > 目的:保留错误判断的证据,避免后续重蹈覆辙
5
+ > 作者说明:多次对实验状态做出错误诊断,本文记录这些错误及其根因
6
+
7
+ ---
8
+
9
+ ## 1. 核心错误:没注意 v13_D 的 cls warmup 是 8 ep 不是 3 ep
10
+
11
+ ### 事实
12
+
13
+ v13_D 配置:
14
+ ```python
15
+ cfg.frame_spatial_loss_warmup_epochs = 8 # v12 是 3
16
+ cfg.num_epochs = 25
17
+ ```
18
+
19
+ 这意味着 **ep0-7 全是 cls-only 训练**(spatial lambda 被 warmup 机制压为 warmup_scale),ep8 才是 spatial loss 真正放开的第一个 epoch。
20
+
21
+ ### v13_D 真实训练曲线
22
+
23
+ ```
24
+ ep F20 o_cls azi 阶段
25
+ 0 0.311 0.650 28.6° cls warmup, spatial loss 几乎为 0
26
+ 1 0.340 0.768 25.5° ...
27
+ 3 0.308 0.788 26.5° ...
28
+ 7 0.193 0.786 31.0° spatial 还没 ramp,azi 越走越远
29
+ 8 0.397 0.876 18.5° ★ spatial loss 真正放开,F 跳涨
30
+ 9 0.402 0.868 17.3°
31
+ 10 0.402 0.864 17.2° (best so far)
32
+ ```
33
+
34
+ ### 我的错误判断(两次对话前)
35
+
36
+ 看到 ep0-7 的数据,我写了:
37
+
38
+ > v13_D:ep1 就到 best,之后一路发散
39
+
40
+ > 这是完全不同的问题,不是过拟合 — 是 top-k rank loss 本身没用对。
41
+
42
+ > Top-K rank loss + EMA + resume optimizer + cosine LR 这四个改动叠加起来打架了
43
+
44
+ **完全错误**。真相:
45
+ - ep1 的 "best" 只是 cls warmup 期间的 F20 虚高值,不代表 spatial 性能
46
+ - F 从 ep1 → ep7 下降,是因为 spatial loss 被压住但 trunk 在学 class(trunk 权重变化 → 没有 spatial 监督 → azi 漂移)
47
+ - ep8 spatial loss 放开后 F 立刻 **0.19 → 0.40**(+107%)
48
+ - **Top-K rank、EMA、cosine LR、resume optimizer 全部工作正常**
49
+
50
+ ### 我为什么犯这个错
51
+
52
+ - 用 v12 的 3 ep warmup 经验直觉套 v13_D 的 8 ep warmup
53
+ - 没有先拉长查看 ep0-10 的完整曲线,只看前几个 epoch 就下诊断
54
+ - "F 在下降" 这个表面现象让我急于给出解释,没有对照 **ramp schedule**
55
+
56
+ ## 2. 次要错误:误判了 v13_B 的状态
57
+
58
+ v13_B 实际跑到 ep4 停下:
59
+
60
+ ```
61
+ ep F20 o_cls
62
+ 0 0.255 0.647
63
+ 1 0.292 0.771
64
+ 2 0.298 0.779
65
+ 3 0.357 0.776 ← spatial 放开第一个 epoch(warmup=3)
66
+ 4 0.356 0.775
67
+ ```
68
+
69
+ v13_B 的 warmup 是 3(与 v12 一致),所以 ep3 才是 spatial 放开第一个 epoch。ep3-4 刚放开 2 个 epoch,F 还没上涨空间。
70
+
71
+ ### 我的错误判断
72
+
73
+ > v13_B 的 ASL + soft-F1 在这个数据规模下不足以扭转 precision/recall 权衡
74
+
75
+ **错**:ep4 就停了训练,根本没给 ASL + soft-F1 时间学习。如果和 v12 同样跑 15 epoch,F 可能到 0.38-0.40 也说不定。
76
+
77
+ ### 更正
78
+
79
+ - v13_B 的结论应该是 **"跑不完整,不能下结论"**,而不是 "设计失败"
80
+ - 如果用户还想看 v13_B 的真实效果,应该重启并跑满 15+ epoch
81
+
82
+ ## 3. v13_C 的判断基本正确,但也需校准
83
+
84
+ v13_C 的 spatial warmup 也是 3 ep(继承 v12)。跑满 15 epoch:
85
+
86
+ ```
87
+ ep F20 val_loss
88
+ 3 0.385 2.67 spatial 放开第一个 epoch,F 就跳到 0.385
89
+ 7 0.387 3.40
90
+ 10 0.385 3.38 F 平了
91
+ 15 0.385 3.60 val_loss 持续上升
92
+ ```
93
+
94
+ v13_C **确实** 从 ep3 就到 0.385,之后 12 个 epoch 没再涨,这是真的 overfitting(val_loss 从 2.67 → 3.60)。
95
+
96
+ 但我的归因可能也有偏差:
97
+ - 我说 "real replication 6× 是失败配方"
98
+ - 但实际上 v13_C 是 C-1(real 6×)+ C-2(refinement)+ C-3(V3 adapter)+ C-4(log-dist) 四个改动同时叠加,不能武断归罪 C-1
99
+ - 正确的做法:做 ablation,只开 C-2,或只开 C-1,分别看
100
+
101
+ ## 4. v13_E 的设计基于错误诊断
102
+
103
+ v13_E 的目标写的是 "F 0.40-0.43,基于 v13_D 崩溃的教训"。
104
+ 但 v13_D 其实没崩,F 已到 0.402 @ ep10,还在涨。
105
+
106
+ **v13_E 的实际价值**:
107
+ - 它开启 num_active head 训练 + SELD evaluator 的 top-K̂ gate —— v13_D 没做的
108
+ - 作为 v13_D 之后的 **扩展实验**(v13_F?),不是替代
109
+
110
+ ## 5. 学到的教训
111
+
112
+ ### 教训 1:看 full trajectory,不看 prefix
113
+ 以后评估实验状态,必须等到 **spatial ramp 结束** 且 **至少 5 个 epoch 的 spatial 阶段数据**。看 ep0-7 的 cls warmup 数据就下判断是严重错误。
114
+
115
+ ### 教训 2:warmup schedule 是不同实验的关键差异
116
+ v12: warmup=3, v13_B/C: warmup=3, v13_D: warmup=8。相同 epoch number 对应的训练阶段完全不同。画图时应标注 "spatial_enabled_epoch" 作为基准点对齐。
117
+
118
+ ### 教训 3:诊断要基于 pipeline 理解
119
+ 我多次说 "Top-K rank 和 Hungarian 对着干"、"ASL 和 gate 互相抵消",但这些都是**基于少量 ep 数据的事后解释**,不是真正的机制推导。下次先问:"这组数据在 pipeline 的哪个阶段?"
120
+
121
+ ### 教训 4:不要急于写新实验替代旧的
122
+ v13_E 本来不需要。v13_D 只要让它跑完 25 epoch 就够了。写 v13_E 的动机是 "v13_D 崩了所以要抢救",这个前提本身就错。
123
+
124
+ ## 6. 接下来该怎么办
125
+
126
+ ### 立即行动
127
+ - **让 v13_D 继续跑到 25 epoch**,不要停
128
+ - **v13_B 不要急着下结论**。如果有算力,重启跑满 15 epoch;没有就标记为 "incomplete"
129
+ - **v13_C 的结论保留**(确实 overfit),但不能归罪单一改动
130
+
131
+ ### v13_D 最终预期(基于 v12 曲线外推)
132
+
133
+ ```
134
+ v12 从 spatial 放开后曲线:
135
+ ep3: 0.353 (刚放开)
136
+ ep12: 0.378 (best, +0.025 / 9 ep)
137
+
138
+ v13_D 类比:
139
+ ep10: 0.402 (刚放开 ramp 结束)
140
+ ep19 (+9 ep): ~0.425 (保守估计)
141
+ ep22 (+12 ep): ~0.43~0.46 (乐观)
142
+ ```
143
+
144
+ EMA + cosine LR 可能再给 +0.005~0.01。最终 v13_D 预期 **0.43 ~ 0.46**。
145
+
146
+ ### v13_E 的定位调整
147
+
148
+ 从"替代 v13_D" 改为 "v13_D 之后的扩展实验":
149
+ - 先等 v13_D 跑完
150
+ - 用 v13_D best.pt 作为 hot-start
151
+ - 在其上启用 num_active head + top-K̂ gate 看是否再涨
152
+ - 如果涨了 → v13_F 路径
153
+ - 如果不涨 → num_active 在这个任务上意义不大
154
+
155
+ 代码和 run 脚本都已落地,随时可用,不影响 v13_D 的实验。
156
+
157
+ ## 7. 对用户的道歉
158
+
159
+ 我多次给出过于自信的错误判断:
160
+ - "v13_C F 卡在 0.38 不再上升" —— 其实那是 overfitting 没错,但原因归咎 C-1 过于武断
161
+ - "v13_D 不收敛" —— 完全错误
162
+ - "Top-K rank loss 本身没用对" —— 没有证据
163
+ - "改 activity loss 的改动都失败了" —— 基于错误数据的推断
164
+
165
+ 应该做但没做的事:
166
+ - 应该先看 full trajectory 再诊断
167
+ - 应该注意 warmup schedule 差异
168
+ - 应该说 "ep0-7 是 cls warmup 期间的数据,不能代表最终性能"
169
+
170
+ **今后对策**:评估实验前先读该实验的 preset 代码,看 schedule 是什么,再去解读数据。不基于前几个 epoch 就给"崩了/不收敛"的结论。
docs/v13_spatial_beats_design.md ADDED
@@ -0,0 +1,528 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # v13_B + v13_C 设计文档
2
+
3
+ > 日期:2026-05-01
4
+ > 作者:Claude + user
5
+ > 目标:在 v12(F20=0.378)基础上,分两个正交实验把 F20 推到 0.45+,合并实验 v14 目标 0.55~0.62
6
+ > 设计原则:**所有改动增量式**,通过 cfg flag 控制,默认全关 → 不破坏任何现有实验(v7/v11/v12 全可正常复现)
7
+
8
+ ---
9
+
10
+ ## 0. 背景与瓶颈诊断(摘自 v12 per-subset 分析)
11
+
12
+ v12 best.pt (F20=0.3779 聚合) 按子集拆分:
13
+
14
+ ```
15
+ 子集 N F20 ER20 LE_CD LR_CD o_cls o_azi aP aR
16
+ ov1_sim 4800 0.386 0.686 26.8° 0.640 0.796 25.5° 0.950 0.046
17
+ ov2_sim 1718 0.299 0.926 29.6° 0.546 0.653 30.0° 0.916 0.151
18
+ ov3_sim 1612 0.270 0.917 30.5° 0.499 0.599 31.8° 0.888 0.165
19
+ ov1_real 3374 0.140 0.924 117.7° 0.232 0.809 27.4° 0.767 0.144
20
+ ov2_real 2230 0.098 0.866 121.1° 0.198 0.747 38.3° 0.655 0.227
21
+ ov3_real 740 0.052 0.766 146.8° 0.125 0.624 51.5° 0.612 0.232
22
+ dcase_starss 4560 0.071 1.171 130.3° 0.185 0.698 36.0° 0.625 0.176
23
+ unified 35021 (~0.40, 推算)
24
+ ```
25
+
26
+ **三大诊断**:
27
+
28
+ 1. **瓶颈是 activity_recall (0.13),不是 cls 也不是 spatial。** oracle_class_acc 在 real 上仍有 0.80,oracle_azi_mae_deg 27-51°,说明表征是健康的,被 activity gate 挡住了。
29
+ 2. **Real/Sim gap 极大**:同 ov 级别下 real F20 比 sim 低 64~81%。原因是 train 里 dcase_real 只占 6%。
30
+ 3. **Overlap 惩罚显著**:sim ov3 F20 比 ov1 低 30%(0.27 vs 0.39),K=4 track head 在 overlap 下 slot 分配混乱。
31
+
32
+ ---
33
+
34
+ ## 1. 实验切分:B 和 C 是正交维度
35
+
36
+ | 维度 | v13_B | v13_C |
37
+ |---|---|---|
38
+ | 改 loss / head 输出接口 | ✅ | ❌(沿用 v12 loss) |
39
+ | 改数据比例 / augment | 仅 augment | ✅ replication |
40
+ | 改主干架构 / 容量 | ❌ | ✅ refinement + adapter V3 |
41
+ | 热启动 | v12 best.pt (strict=False) | v12 best.pt (strict=False) |
42
+ | 能解的子集 | sim + real 的 activity 瓶颈 | ov2/ov3 overlap + real gap |
43
+
44
+ 理由:让 B 的收益和 C 的收益互不污染,便于 ablation。v14 = B + C 合并。
45
+
46
+ ---
47
+
48
+ ## 2. 共同约束(关键:不破坏现有实验)
49
+
50
+ 所有 B/C 改动都要遵守:
51
+
52
+ ### 2.1 Cfg flag 默认全关
53
+
54
+ 每条新改动对应一个 `SpatialBEATsConfig` / `SpatialLossConfig` / `TrainSpatialBEATsConfig` 字段,**默认值就是"这条改动不启用"**。
55
+
56
+ ```python
57
+ # spatial_modules 侧
58
+ self.use_class_activity_bias: bool = False # [B-1]
59
+ self.use_class_conditional_gate: bool = False # [B-3]
60
+ self.use_track_refinement: bool = False # [C-2]
61
+ self.track_refinement_layers: int = 2
62
+ self.patch_adapter_version: str = "v1" # "v1"/"v2"/"v3" [C-3] 加 v3
63
+ self.use_log_distance_head: bool = False # [C-4]
64
+
65
+ # spatial_loss 侧
66
+ self.activity_loss_type: str = "bce" # "bce"/"asymmetric" [B-2]
67
+ self.asl_gamma_neg: float = 4.0
68
+ self.asl_gamma_pos: float = 0.0
69
+ self.asl_probability_margin: float = 0.05
70
+ self.soft_f1_weight: float = 0.0 # 0 = disabled [B-4]
71
+ self.distance_loss_type: str = "l1" # "l1"/"laplace_nll" [C-4]
72
+
73
+ # dataset 侧
74
+ self.use_spec_augment: bool = False # [B-5]
75
+ self.spec_augment_time_mask_ratio: float = 0.0
76
+ self.spec_augment_freq_mask_ratio: float = 0.0
77
+ self.random_gain_db: float = 0.0
78
+ self.channel_dropout_prob: float = 0.0
79
+ self.lowpass_sim_real_prob: float = 0.0
80
+ ```
81
+
82
+ `ov1_unified_v12` preset 和之前所有 preset 都**不动**,因为它们没设置这些 flag → 新代码分支不走。
83
+
84
+ ### 2.2 热启动安全
85
+
86
+ 新加的 `nn.Module` / `nn.Parameter` 全部**零初始化**或**identity 等价初始化**:
87
+
88
+ - `class_activity_bias`:`torch.zeros(num_classes)` → logit 不变
89
+ - `GatingMLP` 最后一层 bias/weight zero-init → gate_logit = 0
90
+ - `TrackRefinementDecoder`:`layer_scale = zeros(num_layers)` → 残差 x + 0 = x
91
+ - `log_distance_head`:bias 初始化为 `log(mean_distance_v12) ≈ log(1.5)`
92
+
93
+ 从 v12 best.pt `strict=False` 加载时:
94
+ - 不存在的 key → 走零初始化
95
+ - 存在的 key(所有 v12 组件)→ 正常加载
96
+ - ep0 forward 输出应与 v12 best.pt 完全一致(或数值上差 < 1e-5)
97
+
98
+ ### 2.3 完全向后兼容
99
+
100
+ 任何旧脚本跑起来(比如 `run_ov1_unified_v12.sh`)**不需要改一行**,因为所有新字段都有默认值且默认 disabled。
101
+
102
+ ### 2.4 Preset 命名
103
+
104
+ - `ov1_unified_v13b`:启用所有 B 相关 flag
105
+ - `ov1_unified_v13c`:启用所有 C 相关 flag
106
+ - `ov1_unified_v14`:B + C 全开,热启动 max(v13b.best, v13c.best)
107
+
108
+ ---
109
+
110
+ ## 3. v13_B 详细设计:Loss + Decision 全面重写
111
+
112
+ ### [B-1] Per-class learnable logit bias
113
+
114
+ **为什么**:全局阈值 0.5 对所有类别一视同仁不合理。稀有类(jackhammer)和常见类(singing)的 activity 先验差异巨大,应当让模型自己学各类的 logit bias(等价于 per-class threshold)。
115
+
116
+ **实现点**:
117
+ - 文件:`spatial_modules.py`
118
+ - 类:`FrameTrackPredictionHeads`(定位:`class_logits` 和 `activity_logits` 的出口处)
119
+ - 新增 parameter:`self.class_activity_bias = nn.Parameter(torch.zeros(num_classes))`
120
+ - 新增 buffer:`self.use_class_activity_bias: bool`
121
+ - Forward 改动:
122
+
123
+ ```python
124
+ # 原:
125
+ activity_logit = self.activity_head(token) # [B, K, T, 1]
126
+
127
+ # 新:
128
+ activity_logit_raw = self.activity_head(token) # [B, K, T, 1]
129
+ if self.use_class_activity_bias:
130
+ class_probs = F.softmax(class_logits, dim=-1) # [B, K, T, C]
131
+ # 用 class_probs 作为加权软分配(避免 argmax 阻断梯度)
132
+ expected_bias = torch.einsum('bktc,c->bkt', class_probs, self.class_activity_bias)
133
+ activity_logit = activity_logit_raw + expected_bias.unsqueeze(-1)
134
+ else:
135
+ activity_logit = activity_logit_raw
136
+ ```
137
+
138
+ **训练/推理一致性**:bias 在训练的 BCE loss 和推理的 sigmoid 里都是**同一个量**,因此不需要 threshold sweep。推理时 threshold 始终 = 0.0(logit 空间)或 0.5(prob 空间),完全等价。
139
+
140
+ **参数量**:63 个标量,忽略不计。
141
+
142
+ ### [B-2] Asymmetric Loss 替换 BCE
143
+
144
+ **为什么**:BCE 把 FN 和 FP 等权重。当前 activity_recall=0.13 说明 FN 惩罚严重不够。ASL 对 easy negatives 用 `(1-p)^γ-` 下压,对 positives 用弱 `γ+=0`,正负不均衡下表现显著好于 BCE。
145
+
146
+ **实现点**:
147
+ - 文件:`spatial_loss.py`
148
+ - 新增函数:`asymmetric_loss_with_logits(logits, targets, gamma_neg=4, gamma_pos=0, margin=0.05)`
149
+ - 在 `compute_frame_track_losses`(或同名函数)里根据 `config.activity_loss_type` 分支:
150
+
151
+ ```python
152
+ if config.activity_loss_type == "asymmetric":
153
+ loss_act = asymmetric_loss_with_logits(
154
+ activity_logit, target_active,
155
+ gamma_neg=config.asl_gamma_neg,
156
+ gamma_pos=config.asl_gamma_pos,
157
+ margin=config.asl_probability_margin,
158
+ )
159
+ else: # "bce"
160
+ loss_act = F.binary_cross_entropy_with_logits(activity_logit, target_active, ...)
161
+ ```
162
+
163
+ **数学**:
164
+ ```
165
+ p = sigmoid(logit)
166
+ positive: -( (1-p)**γ+ ) * log(p)
167
+ negative: p_shifted = max(p - m, 0)
168
+ -( p_shifted**γ- ) * log(1 - p_shifted)
169
+ ```
170
+
171
+ **参数**:`γ+ = 0`, `γ- = 4`, `margin = 0.05`(ASL paper 推荐起点)
172
+
173
+ ### [B-3] Class-conditional gating MLP
174
+
175
+ **为什么**:activity 当前只看 token embedding。应该让 activity 也依赖 class/DOA 的确信度 —— class softmax 尖锐、DOA 稳定时更大胆判 active。
176
+
177
+ **实现点**:
178
+ - 文件:`spatial_modules.py`
179
+ - 新增类:`ClassConditionalGate(embed_dim, num_classes, hidden_dim=128)`
180
+ - 输入:`fused_token [B, K, T, D]`, `class_logits [B, K, T, C]`, `pred_dir [B, K, T, 3]`
181
+ - 融合:`gate_input = concat(token, class_emb_avg, dir_vec)` → MLP → `gate_logit [B, K, T, 1]`
182
+ - class_emb 用 `class_logits.softmax()` 加权的 class embedding(新增 `nn.Embedding(C, 32)`)
183
+ - 在 FrameTrackPredictionHeads 里:
184
+
185
+ ```python
186
+ if self.use_class_conditional_gate:
187
+ gate_logit = self.class_conditional_gate(token, class_logits, pred_dir)
188
+ activity_logit = activity_logit + self.gate_scale * gate_logit
189
+ ```
190
+
191
+ **初始化**:MLP 最后一层 `weight=zero, bias=zero` → gate_logit = 0 → ep0 等价 v12。
192
+
193
+ **参数量**:~80K。
194
+
195
+ ### [B-4] Soft-F1 auxiliary loss
196
+
197
+ **为什么**:BCE/ASL 仍是 per-sample 损失,优化目标和 macro-F20 评测有 gap。Soft-F1 直接按类聚合,和 DCASE 评估同构。
198
+
199
+ **实现点**:
200
+ - 文件:`spatial_loss.py`
201
+ - 新增函数:`soft_macro_f1_loss(activity_logits, class_logits, target_active, target_class)`
202
+ - 对每个类 `c`:
203
+ - `p_c = sigmoid(act_logit) * softmax(class)[c]` (class-c 的软 activity)
204
+ - `y_c = (target_active and target_class==c)`
205
+ - `tp_c = sum(p_c * y_c)`, `fp_c = sum(p_c * (1-y_c))`, `fn_c = sum((1-p_c) * y_c)`
206
+ - `f1_c = 2 tp_c / (2 tp_c + fp_c + fn_c + eps)`
207
+ - `loss = 1 - mean(f1_c)`
208
+ - 在总 loss 里:
209
+
210
+ ```python
211
+ if config.soft_f1_weight > 0:
212
+ total_loss = total_loss + config.soft_f1_weight * soft_macro_f1_loss(...)
213
+ ```
214
+
215
+ **warmup**(已确认采用):前 3 ep `soft_f1_weight=0.1`,第 3 ep 起硬切到 `0.3`。
216
+
217
+ 实现方式:在 `train_spatial_beats.py` 的 epoch 循环里根据 `epoch >= soft_f1_warmup_epochs` 动态设置 `train_cfg.loss.soft_f1_weight`,新增 config 字段:
218
+ ```python
219
+ cfg.loss.soft_f1_weight_warmup: float = 0.1 # ep < warmup_epochs 时使用
220
+ cfg.loss.soft_f1_weight: float = 0.3 # ep >= warmup_epochs
221
+ cfg.loss.soft_f1_warmup_epochs: int = 3
222
+ ```
223
+
224
+ ### [B-5] Real-distribution augment
225
+
226
+ **为什么**:sim_static 混响干净,模型学到的 activity 判据在低 SNR 下崩溃。augment 让模型见到各种"污染"的 spec,对 real 数据更鲁棒。
227
+
228
+ **实现点**:
229
+ - 文件:`spatial_dataset.py`
230
+ - 在 `SpatialDataset.__getitem__` 或 collate 里加 augment pipeline
231
+ - 只在训练集(`split='train'`)启用,valid/test 不启用
232
+ - 新增 config flag:
233
+ - `use_spec_augment`(默认 False)
234
+ - `spec_augment_time_mask_ratio`(0.2 = 20% time 长度)
235
+ - `spec_augment_freq_mask_ratio`(0.15)
236
+ - `random_gain_db`(±8)
237
+ - `channel_dropout_prob`(0.1)
238
+ - `lowpass_sim_real_prob`(0.1,cutoff ∈ U[4000, 8000] Hz)
239
+
240
+ **顺序**:waveform-level augment(gain, lowpass, channel_dropout)→ feature-level augment(SpecAugment)。
241
+
242
+ **重要**:augment 只作用在 FOA 4 通道 waveform / delta feature 上,**target labels 不变**。
243
+
244
+ ### B 实验 preset: `make_ov1_unified_v13b_config`
245
+
246
+ 热启动:`v12_best.pt` (strict=False)
247
+ 开关:
248
+ ```python
249
+ cfg.model.use_class_activity_bias = True # [B-1]
250
+ cfg.model.use_class_conditional_gate = True # [B-3]
251
+ cfg.loss.activity_loss_type = "asymmetric" # [B-2]
252
+ cfg.loss.asl_gamma_neg = 4.0
253
+ cfg.loss.asl_probability_margin = 0.05
254
+ cfg.loss.soft_f1_weight = 0.3 # [B-4]
255
+ cfg.dataset.use_spec_augment = True # [B-5]
256
+ cfg.dataset.spec_augment_time_mask_ratio = 0.2
257
+ cfg.dataset.spec_augment_freq_mask_ratio = 0.15
258
+ cfg.dataset.random_gain_db = 8.0
259
+ cfg.dataset.channel_dropout_prob = 0.1
260
+ cfg.dataset.lowpass_sim_real_prob = 0.1
261
+ cfg.learning_rate = 1e-5
262
+ cfg.num_epochs = 15
263
+ cfg.output_dir = "checkpoints/spatial_beats_ov1_unified_v13b_exp/03_ov123_top4"
264
+ ```
265
+
266
+ **数据 manifest 完全复用 v12**:unified train/valid + old ov1/2/3 sim/real + dcase_starss 作为 val。
267
+
268
+ ---
269
+
270
+ ## 4. v13_C 详细设计:Data + Architecture 全面重写
271
+
272
+ ### [C-1] Real data upsampling (replication)
273
+
274
+ **为什么**:real (dcase_real) 在 train 里占 6%,梯度感受不到。DCASE 社区标准做法是 20-30% real。
275
+
276
+ **实现点**:
277
+ - 预处理脚本:`scripts/split_unified_train_by_source.py`
278
+ - 读 `unified_spatial_foa_fsd63_all/train.jsonl`
279
+ - 按 `data_source` 字段拆成三份:
280
+ - `train_sim_static.jsonl`
281
+ - `train_qa_sim.jsonl`
282
+ - `train_dcase_real.jsonl`
283
+ - 写到 `unified_spatial_foa_fsd63_all/` 同目录下
284
+ - Preset:
285
+
286
+ ```python
287
+ cfg.train_manifest_paths = (
288
+ unified_root / "train_sim_static.jsonl",
289
+ unified_root / "train_qa_sim.jsonl",
290
+ unified_root / "train_dcase_real.jsonl",
291
+ )
292
+ cfg.train_manifest_replication = (1, 1, 6)
293
+ ```
294
+
295
+ **影响估算**(基于 v12 已知分布):
296
+ - sim_static 304K × 1 = 304K
297
+ - qa_sim ~? × 1 = ~?
298
+ - dcase_real 20K × 6 = 120K
299
+ - real 占比从 6% → ~25% (取决于 qa_sim 规模)
300
+
301
+ **兼容性**:`train_manifest_replication` 机制在 `train_spatial_beats.py` 已经存在(v7j 用过),不需要新加框架代码。只改 preset。
302
+
303
+ ### [C-2] Track-wise Refinement Transformer(2 layers)
304
+
305
+ **为什么**:K=4 track slots 之间互相不知道对方在干嘛,overlap 时同一源被多个 slot 抢,或同一 slot 被多个源抢。引入 self-attention 让 slot 互相"排斥"。
306
+
307
+ **实现点**:
308
+ - 文件:`spatial_modules.py`
309
+ - 新增类:
310
+
311
+ ```python
312
+ class TrackRefinementDecoder(nn.Module):
313
+ def __init__(self, num_tracks=4, embed_dim=768, num_layers=2,
314
+ num_heads=8, dim_feedforward=2048, dropout=0.0):
315
+ self.track_queries = nn.Parameter(torch.randn(num_tracks, embed_dim) * 0.02)
316
+ self.layers = nn.ModuleList([
317
+ nn.TransformerDecoderLayer(
318
+ d_model=embed_dim, nhead=num_heads,
319
+ dim_feedforward=dim_feedforward, dropout=dropout,
320
+ activation='gelu', norm_first=True, batch_first=True,
321
+ ) for _ in range(num_layers)
322
+ ])
323
+ # Zero-init layer scale: ep0 refinement = identity
324
+ self.layer_scale = nn.Parameter(torch.zeros(num_layers))
325
+
326
+ def forward(self, memory):
327
+ # memory: fused_spatial_embeddings [B, T, D]
328
+ # 输出:refined track tokens [B, K, T, D]
329
+ B, T, D = memory.shape
330
+ K = self.track_queries.size(0)
331
+ # 复制 K queries 到时间维度:[B, K, T, D]
332
+ q = self.track_queries[None, :, None, :].expand(B, K, T, D)
333
+ # 每个时间步独立做 decoder
334
+ # 为简化,把 T 维 flatten 进 batch:[B*T, K, D] cross-attn with [B*T, 1, D]
335
+ q_flat = q.permute(0, 2, 1, 3).reshape(B * T, K, D)
336
+ mem_flat = memory.reshape(B * T, 1, D)
337
+ for i, layer in enumerate(self.layers):
338
+ out = layer(q_flat, mem_flat)
339
+ q_flat = q_flat + self.layer_scale[i] * (out - q_flat)
340
+ # reshape 回 [B, K, T, D]
341
+ refined = q_flat.reshape(B, T, K, D).permute(0, 2, 1, 3).contiguous()
342
+ return refined
343
+ ```
344
+
345
+ - 在 `SpatialBEATs` 里:
346
+
347
+ ```python
348
+ if cfg.use_track_refinement:
349
+ self.track_refinement = TrackRefinementDecoder(
350
+ num_tracks=cfg.num_tracks,
351
+ embed_dim=cfg.encoder_embed_dim,
352
+ num_layers=cfg.track_refinement_layers,
353
+ )
354
+
355
+ # encode_patches 之后、送入 head 之前:
356
+ if self.track_refinement is not None:
357
+ track_tokens = self.track_refinement(encoder_memory) # [B, K, T, D]
358
+ # 传给 FrameTrackPredictionHeads 的输入从 [B, T, D] 改成 [B, K, T, D]
359
+ else:
360
+ track_tokens = None # head 沿用旧 expand 逻辑
361
+ ```
362
+
363
+ - `FrameTrackPredictionHeads` 的 forward ��个 `track_tokens: Optional[Tensor]` 参数:
364
+ - 传入 None → 沿用现有的"[B,T,D] 复制到 K slots"
365
+ - 传入 `[B,K,T,D]` → 用 refined tokens 走 head
366
+
367
+ **参数量**:2 layer × (self_attn + cross_attn + FFN) ≈ 2 × 2M = 4M。
368
+
369
+ **Zero-init 校验**:`layer_scale = zeros(2)` + 残差公式 `q + scale * (out - q)` → ep0 输出 = `track_queries`(静态,和 memory 无关)。但这会丢掉时间信息 —— **修正**:改用 `q + scale * layer_out`,并且把 track_queries 初始化成 `memory` 投影:
370
+
371
+ 实际更安全的等价初始化:
372
+
373
+ ```python
374
+ # Zero-init 方案:layer 不改 query,query 本身先吸收 memory 信息
375
+ # 思路:在 refine 前先做一次 "identity fallback":如果 scale=0,输出 = memory 广播到 K
376
+ def forward(self, memory):
377
+ B, T, D = memory.shape
378
+ K = ...
379
+ # 初始 track_tokens = memory 广播到 K(+ 一个很小的 query 偏移)
380
+ track_tokens = memory[:, None, :, :].expand(B, K, T, D).contiguous()
381
+ track_tokens = track_tokens + 0.02 * self.track_queries[None, :, None, :]
382
+ # refine
383
+ for i, layer in enumerate(self.layers):
384
+ ...
385
+ track_tokens = track_tokens + self.layer_scale[i] * delta
386
+ return track_tokens
387
+ ```
388
+
389
+ 这样 `layer_scale=0` 时 refinement 输出 ≈ `memory` 广播到 K,和 v12 "把 [B,T,D] 复制到 [B,K,T,D]" 等价。热启动安全。
390
+
391
+ ### [C-3] Multi-scale Patch Adapter V3
392
+
393
+ **为什么**:v12 用的 V2 adapter 只看 3 个时间 bin(30 ms),抓不到房间冲激响应的 early reflection (50-150ms)。V3 加多尺度 + dilated conv。
394
+
395
+ **实现点**:
396
+ - 文件:`spatial_modules.py`
397
+ - 新增类:`SpatialDeltaPatchAdapterV3`
398
+ - 三路 branch:
399
+ - branch_3x3: `Conv2d(C, H, kernel=3, padding=1)` (同 V2)
400
+ - branch_5x5: `Conv2d(C, H, kernel=5, padding=2)` (中尺度)
401
+ - branch_dilated: `Conv2d(C, H, kernel=3, padding=2, dilation=2)` (长时)
402
+ - fuse: `torch.cat` along channel → `Conv2d(3H, H, kernel=1)`
403
+ - 接现有 V2 的 SE block + residual + patchify
404
+ - cfg:`patch_adapter_version: str = "v1"` 增加选项 `"v3"`
405
+
406
+ **参数量**:比 V2 多 ~1M。
407
+
408
+ ### [C-4] Log-distance head + Laplace NLL loss
409
+
410
+ **为什么**:dist_mae=0.57 很差。距离分布长尾,log 后近似高斯。加 uncertainty 头允许模型对不确信的距离给大 variance,减少高 bias 样本的损失。
411
+
412
+ **实现点**:
413
+ - 文件:`spatial_modules.py`, 类 `FrameTrackPredictionHeads`
414
+ - 把现有 `distance_head: Linear(D, 1)` 升级为 `distance_head: Linear(D, 2)` 输出 `[log_dist, log_var]`
415
+ - 初始化:`bias[0] = log(1.5)`(v12 平均距离附近),`bias[1] = log(0.2^2)`(var=0.04 起点)
416
+ - cfg:`use_log_distance_head: bool = False`, `distance_loss_type: str = "l1" / "laplace_nll"`
417
+
418
+ - 文件:`spatial_loss.py`, Laplace NLL:
419
+
420
+ ```python
421
+ def laplace_nll_loss(pred_log_dist, pred_log_var, target_dist, mask):
422
+ # target_dist > 0 的位置才算 loss
423
+ pred_dist = pred_log_dist.exp()
424
+ pred_b = (0.5 * pred_log_var).exp() # Laplace scale
425
+ nll = (target_dist - pred_dist).abs() / pred_b + pred_log_var * 0.5
426
+ return (nll * mask).sum() / mask.sum().clamp(min=1)
427
+ ```
428
+
429
+ 推理时 `pred_distance = exp(pred_log_dist)`。
430
+
431
+ **初期稳定性**(已确认):v13c 从第 0 ep 就切 Laplace NLL(不做 L1 warmup)。如果训练前期 loss_dist 不稳,再回来调。
432
+
433
+ ### C 实验 preset: `make_ov1_unified_v13c_config`
434
+
435
+ 热启动:`v12_best.pt` (strict=False)
436
+ 开关:
437
+ ```python
438
+ cfg.model.use_track_refinement = True # [C-2]
439
+ cfg.model.track_refinement_layers = 2
440
+ cfg.model.patch_adapter_version = "v3" # [C-3]
441
+ cfg.model.use_log_distance_head = True # [C-4]
442
+ cfg.loss.distance_loss_type = "laplace_nll" # [C-4]
443
+
444
+ cfg.train_manifest_paths = (sim_static, qa_sim, dcase_real) # [C-1]
445
+ cfg.train_manifest_replication = (1, 1, 6)
446
+
447
+ cfg.learning_rate = 1e-5
448
+ cfg.num_epochs = 20
449
+ cfg.output_dir = "checkpoints/spatial_beats_ov1_unified_v13c_exp/03_ov123_top4"
450
+ ```
451
+
452
+ **loss 完全沿用 v12 默认**(BCE + L1 → Laplace),`soft_f1_weight=0`, `activity_loss_type="bce"`。
453
+
454
+ ---
455
+
456
+ ## 5. v14 合并实验(后续)
457
+
458
+ 在 B 和 C 都验证有效(F20 > 0.42)后启动:
459
+
460
+ - 热启动:`max(v13b.best, v13c.best).pt`
461
+ - 所有 B 和 C 的 flag 全开
462
+ - LR = 5e-6(更保守,防止双改动发散)
463
+ - epochs = 20
464
+
465
+ 预期 F20 = 0.55~0.62。
466
+
467
+ ---
468
+
469
+ ## 6. 预期结果 & 风险矩阵
470
+
471
+ ### B 预期
472
+ - 聚合 F20: 0.378 → **0.45~0.52**
473
+ - sim ov1: 0.386 → 0.50~0.55
474
+ - real ov1: 0.140 → 0.22~0.28
475
+ - dcase: 0.071 → 0.15~0.20
476
+ - activity_recall: 0.13 → 0.40~0.55
477
+
478
+ ### C 预期
479
+ - 聚合 F20: 0.378 → **0.44~0.50**
480
+ - ov3_sim: 0.270 → 0.38~0.42
481
+ - real ov1: 0.140 → 0.20~0.26
482
+ - dcase: 0.071 → 0.14~0.19
483
+ - dist_mae: 0.566 → 0.38~0.42
484
+
485
+ ### 风险
486
+
487
+ | 风险 | 发生概率 | 兜底 |
488
+ |---|---|---|
489
+ | B-2 ASL γ- 太大发散 | 中 | 先 γ-=2 跑 1 ep 验证,再拉到 4 |
490
+ | B-3 gate 挡掉好样本 | 低 | gate_scale 从 0.5 改 0.2 重跑 |
491
+ | B-4 soft-F1 和 ASL 冲突 | 中 | soft_f1_weight 从 0.3 降到 0.1 |
492
+ | B-5 augment 太强 sim 掉点 | 中 | 减半 augment 比例重跑 |
493
+ | C-1 real 6× 导致 sim 掉点 | 中 | 降到 4× |
494
+ | C-2 refinement 不学 | 中 | 手动设 layer_scale warmup |
495
+ | C-3 多尺度显存爆 | 低 | 去掉 dilated branch |
496
+ | C-4 log-dist 初期不稳 | 中 | 前 3 ep 用 L1 再切 |
497
+ | v14 合并发散 | 高 | 降 LR 到 3e-6,freeze trunk 前半段 |
498
+
499
+ ---
500
+
501
+ ## 7. 落地文件清单
502
+
503
+ | 文件 | 改动类型 | B/C |
504
+ |---|---|---|
505
+ | `spatial_modules.py` | 新增 3 个类 + 现有类加 forward 分支 | B+C |
506
+ | `spatial_loss.py` | 新增 2 个 loss 函数 + config 分支 | B+C |
507
+ | `spatial_dataset.py` | 新增 augment 逻辑 + config 字段 | B |
508
+ | `spatial_beats.py` | config 字段 + 可选模块构造 + forward 分支 | B+C |
509
+ | `train_spatial_beats.py` | 新增 2 个 preset 工厂 + CLI dispatch + choices | B+C |
510
+ | `scripts/split_unified_train_by_source.py` | 新文件,预处理脚本 | C |
511
+ | `run_ov1_unified_v13b.sh` | 新文件 | B |
512
+ | `run_ov1_unified_v13c.sh` | 新文件 | C |
513
+ | `docs/v13_spatial_beats_design.md` | 本文档 | — |
514
+
515
+ 所有现有文件的改动都是**新增分支**,不删除/修改任何现有逻辑。
516
+
517
+ ---
518
+
519
+ ## 8. 验证步骤
520
+
521
+ 每完成一步,按顺序验证:
522
+
523
+ 1. **语法检查**:`python -c "import ast; ast.parse(open(path).read())"`
524
+ 2. **旧 preset 回归**:`python train_spatial_beats.py --preset ov1_unified_v12 --dry-run`(或者 ep=1 跑到第一个 batch),确认 F20 和 v12 一致
525
+ 3. **新模块零初始化等价**:在 `v13b_config` / `v13c_config` 下跑 ep=0 valid,确认和 v12 best.pt 的 valid 指标差异 < 1%
526
+ 4. **B/C 训练**:完整跑 15/20 ep,观察 F20 曲线
527
+ 5. **per-subset eval**:用 `eval_v12_per_subset.py --preset ov1_unified_v13b --checkpoint ...` 看每个子集涨点
528
+ 6. **test eval**:用同脚本加 `--split test` 跑最终测试
docs/v13d_spatial_beats_design.md ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # v13_D 设计文档 — Loss 机制 + 训练细节双重改进
2
+
3
+ > 日期:2026-05-02
4
+ > 起源:v13_B / v13_C 实验 ep3 结果证明 "loss/arch 改动在 cls warmup 结束时贡献微弱",需要重新从**训练机制**角度切入
5
+ > 目标:F20 从 v12 的 0.378 推到 **0.43+**
6
+ > 设计原则:延续 v13_B/C 的"增量式、不破坏现有实验",所有新 flag 默认 False
7
+
8
+ ---
9
+
10
+ ## 0. 为什么 v13_B / v13_C 都不 work
11
+
12
+ ep3(cls warmup 结束)对比:
13
+
14
+ | | v12 ep3 | v13_B ep3 | v13_C ep3 |
15
+ |---|---|---|---|
16
+ | F20 | 0.3529 | 0.3565 (+0.004) | 0.3849 (+0.032) |
17
+ | o_cls | 0.7793 | 0.7757 | 0.7779 |
18
+ | aR | 0.120 | 0.194 | 0.090 |
19
+ | aP | 0.866 | 0.860 | 0.859 |
20
+
21
+ **三个实验的 oracle_class_acc 基本一致** → 三组实验的表征学的是同一个东西,新改动都没触达表征。loss/head decision 层面的改动(ASL / gate / soft-F1)在 cls warmup 结束时**基本还没起作用**,因为:
22
+
23
+ 1. **zero-init 的新模块需要 3-5 ep 才能让 layer_scale 爬起来**,而 cls warmup 只有 3 ep
24
+ 2. **ASL γ-=4 让 loss_activity 绝对值从 0.29 降到 0.07**,实际上是把 activity head 的梯度信号弄弱了
25
+ 3. **augment 反而降低 activity_precision**(0.95 → 0.86)
26
+ 4. **v12 自身从 ep3→ep12 只涨了 +0.025**,外推 v13_B/C 的 best 也就 0.38~0.41
27
+
28
+ ## 1. v13_D 的切入点:**训练机制本身**
29
+
30
+ 不碰架构、不碰表征,只改三件事:
31
+
32
+ - **D-1**: 扩大 cls warmup(3 ep → 8 ep)+ cosine LR + 总 25 ep,让表征学得更稳
33
+ - **D-2**: 用 Top-K rank activity loss 替换 BCE,直接针对 activity_recall=0.13 的根本瓶颈
34
+ - **D-5**: resume optimizer,从 v12 最后的 Adam momentum 继续,避免前 2-3 ep 梯度方向混乱
35
+ - **D-6**: EMA 权重,validate 时用 EMA 模型,ep 间震荡降低 1-2 个点
36
+
37
+ D-2 是核心,D-1 / D-5 / D-6 是辅助。
38
+
39
+ ## 2. 具体设计
40
+
41
+ ### D-1: cls warmup 拉长 + cosine LR + 25 epochs
42
+
43
+ #### 诊断
44
+ v12 `frame_spatial_loss_warmup_epochs = 3`,ep0-2 只训练 cls + activity,ep3 起放开 spatial loss。但:
45
+ - v12 曲线:cls_acc ep0=0.65 → ep1=0.81 → ep2=0.85 → ep3=0.84 → ... → ep12=0.83
46
+ - **ep2 已 0.85 但 ep3 反而微降 0.84**,说明 **ep3 放开 spatial loss 的瞬间干扰了 cls**
47
+ - cls 没有再涨的机会 —— 之后一直在 0.83 附近波动
48
+ - FSD63 的 oracle_cls 卡在 0.78 是 class head 学的,不是 trunk 表征的上限
49
+
50
+ #### 改动
51
+ ```python
52
+ cfg.frame_spatial_loss_warmup_epochs = 8 # 3 → 8
53
+ cfg.frame_spatial_loss_ramp_epochs = 2 # 1 → 2(ramp 更平滑)
54
+ cfg.num_epochs = 25 # 15 → 25
55
+
56
+ # LR: cosine schedule,峰值在 ep 5-8 之间(warmup 结束附近)
57
+ cfg.use_cosine_lr = True # 新 flag
58
+ cfg.cosine_lr_warmup_epochs = 3 # 前 3 ep linear warmup 到 peak
59
+ cfg.cosine_lr_min_ratio = 0.05 # 最后降到 peak * 0.05
60
+ ```
61
+
62
+ **训练循环改动**:
63
+ ```python
64
+ # ep 0-2: linear warmup LR 0 → peak_lr
65
+ # ep 3-24: cosine decay peak_lr → peak_lr * min_ratio
66
+ # ep 0-7: spatial loss weight = warmup_scale (0.0 or 0.1)
67
+ # ep 8-9: linear ramp to 1.0
68
+ # ep 10-24: full spatial loss
69
+ ```
70
+
71
+ ### D-2: Top-K rank activity loss(核心)
72
+
73
+ #### 诊断
74
+ 当前 BCE(或 ASL)对每个 `(b, k, t)` 位置独立判断"该 slot 是否 active"。问题:
75
+ - K=4 slots 在 ov1 数据上永远 3/4 inactive,类别极不平衡
76
+ - BCE 的最优解是 `sigmoid(act_logit) ≈ 0.25`(平均 prior),不会敢预测 active
77
+ - SELD 评估实际用 **sorted by activity logit, take top-K̂** 的方式决策
78
+ - 训练目标和评估目标不对齐
79
+
80
+ #### 设计
81
+
82
+ **Top-K rank loss**:在每一帧 `(b, t)`,强制 top-`N_active_gt` 个 slot 的 activity logit **必须比其他 slot 至少高 margin**,而不是独立回归 0/1。
83
+
84
+ ```python
85
+ def topk_rank_activity_loss(
86
+ activity_logit: Tensor, # [B, K, T]
87
+ target_active: Tensor, # [B, K, T], 0/1
88
+ valid_time: Tensor, # [B, T]
89
+ margin: float = 2.0,
90
+ ) -> Tensor:
91
+ """
92
+ Per-frame marginal ranking loss:
93
+ For each active slot i (target=1) and each inactive slot j (target=0),
94
+ enforce logit[i] > logit[j] + margin.
95
+
96
+ Equivalent to:
97
+ loss = Σ_{i in A, j in I} max(0, margin + logit[j] - logit[i])
98
+ This gives direct gradient that "ranks" active slots above inactive ones,
99
+ which aligns with the DCASE eval pipeline (take top-K̂ per frame).
100
+
101
+ Plus a weak binary regularizer to anchor logit magnitude:
102
+ + 0.1 * BCE(activity_logit, target_active)
103
+ """
104
+ # [B, T] active_count per frame
105
+ n_active = target_active.sum(dim=1) # [B, T]
106
+ # Loop-free formulation using broadcasting:
107
+ # logit_i: [B, K, T] (active side)
108
+ # logit_j: [B, K, T] (inactive side)
109
+ # pairwise diff: [B, K_i, K_j, T]
110
+ # mask: target_active[i] * (1 - target_active[j]) [B, K_i, K_j, T]
111
+ act = target_active.unsqueeze(2) # [B, K, 1, T]
112
+ ina = (1.0 - target_active).unsqueeze(1) # [B, 1, K, T]
113
+ pair_mask = act * ina # [B, K_i, K_j, T]
114
+ logit_i = activity_logit.unsqueeze(2) # [B, K, 1, T]
115
+ logit_j = activity_logit.unsqueeze(1) # [B, 1, K, T]
116
+ diff = logit_j - logit_i + margin # want this <= 0
117
+ # hinge loss, masked
118
+ hinge = F.relu(diff) * pair_mask # [B, K, K, T]
119
+ # normalize by valid pairs count
120
+ pair_valid = pair_mask.sum(dim=(1, 2)) # [B, T]
121
+ time_mask = valid_time.float() * (pair_valid > 0).float()
122
+ loss_rank = (hinge.sum(dim=(1, 2)) * time_mask).sum() / time_mask.sum().clamp(min=1.0)
123
+
124
+ # Anchor term: prevents logits from drifting to ±inf
125
+ loss_bce = F.binary_cross_entropy_with_logits(
126
+ activity_logit, target_active, reduction='none'
127
+ )
128
+ loss_bce = (loss_bce * valid_time.unsqueeze(1)).mean()
129
+
130
+ return loss_rank + 0.1 * loss_bce
131
+ ```
132
+
133
+ #### 为什么比 ASL 好
134
+
135
+ | 特性 | BCE | ASL | **Top-K rank** |
136
+ |---|---|---|---|
137
+ | 优化目标 | per-element logprob | per-element logprob (γ-weighted) | **pairwise ranking** |
138
+ | 受 K 不平衡影响 | 严重 | 缓解 | 无(只看 rank) |
139
+ | 与 DCASE 评估对齐 | ❌ | ❌ | **✓** (top-K̂) |
140
+ | 训练稳定性 | 好 | 中 (γ 过大会崩) | **好**(hinge + 小 BCE anchor) |
141
+ | 已知效果 | v12 aR=0.13 | v13_B aR=0.19,aP 降 | **未验证**,但机制对 |
142
+
143
+ #### Config flag
144
+
145
+ ```python
146
+ # spatial_loss.py
147
+ frame_activity_loss_type: str = "bce" # + "topk_rank"
148
+ topk_rank_margin: float = 2.0
149
+ topk_rank_bce_weight: float = 0.1 # anchor 的 BCE 权重
150
+ ```
151
+
152
+ ### D-5: Resume optimizer(低成本改进)
153
+
154
+ #### 诊断
155
+ v13_B/C 都用 `--no-resume-optimizer`,Adam 的 m/v moment buffer 从零重建。前 2-3 个 epoch 梯度方向不稳,尤其在"换 loss 函数"后更明显。
156
+
157
+ #### 改动
158
+ ```bash
159
+ # run_ov1_unified_v13d.sh 不加 --no-resume-optimizer
160
+ # 但 LR 设为 v12 最后 LR 的 1/3(避免 resume 后太激进)
161
+ SPATIAL_LR="${SPATIAL_LR:-7e-6}" # v12 是 2e-5,这里是 2e-5/3 ≈ 7e-6
162
+ ```
163
+
164
+ **注意**:Optimizer state 包含 LR scheduler 状态,如果我们切 cosine schedule 需要 reset schedule 但保留 Adam moments。实现时:
165
+ ```python
166
+ # 加载 optimizer_state_dict
167
+ optimizer.load_state_dict(ckpt['optimizer_state_dict'])
168
+ # 但把所有 param_group 的 LR 重设为新 LR(cosine scheduler 会从这开始)
169
+ for pg in optimizer.param_groups:
170
+ pg['lr'] = new_peak_lr
171
+ # 删掉 step count(avoid schedule confusion)
172
+ # scheduler 从 epoch=0 重新开始
173
+ ```
174
+
175
+ ### D-6: EMA 权重
176
+
177
+ #### 诊断
178
+ v12 ep10-14 F20 在 0.367-0.378 震荡,SGD 困在鞍点。EMA = 取最近 N 个权重的平滑平均,能稳定在鞍点中间而非某一端。
179
+
180
+ #### 改动
181
+
182
+ **新加 `EMAModel` helper**:
183
+ ```python
184
+ class EMAModel:
185
+ def __init__(self, model: nn.Module, decay: float = 0.9995):
186
+ self.decay = decay
187
+ self.shadow: Dict[str, Tensor] = {
188
+ name: p.detach().clone()
189
+ for name, p in model.named_parameters()
190
+ if p.requires_grad
191
+ }
192
+
193
+ @torch.no_grad()
194
+ def update(self, model: nn.Module):
195
+ for name, p in model.named_parameters():
196
+ if not p.requires_grad: continue
197
+ self.shadow[name].mul_(self.decay).add_(p.detach(), alpha=1 - self.decay)
198
+
199
+ def apply_to(self, model: nn.Module) -> Dict[str, Tensor]:
200
+ """Swap model params with EMA shadow, returns backup for restoration."""
201
+ backup = {}
202
+ for name, p in model.named_parameters():
203
+ if name in self.shadow:
204
+ backup[name] = p.data.clone()
205
+ p.data.copy_(self.shadow[name])
206
+ return backup
207
+
208
+ def restore(self, model: nn.Module, backup: Dict[str, Tensor]):
209
+ for name, p in model.named_parameters():
210
+ if name in backup:
211
+ p.data.copy_(backup[name])
212
+ ```
213
+
214
+ **训练循环**:
215
+ ```python
216
+ # 每 step 后:
217
+ if ema_model is not None:
218
+ ema_model.update(model)
219
+
220
+ # validate 前:
221
+ if ema_model is not None:
222
+ backup = ema_model.apply_to(model)
223
+ val_metrics = evaluate_one_epoch(model, ...)
224
+ if ema_model is not None:
225
+ ema_model.restore(model, backup)
226
+
227
+ # save best.pt 时,保存 EMA 权重而非原始权重
228
+ if ema_model is not None:
229
+ backup = ema_model.apply_to(model)
230
+ torch.save({'model_state_dict': model.state_dict(), ...})
231
+ ema_model.restore(model, backup)
232
+ ```
233
+
234
+ #### Config flag
235
+
236
+ ```python
237
+ # TrainSpatialBEATsConfig
238
+ use_ema: bool = False
239
+ ema_decay: float = 0.9995
240
+ ema_start_epoch: int = 3 # 前 3 ep 不 EMA(避免 warmup 噪声)
241
+ ```
242
+
243
+ ## 3. v13_D preset
244
+
245
+ ```python
246
+ def make_ov1_unified_v13d_config(...):
247
+ cfg = make_ov1_unified_v12_config(...) # v12 为基础
248
+
249
+ # D-1: 扩大 cls warmup,cosine schedule
250
+ cfg.frame_spatial_loss_warmup_epochs = 8
251
+ cfg.frame_spatial_loss_ramp_epochs = 2
252
+ cfg.num_epochs = 25
253
+ cfg.use_cosine_lr = True
254
+ cfg.cosine_lr_warmup_epochs = 3
255
+ cfg.cosine_lr_min_ratio = 0.05
256
+ cfg.learning_rate = 1.5e-5 # peak LR
257
+
258
+ # D-2: Top-K rank activity loss
259
+ cfg.loss.frame_activity_loss_type = "topk_rank"
260
+ cfg.loss.topk_rank_margin = 2.0
261
+ cfg.loss.topk_rank_bce_weight = 0.1
262
+
263
+ # D-5: resume optimizer (在 run script 里,不写 --no-resume-optimizer)
264
+
265
+ # D-6: EMA
266
+ cfg.use_ema = True
267
+ cfg.ema_decay = 0.9995
268
+ cfg.ema_start_epoch = 3
269
+
270
+ cfg.output_dir = "checkpoints/spatial_beats_ov1_unified_v13d_exp/03_ov123_top4"
271
+ return cfg
272
+ ```
273
+
274
+ ## 4. 实现步骤
275
+
276
+ | 文件 | 改动 | 对应 D-* |
277
+ |---|---|---|
278
+ | `spatial_loss.py` | 加 `_topk_rank_activity_loss` + config 字段 + 分支 | D-2 |
279
+ | `train_spatial_beats.py` | 加 `EMAModel` class + cosine LR scheduler + 训练循环 hook | D-1, D-6 |
280
+ | `train_spatial_beats.py` | 加 `make_ov1_unified_v13d_config` + CLI + choices | - |
281
+ | `run_ov1_unified_v13d.sh` | 新脚本,不带 `--no-resume-optimizer` | D-5 |
282
+ | `docs/v13d_spatial_beats_design.md` | 本文档 | - |
283
+
284
+ 所有改动都通过 cfg flag 控制,默认 False → v12/v13_B/v13_C 不受影响。
285
+
286
+ ## 5. 预期结果
287
+
288
+ | 指标 | v12 best | v13_D 预期 | 机制 |
289
+ |---|---|---|---|
290
+ | F20 | 0.378 | **0.42 ~ 0.46** | Top-K rank + EMA + 长 warmup |
291
+ | aR | 0.126 | **0.25 ~ 0.40** | Top-K 强制拉高活跃 slot |
292
+ | aP | 0.855 | **0.80 ~ 0.85** | 可能小降(recall ↑ 的代价),但 hinge 保留 rank 信号 |
293
+ | class_acc | 0.834 | **0.86 ~ 0.88** | 长 warmup 让 cls 真的学完 |
294
+ | azi_mae | 19.7° | **18~20°** | 不变,不是目标 |
295
+
296
+ ## 6. 风险和兜底
297
+
298
+ | 风险 | 概率 | 兜底 |
299
+ |---|---|---|
300
+ | Top-K rank 的 hinge 梯度饱和(margin 太大) | 中 | 降 margin 到 1.0 |
301
+ | margin=2.0 导致 logit 分布爆炸(两端拉开) | 低 | anchor BCE 权重从 0.1 升到 0.3 |
302
+ | EMA 反而降 F(热启动 EMA 初始化问题) | 低 | ema_start_epoch 提到 5 |
303
+ | Cosine LR 峰值太高毁掉 v12 表征 | 中 | peak_lr 降到 1e-5(v12 也是这个) |
304
+ | resume optimizer 把 v12 的 Adam moment 固化在错方向 | 低 | 如果 ep0-2 loss 爆炸,退回 --no-resume-optimizer |
305
+ | 总体不涨(所有改动都没用) | 中 | 写 ablation,跑 v13_D_noema / v13_D_nocos 诊断 |
306
+
307
+ ## 7. 和 v13_B/C 的关系
308
+
309
+ - **v13_D 不依赖 v13_B/C 的改动**,直接从 v12 best.pt 热启动
310
+ - v13_B/C 可以看作"改模块结构"的尝试,v13_D 是"改训练机制"的尝试
311
+ - 如果 v13_D 成功(F ≥ 0.42),**可以在它之上加回 v13_C 的 refinement 2-layer**,那才是真正的 v14
312
+
313
+ ## 8. 验证步骤
314
+
315
+ 1. `python -c "import ast; ast.parse(open('spatial_loss.py').read())"` 语法
316
+ 2. Top-K rank loss 单测:给已知 activity_logit + target 手算验证
317
+ 3. 模型构造 + hot-start v12 best.pt:确认 missing=0, unexpected=0
318
+ 4. 单 batch 前向 + backward:确认 loss 是 scalar、梯度非 NaN
319
+ 5. EMA 单测:update 后 shadow 权重正确
320
+ 6. 1 epoch dry-run:看 cosine LR 曲线 + EMA shadow 随 step 变化
321
+
322
+ ---
323
+
324
+ ## 附:为什么不加更多改动(D-7 per-class expert 等)
325
+
326
+ 诊断:v13_B/C 失败的主因不是"改动不够多",而是"改动不对靶"。v13_D 只碰 loss 机制 + 训练 schedule,属于**最小必要改动**:
327
+
328
+ - D-2 Top-K 直接针对 activity_recall 瓶颈
329
+ - D-1 扩大 warmup 给表征更多学习时间
330
+ - D-6 EMA 降低末期震荡
331
+ - D-5 resume optimizer 让 LR 轨迹连续
332
+
333
+ 加更多改动(per-class expert、class-conditional gate v2)会重复 v13_B 的错误——改了 head 但没触达瓶颈,而且同时改太多东西无法 ablation。