Spatial-BEATs / docs /spatial_beats_simplified_implementation.md
dieKarotte's picture
Add files using upload-large-folder tool
86cbd36 verified
|
Raw
History Blame Contribute Delete
17 kB

Spatial-BEATs 简化版实现文档

1. 文档目的

本文档给出当前推荐的 Spatial-BEATs 简化版实现方案。

这份方案基于一个更务实的判断:

  • 真正重要的是让 FOA -> BEATs trunk 学到稳定的空间表征
  • 后面的模块只需要承担 readout / decode / supervision 的作用
  • 不需要一开始就引入复杂的 slot query decoder
  • 最终给 LLM 的 token 应尽量来自前面的空间 embedding,而不是最终任务头输出
  • 当前阶段需要先支持 encoder-only training

因此,本方案不再以“内部 4 个 slots + decoder 聚合”作为主线,而改为:

FOA waveform
  -> Spatial preprocessor
  -> Multi-channel patch embedding
  -> BEATs trunk
  -> Temporal readout
  -> Spatial embeddings at 2.5 Hz
  -> Fixed-slot prediction heads
  -> Projector
  -> LLM spatial tokens

2. 当前确定的约束

  • 采样率:16 kHz
  • 输入:FOA waveform
  • 多源数据:有
  • 最大同时源数:约 4
  • 每个源有稳定 source-level class label
  • source vocabulary:/apdcephfs_cq12/share_302080740/user/schmittzhu/data/fsd50k/FSD50K.ground_truth/final_vocabulary.csv
  • source class count:当前默认 65
  • 距离预测:连续回归
  • mel 前端参数:对齐 Qwen-2.5-Omni audio tower 的底层配置
    • sample_rate=16000
    • num_mel_bins=128
    • n_fft=400
    • win_length=400
    • hop_length=160
    • dither=0.0
  • 时间 supervision:弱时间窗口
  • 当前位置:clip 内固定
  • 未来会扩展到:位置随时间变化
  • 目标输出 token rate:2.5 Hz
  • 对于任意 clip,第 i 个样本的有效 token 数是 T_s_i = round(duration_i * 2.5)
  • 对于 10 s clip,T_s_i = 25
  • batch 内部按 T_s_max = max_i T_s_i 做 padding
  • 主干初始化:BEATs_iter3+ AS2M pre-trained
  • 当前第一阶段:优先支持只训练 encoder 的监督方案

3. 方案核心思想

3.1 什么是主角

主角是:

  • BEATs trunk

它负责从 FOA 空间特征中学习空间表征。

3.2 什么是配角

配角是:

  • temporal readout
  • prediction heads
  • projector

这些模块的作用只是:

  • 从 trunk 特征中“读出”空间信息
  • 建立 loss 回传路径
  • 把空间 embedding 投影到 LLM 接口

3.3 给 LLM 的 token 从哪里来

最终给 LLM 的 token 应该来自:

  • trunk 后
  • 或 trunk 后再经过一层很浅的 temporal readout 之后

而不是来自:

  • 最终 logits
  • 复杂 decoder 的末端输出

4. 最终结构总览

FOA waveform [B, 4, T]
  -> SpatialBEATsPreprocessor
  -> FOA feature map [B, C_foa, T_f, F]
  -> SpatialPatchEmbedding
  -> patch tokens [B, N_p, D_in]
  -> BEATs trunk
  -> encoder memory [B, N_p, D]
  -> reshape / frequency pooling
  -> temporal tokens [B, T_s_max, D]
  -> shallow temporal readout
  -> spatial embeddings [B, T_s_max, D]
  -> prediction heads
  -> Spatial projector
  -> llm spatial tokens [B, T_s_max, d_llm]

其中:

  • 每个样本:
    • T_s_i = round(duration_i * 2.5)
  • 一个 batch 内:
    • T_s_max = max_i T_s_i
  • 10 s 输入:
    • T_s_i = 25

5. 从 FOA 到 spatial token 的完整过程

5.1 输入层

输入:

  • waveform: [B, 4, T]

