File size: 23,961 Bytes
86cbd36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
# Spatial-BEATs Coding Guide

## 1. 本文档的作用

本文档是 `Spatial-BEATs` 的最终代码实施指南,用于直接指导后续代码开发。

它基于当前已经确认的项目约束:

- FOA 采样率统一到 `16 kHz`
- 每个样本最大同时声源数约为 `4`
- 每个声源都有稳定的 `source-level class label`
- `distance` 使用连续回归
- `Spatial-BEATs` 拥有自己的 `projector`
- 不需要与原始语义 audio encoder 做表示对齐
- 目标 `spatial token rate` 约为 `2.5 Hz`
- 允许增加 `source class auxiliary head`

本文档应视为后续实现的主参考。

## 2. 最终设计结论

### 2.1 总体目标

构建一个独立的 `Spatial-BEATs`- 输入完整 `FOA waveform`
-`FOA` 中计算空间特征
- 将完整空间特征送入 `BEATs backbone`
- 输出可输入 LLM 的 `spatial tokens`

注意:

- 不是 `W-only`
- 不是外挂小 adapter
- 不是在原有语义 encoder 内部混合空间分支

而是:

- 一个独立的 `Spatial Encoder`
- 最大化复用 `BEATs trunk`
- 最终输出自己的空间 token 序列

### 2.2 关键实现原则

1. **完整 FOA 特征经过 BEATs 主干**
2. **尽量不改 BEATs trunk 内部 Transformer**
3. **重做输入 stem**
4. **重做输出头和 token 生成方式**
5. **主训练目标是多源空间建模,不是 clip-level 分类**

## 3. 最终模型架构

推荐最终架构如下:

```text
FOA waveform [B, 4, T]
  -> SpatialBEATsPreprocessor
  -> FOA feature map [B, C_foa, T_f, F]
  -> SpatialPatchEmbedding
  -> BEATs trunk
  -> Patch grid reshape
  -> Temporal downsampler (to 2.5 Hz)
  -> Slot query decoder
  -> Source slot tokens [B, T_s, K, D]
  -> Prediction heads
  -> Spatial projector
  -> LLM spatial tokens [B, N_keep, d_llm]
```

其中:

- `T_s` 是时间 token 数
- `K` 是每个时间步最大 source slot 数
- `D` 是 BEATs hidden dim
- `d_llm` 是 LLM hidden dim

## 4. 固定超参与默认取值

### 4.1 输入参数

- sample rate: `16000`
- mel bins: `128`
- frame length: `25 ms`
- frame shift: `10 ms`

### 4.2 token 相关参数

- token rate: `2.5 Hz`
- 对应时间间隔:`400 ms`
- 对于 `10 s` 样本:
  - `T_s = 25`

### 4.3 source slot 参数

- 最大同时源数:`4`
- 默认 `K = 4`

说明:

- 第一版直接令 `K = 4`
- 不额外引入冗余 slot
- 如果后续发现数据中存在漏标、异常源或更复杂重叠,再考虑改成 `K = 5/6`

### 4.4 输入通道数

默认推荐:

- `W_logmel`
- `X_logmel`
- `Y_logmel`
- `Z_logmel`
- `IVx`
- `IVy`
- `IVz`

因此:

- `C_foa = 7`

## 5. 输入特征定义

### 5.1 推荐特征形式

第一版明确使用:

- `WXYZ log-mel`
- `IVx, IVy, IVz`

其中:

- `WXYZ` 提供 ambisonic 通道信息
- `IV` 提供显式方向 cue

### 5.2 IV 计算建议

建议在 STFT 域中计算 intensity vector,然后再映射到 mel 维:

```text
IVx ~ Re(conj(W) * X)
IVy ~ Re(conj(W) * Y)
IVz ~ Re(conj(W) * Z)
```

可再配合能量归一化:

```text
IV = IV / (|W|^2 + |X|^2 + |Y|^2 + |Z|^2 + eps)
```

实现时可以先得到频域 IV,再通过 mel filter bank 压到 `128` mel bins。

