File size: 23,961 Bytes
86cbd36 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 | # Spatial-BEATs Coding Guide
## 1. 本文档的作用
本文档是 `Spatial-BEATs` 的最终代码实施指南,用于直接指导后续代码开发。
它基于当前已经确认的项目约束:
- FOA 采样率统一到 `16 kHz`
- 每个样本最大同时声源数约为 `4`
- 每个声源都有稳定的 `source-level class label`
- `distance` 使用连续回归
- `Spatial-BEATs` 拥有自己的 `projector`
- 不需要与原始语义 audio encoder 做表示对齐
- 目标 `spatial token rate` 约为 `2.5 Hz`
- 允许增加 `source class auxiliary head`
本文档应视为后续实现的主参考。
## 2. 最终设计结论
### 2.1 总体目标
构建一个独立的 `Spatial-BEATs`:
- 输入完整 `FOA waveform`
- 从 `FOA` 中计算空间特征
- 将完整空间特征送入 `BEATs backbone`
- 输出可输入 LLM 的 `spatial tokens`
注意:
- 不是 `W-only`
- 不是外挂小 adapter
- 不是在原有语义 encoder 内部混合空间分支
而是:
- 一个独立的 `Spatial Encoder`
- 最大化复用 `BEATs trunk`
- 最终输出自己的空间 token 序列
### 2.2 关键实现原则
1. **完整 FOA 特征经过 BEATs 主干**
2. **尽量不改 BEATs trunk 内部 Transformer**
3. **重做输入 stem**
4. **重做输出头和 token 生成方式**
5. **主训练目标是多源空间建模,不是 clip-level 分类**
## 3. 最终模型架构
推荐最终架构如下:
```text
FOA waveform [B, 4, T]
-> SpatialBEATsPreprocessor
-> FOA feature map [B, C_foa, T_f, F]
-> SpatialPatchEmbedding
-> BEATs trunk
-> Patch grid reshape
-> Temporal downsampler (to 2.5 Hz)
-> Slot query decoder
-> Source slot tokens [B, T_s, K, D]
-> Prediction heads
-> Spatial projector
-> LLM spatial tokens [B, N_keep, d_llm]
```
其中:
- `T_s` 是时间 token 数
- `K` 是每个时间步最大 source slot 数
- `D` 是 BEATs hidden dim
- `d_llm` 是 LLM hidden dim
## 4. 固定超参与默认取值
### 4.1 输入参数
- sample rate: `16000`
- mel bins: `128`
- frame length: `25 ms`
- frame shift: `10 ms`
### 4.2 token 相关参数
- token rate: `2.5 Hz`
- 对应时间间隔:`400 ms`
- 对于 `10 s` 样本:
- `T_s = 25`
### 4.3 source slot 参数
- 最大同时源数:`4`
- 默认 `K = 4`
说明:
- 第一版直接令 `K = 4`
- 不额外引入冗余 slot
- 如果后续发现数据中存在漏标、异常源或更复杂重叠,再考虑改成 `K = 5/6`
### 4.4 输入通道数
默认推荐:
- `W_logmel`
- `X_logmel`
- `Y_logmel`
- `Z_logmel`
- `IVx`
- `IVy`
- `IVz`
因此:
- `C_foa = 7`
## 5. 输入特征定义
### 5.1 推荐特征形式
第一版明确使用:
- `WXYZ log-mel`
- `IVx, IVy, IVz`
其中:
- `WXYZ` 提供 ambisonic 通道信息
- `IV` 提供显式方向 cue
### 5.2 IV 计算建议
建议在 STFT 域中计算 intensity vector,然后再映射到 mel 维:
```text
IVx ~ Re(conj(W) * X)
IVy ~ Re(conj(W) * Y)
IVz ~ Re(conj(W) * Z)
```
可再配合能量归一化:
```text
IV = IV / (|W|^2 + |X|^2 + |Y|^2 + |Z|^2 + eps)
```
实现时可以先得到频域 IV,再通过 mel filter bank 压到 `128` mel bins。
### 5.3 为什么不用 binaural IPD
当前任务是 `FOA`,不是 binaural。
Spatial-AST 的 `mel + IPD` 经验可借鉴其结构思路,但不能直接复用其输入表示。
本项目应优先使用:
- FOA 通道本身
- intensity vector
## 6. 对 BEATs 代码的具体改造
## 6.1 尽量保留的部分
建议完全复用:
- `TransformerEncoder`
- `TransformerSentenceEncoderLayer`
- `MultiheadAttention`
- `conv_pos`
- `post_extract_proj`
- trunk 中的 `LayerNorm / FFN / attention`
也就是说:
- [backbone.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/backbone.py) 尽量不改
### 6.2 需要重写的部分
必须重写:
1. `preprocess`
2. `patch_embedding`
3. `extract_features` 的输出形式
4. 原始 `predictor`
### 6.3 推荐新增文件
建议新增如下文件:
- `spatial_beats.py`
- `spatial_modules.py`
- `spatial_loss.py`
- `spatial_dataset.py`
- `train_spatial_beats.py`
- 可选 `infer_spatial_beats.py`
## 7. 预训练权重复用方案
## 7.1 推荐 checkpoint
默认推荐:
- `BEATs_iter3+ (AS2M) pre-trained`
不推荐第一版直接用 fine-tuned checkpoint 作为 trunk 初始化。
### 7.2 直接加载的层
建议直接加载:
- `post_extract_proj`
- `encoder.pos_conv`
- `encoder.layers.*`
- `encoder.layer_norm`
- `layer_norm`
这些层使用:
- `strict=False`
并打印缺失与不匹配项。
### 7.3 不能直接加载的层
以下层需要新初始化:
- 新的 `patch_embedding`
- `temporal downsampler`
- `slot query decoder`
- `prediction heads`
- `spatial projector`
### 7.4 新 patch stem 的初始化
原始 BEATs stem:
```text
Conv2d(1, embed_dim, kernel_size=patch, stride=patch)
```
新的 stem:
```text
Conv2d(7, embed_dim, kernel_size=patch, stride=patch)
```
推荐初始化方案:
- `W_logmel` 通道继承原 BEATs stem 权重
- `X/Y/Z/IVx/IVy/IVz` 通道初始化为较小随机值
推荐做法:
```text
new_weight[:, 0, :, :] = old_weight[:, 0, :, :]
new_weight[:, 1:, :, :] ~ N(0, 0.02 * std(old_weight))
```
不推荐全部复制 inflation 作为默认方案。
第一版优先稳定,而不是让所有通道一开始等价共享单通道语义滤波器。
## 8. 代码结构建议
## 8.1 `spatial_modules.py`
建议包含以下模块:
### `SpatialBEATsPreprocessor`
职责:
- 输入 `FOA waveform [B, 4, T]`
- 计算:
- `WXYZ logmel`
- `IVx, IVy, IVz`
- 输出:
- `foa_feat [B, 7, T_f, 128]`
建议接口:
```python
class SpatialBEATsPreprocessor(nn.Module):
def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
...
```
### `SpatialPatchEmbedding`
职责:
- 对 `foa_feat` 做多通道 patch embedding
建议接口:
```python
class SpatialPatchEmbedding(nn.Module):
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, tuple[int, int]]:
# returns:
# tokens: [B, N_p, D_in]
# grid_hw: (T_p, F_p)
...
```
### `TemporalDownsampler`
职责:
- 将 trunk 输出从 patch 时间分辨率下采样到 `2.5 Hz`
建议输入输出:
- 输入:`grid memory [B, T_p, F_p, D]`
- 先对 `F_p` 做平均或轻量 attention pooling
- 得到:`temporal memory [B, T_p, D]`
- 再用线性插值或 1D conv 下采样到:
- `slot memory [B, T_s, D]`
建议接口:
```python
class TemporalDownsampler(nn.Module):
def forward(self, grid_x: torch.Tensor, target_steps: int) -> torch.Tensor:
# grid_x: [B, T_p, F_p, D]
# out: [B, T_s, D]
...
```
默认推荐:
- 第一版使用 `freq-mean + linear interpolate`
原因:
- 简单
- 稳定
- 容易调试
### `SlotQueryDecoder`
职责:
- 对每个时间步生成 `K=4` 个 source slots
推荐设计:
- 为每个 slot 准备一个 learnable `slot embedding`
- 将时间 token `m_t` 与 slot embedding 相加,形成初始 query
- query 对 trunk memory 做 cross-attention
建议输出:
- `slot_tokens [B, T_s, K, D]`
建议接口:
```python
class SlotQueryDecoder(nn.Module):
def forward(
self,
temporal_memory: torch.Tensor,
encoder_memory: torch.Tensor,
) -> torch.Tensor:
# temporal_memory: [B, T_s, D]
# encoder_memory: [B, N_p, D]
# out: [B, T_s, K, D]
...
```
实现建议:
- 先用 `temporal_memory` 生成时间条件 query
- 再用 `2 层 TransformerDecoderLayer` 或自定义 cross-attn block
第一版推荐:
- `2 层 decoder`
- hidden dim 与 trunk 一致
### `SpatialPredictionHead`
职责:
- 对 `slot_tokens` 预测各任务输出
建议输出:
- `pred_obj: [B, T_s, K]`
- `pred_azi_logits: [B, T_s, K, 360]`
- `pred_ele_logits: [B, T_s, K, 180]`
- `pred_dist: [B, T_s, K, 1]`
- `pred_class_logits: [B, T_s, K, C_cls]`
### `SpatialTokenProjector`
职责:
- 将 slot latent 与结构化坐标信息组合
- 投影到 LLM hidden size
输出:
- `llm_tokens [B, N_keep, d_llm]`
## 8.2 `spatial_beats.py`
建议定义:
### `SpatialBEATsConfig`
字段建议:
- `sample_rate=16000`
- `num_mel_bins=128`
- `token_rate=2.5`
- `max_sources=4`
- `foa_channels=7`
- `distance_max_m`
- `llm_hidden_size`
- `use_class_aux=True`
- `num_decoder_layers=2`
### `SpatialBEATs`
建议结构:
```python
class SpatialBEATs(nn.Module):
def __init__(self, cfg, beats_ckpt=None):
...
def extract_spatial_features(self, waveforms):
...
def extract_spatial_tokens(self, waveforms, audio_lengths=None):
...
def project_tokens_for_llm(self, slot_tokens, preds, keep_mask=None):
...
def forward(self, waveforms, audio_lengths=None, targets=None):
...
```
### `forward()` 推荐返回形式
返回字典:
```python
{
"encoder_memory": ...,
"slot_tokens": ...,
"pred_obj": ...,
"pred_azi_logits": ...,
"pred_ele_logits": ...,
"pred_dist": ...,
"pred_class_logits": ...,
"llm_tokens": ...,
"llm_token_mask": ...,
"token_meta": ...,
}
```
## 8.3 `spatial_loss.py`
建议定义:
### `HungarianMatcher`
输入:
- 预测输出
- GT targets
输出:
- 每个样本每个时间步的匹配索引
### `SpatialSetCriterion`
计算:
- objectness loss
- azimuth loss
- elevation loss
- distance regression loss
- class auxiliary loss
可选:
- temporal smoothness loss
## 8.4 `spatial_dataset.py`
建议数据格式:
```python
sample = {
"waveform": FloatTensor[4, T],
"duration_s": float,
"sources": [
{
"class_id": int,
"azimuth_deg": float,
"elevation_deg": float,
"distance_m": float,
"start_s": float,
"end_s": float,
"is_time_weak": bool,
"is_position_dynamic": bool,
"trajectory": optional,
},
...
]
}
```
当前推荐额外保留以下字段:
- `start_s`: 源开始进入场景的时间
- `end_s`: 若有则保留,否则可由 `start_s + length_s` 推出
- `length_s`: 原始 source clip 时长
- `is_time_weak`: 当前时间边界是否只是弱监督
- `is_position_dynamic`: 该源位置是否随时间变化
- `trajectory`: 若位置变化,则存储分段轨迹或逐帧轨迹
如果当前只有源进入时间和原始 source length,可统一转成:
- `start_s = start time`
- `end_s = start_s + length_s`
- `is_time_weak = True`
## 9. 时间建模、2.5 Hz 输出与弱时间监督
这是当前实现中最关键的新增约束之一。
### 9.1 token rate 的解释
这里定义:
- 每个 `source slot stream` 的输出速率为 `2.5 Hz`
即:
- 每个时间步间隔 `400 ms`
### 9.2 输出张量形状
对时长为 `L` 秒的样本:
```text
T_s = round(L * 2.5)
```
例如:
- `10 s -> 25`
最终 slot token 形状:
```text
[B, T_s, K, D]
```
在当前默认配置下:
- `K = 4`
也就是最多 `4` 条并行 source slot 流。
### 9.3 当前时间标注的含义
你当前可提供的时间信息是:
- 知道每个 source 的 `start time`
- 知道原始 `FSD50K source clip` 的 `length`
- 但这段 `length` 内不保证每一时刻都真的 active
因此,当前不应把 `[start_s, end_s]` 视为严格逐帧真值,而应视为:
- `weak temporal support window`
也就是:
- 源最可能出现的候选时间范围
- 不是精确的逐帧 activity annotation
### 9.4 第一版如何把弱时间标注映射到 2.5 Hz token
对于第 `t` 个时间步,其中心时刻记为 `tau_t`。
对每个 GT source,定义:
- `candidate window = [start_s, end_s]`
第一版推荐构造三种 mask:
1. `pos_window_mask`
- `tau_t` 落在 `[start_s, end_s]` 内
2. `neg_window_mask`
- `tau_t` 明确落在窗口外
3. `ignore_mask`
- 可选,用于窗口边界附近或不确定区域
当前默认建议最简单实现:
- 窗口外:作为 objectness 负样本
- 窗口内:作为弱正样本候选
但不要对窗口内所有步都施加强监督位置 loss。
### 9.5 当前 loss 需要怎么改
为了适应弱时间标注,当前建议把 loss 拆成两层:
#### `L_obj`
- 对窗口外,正常做负样本监督
- 对窗口内,做弱正样本监督
推荐:
- 使用 `BCE` 或 `focal loss`
- 窗口内正样本权重降低
例如:
```text
w_obj_pos_weak = 0.3 ~ 0.5
w_obj_neg = 1.0
```
#### `L_azi / L_ele / L_dist / L_cls`
第一版不要在窗口内所有时间步都强制监督。
推荐只在以下位置监督:
- 与 GT source 匹配且 `pred_obj` 较高的 slot
- 或窗口内的 top-k 高置信时间步
更稳的第一版做法:
- 先对每个 GT source 在窗口内选择 `top-1` 或 `top-2` 个 objectness 最高的时间步参与坐标/类别监督
这样可以避免:
- 源在窗口内部分时间其实不 active
- 但模型被错误惩罚
### 9.6 推荐的第一版弱监督匹配策略
当前建议采用两阶段匹配:
1. 先按时间窗口过滤候选时间步
2. 再在候选时间步内做 slot matching
更具体地说:
- 对每个 GT source,只允许匹配其时间窗口内的 `slot tokens`
- 在这些候选中选出最优 `(t, k)`
这比直接对所有 `[T_s, K]` 位置做全局 Hungarian 更稳。
第一版推荐:
- `per-source best-of-window matching`
而不是:
- 全局 dense set matching
原因:
- 你当前时间标注是弱的
- 先用窗口约束大幅降低匹配歧义更现实
### 9.7 推理时 token 序列长度不需要改变
`2.5 Hz` 的 token rate 不需要变。
要改的是:
- 训练 supervision 的构造方式
- objectness 与坐标 loss 的作用范围
### 9.8 未来如果拿到更好的 activity 标注
如果后续可以拿到:
- energy-based active mask
- frame-level source activity
- VAD / source activation probability
则可把当前的弱时间监督替换成:
- `strong temporal supervision`
到时只需替换 target 构造和 criterion,不需要改主模型结构。
### 9.9 喂给 LLM 时的 token 数量
如果全部展开,理论最大 token 速率为:
```text
2.5 Hz * 4 = 10 tokens / second
```
但推理时可通过 `objectness` 做过滤,所以通常会低于这个上限。
### 9.10 LLM 展开顺序
建议按如下顺序展开:
- 先按时间排序
- 每个时间步内部按 `objectness` 从高到低排序
也就是:
```text
t1_s1, t1_s2, t1_s3, t1_s4, t2_s1, t2_s2, ...
```
然后再过滤低置信 slot。
## 10. 输出给 LLM 的 spatial token 形式
## 10.1 不直接喂原始 logits
不建议直接把:
- 方位分类 logits
- 类别 logits
- 距离标量
直接作为 token 输入给 LLM。
### 10.2 推荐 token 构造方式
每个 slot token `z_{t,k}` 最终形成一个结构化 token:
```text
s_{t,k} = Proj([z_{t,k} ; c_{t,k} ; u_{t,k} ; d_{t,k} ; o_{t,k}])
```
其中:
- `z_{t,k}`: slot latent
- `c_{t,k}`: source class context embedding
- `u_{t,k}`: 方向向量
- `d_{t,k}`: 连续距离 embedding
- `o_{t,k}`: objectness/confidence embedding
### 10.3 各项具体建议
#### `c_{t,k}`: 类别上下文
由 `pred_class_logits` 构造:
```text
p_cls = softmax(pred_class_logits)
c = p_cls @ E_cls
```
其中:
- `E_cls` 是一个可学习类别 embedding 表
作用:
- 给 spatial token 少量语义 grounding
- 但不需要与原始 audio encoder 对齐
#### `u_{t,k}`: 方向向量
先由:
- `pred_azi_logits`
- `pred_ele_logits`
得到预测角度,再转换成单位球坐标向量:
```text
u = [x, y, z]
```
推荐实现:
- 训练时用分布期望或 soft-argmax
- 推理时可用 argmax
#### `d_{t,k}`: 连续距离表示
由于 distance 是连续回归,建议:
- 对 `pred_dist` 做归一化
- 再经一个小 MLP 变成 embedding
#### `o_{t,k}`: 置信度表示
由 `pred_obj` 经 sigmoid 得到 objectness,再做小 MLP 映射。
### 10.4 projector 的最终作用
`SpatialTokenProjector` 的任务是把:
- slot latent
- class context
- direction vector
- distance embedding
- objectness embedding
融合并投影到:
- `d_llm`
输出:
```text
llm_tokens: [B, N_keep, d_llm]
```
### 10.5 是否需要与原 audio encoder 对齐
当前结论:
- **不需要**
因此:
- 这个 projector 完全独立训练
- 只服务于 `Spatial-BEATs -> LLM`
## 11. Loss 设计
## 11.1 任务头与 loss
推荐 loss 组成:
```text
L_total =
lambda_obj * L_obj
+ lambda_azi * L_azi
+ lambda_ele * L_ele
+ lambda_dist * L_dist
+ lambda_cls * L_cls
```
### 11.2 各项定义
- `L_obj`: BCE 或 focal loss,支持弱正样本权重
- `L_azi`: cross entropy
- `L_ele`: cross entropy
- `L_dist`: SmoothL1 / Huber
- `L_cls`: cross entropy
### 11.3 distance 回归的实现
由于你已经明确要连续回归,推荐:
- head 输出 `pred_dist_norm in [0, 1]`
- 再乘以 `distance_max_m`
训练时使用:
- `SmoothL1Loss(pred_dist_norm, gt_dist_norm)`
优点:
- 比直接回归未归一化距离更稳
- 比 MSE 更抗异常值
### 11.4 推荐初始权重
建议第一版从以下权重起步:
```text
lambda_obj = 1.0
lambda_azi = 2.0
lambda_ele = 2.0
lambda_dist = 1.0
lambda_cls = 0.5
```
这里把 `class auxiliary` 权重从之前建议的 `0.25` 提到 `0.5`,因为现在你已经确认:
- 每个源有稳定的 `source-level class label`
- Spatial-BEATs 自身保留一定语义信息是可接受的
### 11.5 当前版本的匹配方式修订
由于当前时间 supervision 是弱的,第一版不建议直接做:
- 全局 `Hungarian matching` over `[T_s, K]`
更推荐:
1. 先根据 `source time window` 过滤候选时间步
2. 再在候选窗口内做匹配
推荐实现:
- `window-constrained matching`
可选两种方式:
#### 方案 A:推荐默认方案
- 对每个 GT source
- 在其 window 内所有 `(t, k)` 候选中,选择 cost 最小的一对
这本质上是:
- `best-of-window assignment`
优点:
- 简单
- 稳定
- 对弱时间监督更友好
#### 方案 B:后续增强方案
- 对每个时间步分别做 Hungarian
- 再加时间连续性正则
这更适合将来位置随时间变化时使用。
### 11.6 匹配 cost
Hungarian matching cost 建议:
```text
cost =
w_obj * cost_obj
+ w_azi * cost_azi
+ w_ele * cost_ele
+ w_dist * cost_dist
+ w_cls * cost_cls
```
推荐初值:
```text
w_obj = 1.0
w_azi = 2.0
w_ele = 2.0
w_dist = 1.0
w_cls = 1.0
```
## 12. 训练策略
### 12.1 第一阶段是否需要 SSL
当前明确结论:
- 第一版 **不做新的 BEATs 式 SSL**
理由:
- 已有空间 GT
- 已有 source class GT
- 已有强 trunk 预训练
- 当前主要目标是空间结构建模
### 12.2 推荐训练阶段
#### Stage A: Warmup
冻结:
- trunk 大部分层
训练:
- preprocessor
- patch stem
- temporal downsampler
- slot query decoder
- prediction heads
- projector
#### Stage B: Upper-trunk finetune
解冻:
- trunk 上层若干层
#### Stage C: Wider finetune
逐步解冻更多层,直到性能稳定。
### 12.3 训练时建议增加的正则项
当前位置在 clip 内固定,因此建议增加:
- `temporal consistency loss`
具体可对同一 source 在相邻时间步的预测加约束:
- objectness 平滑
- azimuth/elevation 分布平滑
- distance 平滑
第一版可选实现:
```text
L_temp =
smooth(pred_obj_t, pred_obj_{t+1})
+ smooth(pred_dist_t, pred_dist_{t+1})
```
由于当前位置固定,这类正则通常有利于稳定训练。
### 12.4 学习率建议
推荐:
```text
lr_trunk = 1e-5 ~ 5e-5
lr_new = 1e-4 ~ 5e-4
```
并使用:
- weight decay
- warmup
- layer-wise lr decay
## 13. 训练与推理输出格式
## 13.1 训练时 `forward()` 输出
建议 `forward()` 返回:
```python
{
"slot_tokens": FloatTensor[B, T_s, K, D],
"pred_obj": FloatTensor[B, T_s, K],
"pred_azi_logits": FloatTensor[B, T_s, K, 360],
"pred_ele_logits": FloatTensor[B, T_s, K, 180],
"pred_dist": FloatTensor[B, T_s, K, 1],
"pred_class_logits": FloatTensor[B, T_s, K, C_cls],
"llm_tokens": FloatTensor[B, N_keep, d_llm],
"llm_token_mask": BoolTensor[B, N_keep],
"token_meta": dict,
}
```
### 13.2 推理时建议额外输出
建议额外输出:
- `pred_azi_deg`
- `pred_ele_deg`
- `pred_dist_m`
- `pred_obj_prob`
- `pred_class_id`
便于后续可视化和调试。
## 14. 最小实现顺序
建议严格按以下顺序实现:
1. 写 `SpatialBEATsPreprocessor`
2. 写 `SpatialPatchEmbedding`
3. 完成 trunk checkpoint 加载
4. 写 `TemporalDownsampler`
5. 写 `SlotQueryDecoder`
6. 写 `SpatialPredictionHead`
7. 写 `SpatialTokenProjector`
8. 写 `HungarianMatcher`
9. 写 `SpatialSetCriterion`
10. 写 dataset 和训练脚本
## 15. 未来支持“位置随时间变化”时需要改什么
你已经说明:
- 当前 clip 内位置固定
- 后续会加入随时间变化的位置
这意味着当前模型结构基本可保留,但 target 和 decoder 训练方式需要升级。
### 15.1 当前结构哪些不用改
以下部分未来仍可直接保留:
- `FOA preprocessor`
- `patch embedding`
- `BEATs trunk`
- `TemporalDownsampler`
- `SlotQueryDecoder`
- `SpatialTokenProjector`
- `2.5 Hz` token rate
### 15.2 未来必须改的部分
未来位置动态化后,需要改:
1. `dataset target format`
2. `matching strategy`
3. `loss supervision`
### 15.3 数据结构怎么升级
当前静态位置:
```python
{
"azimuth_deg": float,
"elevation_deg": float,
"distance_m": float,
}
```
未来动态位置建议升级为:
```python
{
"trajectory": [
{
"time_s": float,
"azimuth_deg": float,
"elevation_deg": float,
"distance_m": float,
},
...
]
}
```
或直接存成与 `2.5 Hz` 对齐的逐步 target:
```python
{
"traj_azi_deg": FloatTensor[T_s],
"traj_ele_deg": FloatTensor[T_s],
"traj_dist_m": FloatTensor[T_s],
"traj_valid_mask": BoolTensor[T_s],
}
```
### 15.4 匹配怎么升级
当前位置固定时:
- `per-source best-of-window matching`
未来位置变化时:
- 更适合改为 `per-time-step matching`
- 或 `track-level matching`
推荐未来版本:
- 每个 source 对应一条 slot track
- 在整个时间维上维持 slot identity
### 15.5 loss 怎么升级
未来动态位置时:
- `L_azi / L_ele / L_dist` 应按时间步计算
- `temporal consistency loss` 不能再强制“位置恒定”
- 应改成“速度平滑”或“轨迹平滑”
也就是从:
- `constant-position regularization`
升级成:
- `trajectory smoothness regularization`
### 15.6 代码层面建议现在就预留的接口
为了兼容未来动态位置,当前第一版建议在数据与 loss 接口里预留:
- `is_position_dynamic`
- `trajectory`
- `traj_valid_mask`
即使第一版不用,也建议把字段和分支接口预留出来。
## 16. 当前仍需要确认的问题
虽然核心方案已经足够落地,但还有一个关键问题最好在编码前确认:
当前核心方案已经足够编码。
如果后续继续推进,唯一还值得尽早确认的是:
- 是否能从原始 source waveform 自动提取更精细的 energy/activity mask
如果可以,第一版的弱时间监督会明显更稳。
## 17. 结论
当前可以直接进入代码实现的最终方案是:
- `16k FOA`
- `WXYZ + IV`
- `K=4`
- `2.5 Hz` slot token streams
- `distance` 连续回归
- `class auxiliary head` 开启
- `BEATs_iter3+ AS2M pre-trained` 作为 trunk 初始化
- `Spatial-BEATs` 拥有自己的 projector
- 最终输出自己的 LLM spatial tokens
- 当前时间 supervision 按 `weak temporal window` 处理
- 当前位置 supervision 按 `clip-level fixed position` 处理
- 未来动态位置仅需升级 target/matching/loss,不需要重写主干结构
这份文档已经足够作为第一版实现蓝图使用。
|