四个通道分别是:

  • W
  • X
  • Y
  • Z

10 s, 16kHz

  • T = 160000

5.2 SpatialBEATsPreprocessor

目标:

  • 把原始 FOA 波形转成多通道空间特征图

推荐输出通道:

  • W_logmel
  • X_logmel
  • Y_logmel
  • Z_logmel
  • IVx
  • IVy
  • IVz

因此:

  • C_foa = 7

内部步骤:

  1. WXYZ 做 STFT
  2. 计算各通道 log-mel
  3. 计算 IVx, IVy, IVz
  4. IV 映射到 mel 维
  5. 拼接成特征图

输出:

  • foa_feat: [B, 7, T_f, 128]

其中:

  • 128 是 mel bins
  • T_f 约等于 1000,如果帧移为 10 ms

5.3 SpatialPatchEmbedding

目标:

  • 7 通道 FOA 特征图切成 patch token

原始 BEATs 的 stem 是单通道:

Conv2d(1, embed_dim, kernel_size=patch, stride=patch)

新的 stem 改成:

Conv2d(7, embed_dim, kernel_size=patch, stride=patch)

输入:

  • foa_feat: [B, 7, T_f, 128]

输出:

  • patch_tokens: [B, N_p, D_in]
  • grid_hw = (T_p, F_p)

建议:

  • D_in = 512

原因:

  • 与 BEATs patch embedding dim 对齐

5.4 BEATs trunk

这是整个模型最重要的部分。

保留并复用的模块包括:

  • patch token layer norm
  • post_extract_proj
  • dropout_input
  • TransformerEncoder
  • conv_pos
  • 所有 transformer layers
  • LayerNorm / FFN / attention

输入:

  • patch_tokens: [B, N_p, 512]

输出:

  • encoder_memory: [B, N_p, 768]

建议:

  • 主干 hidden dim 保持 768
  • 直接加载 BEATs_iter3+ AS2M pre-trained

5.5 Reshape + Frequency Pooling

目标:

  • 把 trunk 输出变成可读的时间序列表示

步骤:

  1. encoder_memory [B, N_p, D] reshape 成:
    • grid_memory [B, T_p, F_p, D]
  2. F_p 维做 pooling

推荐第一版:

  • mean pooling over frequency

输出:

  • temporal_patch_tokens: [B, T_p, D]

这一步的意义是:

  • 先把频率信息压到时间轴上
  • 便于后面构造固定 2.5 Hz 的时序空间 token

5.6 Temporal Resampler

目标:

  • 把 patch 级时间序列压到目标 token rate

输入:

  • temporal_patch_tokens: [B, T_p, D]

输出:

  • temporal_tokens: [B, T_s, D]

其中:

  • i 个样本:
    • T_s_i = round(duration_i * 2.5)
  • batch padding 后:
    • temporal_tokens: [B, T_s_max, D]
  • 10 s clip:
    • T_s_i = 25

推荐第一版:

  • 线性插值
  • 或轻量 Conv1d 下采样

注意:

  • 2.5 Hz 明确指最终给 LLM 的 token rate
  • 对单个 10 s 样本,10 s -> 25 个有效 token
  • 对 mixed-length batch,张量会 pad 到 T_s_max

5.7 Shallow Temporal Readout

目标:

  • 在时间维上再做一层轻量整理
  • 让 trunk 输出更适合做空间监督和给 LLM 使用

推荐做法:

  • 1~2 层 transformer encoder

输入:

  • temporal_tokens: [B, T_s_max, 768]

输出:

  • spatial_embeddings: [B, T_s_max, 768]

这一层的作用是:

  • 做一个浅层 readout neck
  • 不承担复杂检测器职责
  • 不强调 decoder 身份
  • 只是把 trunk 表征整理成更干净的时间步级空间 embedding

如果你想先做最简单版本,也可以直接:

  • temporal_tokens -> LayerNorm -> spatial_embeddings

把 shallow transformer 作为后续增强项。

5.8 Prediction Heads