### 5.3 为什么不用 binaural IPD

当前任务是 `FOA`,不是 binaural。

Spatial-AST 的 `mel + IPD` 经验可借鉴其结构思路,但不能直接复用其输入表示。

本项目应优先使用:

- FOA 通道本身
- intensity vector

## 6. 对 BEATs 代码的具体改造

## 6.1 尽量保留的部分

建议完全复用:

- `TransformerEncoder`
- `TransformerSentenceEncoderLayer`
- `MultiheadAttention`
- `conv_pos`
- `post_extract_proj`
- trunk 中的 `LayerNorm / FFN / attention`

也就是说:

- [backbone.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/backbone.py) 尽量不改

### 6.2 需要重写的部分

必须重写:

1. `preprocess`
2. `patch_embedding`
3. `extract_features` 的输出形式
4. 原始 `predictor`

### 6.3 推荐新增文件

建议新增如下文件:

- `spatial_beats.py`
- `spatial_modules.py`
- `spatial_loss.py`
- `spatial_dataset.py`
- `train_spatial_beats.py`
- 可选 `infer_spatial_beats.py`

## 7. 预训练权重复用方案

## 7.1 推荐 checkpoint

默认推荐:

- `BEATs_iter3+ (AS2M) pre-trained`

不推荐第一版直接用 fine-tuned checkpoint 作为 trunk 初始化。

### 7.2 直接加载的层

建议直接加载:

- `post_extract_proj`
- `encoder.pos_conv`
- `encoder.layers.*`
- `encoder.layer_norm`
- `layer_norm`

这些层使用:

- `strict=False`

并打印缺失与不匹配项。

### 7.3 不能直接加载的层

以下层需要新初始化:

- 新的 `patch_embedding`
- `temporal downsampler`
- `slot query decoder`
- `prediction heads`
- `spatial projector`

### 7.4 新 patch stem 的初始化

原始 BEATs stem:

```text
Conv2d(1, embed_dim, kernel_size=patch, stride=patch)
```

新的 stem:

```text
Conv2d(7, embed_dim, kernel_size=patch, stride=patch)
```

推荐初始化方案:

- `W_logmel` 通道继承原 BEATs stem 权重
- `X/Y/Z/IVx/IVy/IVz` 通道初始化为较小随机值

推荐做法:

```text
new_weight[:, 0, :, :] = old_weight[:, 0, :, :]
new_weight[:, 1:, :, :] ~ N(0, 0.02 * std(old_weight))
```

不推荐全部复制 inflation 作为默认方案。  
第一版优先稳定,而不是让所有通道一开始等价共享单通道语义滤波器。

## 8. 代码结构建议

## 8.1 `spatial_modules.py`

建议包含以下模块:

### `SpatialBEATsPreprocessor`

职责:

- 输入 `FOA waveform [B, 4, T]`
- 计算:
  - `WXYZ logmel`
  - `IVx, IVy, IVz`
- 输出:
  - `foa_feat [B, 7, T_f, 128]`

建议接口:

```python
class SpatialBEATsPreprocessor(nn.Module):
    def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
        ...
```

### `SpatialPatchEmbedding`

职责:

- 对 `foa_feat` 做多通道 patch embedding

建议接口:

```python
class SpatialPatchEmbedding(nn.Module):
    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, tuple[int, int]]:
        # returns:
        # tokens: [B, N_p, D_in]
        # grid_hw: (T_p, F_p)
        ...
```

### `TemporalDownsampler`

职责:

- 将 trunk 输出从 patch 时间分辨率下采样到 `2.5 Hz`

建议输入输出:

- 输入:`grid memory [B, T_p, F_p, D]`
- 先对 `F_p` 做平均或轻量 attention pooling
- 得到:`temporal memory [B, T_p, D]`
- 再用线性插值或 1D conv 下采样到:
  - `slot memory [B, T_s, D]`

建议接口:

```python
class TemporalDownsampler(nn.Module):
    def forward(self, grid_x: torch.Tensor, target_steps: int) -> torch.Tensor:
        # grid_x: [B, T_p, F_p, D]
        # out: [B, T_s, D]
        ...
```

