Spatial-BEATs 简化版实现文档
1. 文档目的
本文档给出当前推荐的 Spatial-BEATs 简化版实现方案。
这份方案基于一个更务实的判断:
- 真正重要的是让
FOA -> BEATs trunk学到稳定的空间表征 - 后面的模块只需要承担
readout / decode / supervision的作用 - 不需要一开始就引入复杂的
slot query decoder - 最终给 LLM 的 token 应尽量来自前面的空间 embedding,而不是最终任务头输出
- 当前阶段需要先支持
encoder-only training
因此,本方案不再以“内部 4 个 slots + decoder 聚合”作为主线,而改为:
FOA waveform
-> Spatial preprocessor
-> Multi-channel patch embedding
-> BEATs trunk
-> Temporal readout
-> Spatial embeddings at 2.5 Hz
-> Fixed-slot prediction heads
-> Projector
-> LLM spatial tokens
2. 当前确定的约束
- 采样率:
16 kHz - 输入:
FOA waveform - 多源数据:有
- 最大同时源数:约
4 - 每个源有稳定
source-level class label - source vocabulary:
/apdcephfs_cq12/share_302080740/user/schmittzhu/data/fsd50k/FSD50K.ground_truth/final_vocabulary.csv - source class count:当前默认
65 - 距离预测:
连续回归 - mel 前端参数:对齐
Qwen-2.5-Omniaudio tower 的底层配置sample_rate=16000num_mel_bins=128n_fft=400win_length=400hop_length=160dither=0.0
- 时间 supervision:
弱时间窗口 - 当前位置:
clip 内固定 - 未来会扩展到:
位置随时间变化 - 目标输出 token rate:
2.5 Hz - 对于任意 clip,第
i个样本的有效 token 数是T_s_i = round(duration_i * 2.5) - 对于
10 sclip,T_s_i = 25 - batch 内部按
T_s_max = max_i T_s_i做 padding - 主干初始化:
BEATs_iter3+ AS2M pre-trained - 当前第一阶段:优先支持只训练 encoder 的监督方案
3. 方案核心思想
3.1 什么是主角
主角是:
BEATs trunk
它负责从 FOA 空间特征中学习空间表征。
3.2 什么是配角
配角是:
- temporal readout
- prediction heads
- projector
这些模块的作用只是:
- 从 trunk 特征中“读出”空间信息
- 建立 loss 回传路径
- 把空间 embedding 投影到 LLM 接口
3.3 给 LLM 的 token 从哪里来
最终给 LLM 的 token 应该来自:
- trunk 后
- 或 trunk 后再经过一层很浅的 temporal readout 之后
而不是来自:
- 最终 logits
- 复杂 decoder 的末端输出
4. 最终结构总览
FOA waveform [B, 4, T]
-> SpatialBEATsPreprocessor
-> FOA feature map [B, C_foa, T_f, F]
-> SpatialPatchEmbedding
-> patch tokens [B, N_p, D_in]
-> BEATs trunk
-> encoder memory [B, N_p, D]
-> reshape / frequency pooling
-> temporal tokens [B, T_s_max, D]
-> shallow temporal readout
-> spatial embeddings [B, T_s_max, D]
-> prediction heads
-> Spatial projector
-> llm spatial tokens [B, T_s_max, d_llm]
其中:
- 每个样本:
T_s_i = round(duration_i * 2.5)
- 一个 batch 内:
T_s_max = max_i T_s_i
- 对
10 s输入:T_s_i = 25
5. 从 FOA 到 spatial token 的完整过程
5.1 输入层
输入:
waveform: [B, 4, T]
四个通道分别是:
WXYZ
对 10 s, 16kHz:
T = 160000
5.2 SpatialBEATsPreprocessor
目标:
- 把原始 FOA 波形转成多通道空间特征图
推荐输出通道:
W_logmelX_logmelY_logmelZ_logmelIVxIVyIVz
因此:
C_foa = 7
内部步骤:
- 对
WXYZ做 STFT - 计算各通道
log-mel - 计算
IVx, IVy, IVz - 将
IV映射到 mel 维 - 拼接成特征图
输出:
foa_feat: [B, 7, T_f, 128]
其中:
128是 mel binsT_f约等于1000,如果帧移为10 ms
5.3 SpatialPatchEmbedding
目标:
- 把
7通道 FOA 特征图切成 patch token
原始 BEATs 的 stem 是单通道:
Conv2d(1, embed_dim, kernel_size=patch, stride=patch)
新的 stem 改成:
Conv2d(7, embed_dim, kernel_size=patch, stride=patch)
输入:
foa_feat: [B, 7, T_f, 128]
输出:
patch_tokens: [B, N_p, D_in]grid_hw = (T_p, F_p)
建议:
D_in = 512
原因:
- 与 BEATs patch embedding dim 对齐
5.4 BEATs trunk
这是整个模型最重要的部分。
保留并复用的模块包括:
- patch token layer norm
post_extract_projdropout_inputTransformerEncoderconv_pos- 所有 transformer layers
- LayerNorm / FFN / attention
输入:
patch_tokens: [B, N_p, 512]
输出:
encoder_memory: [B, N_p, 768]
建议:
- 主干 hidden dim 保持
768 - 直接加载
BEATs_iter3+ AS2M pre-trained
5.5 Reshape + Frequency Pooling
目标:
- 把 trunk 输出变成可读的时间序列表示
步骤:
- 将
encoder_memory [B, N_p, D]reshape 成:grid_memory [B, T_p, F_p, D]
- 对
F_p维做 pooling
推荐第一版:
mean pooling over frequency
输出:
temporal_patch_tokens: [B, T_p, D]
这一步的意义是:
- 先把频率信息压到时间轴上
- 便于后面构造固定
2.5 Hz的时序空间 token
5.6 Temporal Resampler
目标:
- 把 patch 级时间序列压到目标 token rate
输入:
temporal_patch_tokens: [B, T_p, D]
输出:
temporal_tokens: [B, T_s, D]
其中:
- 第
i个样本:T_s_i = round(duration_i * 2.5)
- batch padding 后:
temporal_tokens: [B, T_s_max, D]
- 对
10 sclip:T_s_i = 25
推荐第一版:
- 线性插值
- 或轻量
Conv1d下采样
注意:
2.5 Hz明确指最终给 LLM 的 token rate- 对单个
10 s样本,10 s -> 25个有效 token - 对 mixed-length batch,张量会 pad 到
T_s_max
5.7 Shallow Temporal Readout
目标:
- 在时间维上再做一层轻量整理
- 让 trunk 输出更适合做空间监督和给 LLM 使用
推荐做法:
1~2层 transformer encoder
输入:
temporal_tokens: [B, T_s_max, 768]
输出:
spatial_embeddings: [B, T_s_max, 768]
这一层的作用是:
- 做一个浅层 readout neck
- 不承担复杂检测器职责
- 不强调 decoder 身份
- 只是把 trunk 表征整理成更干净的时间步级空间 embedding
如果你想先做最简单版本,也可以直接:
temporal_tokens -> LayerNorm -> spatial_embeddings
把 shallow transformer 作为后续增强项。
5.8 Prediction Heads
目标:
- 从
spatial_embeddings上接显式监督 - 通过 loss 让前面的 trunk 真正学到空间表征
- 在不引入复杂 decoder 的前提下提供多源监督出口
输入:
spatial_embeddings: [B, T_s_max, 768]
5.8.1 Encoder-only 阶段的固定槽位 readout
虽然简化版不再使用复杂 slot query decoder,但当前仍然有:
- 最大同时源数约为
4
因此,encoder-only 训练阶段推荐在 spatial_embeddings 后接一个很轻的固定槽位 readout:
spatial_embeddings [B, T_s_max, 768]
-> Linear / MLP expand
-> slot_latents [B, T_s_max, 4, H]
-> shared prediction heads
推荐默认:
H = 768
最简单实现:
Linear(768 -> 4 * 768)
reshape -> [B, 25, 4, 768]
更稳一点的实现:
Linear(768 -> 768)
-> GELU
-> Linear(768 -> 4 * 768)
reshape -> [B, 25, 4, 768]
这一步的作用不是做复杂目标解析,而只是:
- 给每个时间步提供
4个固定 source 槽位 - 让多源监督有明确着力点
- 把 loss 稳定回传到前面的 trunk
5.8.2 预测头设计
对 slot_latents [B, 25, 4, H] 接共享或独立 heads。
建议输出头:
activity / objectness headazimuth headelevation headdistance headclass auxiliary head
输出形式建议:
pred_activity: [B, T_s_max, 4]pred_azi_logits: [B, T_s_max, 4, 360]pred_ele_logits: [B, T_s_max, 4, 180]pred_dist: [B, T_s_max, 4, 1]pred_class_logits: [B, T_s_max, 4, C_cls]
其中:
pred_activity负责当前时间步当前槽位是否解释某个源pred_azi_logits/pred_ele_logits负责方向pred_dist负责连续距离回归pred_class_logits负责 source-level 辅助类别监督
当前默认建议:
C_cls = 65- label vocabulary 来自:
final_vocabulary.csv
- 推荐字段:
label_idfinal_label
5.8.3 为什么这里仍然保留 K=4
在简化版中:
K=4不再体现在复杂 decoder 结构里- 但仍然通过固定槽位 readout 出现在监督头中
也就是说:
- 不需要复杂 query-based object slots
- 但仍然需要一个简单的多槽位 readout 来承载多源标签
这更符合当前“先把主干训练出来”的目标。
5.9 Spatial Projector
目标:
- 把
spatial_embeddings投影到 LLM hidden size
输入:
spatial_embeddings: [B, T_s_max, 768]
推荐:
- 独立的
MLP projector
形式:
Linear(768 -> D_mid)
-> GELU
-> LayerNorm
-> Linear(D_mid -> d_llm)
输出:
llm_spatial_tokens: [B, T_s_max, d_llm]
这就是最终喂给 LLM 的 spatial tokens。
5.9.1 Encoder-only 阶段 projector 的角色
当前第一阶段如果只训练 encoder,本质目标是:
- 用监督把
spatial_embeddings训好
因此 projector 在这一阶段有两种合理策略:
方案 A:先不训练 projector,推荐默认
只训练:
- preprocessor
- patch embedding
- BEATs trunk
- temporal readout
- fixed-slot prediction heads
projector 只保留接口,不参与训练
方案 B:一并训练 projector
- 当你希望尽早固定 LLM 接口维度时可以启用
但第一版默认推荐:
- 先把 encoder 训好,再训练 projector
6. 每层的输入输出总结
6.1 SpatialBEATsPreprocessor
- 输入:
[B, 4, T] - 输出:
[B, 7, T_f, 128]
6.2 SpatialPatchEmbedding
- 输入:
[B, 7, T_f, 128] - 输出:
[B, N_p, 512]
6.3 BEATs trunk
- 输入:
[B, N_p, 512] - 输出:
[B, N_p, 768]
6.4 Reshape + Frequency Pooling
- 输入:
[B, N_p, 768] - 输出:
[B, T_p, 768]
6.5 Temporal Resampler
- 输入:
[B, T_p, 768] - 输出:
[B, 25, 768]
6.6 Shallow Temporal Readout
- 输入:
[B, T_s_max, 768] - 输出:
[B, T_s_max, 768]
6.7 Prediction Heads
- 输入:
[B, T_s_max, 768] - 先经
Linear / MLP expand变成:[B, T_s_max, 4, H] - 输出:
activity [B, T_s_max, 4]azimuth [B, T_s_max, 4, 360]elevation [B, T_s_max, 4, 180]distance [B, T_s_max, 4, 1]class [B, T_s_max, 4, C_cls]
6.8 Spatial Projector
- 输入:
[B, T_s_max, 768] - 输出:
[B, T_s_max, d_llm]
7. Loss 如何作用到前面的表征
这个简化方案的关键点在于:
- 后面的 heads 并不是模型重点
- 它们只是为了接监督
loss 从这些 heads 回传后,会更新:
- readout neck
- temporal resampler
- BEATs trunk
- patch embedding
- FOA preprocessor
因此:
- 只要后面的 decode/head 能稳定预测显式空间信息
- 前面的 trunk 就会被训练成空间 encoder
推荐 loss
当前建议:
L_activityL_aziL_eleL_distL_cls_auxL_temp
总损失:
L_total =
lambda_act * L_activity
+ lambda_azi * L_azi
+ lambda_ele * L_ele
+ lambda_dist * L_dist
+ lambda_cls * L_cls_aux
+ lambda_temp * L_temp
当前可继续保留:
- 方向分类
- 仰角分类
- 距离连续回归
- 辅助类别监督
- 时间一致性正则
7.1 Encoder-only 训练阶段的 loss 定义
L_activity
用于监督:
- 当前时间步当前槽位是否激活
建议:
BCEWithLogitsLoss或focal loss
结合当前的弱时间窗口 supervision:
- 窗口外:负样本
- 窗口内:弱正样本
L_cls_aux
用于监督:
- 被分配到某个 GT source 的槽位类别
建议:
CrossEntropyLoss
L_azi 和 L_ele
用于监督:
- 槽位的方向分类
建议:
CrossEntropyLossazimuth:360binselevation:180bins
L_dist
用于监督:
- 槽位的连续距离回归
建议:
- 先将距离归一化到
[0, 1] - 使用
SmoothL1Loss
L_temp
由于当前位置当前是 clip 内固定,建议加入:
- 时间一致性约束
例如对同一 source 在相邻时间步对应的槽位加:
- class 分布平滑
- direction 分布平滑
- distance 平滑
第一版可以先只加:
- distance smoothness
- activity smoothness
7.2 Encoder-only 训练时的匹配方式
因为当前 heads 是固定 4 槽位,而不是 query decoder,所以建议:
- 使用轻量匹配
- 不必依赖复杂 decoder 结构
推荐策略:
- 对每个 GT source,根据其时间窗口筛出候选时间步
- 在该时间步的
4个固定槽位中,选择 cost 最小的槽位 - 对该槽位施加:
- activity
- class
- azi
- ele
- dist
推荐 cost:
cost =
w_act * cost_act
+ w_cls * cost_cls
+ w_azi * cost_azi
+ w_ele * cost_ele
+ w_dist * cost_dist
第一版可以不做全局 Hungarian,直接做:
per-step fixed-slot matching
如果后面发现多源冲突严重,再升级 matching。
8. 关于多源 supervision 的理解
虽然当前结构没有显式 slot query decoder,但这并不等于完全不考虑多源。
当前更合理的理解是:
- 让
spatial_embeddings [B, 25, D]表示该时间步的空间场景表征 - 再通过固定
4槽位 readout head 承载多源标签 - supervision 用这些多源标签约束 trunk 必须编码多个源的空间信息
这相当于:
- 先学一个强的 time-step level spatial embedding
- 再决定是否需要升级成更复杂的 object-centric / query-based 版本
这是更稳的第一版路径。
9. 为什么这个简化版更合适当前阶段
9.1 更符合老师的建议
老师的核心意见可以总结为:
- 不必执着于 decoder 结构本身
- 后面的模块只要能 decode 出空间监督即可
- 真正要训好的是前面的 trunk
这个简化版完全符合这个思路。
9.2 工程复杂度更低
不需要一开始就实现:
- slot query decoder
- cross-attention decoder
- objectness pooling
- 复杂 slot matching
9.3 更利于先验证 trunk 是否真的学到空间表征
如果这版能成功:
- 说明 trunk + temporal readout 本身已经足够表达空间信息
如果这版不够:
- 再升级到 slot/object 版本
这样路线更清晰。
10. 推荐实现版本
V1:最小可行版
FOA -> preprocessor -> patch embed -> BEATs trunk
-> frequency pooling -> temporal resampler
-> LayerNorm
-> spatial embeddings
-> Linear expand to 4 slots
-> heads
V2:推荐正式版
FOA -> preprocessor -> patch embed -> BEATs trunk
-> frequency pooling -> temporal resampler
-> 1~2 层 shallow transformer readout
-> spatial embeddings
-> Linear / MLP expand to 4 slots
-> heads
-> projector
V3:后续增强版
如果未来发现:
- 多源关系建模仍然不足
- 动态轨迹下表达不够
再升级为:
- slot-based decoder
- query-based object-centric readout
11. 推荐代码划分
建议新增或保留以下文件:
spatial_beats.pyspatial_modules.pyspatial_loss.pyspatial_dataset.pytrain_spatial_beats.py
spatial_modules.py
建议包含:
SpatialBEATsPreprocessorSpatialPatchEmbeddingTemporalResamplerTemporalReadoutTransformerFixedSlotReadoutHeadSpatialPredictionHeadsSpatialProjector
spatial_beats.py
建议主类:
SpatialBEATsConfigSpatialBEATs
主类中建议暴露:
extract_spatial_embeddings()project_for_llm()forward()
12. 最终输出接口
推荐 forward() 返回:
{
"encoder_memory": FloatTensor[B, N_p, 768],
"temporal_tokens": FloatTensor[B, 25, 768],
"spatial_embeddings": FloatTensor[B, 25, 768],
"slot_latents": FloatTensor[B, 25, 4, H],
"pred_activity": FloatTensor[B, 25, 4],
"pred_azi_logits": FloatTensor[B, 25, 4, 360],
"pred_ele_logits": FloatTensor[B, 25, 4, 180],
"pred_dist": FloatTensor[B, 25, 4, 1],
"pred_class_logits": FloatTensor[B, 25, 4, C_cls],
"llm_spatial_tokens": FloatTensor[B, 25, d_llm],
}
其中:
spatial_embeddings是最核心的中间表示slot_latents只是 encoder-only 监督出口llm_spatial_tokens是最终给 LLM 的接口
13. 结论
当前推荐的简化版最终架构是:
FOA -> spatial features -> BEATs trunk -> 2.5Hz temporal spatial embeddings -> 固定4槽位 readout heads -> projector
这套方案的重点是:
- 用显式空间监督把前面的
BEATs trunk训练成空间 encoder - 后面的 fixed-slot head 只承担监督和 readout 的职责
- 给 LLM 的 token 直接来自前面的
spatial_embeddings
这是当前最适合进入代码实现的主线方案。