目标:

  • spatial_embeddings 上接显式监督
  • 通过 loss 让前面的 trunk 真正学到空间表征
  • 在不引入复杂 decoder 的前提下提供多源监督出口

输入:

  • spatial_embeddings: [B, T_s_max, 768]

5.8.1 Encoder-only 阶段的固定槽位 readout

虽然简化版不再使用复杂 slot query decoder,但当前仍然有:

  • 最大同时源数约为 4

因此,encoder-only 训练阶段推荐在 spatial_embeddings 后接一个很轻的固定槽位 readout:

spatial_embeddings [B, T_s_max, 768]
  -> Linear / MLP expand
  -> slot_latents [B, T_s_max, 4, H]
  -> shared prediction heads

推荐默认:

  • H = 768

最简单实现:

Linear(768 -> 4 * 768)
reshape -> [B, 25, 4, 768]

更稳一点的实现:

Linear(768 -> 768)
-> GELU
-> Linear(768 -> 4 * 768)
reshape -> [B, 25, 4, 768]

这一步的作用不是做复杂目标解析,而只是:

  • 给每个时间步提供 4 个固定 source 槽位
  • 让多源监督有明确着力点
  • 把 loss 稳定回传到前面的 trunk

5.8.2 预测头设计

slot_latents [B, 25, 4, H] 接共享或独立 heads。

建议输出头:

  • activity / objectness head
  • azimuth head
  • elevation head
  • distance head
  • class auxiliary head

输出形式建议:

  • pred_activity: [B, T_s_max, 4]
  • pred_azi_logits: [B, T_s_max, 4, 360]
  • pred_ele_logits: [B, T_s_max, 4, 180]
  • pred_dist: [B, T_s_max, 4, 1]
  • pred_class_logits: [B, T_s_max, 4, C_cls]

其中:

  • pred_activity 负责当前时间步当前槽位是否解释某个源
  • pred_azi_logits / pred_ele_logits 负责方向
  • pred_dist 负责连续距离回归
  • pred_class_logits 负责 source-level 辅助类别监督

当前默认建议:

  • C_cls = 65
  • label vocabulary 来自:
    • final_vocabulary.csv
  • 推荐字段:
    • label_id
    • final_label

5.8.3 为什么这里仍然保留 K=4

在简化版中:

  • K=4 不再体现在复杂 decoder 结构里
  • 但仍然通过固定槽位 readout 出现在监督头中

也就是说:

  • 不需要复杂 query-based object slots
  • 但仍然需要一个简单的多槽位 readout 来承载多源标签

这更符合当前“先把主干训练出来”的目标。

5.9 Spatial Projector

目标:

  • spatial_embeddings 投影到 LLM hidden size

输入:

  • spatial_embeddings: [B, T_s_max, 768]

推荐:

  • 独立的 MLP projector

形式:

Linear(768 -> D_mid)
-> GELU
-> LayerNorm
-> Linear(D_mid -> d_llm)

输出:

  • llm_spatial_tokens: [B, T_s_max, d_llm]

这就是最终喂给 LLM 的 spatial tokens。

5.9.1 Encoder-only 阶段 projector 的角色

当前第一阶段如果只训练 encoder,本质目标是:

  • 用监督把 spatial_embeddings 训好

因此 projector 在这一阶段有两种合理策略:

方案 A:先不训练 projector,推荐默认

  • 只训练:

    • preprocessor
    • patch embedding
    • BEATs trunk
    • temporal readout
    • fixed-slot prediction heads
  • projector 只保留接口,不参与训练

方案 B:一并训练 projector

  • 当你希望尽早固定 LLM 接口维度时可以启用

但第一版默认推荐:

  • 先把 encoder 训好,再训练 projector

6. 每层的输入输出总结

6.1 SpatialBEATsPreprocessor

  • 输入:[B, 4, T]
  • 输出:[B, 7, T_f, 128]

6.2 SpatialPatchEmbedding

  • 输入:[B, 7, T_f, 128]
  • 输出:[B, N_p, 512]

6.3 BEATs trunk

  • 输入:[B, N_p, 512]
  • 输出:[B, N_p, 768]