默认推荐:

- 第一版使用 `freq-mean + linear interpolate`

原因:

- 简单
- 稳定
- 容易调试

### `SlotQueryDecoder`

职责:

- 对每个时间步生成 `K=4` 个 source slots

推荐设计:

- 为每个 slot 准备一个 learnable `slot embedding`
- 将时间 token `m_t` 与 slot embedding 相加,形成初始 query
- query 对 trunk memory 做 cross-attention

建议输出:

- `slot_tokens [B, T_s, K, D]`

建议接口:

```python
class SlotQueryDecoder(nn.Module):
    def forward(
        self,
        temporal_memory: torch.Tensor,
        encoder_memory: torch.Tensor,
    ) -> torch.Tensor:
        # temporal_memory: [B, T_s, D]
        # encoder_memory: [B, N_p, D]
        # out: [B, T_s, K, D]
        ...
```

实现建议:

- 先用 `temporal_memory` 生成时间条件 query
- 再用 `2 层 TransformerDecoderLayer` 或自定义 cross-attn block

第一版推荐:

- `2 层 decoder`
- hidden dim 与 trunk 一致

### `SpatialPredictionHead`

职责:

-`slot_tokens` 预测各任务输出

建议输出:

- `pred_obj: [B, T_s, K]`
- `pred_azi_logits: [B, T_s, K, 360]`
- `pred_ele_logits: [B, T_s, K, 180]`
- `pred_dist: [B, T_s, K, 1]`
- `pred_class_logits: [B, T_s, K, C_cls]`

### `SpatialTokenProjector`

职责:

- 将 slot latent 与结构化坐标信息组合
- 投影到 LLM hidden size

输出:

- `llm_tokens [B, N_keep, d_llm]`

## 8.2 `spatial_beats.py`

建议定义:

### `SpatialBEATsConfig`

字段建议:

- `sample_rate=16000`
- `num_mel_bins=128`
- `token_rate=2.5`
- `max_sources=4`
- `foa_channels=7`
- `distance_max_m`
- `llm_hidden_size`
- `use_class_aux=True`
- `num_decoder_layers=2`

### `SpatialBEATs`

建议结构:

```python
class SpatialBEATs(nn.Module):
    def __init__(self, cfg, beats_ckpt=None):
        ...

    def extract_spatial_features(self, waveforms):
        ...

    def extract_spatial_tokens(self, waveforms, audio_lengths=None):
        ...

    def project_tokens_for_llm(self, slot_tokens, preds, keep_mask=None):
        ...

    def forward(self, waveforms, audio_lengths=None, targets=None):
        ...
```

### `forward()` 推荐返回形式

返回字典:

```python
{
    "encoder_memory": ...,
    "slot_tokens": ...,
    "pred_obj": ...,
    "pred_azi_logits": ...,
    "pred_ele_logits": ...,
    "pred_dist": ...,
    "pred_class_logits": ...,
    "llm_tokens": ...,
    "llm_token_mask": ...,
    "token_meta": ...,
}
```

## 8.3 `spatial_loss.py`

建议定义:

### `HungarianMatcher`

输入:

- 预测输出
- GT targets

输出:

- 每个样本每个时间步的匹配索引

### `SpatialSetCriterion`

计算:

- objectness loss
- azimuth loss
- elevation loss
- distance regression loss
- class auxiliary loss

可选:

- temporal smoothness loss

## 8.4 `spatial_dataset.py`

建议数据格式:

```python
sample = {
    "waveform": FloatTensor[4, T],
    "duration_s": float,
    "sources": [
        {
            "class_id": int,
            "azimuth_deg": float,
            "elevation_deg": float,
            "distance_m": float,
            "start_s": float,
            "end_s": float,
            "is_time_weak": bool,
            "is_position_dynamic": bool,
            "trajectory": optional,
        },
        ...
    ]
}
```

当前推荐额外保留以下字段:

- `start_s`: 源开始进入场景的时间
- `end_s`: 若有则保留,否则可由 `start_s + length_s` 推出
- `length_s`: 原始 source clip 时长
- `is_time_weak`: 当前时间边界是否只是弱监督
- `is_position_dynamic`: 该源位置是否随时间变化
- `trajectory`: 若位置变化,则存储分段轨迹或逐帧轨迹

如果当前只有源进入时间和原始 source length,可统一转成:

- `start_s = start time`
- `end_s = start_s + length_s`
- `is_time_weak = True`

## 9. 时间建模、2.5 Hz 输出与弱时间监督

这是当前实现中最关键的新增约束之一。

### 9.1 token rate 的解释

这里定义:

- 每个 `source slot stream` 的输出速率为 `2.5 Hz`

即:

- 每个时间步间隔 `400 ms`

### 9.2 输出张量形状

对时长为 `L` 秒的样本:

```text
T_s = round(L * 2.5)
```

例如:

- `10 s -> 25`

最终 slot token 形状:

```text
[B, T_s, K, D]
```

在当前默认配置下:

- `K = 4`

也就是最多 `4` 条并行 source slot 流。

### 9.3 当前时间标注的含义

你当前可提供的时间信息是:

- 知道每个 source 的 `start time`
- 知道原始 `FSD50K source clip``length`
- 但这段 `length` 内不保证每一时刻都真的 active

因此,当前不应把 `[start_s, end_s]` 视为严格逐帧真值,而应视为:

- `weak temporal support window`

也就是:

- 源最可能出现的候选时间范围
- 不是精确的逐帧 activity annotation

### 9.4 第一版如何把弱时间标注映射到 2.5 Hz token

对于第 `t` 个时间步,其中心时刻记为 `tau_t`。

对每个 GT source,定义:

- `candidate window = [start_s, end_s]`

第一版推荐构造三种 mask:

1. `pos_window_mask`
   - `tau_t` 落在 `[start_s, end_s]`2. `neg_window_mask`
   - `tau_t` 明确落在窗口外
3. `ignore_mask`
   - 可选,用于窗口边界附近或不确定区域

当前默认建议最简单实现:

- 窗口外:作为 objectness 负样本
- 窗口内:作为弱正样本候选

但不要对窗口内所有步都施加强监督位置 loss。

### 9.5 当前 loss 需要怎么改

为了适应弱时间标注,当前建议把 loss 拆成两层:

#### `L_obj`

- 对窗口外,正常做负样本监督
- 对窗口内,做弱正样本监督

推荐:

- 使用 `BCE` 或 `focal loss`
- 窗口内正样本权重降低

例如:

```text
w_obj_pos_weak = 0.3 ~ 0.5
w_obj_neg = 1.0
```

#### `L_azi / L_ele / L_dist / L_cls`

第一版不要在窗口内所有时间步都强制监督。  
推荐只在以下位置监督:

- 与 GT source 匹配且 `pred_obj` 较高的 slot
- 或窗口内的 top-k 高置信时间步

更稳的第一版做法:

- 先对每个 GT source 在窗口内选择 `top-1` 或 `top-2` 个 objectness 最高的时间步参与坐标/类别监督

这样可以避免:

- 源在窗口内部分时间其实不 active
- 但模型被错误惩罚

### 9.6 推荐的第一版弱监督匹配策略

当前建议采用两阶段匹配:

1. 先按时间窗口过滤候选时间步
2. 再在候选时间步内做 slot matching

更具体地说:

- 对每个 GT source,只允许匹配其时间窗口内的 `slot tokens`
- 在这些候选中选出最优 `(t, k)`

这比直接对所有 `[T_s, K]` 位置做全局 Hungarian 更稳。

第一版推荐:

- `per-source best-of-window matching`

而不是:

- 全局 dense set matching

原因:

- 你当前时间标注是弱的
- 先用窗口约束大幅降低匹配歧义更现实

### 9.7 推理时 token 序列长度不需要改变

`2.5 Hz` 的 token rate 不需要变。  
要改的是:

- 训练 supervision 的构造方式
- objectness 与坐标 loss 的作用范围

### 9.8 未来如果拿到更好的 activity 标注

如果后续可以拿到:

- energy-based active mask
- frame-level source activity
- VAD / source activation probability

则可把当前的弱时间监督替换成:

- `strong temporal supervision`

到时只需替换 target 构造和 criterion,不需要改主模型结构。

### 9.9 喂给 LLM 时的 token 数量

如果全部展开,理论最大 token 速率为:

```text
2.5 Hz * 4 = 10 tokens / second
```

但推理时可通过 `objectness` 做过滤,所以通常会低于这个上限。

### 9.10 LLM 展开顺序

建议按如下顺序展开:

- 先按时间排序
- 每个时间步内部按 `objectness` 从高到低排序

也就是:

```text
t1_s1, t1_s2, t1_s3, t1_s4, t2_s1, t2_s2, ...
```

然后再过滤低置信 slot。

## 10. 输出给 LLM 的 spatial token 形式

## 10.1 不直接喂原始 logits

不建议直接把:

- 方位分类 logits
- 类别 logits
- 距离标量

直接作为 token 输入给 LLM。

### 10.2 推荐 token 构造方式

每个 slot token `z_{t,k}` 最终形成一个结构化 token:

```text
s_{t,k} = Proj([z_{t,k} ; c_{t,k} ; u_{t,k} ; d_{t,k} ; o_{t,k}])
```

其中:

- `z_{t,k}`: slot latent
- `c_{t,k}`: source class context embedding
- `u_{t,k}`: 方向向量
- `d_{t,k}`: 连续距离 embedding
- `o_{t,k}`: objectness/confidence embedding

### 10.3 各项具体建议

#### `c_{t,k}`: 类别上下文

由 `pred_class_logits` 构造:

```text
p_cls = softmax(pred_class_logits)
c = p_cls @ E_cls
```

其中:

- `E_cls` 是一个可学习类别 embedding 表

作用:

- 给 spatial token 少量语义 grounding
- 但不需要与原始 audio encoder 对齐

#### `u_{t,k}`: 方向向量

先由:

- `pred_azi_logits`
- `pred_ele_logits`

得到预测角度,再转换成单位球坐标向量:

```text
u = [x, y, z]
```

推荐实现:

- 训练时用分布期望或 soft-argmax
- 推理时可用 argmax

#### `d_{t,k}`: 连续距离表示

由于 distance 是连续回归,建议:

-`pred_dist` 做归一化
- 再经一个小 MLP 变成 embedding

#### `o_{t,k}`: 置信度表示

由 `pred_obj` 经 sigmoid 得到 objectness,再做小 MLP 映射。

### 10.4 projector 的最终作用

`SpatialTokenProjector` 的任务是把:

- slot latent
- class context
- direction vector
- distance embedding
- objectness embedding

融合并投影到:

- `d_llm`

输出:

```text
llm_tokens: [B, N_keep, d_llm]
```

### 10.5 是否需要与原 audio encoder 对齐

当前结论:

- **不需要**

因此:

- 这个 projector 完全独立训练
- 只服务于 `Spatial-BEATs -> LLM`

## 11. Loss 设计

## 11.1 任务头与 loss

推荐 loss 组成:

```text
L_total =
  lambda_obj  * L_obj
  + lambda_azi  * L_azi
  + lambda_ele  * L_ele
  + lambda_dist * L_dist
  + lambda_cls  * L_cls
```

### 11.2 各项定义

- `L_obj`: BCE 或 focal loss,支持弱正样本权重
- `L_azi`: cross entropy
- `L_ele`: cross entropy
- `L_dist`: SmoothL1 / Huber
- `L_cls`: cross entropy

### 11.3 distance 回归的实现

由于你已经明确要连续回归,推荐:

- head 输出 `pred_dist_norm in [0, 1]`
- 再乘以 `distance_max_m`

训练时使用:

- `SmoothL1Loss(pred_dist_norm, gt_dist_norm)`

