| # Spatial-BEATs 实现规格 |
|
|
| ## 1. 目标 |
|
|
| 本规格文档用于将前期讨论收敛为一个可以直接实施的 `Spatial-BEATs` 方案。 |
|
|
| 目标是构建一个独立的 `Spatial Encoder`: |
|
|
| - 输入为完整 `FOA` 音频及其派生空间特征 |
| - 完整的 `FOA` 特征经过 `BEATs backbone` |
| - 最大化复用 `BEATs` 预训练权重 |
| - 输出一组 `source-level spatial tokens` |
| - 这些 token 作为独立模态输入给 LLM |
| - 原有语义 audio encoder 保持不动 |
|
|
| 这里的关键原则是: |
|
|
| > 不是让 `W-only` 走主干,再外挂一个小空间 adapter;而是让完整 FOA 空间特征真正进入 BEATs 主干,并在主干之后产出结构化空间 token。 |
|
|
| ## 2. 最终任务定义 |
|
|
| ### 2.1 核心任务 |
|
|
| `Spatial-BEATs` 的主任务定义为: |
|
|
| - 给定一个多源 `FOA` 音频片段 |
| - 预测其中最多 `K` 个潜在声源的空间表示 |
| - 每个表示对应一个 `source token` |
|
|
| 每个 source token 至少承载: |
|
|
| - `objectness` |
| - `azimuth` |
| - `elevation` |
| - `distance` |
|
|
| 可选承载: |
|
|
| - `source class auxiliary logits` |
| - `source embedding` |
|
|
| ### 2.2 推荐监督形式 |
|
|
| 如果训练数据中每个源都有标注,则推荐采用: |
|
|
| - `set prediction` |
| - `K` 个预测 token 对 `N` 个 GT sources |
| - 用 `Hungarian matching` 做一一匹配 |
|
|
| 不建议采用: |
|
|
| - 单一 scene-level spatial token |
| - 仅回归整段音频的全局空间摘要 |
|
|
| 原因是这会损失多源结构,不利于后续 LLM 做关系推理。 |
|
|
| ## 3. 最终架构 |
|
|
| 推荐最终架构: |
|
|
| ```text |
| FOA waveform |
| -> SpatialBEATsPreprocessor |
| -> FOA feature map [B, C_foa, T, F] |
| -> FOA patch embedding |
| -> BEATs trunk |
| -> Spatial query decoder |
| -> K source tokens |
| -> Spatial prediction heads |
| -> LLM projector |
| ``` |
|
|
| 为了最大化复用 BEATs 主干,本方案尽量不改 trunk 内部的 Transformer 结构。 |
|
|
| ## 4. 输入特征定义 |
|
|
| ### 4.1 默认推荐特征 |
|
|
| 第一版推荐输入通道: |
|
|
| - `W_logmel` |
| - `X_logmel` |
| - `Y_logmel` |
| - `Z_logmel` |
| - `IVx` |
| - `IVy` |
| - `IVz` |
|
|
| 即: |
|
|
| - `C_foa = 7` |
|
|
| 这是默认推荐方案。 |
|
|
| ### 4.2 备选输入特征 |
|
|
| 若希望先降低复杂度,可以使用: |
|
|
| - `WXYZ logmel` |
|
|
| 即: |
|
|
| - `C_foa = 4` |
|
|
| 但这只适合最小原型。 |
| 如果目标是稳定学习空间方向与结构,优先使用 `WXYZ + IV`。 |
|
|
| ### 4.3 前端参数建议 |
|
|
| 为了最大化复用 BEATs 主干,推荐保持与 BEATs 接近的时频分辨率: |
|
|
| - sample rate:优先 `16k` |
| - mel bins:`128` |
| - frame length:`25 ms` |
| - frame shift:`10 ms` |
|
|
| 原因: |
|
|
| - 这能让 trunk 看到与原始 BEATs 更接近的 patch 几何结构 |
| - patch embedding 和后续序列长度更容易保持一致 |
| - 预训练权重复用更稳定 |
|
|
| ### 4.4 为什么不沿用 Spatial-AST 的 binaural 前端 |
|
|
| Spatial-AST 采用的是: |
|
|
| - 双耳 log-mel |
| - IPD |
|
|
| 这适合 binaural,不适合直接迁移到 FOA。 |
|
|
| FOA 下应优先利用: |
|
|
| - ambisonic 通道本身 |
| - intensity vector |
| - 其他 FOA 物理特征 |
|
|
| ## 5. 对 BEATs 具体修改哪些模块 |
|
|
| 下面按模块说明修改方案。 |
|
|
| ### 5.1 保留不动的模块 |
|
|
| 建议尽量保留: |
|
|
| - `TransformerEncoder` |
| - `TransformerSentenceEncoderLayer` |
| - `MultiheadAttention` |
| - `conv_pos` |
| - `LayerNorm` |
| - `FFN` |
| - `post_extract_proj` |
|
|
| 也就是 `backbone.py` 内的主干结构和 `BEATs.py` 中的 trunk 逻辑尽量不动。 |
|
|
| ### 5.2 必须修改的模块 |
|
|
| 必须重做: |
|
|
| 1. `preprocess` |
| 2. `patch_embedding` |
| 3. `extract_features` 输出头部逻辑 |
| 4. 下游 `predictor` |
|
|
| ### 5.3 推荐新增的模块 |
|
|
| 建议新增: |
|
|
| 1. `SpatialBEATsPreprocessor` |
| 2. `SpatialPatchEmbedding` |
| 3. `SpatialQueryDecoder` |
| 4. `SpatialPredictionHead` |
| 5. `SpatialTokenProjector` |
| 6. `HungarianMatcher` |
| 7. `SpatialSetCriterion` |
|
|
| ## 6. 代码级映射建议 |
|
|
| ### 6.1 现有文件建议 |
|
|
| 建议保留和复用: |
|
|
| - [BEATs.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/BEATs.py) |
| - [backbone.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/backbone.py) |
|
|
| 建议新增: |
|
|
| - `spatial_beats.py` |
| - `spatial_modules.py` |
| - `spatial_loss.py` |
| - `spatial_dataset.py` |
| - `train_spatial_beats.py` |
|
|
| ### 6.2 `spatial_beats.py` 建议包含 |
| |
| 建议实现: |
| |
| - `SpatialBEATsConfig` |
| - `SpatialBEATs` |
| - `SpatialBEATs.extract_spatial_tokens()` |
| - `SpatialBEATs.forward()` |
| |
| ### 6.3 `spatial_modules.py` 建议包含 |
|
|
| 建议实现: |
|
|
| - `SpatialBEATsPreprocessor` |
| - `SpatialPatchEmbedding` |
| - `SpatialQueryDecoder` |
| - `SpatialPredictionHead` |
| - `SpatialTokenProjector` |
|
|
| ### 6.4 `spatial_loss.py` 建议包含 |
| |
| 建议实现: |
| |
| - `HungarianMatcher` |
| - `SpatialSetCriterion` |
| |
| ## 7. 预训练权重如何复用 |
| |
| ## 7.1 默认推荐权重 |
| |
| 默认推荐: |
| |
| - `BEATs_iter3+ (AS2M) pre-trained` |
|
|
| 而不是: |
|
|
| - fine-tuned checkpoints |
|
|
| 原因: |
|
|
| - `pre-trained` 更适合作为 trunk 初始化 |
| - `fine-tuned` 更偏向 AudioSet 分类判别 |
| - 你这里的 spatial encoder 应与原语义 encoder 职责分离 |
|
|
| ### 7.2 必须直接加载的层 |
|
|
| 这些层建议直接加载原 BEATs checkpoint: |
|
|
| - `post_extract_proj` |
| - `encoder.pos_conv` |
| - `encoder.layers.*` |
| - `encoder.layer_norm` |
| - `layer_norm` |
|
|
| 即除了输入 stem 和输出头,主干参数都尽量继承。 |
|
|
| ### 7.3 需要特殊初始化的层 |
|
|
| 以下层因为 shape 不同,不能直接 strict load: |
|
|
| - `patch_embedding` |
| - 新增的 `query decoder` |
| - 新增的 `spatial heads` |
| - 新增的 `LLM projector` |
|
|
| ### 7.4 新 patch embedding 的初始化策略 |
|
|
| 原 BEATs stem 是: |
|
|
| - `Conv2d(1, embed_dim, kernel_size=patch, stride=patch)` |
|
|
| 新 stem 建议是: |
|
|
| - `Conv2d(C_foa, embed_dim, kernel_size=patch, stride=patch)` |
|
|
| 推荐初始化策略: |
|
|
| #### 方案 A:保守初始化,默认推荐 |
|
|
| - `W_logmel` 通道继承原 stem 权重 |
| - 其他空间通道初始化为 `0` 或较小随机值 |
|
|
| 优点: |
|
|
| - 最大程度保留原 BEATs 初始分布 |
| - trunk 适配更稳 |
|
|
| 缺点: |
|
|
| - 训练初期空间通道利用较慢 |
|
|
| #### 方案 B:通道 inflation |
|
|
| - 把原 stem 权重复制到全部输入通道 |
| - 再按通道数做归一化 |
|
|
| 优点: |
|
|
| - 所有通道一开始都能进入主干 |
|
|
| 缺点: |
|
|
| - 初始统计更可能偏离原 BEATs |
|
|
| 最终推荐: |
|
|
| - 第一版用 `方案 A` |
| - 后续做 ablation 再比较 `方案 B` |
|
|
| ## 8. Spatial token 模块的最终设计 |
|
|
| ### 8.1 为什么不用全局池化 |
|
|
| 原始 BEATs 的输出方式更接近: |
|
|
| - patch sequence |
| - mean pooling |
| - clip-level prediction |
|
|
| 这不适合多源空间任务。 |
|
|
| ### 8.2 最终推荐:Query Decoder |
|
|
| 在 trunk 输出后新增: |
|
|
| - `K` 个 learnable source queries |
| - 一个轻量 `cross-attention decoder` |
|
|
| 输入: |
|
|
| - encoder memory:`H in R^{B x T x D}` |
| - source queries:`Q in R^{B x K x D}` |
|
|
| 输出: |
|
|
| - `Z in R^{B x K x D}` |
|
|
| 这里的 `Z[:, i, :]` 即第 `i` 个 `source token` |
|
|
| ### 8.3 为什么 query decoder 是当前最优解 |
|
|
| 它的优点: |
|
|
| - 不改 trunk 内部结构 |
| - 仍然让完整 FOA 特征经过 backbone |
| - 适合多源 set prediction |
| - 最利于最大化复用 trunk 权重 |
|
|
| ## 9. 输出头设计 |
|
|
| 对每个 source token `z_i`,预测: |
|
|
| - `objectness` |
| - `azimuth` |
| - `elevation` |
| - `distance` |
| - 可选 `class_aux` |
|
|
| ### 9.1 离散还是连续 |
|
|
| 第一版推荐全部使用离散分类头: |
|
|
| - `azimuth`: 360 bins |
| - `elevation`: 180 bins |
| - `distance`: 按数据分桶,例如 `0.5m` 一档 |
|
|
| 原因: |
|
|
| - 与已有 Spatial-AST/BAT 经验一致 |
| - 分类头更稳 |
| - 更便于构造离散坐标 embedding |
|
|
| ### 9.2 objectness 头 |
|
|
| 推荐增加: |
|
|
| - `objectness_head: D -> 1` |
|
|
| 用于: |
|
|
| - 判断当前 token 是否对应真实声源 |
| - 作为 Hungarian matching 的一部分 |
| - 推理时做 token 保留/裁剪 |
|
|
| ### 9.3 类别头 |
|
|
| 类别头建议作为: |
|
|
| - `auxiliary head` |
|
|
| 而不是最终 LLM 的主要输入内容。 |
|
|
| 这样做的作用: |
|
|
| - 让 query token 更容易学会 source slot 对齐 |
| - 但不把 Spatial-BEATs 变成第二个强语义 encoder |
|
|
| ## 10. Loss 设计 |
|
|
| 推荐总损失: |
|
|
| ```text |
| L_total = |
| lambda_obj * L_obj |
| + lambda_azi * L_azi |
| + lambda_ele * L_ele |
| + lambda_dist * L_dist |
| + lambda_cls * L_cls_aux |
| ``` |
|
|
| ### 10.1 匹配方式 |
|
|
| 使用 `Hungarian matching`: |
|
|
| - 预测:`K` 个 token |
| - GT:`N` 个 sources |
| - 成本由以下项构成: |
| - objectness cost |
| - azimuth cost |
| - elevation cost |
| - distance cost |
| - optional class cost |
|
|
| ### 10.2 损失项定义 |
|
|
| 推荐: |
|
|
| - `L_obj`: BCE 或 focal loss |
| - `L_azi`: cross entropy |
| - `L_ele`: cross entropy |
| - `L_dist`: cross entropy |
| - `L_cls_aux`: cross entropy 或 BCE |
|
|
| ### 10.3 初始 loss 权重建议 |
|
|
| 第一版建议从以下权重起步: |
|
|
| ```text |
| lambda_obj = 1.0 |
| lambda_azi = 2.0 |
| lambda_ele = 2.0 |
| lambda_dist = 1.0 |
| lambda_cls = 0.25 |
| ``` |
|
|
| 解释: |
|
|
| - 方向任务通常更关键 |
| - 距离次之 |
| - objectness 必须稳定 |
| - 类别监督只作为辅助 |
|
|
| ### 10.4 不建议的做法 |
|
|
| 第一版不建议: |
|
|
| - 重分类损失压倒空间损失 |
| - 直接照搬 Spatial-AST 的 `1250 * cls` |
|
|
| 原因: |
|
|
| - Spatial-AST 的目标之一是保住 sound event detection |
| - 这里 `Spatial-BEATs` 的主要目标是空间 token |
| - 原项目已有独立语义 encoder |
|
|
| ## 11. 训练策略 |
|
|
| ### 11.1 第一阶段是否需要 SSL |
|
|
| 当前最终结论: |
|
|
| - 第一版 **不需要** 重新做 BEATs 式 SSL |
|
|
| 因为当前已经有: |
|
|
| - 多源监督 |
| - 每个源的空间标注 |
| - 可复用的 BEATs 主干预训练 |
|
|
| 所以第一阶段应优先做: |
|
|
| - `supervised multi-source spatial training` |
|
|
| ### 11.2 分阶段训练建议 |
|
|
| #### Stage A:Warmup |
|
|
| 冻结: |
|
|
| - 大部分 trunk |
|
|
| 只训练: |
|
|
| - FOA preprocessor |
| - patch embedding |
| - query decoder |
| - spatial heads |
| - LLM projector |
|
|
| 目的: |
|
|
| - 让新输入 stem 和新输出头稳定接入 trunk |
|
|
| #### Stage B:Upper-trunk finetune |
|
|
| 解冻: |
|
|
| - trunk 上层若干层 |
|
|
| 目的: |
|
|
| - 让主干逐步适应 FOA 空间任务 |
|
|
| #### Stage C:Near-full finetune |
|
|
| 进一步解冻: |
|
|
| - 更多 encoder layers |
|
|
| 目的: |
|
|
| - 提升空间表示上限 |
|
|
| ### 11.3 学习率建议 |
|
|
| 推荐: |
|
|
| - trunk:较小 lr |
| - 新模块:较大学习率 |
|
|
| 例如: |
|
|
| ```text |
| lr_trunk = 1e-5 ~ 5e-5 |
| lr_new = 1e-4 ~ 5e-4 |
| ``` |
|
|
| 并配合: |
|
|
| - layer-wise lr decay |
|
|
| ## 12. 最终输出给 LLM 的 spatial token 形式 |
|
|
| 这是本项目最关键的接口定义之一。 |
|
|
| ### 12.1 内部 token 形式 |
|
|
| `Spatial-BEATs` 内部输出: |
|
|
| - `Z in R^{B x K x D}` |
|
|
| 其中: |
|
|
| - `B`: batch size |
| - `K`: source token 数 |
| - `D`: Spatial-BEATs hidden dim,建议与 BEATs trunk 一致 |
|
|
| ### 12.2 不建议直接把 raw logits 喂给 LLM |
|
|
| 不建议直接给 LLM: |
|
|
| - azimuth logits |
| - elevation logits |
| - distance logits |
| - objectness logits |
|
|
| 这些是监督头,不是最终模态表示。 |
|
|
| ### 12.3 最终推荐的 LLM spatial token 形式 |
|
|
| 最终推荐送给 LLM 的每个 token 形式为: |
|
|
| ```text |
| s_i = Proj([z_i ; e_azi(i) ; e_ele(i) ; e_dist(i) ; e_obj(i)]) |
| ``` |
|
|
| 其中: |
|
|
| - `z_i`: query decoder 输出的 latent token |
| - `e_azi(i)`: 由预测 azimuth bin 查表得到的 embedding |
| - `e_ele(i)`: 由预测 elevation bin 查表得到的 embedding |
| - `e_dist(i)`: 由预测 distance bin 查表得到的 embedding |
| - `e_obj(i)`: 由 objectness/confidence 产生的 embedding |
| - `Proj`: 投影到 LLM hidden size 的 MLP/Linear |
|
|
| 最终: |
|
|
| - `s_i in R^{d_llm}` |
|
|
| ### 12.4 为什么采用“latent + structured embedding”的混合形式 |
|
|
| 原因: |
|
|
| 1. `z_i` 保留丰富的隐式空间结构信息 |
| 2. `坐标 embedding` 给 LLM 显式离散空间线索 |
| 3. `confidence` 有助于 LLM 区分可靠/不可靠 token |
|
|
| 这比单纯只传: |
|
|
| - raw latent token |
|
|
| 或者只传: |
|
|
| - 显式坐标 one-hot / scalar |
|
|
| 都更合适。 |
|
|
| ### 12.5 最终序列形式 |
|
|
| 送入 LLM 时推荐: |
|
|
| ```text |
| <SPATIAL_START>, s_1, s_2, ..., s_K, <SPATIAL_END> |
| ``` |
|
|
| 并且: |
|
|
| - 按 `objectness` 从高到低排序 |
| - 对低置信 token 可直接截断或 mask |
|
|
| ### 12.6 是否保留全部 K 个 token |
|
|
| 默认推荐: |
|
|
| - 训练时保留全部 `K` |
| - 推理时按 `objectness` 过滤 |
|
|
| 例如: |
|
|
| - 保留前 `K_keep` |
| - 或保留 `obj > threshold` 的 token |
|
|
| ## 13. 与原语义 audio encoder 的关系 |
|
|
| 为了避免“两个 encoder 在做同样的事”,推荐如下职责划分: |
|
|
| - 原语义 audio encoder:负责 `what` |
| - Spatial-BEATs:负责 `where / spatial structure / relations` |
|
|
| ### 13.1 是否允许 Spatial-BEATs 学类别 |
|
|
| 允许,但只作为辅助。 |
|
|
| 建议: |
|
|
| - 类别头只用于训练 |
| - 最终输入给 LLM 的空间 token 不直接暴露完整类别 logits |
|
|
| ### 13.2 是否需要和语义 encoder 做对齐 |
|
|
| 第一版不是必须。 |
|
|
| 若后续希望更强的 source grounding,可进一步加入: |
|
|
| - semantic distillation |
| - cross-encoder alignment |
| - source-wise contrastive loss |
|
|
| 但这些应放到第二阶段。 |
|
|
| ## 14. 第一版推荐配置 |
|
|
| 第一版默认建议: |
|
|
| - 输入特征:`WXYZ + IVxyz` |
| - `C_foa = 7` |
| - 采样率:`16k` |
| - mel bins:`128` |
| - patch 配置:与 BEATs 保持一致 |
| - 预训练权重:`BEATs_iter3+ AS2M pre-trained` |
| - trunk:最大化加载 |
| - patch stem:`W` 继承,其余通道小初始化 |
| - 输出:`K` 个 source tokens |
| - token 解码:轻量 query decoder |
| - 监督:Hungarian matching + 多头空间分类 |
| - LLM 输入:`latent + structured coordinate embedding` 的混合 token |
|
|
| ## 15. 实现优先级 |
|
|
| 推荐按如下优先级推进: |
|
|
| 1. 实现 `FOA preprocessor` |
| 2. 实现多通道 `patch embedding` |
| 3. 完成 trunk ckpt 加载 |
| 4. 实现 `query decoder` |
| 5. 实现 `objectness / azi / ele / dist` heads |
| 6. 实现 `Hungarian matcher + criterion` |
| 7. 实现 `LLM projector` |
| 8. 完成训练脚本 |
|
|
| ## 16. 当前仍需用户确认的问题 |
|
|
| 以下问题会直接影响第一版实现细节: |
|
|
| 1. `FOA` 数据当前主要采样率是多少?是 `16k`、`24k`、`32k` 还是 `48k`? |
| 2. 每个样本中 `最大同时源数` 大概是多少?这会影响 `K` 的默认设定。 |
| 3. 每个源是否都有 `source-level class label`?如果有,类别头和匹配会更稳。 |
| 4. 你希望 `distance` 是离散分类还是连续回归?当前默认推荐离散分类。 |
| 5. 下游 LLM 的 hidden size 是多少?是否已有固定的 audio token projector? |
| 6. 你是否希望 Spatial-BEATs 在第一版就具备一定的 source semantic 辅助能力,还是严格只做空间? |
|
|
| ## 17. 结论 |
|
|
| 当前最终方案已经明确: |
|
|
| - **完整 FOA 特征进入 BEATs 主干** |
| - **最大化复用 trunk 预训练** |
| - **重做输入 stem** |
| - **重做输出为多源 spatial tokens** |
| - **第一版采用监督式 set prediction** |
| - **最终给 LLM 的不是 raw logits,而是融合 latent 与坐标 embedding 的 spatial tokens** |
|
|
| 这是当前最符合项目目标、也最稳妥的 `Spatial-BEATs` 方案。 |
|
|