File size: 17,045 Bytes
86cbd36 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 | # Spatial-BEATs 简化版实现文档
## 1. 文档目的
本文档给出当前推荐的 `Spatial-BEATs` 简化版实现方案。
这份方案基于一个更务实的判断:
- 真正重要的是让 `FOA -> BEATs trunk` 学到稳定的空间表征
- 后面的模块只需要承担 `readout / decode / supervision` 的作用
- 不需要一开始就引入复杂的 `slot query decoder`
- 最终给 LLM 的 token 应尽量来自前面的空间 embedding,而不是最终任务头输出
- 当前阶段需要先支持 `encoder-only training`
因此,本方案不再以“内部 4 个 slots + decoder 聚合”作为主线,而改为:
```text
FOA waveform
-> Spatial preprocessor
-> Multi-channel patch embedding
-> BEATs trunk
-> Temporal readout
-> Spatial embeddings at 2.5 Hz
-> Fixed-slot prediction heads
-> Projector
-> LLM spatial tokens
```
## 2. 当前确定的约束
- 采样率:`16 kHz`
- 输入:`FOA waveform`
- 多源数据:有
- 最大同时源数:约 `4`
- 每个源有稳定 `source-level class label`
- source vocabulary:`/apdcephfs_cq12/share_302080740/user/schmittzhu/data/fsd50k/FSD50K.ground_truth/final_vocabulary.csv`
- source class count:当前默认 `65`
- 距离预测:`连续回归`
- mel 前端参数:对齐 `Qwen-2.5-Omni` audio tower 的底层配置
- `sample_rate=16000`
- `num_mel_bins=128`
- `n_fft=400`
- `win_length=400`
- `hop_length=160`
- `dither=0.0`
- 时间 supervision:`弱时间窗口`
- 当前位置:`clip 内固定`
- 未来会扩展到:`位置随时间变化`
- 目标输出 token rate:`2.5 Hz`
- 对于任意 clip,第 `i` 个样本的有效 token 数是 `T_s_i = round(duration_i * 2.5)`
- 对于 `10 s` clip,`T_s_i = 25`
- batch 内部按 `T_s_max = max_i T_s_i` 做 padding
- 主干初始化:`BEATs_iter3+ AS2M pre-trained`
- 当前第一阶段:优先支持只训练 encoder 的监督方案
## 3. 方案核心思想
### 3.1 什么是主角
主角是:
- `BEATs trunk`
它负责从 FOA 空间特征中学习空间表征。
### 3.2 什么是配角
配角是:
- temporal readout
- prediction heads
- projector
这些模块的作用只是:
- 从 trunk 特征中“读出”空间信息
- 建立 loss 回传路径
- 把空间 embedding 投影到 LLM 接口
### 3.3 给 LLM 的 token 从哪里来
最终给 LLM 的 token 应该来自:
- trunk 后
- 或 trunk 后再经过一层很浅的 temporal readout 之后
而不是来自:
- 最终 logits
- 复杂 decoder 的末端输出
## 4. 最终结构总览
```text
FOA waveform [B, 4, T]
-> SpatialBEATsPreprocessor
-> FOA feature map [B, C_foa, T_f, F]
-> SpatialPatchEmbedding
-> patch tokens [B, N_p, D_in]
-> BEATs trunk
-> encoder memory [B, N_p, D]
-> reshape / frequency pooling
-> temporal tokens [B, T_s_max, D]
-> shallow temporal readout
-> spatial embeddings [B, T_s_max, D]
-> prediction heads
-> Spatial projector
-> llm spatial tokens [B, T_s_max, d_llm]
```
其中:
- 每个样本:
- `T_s_i = round(duration_i * 2.5)`
- 一个 batch 内:
- `T_s_max = max_i T_s_i`
- 对 `10 s` 输入:
- `T_s_i = 25`
## 5. 从 FOA 到 spatial token 的完整过程
## 5.1 输入层
输入:
- `waveform: [B, 4, T]`
四个通道分别是:
- `W`
- `X`
- `Y`
- `Z`
对 `10 s, 16kHz`:
- `T = 160000`
## 5.2 SpatialBEATsPreprocessor
目标:
- 把原始 FOA 波形转成多通道空间特征图
推荐输出通道:
- `W_logmel`
- `X_logmel`
- `Y_logmel`
- `Z_logmel`
- `IVx`
- `IVy`
- `IVz`
因此:
- `C_foa = 7`
内部步骤:
1. 对 `WXYZ` 做 STFT
2. 计算各通道 `log-mel`
3. 计算 `IVx, IVy, IVz`
4. 将 `IV` 映射到 mel 维
5. 拼接成特征图
输出:
- `foa_feat: [B, 7, T_f, 128]`
其中:
- `128` 是 mel bins
- `T_f` 约等于 `1000`,如果帧移为 `10 ms`
## 5.3 SpatialPatchEmbedding
目标:
- 把 `7` 通道 FOA 特征图切成 patch token
原始 BEATs 的 stem 是单通道:
```text
Conv2d(1, embed_dim, kernel_size=patch, stride=patch)
```
新的 stem 改成:
```text
Conv2d(7, embed_dim, kernel_size=patch, stride=patch)
```
输入:
- `foa_feat: [B, 7, T_f, 128]`
输出:
- `patch_tokens: [B, N_p, D_in]`
- `grid_hw = (T_p, F_p)`
建议:
- `D_in = 512`
原因:
- 与 BEATs patch embedding dim 对齐
## 5.4 BEATs trunk
这是整个模型最重要的部分。
保留并复用的模块包括:
- patch token layer norm
- `post_extract_proj`
- `dropout_input`
- `TransformerEncoder`
- `conv_pos`
- 所有 transformer layers
- LayerNorm / FFN / attention
输入:
- `patch_tokens: [B, N_p, 512]`
输出:
- `encoder_memory: [B, N_p, 768]`
建议:
- 主干 hidden dim 保持 `768`
- 直接加载 `BEATs_iter3+ AS2M pre-trained`
## 5.5 Reshape + Frequency Pooling
目标:
- 把 trunk 输出变成可读的时间序列表示
步骤:
1. 将 `encoder_memory [B, N_p, D]` reshape 成:
- `grid_memory [B, T_p, F_p, D]`
2. 对 `F_p` 维做 pooling
推荐第一版:
- `mean pooling over frequency`
输出:
- `temporal_patch_tokens: [B, T_p, D]`
这一步的意义是:
- 先把频率信息压到时间轴上
- 便于后面构造固定 `2.5 Hz` 的时序空间 token
## 5.6 Temporal Resampler
目标:
- 把 patch 级时间序列压到目标 token rate
输入:
- `temporal_patch_tokens: [B, T_p, D]`
输出:
- `temporal_tokens: [B, T_s, D]`
其中:
- 第 `i` 个样本:
- `T_s_i = round(duration_i * 2.5)`
- batch padding 后:
- `temporal_tokens: [B, T_s_max, D]`
- 对 `10 s` clip:
- `T_s_i = 25`
推荐第一版:
- 线性插值
- 或轻量 `Conv1d` 下采样
注意:
- `2.5 Hz` 明确指最终给 LLM 的 token rate
- 对单个 `10 s` 样本,`10 s -> 25` 个有效 token
- 对 mixed-length batch,张量会 pad 到 `T_s_max`
## 5.7 Shallow Temporal Readout
目标:
- 在时间维上再做一层轻量整理
- 让 trunk 输出更适合做空间监督和给 LLM 使用
推荐做法:
- `1~2` 层 transformer encoder
输入:
- `temporal_tokens: [B, T_s_max, 768]`
输出:
- `spatial_embeddings: [B, T_s_max, 768]`
这一层的作用是:
- 做一个浅层 readout neck
- 不承担复杂检测器职责
- 不强调 decoder 身份
- 只是把 trunk 表征整理成更干净的时间步级空间 embedding
如果你想先做最简单版本,也可以直接:
- `temporal_tokens -> LayerNorm -> spatial_embeddings`
把 shallow transformer 作为后续增强项。
## 5.8 Prediction Heads
目标:
- 从 `spatial_embeddings` 上接显式监督
- 通过 loss 让前面的 trunk 真正学到空间表征
- 在不引入复杂 decoder 的前提下提供多源监督出口
输入:
- `spatial_embeddings: [B, T_s_max, 768]`
### 5.8.1 Encoder-only 阶段的固定槽位 readout
虽然简化版不再使用复杂 `slot query decoder`,但当前仍然有:
- 最大同时源数约为 `4`
因此,encoder-only 训练阶段推荐在 `spatial_embeddings` 后接一个很轻的固定槽位 readout:
```text
spatial_embeddings [B, T_s_max, 768]
-> Linear / MLP expand
-> slot_latents [B, T_s_max, 4, H]
-> shared prediction heads
```
推荐默认:
- `H = 768`
最简单实现:
```text
Linear(768 -> 4 * 768)
reshape -> [B, 25, 4, 768]
```
更稳一点的实现:
```text
Linear(768 -> 768)
-> GELU
-> Linear(768 -> 4 * 768)
reshape -> [B, 25, 4, 768]
```
这一步的作用不是做复杂目标解析,而只是:
- 给每个时间步提供 `4` 个固定 source 槽位
- 让多源监督有明确着力点
- 把 loss 稳定回传到前面的 trunk
### 5.8.2 预测头设计
对 `slot_latents [B, 25, 4, H]` 接共享或独立 heads。
建议输出头:
- `activity / objectness head`
- `azimuth head`
- `elevation head`
- `distance head`
- `class auxiliary head`
输出形式建议:
- `pred_activity: [B, T_s_max, 4]`
- `pred_azi_logits: [B, T_s_max, 4, 360]`
- `pred_ele_logits: [B, T_s_max, 4, 180]`
- `pred_dist: [B, T_s_max, 4, 1]`
- `pred_class_logits: [B, T_s_max, 4, C_cls]`
其中:
- `pred_activity` 负责当前时间步当前槽位是否解释某个源
- `pred_azi_logits` / `pred_ele_logits` 负责方向
- `pred_dist` 负责连续距离回归
- `pred_class_logits` 负责 source-level 辅助类别监督
当前默认建议:
- `C_cls = 65`
- label vocabulary 来自:
- `final_vocabulary.csv`
- 推荐字段:
- `label_id`
- `final_label`
### 5.8.3 为什么这里仍然保留 K=4
在简化版中:
- `K=4` 不再体现在复杂 decoder 结构里
- 但仍然通过固定槽位 readout 出现在监督头中
也就是说:
- 不需要复杂 query-based object slots
- 但仍然需要一个简单的多槽位 readout 来承载多源标签
这更符合当前“先把主干训练出来”的目标。
## 5.9 Spatial Projector
目标:
- 把 `spatial_embeddings` 投影到 LLM hidden size
输入:
- `spatial_embeddings: [B, T_s_max, 768]`
推荐:
- 独立的 `MLP projector`
形式:
```text
Linear(768 -> D_mid)
-> GELU
-> LayerNorm
-> Linear(D_mid -> d_llm)
```
输出:
- `llm_spatial_tokens: [B, T_s_max, d_llm]`
这就是最终喂给 LLM 的 spatial tokens。
### 5.9.1 Encoder-only 阶段 projector 的角色
当前第一阶段如果只训练 encoder,本质目标是:
- 用监督把 `spatial_embeddings` 训好
因此 projector 在这一阶段有两种合理策略:
#### 方案 A:先不训练 projector,推荐默认
- 只训练:
- preprocessor
- patch embedding
- BEATs trunk
- temporal readout
- fixed-slot prediction heads
- projector 只保留接口,不参与训练
#### 方案 B:一并训练 projector
- 当你希望尽早固定 LLM 接口维度时可以启用
但第一版默认推荐:
- **先把 encoder 训好,再训练 projector**
## 6. 每层的输入输出总结
### 6.1 SpatialBEATsPreprocessor
- 输入:`[B, 4, T]`
- 输出:`[B, 7, T_f, 128]`
### 6.2 SpatialPatchEmbedding
- 输入:`[B, 7, T_f, 128]`
- 输出:`[B, N_p, 512]`
### 6.3 BEATs trunk
- 输入:`[B, N_p, 512]`
- 输出:`[B, N_p, 768]`
### 6.4 Reshape + Frequency Pooling
- 输入:`[B, N_p, 768]`
- 输出:`[B, T_p, 768]`
### 6.5 Temporal Resampler
- 输入:`[B, T_p, 768]`
- 输出:`[B, 25, 768]`
### 6.6 Shallow Temporal Readout
- 输入:`[B, T_s_max, 768]`
- 输出:`[B, T_s_max, 768]`
### 6.7 Prediction Heads
- 输入:`[B, T_s_max, 768]`
- 先经 `Linear / MLP expand` 变成:`[B, T_s_max, 4, H]`
- 输出:
- `activity [B, T_s_max, 4]`
- `azimuth [B, T_s_max, 4, 360]`
- `elevation [B, T_s_max, 4, 180]`
- `distance [B, T_s_max, 4, 1]`
- `class [B, T_s_max, 4, C_cls]`
### 6.8 Spatial Projector
- 输入:`[B, T_s_max, 768]`
- 输出:`[B, T_s_max, d_llm]`
## 7. Loss 如何作用到前面的表征
这个简化方案的关键点在于:
- 后面的 heads 并不是模型重点
- 它们只是为了接监督
loss 从这些 heads 回传后,会更新:
1. readout neck
2. temporal resampler
3. BEATs trunk
4. patch embedding
5. FOA preprocessor
因此:
- 只要后面的 decode/head 能稳定预测显式空间信息
- 前面的 trunk 就会被训练成空间 encoder
### 推荐 loss
当前建议:
- `L_activity`
- `L_azi`
- `L_ele`
- `L_dist`
- `L_cls_aux`
- `L_temp`
总损失:
```text
L_total =
lambda_act * L_activity
+ lambda_azi * L_azi
+ lambda_ele * L_ele
+ lambda_dist * L_dist
+ lambda_cls * L_cls_aux
+ lambda_temp * L_temp
```
当前可继续保留:
- 方向分类
- 仰角分类
- 距离连续回归
- 辅助类别监督
- 时间一致性正则
### 7.1 Encoder-only 训练阶段的 loss 定义
#### `L_activity`
用于监督:
- 当前时间步当前槽位是否激活
建议:
- `BCEWithLogitsLoss` 或 `focal loss`
结合当前的弱时间窗口 supervision:
- 窗口外:负样本
- 窗口内:弱正样本
#### `L_cls_aux`
用于监督:
- 被分配到某个 GT source 的槽位类别
建议:
- `CrossEntropyLoss`
#### `L_azi` 和 `L_ele`
用于监督:
- 槽位的方向分类
建议:
- `CrossEntropyLoss`
- `azimuth`: `360` bins
- `elevation`: `180` bins
#### `L_dist`
用于监督:
- 槽位的连续距离回归
建议:
- 先将距离归一化到 `[0, 1]`
- 使用 `SmoothL1Loss`
#### `L_temp`
由于当前位置当前是 clip 内固定,建议加入:
- 时间一致性约束
例如对同一 source 在相邻时间步对应的槽位加:
- class 分布平滑
- direction 分布平滑
- distance 平滑
第一版可以先只加:
- distance smoothness
- activity smoothness
## 7.2 Encoder-only 训练时的匹配方式
因为当前 heads 是固定 `4` 槽位,而不是 query decoder,所以建议:
- 使用轻量匹配
- 不必依赖复杂 decoder 结构
推荐策略:
1. 对每个 GT source,根据其时间窗口筛出候选时间步
2. 在该时间步的 `4` 个固定槽位中,选择 cost 最小的槽位
3. 对该槽位施加:
- activity
- class
- azi
- ele
- dist
推荐 cost:
```text
cost =
w_act * cost_act
+ w_cls * cost_cls
+ w_azi * cost_azi
+ w_ele * cost_ele
+ w_dist * cost_dist
```
第一版可以不做全局 Hungarian,直接做:
- `per-step fixed-slot matching`
如果后面发现多源冲突严重,再升级 matching。
## 8. 关于多源 supervision 的理解
虽然当前结构没有显式 `slot query decoder`,但这并不等于完全不考虑多源。
当前更合理的理解是:
- 让 `spatial_embeddings [B, 25, D]` 表示该时间步的空间场景表征
- 再通过固定 `4` 槽位 readout head 承载多源标签
- supervision 用这些多源标签约束 trunk 必须编码多个源的空间信息
这相当于:
- 先学一个强的 time-step level spatial embedding
- 再决定是否需要升级成更复杂的 object-centric / query-based 版本
这是更稳的第一版路径。
## 9. 为什么这个简化版更合适当前阶段
### 9.1 更符合老师的建议
老师的核心意见可以总结为:
- 不必执着于 decoder 结构本身
- 后面的模块只要能 decode 出空间监督即可
- 真正要训好的是前面的 trunk
这个简化版完全符合这个思路。
### 9.2 工程复杂度更低
不需要一开始就实现:
- slot query decoder
- cross-attention decoder
- objectness pooling
- 复杂 slot matching
### 9.3 更利于先验证 trunk 是否真的学到空间表征
如果这版能成功:
- 说明 trunk + temporal readout 本身已经足够表达空间信息
如果这版不够:
- 再升级到 slot/object 版本
这样路线更清晰。
## 10. 推荐实现版本
### V1:最小可行版
```text
FOA -> preprocessor -> patch embed -> BEATs trunk
-> frequency pooling -> temporal resampler
-> LayerNorm
-> spatial embeddings
-> Linear expand to 4 slots
-> heads
```
### V2:推荐正式版
```text
FOA -> preprocessor -> patch embed -> BEATs trunk
-> frequency pooling -> temporal resampler
-> 1~2 层 shallow transformer readout
-> spatial embeddings
-> Linear / MLP expand to 4 slots
-> heads
-> projector
```
### V3:后续增强版
如果未来发现:
- 多源关系建模仍然不足
- 动态轨迹下表达不够
再升级为:
- slot-based decoder
- query-based object-centric readout
## 11. 推荐代码划分
建议新增或保留以下文件:
- `spatial_beats.py`
- `spatial_modules.py`
- `spatial_loss.py`
- `spatial_dataset.py`
- `train_spatial_beats.py`
### `spatial_modules.py`
建议包含:
- `SpatialBEATsPreprocessor`
- `SpatialPatchEmbedding`
- `TemporalResampler`
- `TemporalReadoutTransformer`
- `FixedSlotReadoutHead`
- `SpatialPredictionHeads`
- `SpatialProjector`
### `spatial_beats.py`
建议主类:
- `SpatialBEATsConfig`
- `SpatialBEATs`
主类中建议暴露:
- `extract_spatial_embeddings()`
- `project_for_llm()`
- `forward()`
## 12. 最终输出接口
推荐 `forward()` 返回:
```python
{
"encoder_memory": FloatTensor[B, N_p, 768],
"temporal_tokens": FloatTensor[B, 25, 768],
"spatial_embeddings": FloatTensor[B, 25, 768],
"slot_latents": FloatTensor[B, 25, 4, H],
"pred_activity": FloatTensor[B, 25, 4],
"pred_azi_logits": FloatTensor[B, 25, 4, 360],
"pred_ele_logits": FloatTensor[B, 25, 4, 180],
"pred_dist": FloatTensor[B, 25, 4, 1],
"pred_class_logits": FloatTensor[B, 25, 4, C_cls],
"llm_spatial_tokens": FloatTensor[B, 25, d_llm],
}
```
其中:
- `spatial_embeddings` 是最核心的中间表示
- `slot_latents` 只是 encoder-only 监督出口
- `llm_spatial_tokens` 是最终给 LLM 的接口
## 13. 结论
当前推荐的简化版最终架构是:
- `FOA -> spatial features -> BEATs trunk -> 2.5Hz temporal spatial embeddings -> 固定4槽位 readout heads -> projector`
这套方案的重点是:
- 用显式空间监督把前面的 `BEATs trunk` 训练成空间 encoder
- 后面的 fixed-slot head 只承担监督和 readout 的职责
- 给 LLM 的 token 直接来自前面的 `spatial_embeddings`
这是当前最适合进入代码实现的主线方案。
|