优点:

- 比直接回归未归一化距离更稳
- 比 MSE 更抗异常值

### 11.4 推荐初始权重

建议第一版从以下权重起步:

```text
lambda_obj  = 1.0
lambda_azi  = 2.0
lambda_ele  = 2.0
lambda_dist = 1.0
lambda_cls  = 0.5
```

这里把 `class auxiliary` 权重从之前建议的 `0.25` 提到 `0.5`,因为现在你已经确认:

- 每个源有稳定的 `source-level class label`
- Spatial-BEATs 自身保留一定语义信息是可接受的

### 11.5 当前版本的匹配方式修订

由于当前时间 supervision 是弱的,第一版不建议直接做:

- 全局 `Hungarian matching` over `[T_s, K]`

更推荐:

1. 先根据 `source time window` 过滤候选时间步
2. 再在候选窗口内做匹配

推荐实现:

- `window-constrained matching`

可选两种方式:

#### 方案 A:推荐默认方案

- 对每个 GT source
- 在其 window 内所有 `(t, k)` 候选中,选择 cost 最小的一对

这本质上是:

- `best-of-window assignment`

优点:

- 简单
- 稳定
- 对弱时间监督更友好

#### 方案 B:后续增强方案

- 对每个时间步分别做 Hungarian
- 再加时间连续性正则

这更适合将来位置随时间变化时使用。

### 11.6 匹配 cost

Hungarian matching cost 建议:

```text
cost =
  w_obj  * cost_obj
  + w_azi  * cost_azi
  + w_ele  * cost_ele
  + w_dist * cost_dist
  + w_cls  * cost_cls
```

推荐初值:

```text
w_obj  = 1.0
w_azi  = 2.0
w_ele  = 2.0
w_dist = 1.0
w_cls  = 1.0
```

## 12. 训练策略

### 12.1 第一阶段是否需要 SSL

当前明确结论:

- 第一版 **不做新的 BEATs 式 SSL**

理由:

- 已有空间 GT
- 已有 source class GT
- 已有强 trunk 预训练
- 当前主要目标是空间结构建模

### 12.2 推荐训练阶段

#### Stage A: Warmup

冻结:

- trunk 大部分层

训练:

- preprocessor
- patch stem
- temporal downsampler
- slot query decoder
- prediction heads
- projector

#### Stage B: Upper-trunk finetune

解冻:

- trunk 上层若干层

#### Stage C: Wider finetune

逐步解冻更多层,直到性能稳定。

### 12.3 训练时建议增加的正则项

当前位置在 clip 内固定,因此建议增加:

- `temporal consistency loss`

具体可对同一 source 在相邻时间步的预测加约束:

- objectness 平滑
- azimuth/elevation 分布平滑
- distance 平滑

第一版可选实现:

```text
L_temp =
  smooth(pred_obj_t, pred_obj_{t+1})
  + smooth(pred_dist_t, pred_dist_{t+1})
```

由于当前位置固定,这类正则通常有利于稳定训练。

### 12.4 学习率建议

推荐:

```text
lr_trunk = 1e-5 ~ 5e-5
lr_new   = 1e-4 ~ 5e-4
```

并使用:

- weight decay
- warmup
- layer-wise lr decay

## 13. 训练与推理输出格式

## 13.1 训练时 `forward()` 输出

建议 `forward()` 返回:

```python
{
    "slot_tokens": FloatTensor[B, T_s, K, D],
    "pred_obj": FloatTensor[B, T_s, K],
    "pred_azi_logits": FloatTensor[B, T_s, K, 360],
    "pred_ele_logits": FloatTensor[B, T_s, K, 180],
    "pred_dist": FloatTensor[B, T_s, K, 1],
    "pred_class_logits": FloatTensor[B, T_s, K, C_cls],
    "llm_tokens": FloatTensor[B, N_keep, d_llm],
    "llm_token_mask": BoolTensor[B, N_keep],
    "token_meta": dict,
}
```

### 13.2 推理时建议额外输出

建议额外输出:

- `pred_azi_deg`
- `pred_ele_deg`
- `pred_dist_m`
- `pred_obj_prob`
- `pred_class_id`

便于后续可视化和调试。

## 14. 最小实现顺序

建议严格按以下顺序实现:

1.`SpatialBEATsPreprocessor`
2.`SpatialPatchEmbedding`
3. 完成 trunk checkpoint 加载
4.`TemporalDownsampler`
5.`SlotQueryDecoder`
6.`SpatialPredictionHead`
7.`SpatialTokenProjector`
8.`HungarianMatcher`
9.`SpatialSetCriterion`
10. 写 dataset 和训练脚本

## 15. 未来支持“位置随时间变化”时需要改什么

你已经说明:

- 当前 clip 内位置固定
- 后续会加入随时间变化的位置

这意味着当前模型结构基本可保留,但 target 和 decoder 训练方式需要升级。

### 15.1 当前结构哪些不用改

以下部分未来仍可直接保留:

- `FOA preprocessor`
- `patch embedding`
- `BEATs trunk`
- `TemporalDownsampler`
- `SlotQueryDecoder`
- `SpatialTokenProjector`
- `2.5 Hz` token rate

### 15.2 未来必须改的部分

未来位置动态化后,需要改:

1. `dataset target format`
2. `matching strategy`
3. `loss supervision`

### 15.3 数据结构怎么升级

当前静态位置:

```python
{
    "azimuth_deg": float,
    "elevation_deg": float,
    "distance_m": float,
}
```

未来动态位置建议升级为:

```python
{
    "trajectory": [
        {
            "time_s": float,
            "azimuth_deg": float,
            "elevation_deg": float,
            "distance_m": float,
        },
        ...
    ]
}
```

或直接存成与 `2.5 Hz` 对齐的逐步 target:

```python
{
    "traj_azi_deg": FloatTensor[T_s],
    "traj_ele_deg": FloatTensor[T_s],
    "traj_dist_m": FloatTensor[T_s],
    "traj_valid_mask": BoolTensor[T_s],
}
```

### 15.4 匹配怎么升级

当前位置固定时:

- `per-source best-of-window matching`

未来位置变化时:

- 更适合改为 `per-time-step matching`
-`track-level matching`

推荐未来版本:

- 每个 source 对应一条 slot track
- 在整个时间维上维持 slot identity

### 15.5 loss 怎么升级

未来动态位置时:

- `L_azi / L_ele / L_dist` 应按时间步计算
- `temporal consistency loss` 不能再强制“位置恒定”
- 应改成“速度平滑”或“轨迹平滑”

也就是从:

- `constant-position regularization`

升级成:

- `trajectory smoothness regularization`

### 15.6 代码层面建议现在就预留的接口

为了兼容未来动态位置,当前第一版建议在数据与 loss 接口里预留:

- `is_position_dynamic`
- `trajectory`
- `traj_valid_mask`

即使第一版不用,也建议把字段和分支接口预留出来。

## 16. 当前仍需要确认的问题

虽然核心方案已经足够落地,但还有一个关键问题最好在编码前确认:

当前核心方案已经足够编码。  
如果后续继续推进,唯一还值得尽早确认的是:

- 是否能从原始 source waveform 自动提取更精细的 energy/activity mask

如果可以,第一版的弱时间监督会明显更稳。

## 17. 结论

当前可以直接进入代码实现的最终方案是:

- `16k FOA`
- `WXYZ + IV`
- `K=4`
- `2.5 Hz` slot token streams
- `distance` 连续回归
- `class auxiliary head` 开启
- `BEATs_iter3+ AS2M pre-trained` 作为 trunk 初始化
- `Spatial-BEATs` 拥有自己的 projector
- 最终输出自己的 LLM spatial tokens
- 当前时间 supervision 按 `weak temporal window` 处理
- 当前位置 supervision 按 `clip-level fixed position` 处理
- 未来动态位置仅需升级 target/matching/loss,不需要重写主干结构

这份文档已经足够作为第一版实现蓝图使用。