File size: 14,441 Bytes
bf04039 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 | # Spatial-BEATs 实现规格
## 1. 目标
本规格文档用于将前期讨论收敛为一个可以直接实施的 `Spatial-BEATs` 方案。
目标是构建一个独立的 `Spatial Encoder`:
- 输入为完整 `FOA` 音频及其派生空间特征
- 完整的 `FOA` 特征经过 `BEATs backbone`
- 最大化复用 `BEATs` 预训练权重
- 输出一组 `source-level spatial tokens`
- 这些 token 作为独立模态输入给 LLM
- 原有语义 audio encoder 保持不动
这里的关键原则是:
> 不是让 `W-only` 走主干,再外挂一个小空间 adapter;而是让完整 FOA 空间特征真正进入 BEATs 主干,并在主干之后产出结构化空间 token。
## 2. 最终任务定义
### 2.1 核心任务
`Spatial-BEATs` 的主任务定义为:
- 给定一个多源 `FOA` 音频片段
- 预测其中最多 `K` 个潜在声源的空间表示
- 每个表示对应一个 `source token`
每个 source token 至少承载:
- `objectness`
- `azimuth`
- `elevation`
- `distance`
可选承载:
- `source class auxiliary logits`
- `source embedding`
### 2.2 推荐监督形式
如果训练数据中每个源都有标注,则推荐采用:
- `set prediction`
- `K` 个预测 token 对 `N` 个 GT sources
- 用 `Hungarian matching` 做一一匹配
不建议采用:
- 单一 scene-level spatial token
- 仅回归整段音频的全局空间摘要
原因是这会损失多源结构,不利于后续 LLM 做关系推理。
## 3. 最终架构
推荐最终架构:
```text
FOA waveform
-> SpatialBEATsPreprocessor
-> FOA feature map [B, C_foa, T, F]
-> FOA patch embedding
-> BEATs trunk
-> Spatial query decoder
-> K source tokens
-> Spatial prediction heads
-> LLM projector
```
为了最大化复用 BEATs 主干,本方案尽量不改 trunk 内部的 Transformer 结构。
## 4. 输入特征定义
### 4.1 默认推荐特征
第一版推荐输入通道:
- `W_logmel`
- `X_logmel`
- `Y_logmel`
- `Z_logmel`
- `IVx`
- `IVy`
- `IVz`
即:
- `C_foa = 7`
这是默认推荐方案。
### 4.2 备选输入特征
若希望先降低复杂度,可以使用:
- `WXYZ logmel`
即:
- `C_foa = 4`
但这只适合最小原型。
如果目标是稳定学习空间方向与结构,优先使用 `WXYZ + IV`。
### 4.3 前端参数建议
为了最大化复用 BEATs 主干,推荐保持与 BEATs 接近的时频分辨率:
- sample rate:优先 `16k`
- mel bins:`128`
- frame length:`25 ms`
- frame shift:`10 ms`
原因:
- 这能让 trunk 看到与原始 BEATs 更接近的 patch 几何结构
- patch embedding 和后续序列长度更容易保持一致
- 预训练权重复用更稳定
### 4.4 为什么不沿用 Spatial-AST 的 binaural 前端
Spatial-AST 采用的是:
- 双耳 log-mel
- IPD
这适合 binaural,不适合直接迁移到 FOA。
FOA 下应优先利用:
- ambisonic 通道本身
- intensity vector
- 其他 FOA 物理特征
## 5. 对 BEATs 具体修改哪些模块
下面按模块说明修改方案。
### 5.1 保留不动的模块
建议尽量保留:
- `TransformerEncoder`
- `TransformerSentenceEncoderLayer`
- `MultiheadAttention`
- `conv_pos`
- `LayerNorm`
- `FFN`
- `post_extract_proj`
也就是 `backbone.py` 内的主干结构和 `BEATs.py` 中的 trunk 逻辑尽量不动。
### 5.2 必须修改的模块
必须重做:
1. `preprocess`
2. `patch_embedding`
3. `extract_features` 输出头部逻辑
4. 下游 `predictor`
### 5.3 推荐新增的模块
建议新增:
1. `SpatialBEATsPreprocessor`
2. `SpatialPatchEmbedding`
3. `SpatialQueryDecoder`
4. `SpatialPredictionHead`
5. `SpatialTokenProjector`
6. `HungarianMatcher`
7. `SpatialSetCriterion`
## 6. 代码级映射建议
### 6.1 现有文件建议
建议保留和复用:
- [BEATs.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/BEATs.py)
- [backbone.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/backbone.py)
建议新增:
- `spatial_beats.py`
- `spatial_modules.py`
- `spatial_loss.py`
- `spatial_dataset.py`
- `train_spatial_beats.py`
### 6.2 `spatial_beats.py` 建议包含
建议实现:
- `SpatialBEATsConfig`
- `SpatialBEATs`
- `SpatialBEATs.extract_spatial_tokens()`
- `SpatialBEATs.forward()`
### 6.3 `spatial_modules.py` 建议包含
建议实现:
- `SpatialBEATsPreprocessor`
- `SpatialPatchEmbedding`
- `SpatialQueryDecoder`
- `SpatialPredictionHead`
- `SpatialTokenProjector`
### 6.4 `spatial_loss.py` 建议包含
建议实现:
- `HungarianMatcher`
- `SpatialSetCriterion`
## 7. 预训练权重如何复用
## 7.1 默认推荐权重
默认推荐:
- `BEATs_iter3+ (AS2M) pre-trained`
而不是:
- fine-tuned checkpoints
原因:
- `pre-trained` 更适合作为 trunk 初始化
- `fine-tuned` 更偏向 AudioSet 分类判别
- 你这里的 spatial encoder 应与原语义 encoder 职责分离
### 7.2 必须直接加载的层
这些层建议直接加载原 BEATs checkpoint:
- `post_extract_proj`
- `encoder.pos_conv`
- `encoder.layers.*`
- `encoder.layer_norm`
- `layer_norm`
即除了输入 stem 和输出头,主干参数都尽量继承。
### 7.3 需要特殊初始化的层
以下层因为 shape 不同,不能直接 strict load:
- `patch_embedding`
- 新增的 `query decoder`
- 新增的 `spatial heads`
- 新增的 `LLM projector`
### 7.4 新 patch embedding 的初始化策略
原 BEATs stem 是:
- `Conv2d(1, embed_dim, kernel_size=patch, stride=patch)`
新 stem 建议是:
- `Conv2d(C_foa, embed_dim, kernel_size=patch, stride=patch)`
推荐初始化策略:
#### 方案 A:保守初始化,默认推荐
- `W_logmel` 通道继承原 stem 权重
- 其他空间通道初始化为 `0` 或较小随机值
优点:
- 最大程度保留原 BEATs 初始分布
- trunk 适配更稳
缺点:
- 训练初期空间通道利用较慢
#### 方案 B:通道 inflation
- 把原 stem 权重复制到全部输入通道
- 再按通道数做归一化
优点:
- 所有通道一开始都能进入主干
缺点:
- 初始统计更可能偏离原 BEATs
最终推荐:
- 第一版用 `方案 A`
- 后续做 ablation 再比较 `方案 B`
## 8. Spatial token 模块的最终设计
### 8.1 为什么不用全局池化
原始 BEATs 的输出方式更接近:
- patch sequence
- mean pooling
- clip-level prediction
这不适合多源空间任务。
### 8.2 最终推荐:Query Decoder
在 trunk 输出后新增:
- `K` 个 learnable source queries
- 一个轻量 `cross-attention decoder`
输入:
- encoder memory:`H in R^{B x T x D}`
- source queries:`Q in R^{B x K x D}`
输出:
- `Z in R^{B x K x D}`
这里的 `Z[:, i, :]` 即第 `i` 个 `source token`
### 8.3 为什么 query decoder 是当前最优解
它的优点:
- 不改 trunk 内部结构
- 仍然让完整 FOA 特征经过 backbone
- 适合多源 set prediction
- 最利于最大化复用 trunk 权重
## 9. 输出头设计
对每个 source token `z_i`,预测:
- `objectness`
- `azimuth`
- `elevation`
- `distance`
- 可选 `class_aux`
### 9.1 离散还是连续
第一版推荐全部使用离散分类头:
- `azimuth`: 360 bins
- `elevation`: 180 bins
- `distance`: 按数据分桶,例如 `0.5m` 一档
原因:
- 与已有 Spatial-AST/BAT 经验一致
- 分类头更稳
- 更便于构造离散坐标 embedding
### 9.2 objectness 头
推荐增加:
- `objectness_head: D -> 1`
用于:
- 判断当前 token 是否对应真实声源
- 作为 Hungarian matching 的一部分
- 推理时做 token 保留/裁剪
### 9.3 类别头
类别头建议作为:
- `auxiliary head`
而不是最终 LLM 的主要输入内容。
这样做的作用:
- 让 query token 更容易学会 source slot 对齐
- 但不把 Spatial-BEATs 变成第二个强语义 encoder
## 10. Loss 设计
推荐总损失:
```text
L_total =
lambda_obj * L_obj
+ lambda_azi * L_azi
+ lambda_ele * L_ele
+ lambda_dist * L_dist
+ lambda_cls * L_cls_aux
```
### 10.1 匹配方式
使用 `Hungarian matching`:
- 预测:`K` 个 token
- GT:`N` 个 sources
- 成本由以下项构成:
- objectness cost
- azimuth cost
- elevation cost
- distance cost
- optional class cost
### 10.2 损失项定义
推荐:
- `L_obj`: BCE 或 focal loss
- `L_azi`: cross entropy
- `L_ele`: cross entropy
- `L_dist`: cross entropy
- `L_cls_aux`: cross entropy 或 BCE
### 10.3 初始 loss 权重建议
第一版建议从以下权重起步:
```text
lambda_obj = 1.0
lambda_azi = 2.0
lambda_ele = 2.0
lambda_dist = 1.0
lambda_cls = 0.25
```
解释:
- 方向任务通常更关键
- 距离次之
- objectness 必须稳定
- 类别监督只作为辅助
### 10.4 不建议的做法
第一版不建议:
- 重分类损失压倒空间损失
- 直接照搬 Spatial-AST 的 `1250 * cls`
原因:
- Spatial-AST 的目标之一是保住 sound event detection
- 这里 `Spatial-BEATs` 的主要目标是空间 token
- 原项目已有独立语义 encoder
## 11. 训练策略
### 11.1 第一阶段是否需要 SSL
当前最终结论:
- 第一版 **不需要** 重新做 BEATs 式 SSL
因为当前已经有:
- 多源监督
- 每个源的空间标注
- 可复用的 BEATs 主干预训练
所以第一阶段应优先做:
- `supervised multi-source spatial training`
### 11.2 分阶段训练建议
#### Stage A:Warmup
冻结:
- 大部分 trunk
只训练:
- FOA preprocessor
- patch embedding
- query decoder
- spatial heads
- LLM projector
目的:
- 让新输入 stem 和新输出头稳定接入 trunk
#### Stage B:Upper-trunk finetune
解冻:
- trunk 上层若干层
目的:
- 让主干逐步适应 FOA 空间任务
#### Stage C:Near-full finetune
进一步解冻:
- 更多 encoder layers
目的:
- 提升空间表示上限
### 11.3 学习率建议
推荐:
- trunk:较小 lr
- 新模块:较大学习率
例如:
```text
lr_trunk = 1e-5 ~ 5e-5
lr_new = 1e-4 ~ 5e-4
```
并配合:
- layer-wise lr decay
## 12. 最终输出给 LLM 的 spatial token 形式
这是本项目最关键的接口定义之一。
### 12.1 内部 token 形式
`Spatial-BEATs` 内部输出:
- `Z in R^{B x K x D}`
其中:
- `B`: batch size
- `K`: source token 数
- `D`: Spatial-BEATs hidden dim,建议与 BEATs trunk 一致
### 12.2 不建议直接把 raw logits 喂给 LLM
不建议直接给 LLM:
- azimuth logits
- elevation logits
- distance logits
- objectness logits
这些是监督头,不是最终模态表示。
### 12.3 最终推荐的 LLM spatial token 形式
最终推荐送给 LLM 的每个 token 形式为:
```text
s_i = Proj([z_i ; e_azi(i) ; e_ele(i) ; e_dist(i) ; e_obj(i)])
```
其中:
- `z_i`: query decoder 输出的 latent token
- `e_azi(i)`: 由预测 azimuth bin 查表得到的 embedding
- `e_ele(i)`: 由预测 elevation bin 查表得到的 embedding
- `e_dist(i)`: 由预测 distance bin 查表得到的 embedding
- `e_obj(i)`: 由 objectness/confidence 产生的 embedding
- `Proj`: 投影到 LLM hidden size 的 MLP/Linear
最终:
- `s_i in R^{d_llm}`
### 12.4 为什么采用“latent + structured embedding”的混合形式
原因:
1. `z_i` 保留丰富的隐式空间结构信息
2. `坐标 embedding` 给 LLM 显式离散空间线索
3. `confidence` 有助于 LLM 区分可靠/不可靠 token
这比单纯只传:
- raw latent token
或者只传:
- 显式坐标 one-hot / scalar
都更合适。
### 12.5 最终序列形式
送入 LLM 时推荐:
```text
<SPATIAL_START>, s_1, s_2, ..., s_K, <SPATIAL_END>
```
并且:
- 按 `objectness` 从高到低排序
- 对低置信 token 可直接截断或 mask
### 12.6 是否保留全部 K 个 token
默认推荐:
- 训练时保留全部 `K`
- 推理时按 `objectness` 过滤
例如:
- 保留前 `K_keep`
- 或保留 `obj > threshold` 的 token
## 13. 与原语义 audio encoder 的关系
为了避免“两个 encoder 在做同样的事”,推荐如下职责划分:
- 原语义 audio encoder:负责 `what`
- Spatial-BEATs:负责 `where / spatial structure / relations`
### 13.1 是否允许 Spatial-BEATs 学类别
允许,但只作为辅助。
建议:
- 类别头只用于训练
- 最终输入给 LLM 的空间 token 不直接暴露完整类别 logits
### 13.2 是否需要和语义 encoder 做对齐
第一版不是必须。
若后续希望更强的 source grounding,可进一步加入:
- semantic distillation
- cross-encoder alignment
- source-wise contrastive loss
但这些应放到第二阶段。
## 14. 第一版推荐配置
第一版默认建议:
- 输入特征:`WXYZ + IVxyz`
- `C_foa = 7`
- 采样率:`16k`
- mel bins:`128`
- patch 配置:与 BEATs 保持一致
- 预训练权重:`BEATs_iter3+ AS2M pre-trained`
- trunk:最大化加载
- patch stem:`W` 继承,其余通道小初始化
- 输出:`K` 个 source tokens
- token 解码:轻量 query decoder
- 监督:Hungarian matching + 多头空间分类
- LLM 输入:`latent + structured coordinate embedding` 的混合 token
## 15. 实现优先级
推荐按如下优先级推进:
1. 实现 `FOA preprocessor`
2. 实现多通道 `patch embedding`
3. 完成 trunk ckpt 加载
4. 实现 `query decoder`
5. 实现 `objectness / azi / ele / dist` heads
6. 实现 `Hungarian matcher + criterion`
7. 实现 `LLM projector`
8. 完成训练脚本
## 16. 当前仍需用户确认的问题
以下问题会直接影响第一版实现细节:
1. `FOA` 数据当前主要采样率是多少?是 `16k`、`24k`、`32k` 还是 `48k`?
2. 每个样本中 `最大同时源数` 大概是多少?这会影响 `K` 的默认设定。
3. 每个源是否都有 `source-level class label`?如果有,类别头和匹配会更稳。
4. 你希望 `distance` 是离散分类还是连续回归?当前默认推荐离散分类。
5. 下游 LLM 的 hidden size 是多少?是否已有固定的 audio token projector?
6. 你是否希望 Spatial-BEATs 在第一版就具备一定的 source semantic 辅助能力,还是严格只做空间?
## 17. 结论
当前最终方案已经明确:
- **完整 FOA 特征进入 BEATs 主干**
- **最大化复用 trunk 预训练**
- **重做输入 stem**
- **重做输出为多源 spatial tokens**
- **第一版采用监督式 set prediction**
- **最终给 LLM 的不是 raw logits,而是融合 latent 与坐标 embedding 的 spatial tokens**
这是当前最符合项目目标、也最稳妥的 `Spatial-BEATs` 方案。
|