6.4 Reshape + Frequency Pooling

  • 输入:[B, N_p, 768]
  • 输出:[B, T_p, 768]

6.5 Temporal Resampler

  • 输入:[B, T_p, 768]
  • 输出:[B, 25, 768]

6.6 Shallow Temporal Readout

  • 输入:[B, T_s_max, 768]
  • 输出:[B, T_s_max, 768]

6.7 Prediction Heads

  • 输入:[B, T_s_max, 768]
  • 先经 Linear / MLP expand 变成:[B, T_s_max, 4, H]
  • 输出:
    • activity [B, T_s_max, 4]
    • azimuth [B, T_s_max, 4, 360]
    • elevation [B, T_s_max, 4, 180]
    • distance [B, T_s_max, 4, 1]
    • class [B, T_s_max, 4, C_cls]

6.8 Spatial Projector

  • 输入:[B, T_s_max, 768]
  • 输出:[B, T_s_max, d_llm]

7. Loss 如何作用到前面的表征

这个简化方案的关键点在于:

  • 后面的 heads 并不是模型重点
  • 它们只是为了接监督

loss 从这些 heads 回传后,会更新:

  1. readout neck
  2. temporal resampler
  3. BEATs trunk
  4. patch embedding
  5. FOA preprocessor

因此:

  • 只要后面的 decode/head 能稳定预测显式空间信息
  • 前面的 trunk 就会被训练成空间 encoder

推荐 loss

当前建议:

  • L_activity
  • L_azi
  • L_ele
  • L_dist
  • L_cls_aux
  • L_temp

总损失:

L_total =
  lambda_act * L_activity
  + lambda_azi * L_azi
  + lambda_ele * L_ele
  + lambda_dist * L_dist
  + lambda_cls * L_cls_aux
  + lambda_temp * L_temp

当前可继续保留:

  • 方向分类
  • 仰角分类
  • 距离连续回归
  • 辅助类别监督
  • 时间一致性正则

7.1 Encoder-only 训练阶段的 loss 定义

L_activity

用于监督:

  • 当前时间步当前槽位是否激活

建议:

  • BCEWithLogitsLossfocal loss

结合当前的弱时间窗口 supervision:

  • 窗口外:负样本
  • 窗口内:弱正样本

L_cls_aux

用于监督:

  • 被分配到某个 GT source 的槽位类别

建议:

  • CrossEntropyLoss

L_aziL_ele

用于监督:

  • 槽位的方向分类

建议:

  • CrossEntropyLoss
  • azimuth: 360 bins
  • elevation: 180 bins

L_dist

用于监督:

  • 槽位的连续距离回归

建议:

  • 先将距离归一化到 [0, 1]
  • 使用 SmoothL1Loss

L_temp

由于当前位置当前是 clip 内固定,建议加入:

  • 时间一致性约束

例如对同一 source 在相邻时间步对应的槽位加:

  • class 分布平滑
  • direction 分布平滑
  • distance 平滑

第一版可以先只加:

  • distance smoothness
  • activity smoothness

7.2 Encoder-only 训练时的匹配方式

因为当前 heads 是固定 4 槽位,而不是 query decoder,所以建议:

  • 使用轻量匹配
  • 不必依赖复杂 decoder 结构

推荐策略:

  1. 对每个 GT source,根据其时间窗口筛出候选时间步
  2. 在该时间步的 4 个固定槽位中,选择 cost 最小的槽位
  3. 对该槽位施加:
    • activity
    • class
    • azi
    • ele
    • dist

推荐 cost:

cost =
  w_act  * cost_act
  + w_cls  * cost_cls
  + w_azi  * cost_azi
  + w_ele  * cost_ele
  + w_dist * cost_dist

第一版可以不做全局 Hungarian,直接做:

  • per-step fixed-slot matching

如果后面发现多源冲突严重,再升级 matching。

8. 关于多源 supervision 的理解

虽然当前结构没有显式 slot query decoder,但这并不等于完全不考虑多源。

