# Spatial-BEATs 简化版实现文档 ## 1. 文档目的 本文档给出当前推荐的 `Spatial-BEATs` 简化版实现方案。 这份方案基于一个更务实的判断: - 真正重要的是让 `FOA -> BEATs trunk` 学到稳定的空间表征 - 后面的模块只需要承担 `readout / decode / supervision` 的作用 - 不需要一开始就引入复杂的 `slot query decoder` - 最终给 LLM 的 token 应尽量来自前面的空间 embedding,而不是最终任务头输出 - 当前阶段需要先支持 `encoder-only training` 因此,本方案不再以“内部 4 个 slots + decoder 聚合”作为主线,而改为: ```text FOA waveform -> Spatial preprocessor -> Multi-channel patch embedding -> BEATs trunk -> Temporal readout -> Spatial embeddings at 2.5 Hz -> Fixed-slot prediction heads -> Projector -> LLM spatial tokens ``` ## 2. 当前确定的约束 - 采样率:`16 kHz` - 输入:`FOA waveform` - 多源数据:有 - 最大同时源数:约 `4` - 每个源有稳定 `source-level class label` - source vocabulary:`/apdcephfs_cq12/share_302080740/user/schmittzhu/data/fsd50k/FSD50K.ground_truth/final_vocabulary.csv` - source class count:当前默认 `65` - 距离预测:`连续回归` - mel 前端参数:对齐 `Qwen-2.5-Omni` audio tower 的底层配置 - `sample_rate=16000` - `num_mel_bins=128` - `n_fft=400` - `win_length=400` - `hop_length=160` - `dither=0.0` - 时间 supervision:`弱时间窗口` - 当前位置:`clip 内固定` - 未来会扩展到:`位置随时间变化` - 目标输出 token rate:`2.5 Hz` - 对于任意 clip,第 `i` 个样本的有效 token 数是 `T_s_i = round(duration_i * 2.5)` - 对于 `10 s` clip,`T_s_i = 25` - batch 内部按 `T_s_max = max_i T_s_i` 做 padding - 主干初始化:`BEATs_iter3+ AS2M pre-trained` - 当前第一阶段:优先支持只训练 encoder 的监督方案 ## 3. 方案核心思想 ### 3.1 什么是主角 主角是: - `BEATs trunk` 它负责从 FOA 空间特征中学习空间表征。 ### 3.2 什么是配角 配角是: - temporal readout - prediction heads - projector 这些模块的作用只是: - 从 trunk 特征中“读出”空间信息 - 建立 loss 回传路径 - 把空间 embedding 投影到 LLM 接口 ### 3.3 给 LLM 的 token 从哪里来 最终给 LLM 的 token 应该来自: - trunk 后 - 或 trunk 后再经过一层很浅的 temporal readout 之后 而不是来自: - 最终 logits - 复杂 decoder 的末端输出 ## 4. 最终结构总览 ```text FOA waveform [B, 4, T] -> SpatialBEATsPreprocessor -> FOA feature map [B, C_foa, T_f, F] -> SpatialPatchEmbedding -> patch tokens [B, N_p, D_in] -> BEATs trunk -> encoder memory [B, N_p, D] -> reshape / frequency pooling -> temporal tokens [B, T_s_max, D] -> shallow temporal readout -> spatial embeddings [B, T_s_max, D] -> prediction heads -> Spatial projector -> llm spatial tokens [B, T_s_max, d_llm] ``` 其中: - 每个样本: - `T_s_i = round(duration_i * 2.5)` - 一个 batch 内: - `T_s_max = max_i T_s_i` - 对 `10 s` 输入: - `T_s_i = 25` ## 5. 从 FOA 到 spatial token 的完整过程 ## 5.1 输入层 输入: - `waveform: [B, 4, T]` 四个通道分别是: - `W` - `X` - `Y` - `Z` 对 `10 s, 16kHz`: - `T = 160000` ## 5.2 SpatialBEATsPreprocessor 目标: - 把原始 FOA 波形转成多通道空间特征图 推荐输出通道: - `W_logmel` - `X_logmel` - `Y_logmel` - `Z_logmel` - `IVx` - `IVy` - `IVz` 因此: - `C_foa = 7` 内部步骤: 1. 对 `WXYZ` 做 STFT 2. 计算各通道 `log-mel` 3. 计算 `IVx, IVy, IVz` 4. 将 `IV` 映射到 mel 维 5. 拼接成特征图 输出: - `foa_feat: [B, 7, T_f, 128]` 其中: - `128` 是 mel bins - `T_f` 约等于 `1000`,如果帧移为 `10 ms` ## 5.3 SpatialPatchEmbedding 目标: - 把 `7` 通道 FOA 特征图切成 patch token 原始 BEATs 的 stem 是单通道: ```text Conv2d(1, embed_dim, kernel_size=patch, stride=patch) ``` 新的 stem 改成: ```text Conv2d(7, embed_dim, kernel_size=patch, stride=patch) ``` 输入: - `foa_feat: [B, 7, T_f, 128]` 输出: - `patch_tokens: [B, N_p, D_in]` - `grid_hw = (T_p, F_p)` 建议: - `D_in = 512` 原因: - 与 BEATs patch embedding dim 对齐 ## 5.4 BEATs trunk 这是整个模型最重要的部分。 保留并复用的模块包括: - patch token layer norm - `post_extract_proj` - `dropout_input` - `TransformerEncoder` - `conv_pos` - 所有 transformer layers - LayerNorm / FFN / attention 输入: - `patch_tokens: [B, N_p, 512]` 输出: - `encoder_memory: [B, N_p, 768]` 建议: - 主干 hidden dim 保持 `768` - 直接加载 `BEATs_iter3+ AS2M pre-trained` ## 5.5 Reshape + Frequency Pooling 目标: - 把 trunk 输出变成可读的时间序列表示 步骤: 1. 将 `encoder_memory [B, N_p, D]` reshape 成: - `grid_memory [B, T_p, F_p, D]` 2. 对 `F_p` 维做 pooling 推荐第一版: - `mean pooling over frequency` 输出: - `temporal_patch_tokens: [B, T_p, D]` 这一步的意义是: - 先把频率信息压到时间轴上 - 便于后面构造固定 `2.5 Hz` 的时序空间 token ## 5.6 Temporal Resampler 目标: - 把 patch 级时间序列压到目标 token rate 输入: - `temporal_patch_tokens: [B, T_p, D]` 输出: - `temporal_tokens: [B, T_s, D]` 其中: - 第 `i` 个样本: - `T_s_i = round(duration_i * 2.5)` - batch padding 后: - `temporal_tokens: [B, T_s_max, D]` - 对 `10 s` clip: - `T_s_i = 25` 推荐第一版: - 线性插值 - 或轻量 `Conv1d` 下采样 注意: - `2.5 Hz` 明确指最终给 LLM 的 token rate - 对单个 `10 s` 样本,`10 s -> 25` 个有效 token - 对 mixed-length batch,张量会 pad 到 `T_s_max` ## 5.7 Shallow Temporal Readout 目标: - 在时间维上再做一层轻量整理 - 让 trunk 输出更适合做空间监督和给 LLM 使用 推荐做法: - `1~2` 层 transformer encoder 输入: - `temporal_tokens: [B, T_s_max, 768]` 输出: - `spatial_embeddings: [B, T_s_max, 768]` 这一层的作用是: - 做一个浅层 readout neck - 不承担复杂检测器职责 - 不强调 decoder 身份 - 只是把 trunk 表征整理成更干净的时间步级空间 embedding 如果你想先做最简单版本,也可以直接: - `temporal_tokens -> LayerNorm -> spatial_embeddings` 把 shallow transformer 作为后续增强项。 ## 5.8 Prediction Heads 目标: - 从 `spatial_embeddings` 上接显式监督 - 通过 loss 让前面的 trunk 真正学到空间表征 - 在不引入复杂 decoder 的前提下提供多源监督出口 输入: - `spatial_embeddings: [B, T_s_max, 768]` ### 5.8.1 Encoder-only 阶段的固定槽位 readout 虽然简化版不再使用复杂 `slot query decoder`,但当前仍然有: - 最大同时源数约为 `4` 因此,encoder-only 训练阶段推荐在 `spatial_embeddings` 后接一个很轻的固定槽位 readout: ```text spatial_embeddings [B, T_s_max, 768] -> Linear / MLP expand -> slot_latents [B, T_s_max, 4, H] -> shared prediction heads ``` 推荐默认: - `H = 768` 最简单实现: ```text Linear(768 -> 4 * 768) reshape -> [B, 25, 4, 768] ``` 更稳一点的实现: ```text Linear(768 -> 768) -> GELU -> Linear(768 -> 4 * 768) reshape -> [B, 25, 4, 768] ``` 这一步的作用不是做复杂目标解析,而只是: - 给每个时间步提供 `4` 个固定 source 槽位 - 让多源监督有明确着力点 - 把 loss 稳定回传到前面的 trunk ### 5.8.2 预测头设计 对 `slot_latents [B, 25, 4, H]` 接共享或独立 heads。 建议输出头: - `activity / objectness head` - `azimuth head` - `elevation head` - `distance head` - `class auxiliary head` 输出形式建议: - `pred_activity: [B, T_s_max, 4]` - `pred_azi_logits: [B, T_s_max, 4, 360]` - `pred_ele_logits: [B, T_s_max, 4, 180]` - `pred_dist: [B, T_s_max, 4, 1]` - `pred_class_logits: [B, T_s_max, 4, C_cls]` 其中: - `pred_activity` 负责当前时间步当前槽位是否解释某个源 - `pred_azi_logits` / `pred_ele_logits` 负责方向 - `pred_dist` 负责连续距离回归 - `pred_class_logits` 负责 source-level 辅助类别监督 当前默认建议: - `C_cls = 65` - label vocabulary 来自: - `final_vocabulary.csv` - 推荐字段: - `label_id` - `final_label` ### 5.8.3 为什么这里仍然保留 K=4 在简化版中: - `K=4` 不再体现在复杂 decoder 结构里 - 但仍然通过固定槽位 readout 出现在监督头中 也就是说: - 不需要复杂 query-based object slots - 但仍然需要一个简单的多槽位 readout 来承载多源标签 这更符合当前“先把主干训练出来”的目标。 ## 5.9 Spatial Projector 目标: - 把 `spatial_embeddings` 投影到 LLM hidden size 输入: - `spatial_embeddings: [B, T_s_max, 768]` 推荐: - 独立的 `MLP projector` 形式: ```text Linear(768 -> D_mid) -> GELU -> LayerNorm -> Linear(D_mid -> d_llm) ``` 输出: - `llm_spatial_tokens: [B, T_s_max, d_llm]` 这就是最终喂给 LLM 的 spatial tokens。 ### 5.9.1 Encoder-only 阶段 projector 的角色 当前第一阶段如果只训练 encoder,本质目标是: - 用监督把 `spatial_embeddings` 训好 因此 projector 在这一阶段有两种合理策略: #### 方案 A:先不训练 projector,推荐默认 - 只训练: - preprocessor - patch embedding - BEATs trunk - temporal readout - fixed-slot prediction heads - projector 只保留接口,不参与训练 #### 方案 B:一并训练 projector - 当你希望尽早固定 LLM 接口维度时可以启用 但第一版默认推荐: - **先把 encoder 训好,再训练 projector** ## 6. 每层的输入输出总结 ### 6.1 SpatialBEATsPreprocessor - 输入:`[B, 4, T]` - 输出:`[B, 7, T_f, 128]` ### 6.2 SpatialPatchEmbedding - 输入:`[B, 7, T_f, 128]` - 输出:`[B, N_p, 512]` ### 6.3 BEATs trunk - 输入:`[B, N_p, 512]` - 输出:`[B, N_p, 768]` ### 6.4 Reshape + Frequency Pooling - 输入:`[B, N_p, 768]` - 输出:`[B, T_p, 768]` ### 6.5 Temporal Resampler - 输入:`[B, T_p, 768]` - 输出:`[B, 25, 768]` ### 6.6 Shallow Temporal Readout - 输入:`[B, T_s_max, 768]` - 输出:`[B, T_s_max, 768]` ### 6.7 Prediction Heads - 输入:`[B, T_s_max, 768]` - 先经 `Linear / MLP expand` 变成:`[B, T_s_max, 4, H]` - 输出: - `activity [B, T_s_max, 4]` - `azimuth [B, T_s_max, 4, 360]` - `elevation [B, T_s_max, 4, 180]` - `distance [B, T_s_max, 4, 1]` - `class [B, T_s_max, 4, C_cls]` ### 6.8 Spatial Projector - 输入:`[B, T_s_max, 768]` - 输出:`[B, T_s_max, d_llm]` ## 7. Loss 如何作用到前面的表征 这个简化方案的关键点在于: - 后面的 heads 并不是模型重点 - 它们只是为了接监督 loss 从这些 heads 回传后,会更新: 1. readout neck 2. temporal resampler 3. BEATs trunk 4. patch embedding 5. FOA preprocessor 因此: - 只要后面的 decode/head 能稳定预测显式空间信息 - 前面的 trunk 就会被训练成空间 encoder ### 推荐 loss 当前建议: - `L_activity` - `L_azi` - `L_ele` - `L_dist` - `L_cls_aux` - `L_temp` 总损失: ```text L_total = lambda_act * L_activity + lambda_azi * L_azi + lambda_ele * L_ele + lambda_dist * L_dist + lambda_cls * L_cls_aux + lambda_temp * L_temp ``` 当前可继续保留: - 方向分类 - 仰角分类 - 距离连续回归 - 辅助类别监督 - 时间一致性正则 ### 7.1 Encoder-only 训练阶段的 loss 定义 #### `L_activity` 用于监督: - 当前时间步当前槽位是否激活 建议: - `BCEWithLogitsLoss` 或 `focal loss` 结合当前的弱时间窗口 supervision: - 窗口外:负样本 - 窗口内:弱正样本 #### `L_cls_aux` 用于监督: - 被分配到某个 GT source 的槽位类别 建议: - `CrossEntropyLoss` #### `L_azi` 和 `L_ele` 用于监督: - 槽位的方向分类 建议: - `CrossEntropyLoss` - `azimuth`: `360` bins - `elevation`: `180` bins #### `L_dist` 用于监督: - 槽位的连续距离回归 建议: - 先将距离归一化到 `[0, 1]` - 使用 `SmoothL1Loss` #### `L_temp` 由于当前位置当前是 clip 内固定,建议加入: - 时间一致性约束 例如对同一 source 在相邻时间步对应的槽位加: - class 分布平滑 - direction 分布平滑 - distance 平滑 第一版可以先只加: - distance smoothness - activity smoothness ## 7.2 Encoder-only 训练时的匹配方式 因为当前 heads 是固定 `4` 槽位,而不是 query decoder,所以建议: - 使用轻量匹配 - 不必依赖复杂 decoder 结构 推荐策略: 1. 对每个 GT source,根据其时间窗口筛出候选时间步 2. 在该时间步的 `4` 个固定槽位中,选择 cost 最小的槽位 3. 对该槽位施加: - activity - class - azi - ele - dist 推荐 cost: ```text cost = w_act * cost_act + w_cls * cost_cls + w_azi * cost_azi + w_ele * cost_ele + w_dist * cost_dist ``` 第一版可以不做全局 Hungarian,直接做: - `per-step fixed-slot matching` 如果后面发现多源冲突严重,再升级 matching。 ## 8. 关于多源 supervision 的理解 虽然当前结构没有显式 `slot query decoder`,但这并不等于完全不考虑多源。 当前更合理的理解是: - 让 `spatial_embeddings [B, 25, D]` 表示该时间步的空间场景表征 - 再通过固定 `4` 槽位 readout head 承载多源标签 - supervision 用这些多源标签约束 trunk 必须编码多个源的空间信息 这相当于: - 先学一个强的 time-step level spatial embedding - 再决定是否需要升级成更复杂的 object-centric / query-based 版本 这是更稳的第一版路径。 ## 9. 为什么这个简化版更合适当前阶段 ### 9.1 更符合老师的建议 老师的核心意见可以总结为: - 不必执着于 decoder 结构本身 - 后面的模块只要能 decode 出空间监督即可 - 真正要训好的是前面的 trunk 这个简化版完全符合这个思路。 ### 9.2 工程复杂度更低 不需要一开始就实现: - slot query decoder - cross-attention decoder - objectness pooling - 复杂 slot matching ### 9.3 更利于先验证 trunk 是否真的学到空间表征 如果这版能成功: - 说明 trunk + temporal readout 本身已经足够表达空间信息 如果这版不够: - 再升级到 slot/object 版本 这样路线更清晰。 ## 10. 推荐实现版本 ### V1:最小可行版 ```text FOA -> preprocessor -> patch embed -> BEATs trunk -> frequency pooling -> temporal resampler -> LayerNorm -> spatial embeddings -> Linear expand to 4 slots -> heads ``` ### V2:推荐正式版 ```text FOA -> preprocessor -> patch embed -> BEATs trunk -> frequency pooling -> temporal resampler -> 1~2 层 shallow transformer readout -> spatial embeddings -> Linear / MLP expand to 4 slots -> heads -> projector ``` ### V3:后续增强版 如果未来发现: - 多源关系建模仍然不足 - 动态轨迹下表达不够 再升级为: - slot-based decoder - query-based object-centric readout ## 11. 推荐代码划分 建议新增或保留以下文件: - `spatial_beats.py` - `spatial_modules.py` - `spatial_loss.py` - `spatial_dataset.py` - `train_spatial_beats.py` ### `spatial_modules.py` 建议包含: - `SpatialBEATsPreprocessor` - `SpatialPatchEmbedding` - `TemporalResampler` - `TemporalReadoutTransformer` - `FixedSlotReadoutHead` - `SpatialPredictionHeads` - `SpatialProjector` ### `spatial_beats.py` 建议主类: - `SpatialBEATsConfig` - `SpatialBEATs` 主类中建议暴露: - `extract_spatial_embeddings()` - `project_for_llm()` - `forward()` ## 12. 最终输出接口 推荐 `forward()` 返回: ```python { "encoder_memory": FloatTensor[B, N_p, 768], "temporal_tokens": FloatTensor[B, 25, 768], "spatial_embeddings": FloatTensor[B, 25, 768], "slot_latents": FloatTensor[B, 25, 4, H], "pred_activity": FloatTensor[B, 25, 4], "pred_azi_logits": FloatTensor[B, 25, 4, 360], "pred_ele_logits": FloatTensor[B, 25, 4, 180], "pred_dist": FloatTensor[B, 25, 4, 1], "pred_class_logits": FloatTensor[B, 25, 4, C_cls], "llm_spatial_tokens": FloatTensor[B, 25, d_llm], } ``` 其中: - `spatial_embeddings` 是最核心的中间表示 - `slot_latents` 只是 encoder-only 监督出口 - `llm_spatial_tokens` 是最终给 LLM 的接口 ## 13. 结论 当前推荐的简化版最终架构是: - `FOA -> spatial features -> BEATs trunk -> 2.5Hz temporal spatial embeddings -> 固定4槽位 readout heads -> projector` 这套方案的重点是: - 用显式空间监督把前面的 `BEATs trunk` 训练成空间 encoder - 后面的 fixed-slot head 只承担监督和 readout 的职责 - 给 LLM 的 token 直接来自前面的 `spatial_embeddings` 这是当前最适合进入代码实现的主线方案。