当前更合理的理解是:

  • spatial_embeddings [B, 25, D] 表示该时间步的空间场景表征
  • 再通过固定 4 槽位 readout head 承载多源标签
  • supervision 用这些多源标签约束 trunk 必须编码多个源的空间信息

这相当于:

  • 先学一个强的 time-step level spatial embedding
  • 再决定是否需要升级成更复杂的 object-centric / query-based 版本

这是更稳的第一版路径。

9. 为什么这个简化版更合适当前阶段

9.1 更符合老师的建议

老师的核心意见可以总结为:

  • 不必执着于 decoder 结构本身
  • 后面的模块只要能 decode 出空间监督即可
  • 真正要训好的是前面的 trunk

这个简化版完全符合这个思路。

9.2 工程复杂度更低

不需要一开始就实现:

  • slot query decoder
  • cross-attention decoder
  • objectness pooling
  • 复杂 slot matching

9.3 更利于先验证 trunk 是否真的学到空间表征

如果这版能成功:

  • 说明 trunk + temporal readout 本身已经足够表达空间信息

如果这版不够:

  • 再升级到 slot/object 版本

这样路线更清晰。

10. 推荐实现版本

V1:最小可行版

FOA -> preprocessor -> patch embed -> BEATs trunk
   -> frequency pooling -> temporal resampler
   -> LayerNorm
   -> spatial embeddings
   -> Linear expand to 4 slots
   -> heads

V2:推荐正式版

FOA -> preprocessor -> patch embed -> BEATs trunk
   -> frequency pooling -> temporal resampler
   -> 1~2 层 shallow transformer readout
   -> spatial embeddings
   -> Linear / MLP expand to 4 slots
   -> heads
   -> projector

V3:后续增强版

如果未来发现:

  • 多源关系建模仍然不足
  • 动态轨迹下表达不够

再升级为:

  • slot-based decoder
  • query-based object-centric readout

11. 推荐代码划分

建议新增或保留以下文件:

  • spatial_beats.py
  • spatial_modules.py
  • spatial_loss.py
  • spatial_dataset.py
  • train_spatial_beats.py

spatial_modules.py

建议包含:

  • SpatialBEATsPreprocessor
  • SpatialPatchEmbedding
  • TemporalResampler
  • TemporalReadoutTransformer
  • FixedSlotReadoutHead
  • SpatialPredictionHeads
  • SpatialProjector

spatial_beats.py

建议主类:

  • SpatialBEATsConfig
  • SpatialBEATs

主类中建议暴露:

  • extract_spatial_embeddings()
  • project_for_llm()
  • forward()

12. 最终输出接口

推荐 forward() 返回:

{
    "encoder_memory": FloatTensor[B, N_p, 768],
    "temporal_tokens": FloatTensor[B, 25, 768],
    "spatial_embeddings": FloatTensor[B, 25, 768],
    "slot_latents": FloatTensor[B, 25, 4, H],
    "pred_activity": FloatTensor[B, 25, 4],
    "pred_azi_logits": FloatTensor[B, 25, 4, 360],
    "pred_ele_logits": FloatTensor[B, 25, 4, 180],
    "pred_dist": FloatTensor[B, 25, 4, 1],
    "pred_class_logits": FloatTensor[B, 25, 4, C_cls],
    "llm_spatial_tokens": FloatTensor[B, 25, d_llm],
}

其中:

  • spatial_embeddings 是最核心的中间表示
  • slot_latents 只是 encoder-only 监督出口
  • llm_spatial_tokens 是最终给 LLM 的接口

13. 结论

当前推荐的简化版最终架构是:

  • FOA -> spatial features -> BEATs trunk -> 2.5Hz temporal spatial embeddings -> 固定4槽位 readout heads -> projector

这套方案的重点是:

  • 用显式空间监督把前面的 BEATs trunk 训练成空间 encoder
  • 后面的 fixed-slot head 只承担监督和 readout 的职责
  • 给 LLM 的 token 直接来自前面的 spatial_embeddings

这是当前最适合进入代码实现的主线方案。