# Spatial-BEATs 实现规格 ## 1. 目标 本规格文档用于将前期讨论收敛为一个可以直接实施的 `Spatial-BEATs` 方案。 目标是构建一个独立的 `Spatial Encoder`: - 输入为完整 `FOA` 音频及其派生空间特征 - 完整的 `FOA` 特征经过 `BEATs backbone` - 最大化复用 `BEATs` 预训练权重 - 输出一组 `source-level spatial tokens` - 这些 token 作为独立模态输入给 LLM - 原有语义 audio encoder 保持不动 这里的关键原则是: > 不是让 `W-only` 走主干,再外挂一个小空间 adapter;而是让完整 FOA 空间特征真正进入 BEATs 主干,并在主干之后产出结构化空间 token。 ## 2. 最终任务定义 ### 2.1 核心任务 `Spatial-BEATs` 的主任务定义为: - 给定一个多源 `FOA` 音频片段 - 预测其中最多 `K` 个潜在声源的空间表示 - 每个表示对应一个 `source token` 每个 source token 至少承载: - `objectness` - `azimuth` - `elevation` - `distance` 可选承载: - `source class auxiliary logits` - `source embedding` ### 2.2 推荐监督形式 如果训练数据中每个源都有标注,则推荐采用: - `set prediction` - `K` 个预测 token 对 `N` 个 GT sources - 用 `Hungarian matching` 做一一匹配 不建议采用: - 单一 scene-level spatial token - 仅回归整段音频的全局空间摘要 原因是这会损失多源结构,不利于后续 LLM 做关系推理。 ## 3. 最终架构 推荐最终架构: ```text FOA waveform -> SpatialBEATsPreprocessor -> FOA feature map [B, C_foa, T, F] -> FOA patch embedding -> BEATs trunk -> Spatial query decoder -> K source tokens -> Spatial prediction heads -> LLM projector ``` 为了最大化复用 BEATs 主干,本方案尽量不改 trunk 内部的 Transformer 结构。 ## 4. 输入特征定义 ### 4.1 默认推荐特征 第一版推荐输入通道: - `W_logmel` - `X_logmel` - `Y_logmel` - `Z_logmel` - `IVx` - `IVy` - `IVz` 即: - `C_foa = 7` 这是默认推荐方案。 ### 4.2 备选输入特征 若希望先降低复杂度,可以使用: - `WXYZ logmel` 即: - `C_foa = 4` 但这只适合最小原型。 如果目标是稳定学习空间方向与结构,优先使用 `WXYZ + IV`。 ### 4.3 前端参数建议 为了最大化复用 BEATs 主干,推荐保持与 BEATs 接近的时频分辨率: - sample rate:优先 `16k` - mel bins:`128` - frame length:`25 ms` - frame shift:`10 ms` 原因: - 这能让 trunk 看到与原始 BEATs 更接近的 patch 几何结构 - patch embedding 和后续序列长度更容易保持一致 - 预训练权重复用更稳定 ### 4.4 为什么不沿用 Spatial-AST 的 binaural 前端 Spatial-AST 采用的是: - 双耳 log-mel - IPD 这适合 binaural,不适合直接迁移到 FOA。 FOA 下应优先利用: - ambisonic 通道本身 - intensity vector - 其他 FOA 物理特征 ## 5. 对 BEATs 具体修改哪些模块 下面按模块说明修改方案。 ### 5.1 保留不动的模块 建议尽量保留: - `TransformerEncoder` - `TransformerSentenceEncoderLayer` - `MultiheadAttention` - `conv_pos` - `LayerNorm` - `FFN` - `post_extract_proj` 也就是 `backbone.py` 内的主干结构和 `BEATs.py` 中的 trunk 逻辑尽量不动。 ### 5.2 必须修改的模块 必须重做: 1. `preprocess` 2. `patch_embedding` 3. `extract_features` 输出头部逻辑 4. 下游 `predictor` ### 5.3 推荐新增的模块 建议新增: 1. `SpatialBEATsPreprocessor` 2. `SpatialPatchEmbedding` 3. `SpatialQueryDecoder` 4. `SpatialPredictionHead` 5. `SpatialTokenProjector` 6. `HungarianMatcher` 7. `SpatialSetCriterion` ## 6. 代码级映射建议 ### 6.1 现有文件建议 建议保留和复用: - [BEATs.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/BEATs.py) - [backbone.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/backbone.py) 建议新增: - `spatial_beats.py` - `spatial_modules.py` - `spatial_loss.py` - `spatial_dataset.py` - `train_spatial_beats.py` ### 6.2 `spatial_beats.py` 建议包含 建议实现: - `SpatialBEATsConfig` - `SpatialBEATs` - `SpatialBEATs.extract_spatial_tokens()` - `SpatialBEATs.forward()` ### 6.3 `spatial_modules.py` 建议包含 建议实现: - `SpatialBEATsPreprocessor` - `SpatialPatchEmbedding` - `SpatialQueryDecoder` - `SpatialPredictionHead` - `SpatialTokenProjector` ### 6.4 `spatial_loss.py` 建议包含 建议实现: - `HungarianMatcher` - `SpatialSetCriterion` ## 7. 预训练权重如何复用 ## 7.1 默认推荐权重 默认推荐: - `BEATs_iter3+ (AS2M) pre-trained` 而不是: - fine-tuned checkpoints 原因: - `pre-trained` 更适合作为 trunk 初始化 - `fine-tuned` 更偏向 AudioSet 分类判别 - 你这里的 spatial encoder 应与原语义 encoder 职责分离 ### 7.2 必须直接加载的层 这些层建议直接加载原 BEATs checkpoint: - `post_extract_proj` - `encoder.pos_conv` - `encoder.layers.*` - `encoder.layer_norm` - `layer_norm` 即除了输入 stem 和输出头,主干参数都尽量继承。 ### 7.3 需要特殊初始化的层 以下层因为 shape 不同,不能直接 strict load: - `patch_embedding` - 新增的 `query decoder` - 新增的 `spatial heads` - 新增的 `LLM projector` ### 7.4 新 patch embedding 的初始化策略 原 BEATs stem 是: - `Conv2d(1, embed_dim, kernel_size=patch, stride=patch)` 新 stem 建议是: - `Conv2d(C_foa, embed_dim, kernel_size=patch, stride=patch)` 推荐初始化策略: #### 方案 A:保守初始化,默认推荐 - `W_logmel` 通道继承原 stem 权重 - 其他空间通道初始化为 `0` 或较小随机值 优点: - 最大程度保留原 BEATs 初始分布 - trunk 适配更稳 缺点: - 训练初期空间通道利用较慢 #### 方案 B:通道 inflation - 把原 stem 权重复制到全部输入通道 - 再按通道数做归一化 优点: - 所有通道一开始都能进入主干 缺点: - 初始统计更可能偏离原 BEATs 最终推荐: - 第一版用 `方案 A` - 后续做 ablation 再比较 `方案 B` ## 8. Spatial token 模块的最终设计 ### 8.1 为什么不用全局池化 原始 BEATs 的输出方式更接近: - patch sequence - mean pooling - clip-level prediction 这不适合多源空间任务。 ### 8.2 最终推荐:Query Decoder 在 trunk 输出后新增: - `K` 个 learnable source queries - 一个轻量 `cross-attention decoder` 输入: - encoder memory:`H in R^{B x T x D}` - source queries:`Q in R^{B x K x D}` 输出: - `Z in R^{B x K x D}` 这里的 `Z[:, i, :]` 即第 `i` 个 `source token` ### 8.3 为什么 query decoder 是当前最优解 它的优点: - 不改 trunk 内部结构 - 仍然让完整 FOA 特征经过 backbone - 适合多源 set prediction - 最利于最大化复用 trunk 权重 ## 9. 输出头设计 对每个 source token `z_i`,预测: - `objectness` - `azimuth` - `elevation` - `distance` - 可选 `class_aux` ### 9.1 离散还是连续 第一版推荐全部使用离散分类头: - `azimuth`: 360 bins - `elevation`: 180 bins - `distance`: 按数据分桶,例如 `0.5m` 一档 原因: - 与已有 Spatial-AST/BAT 经验一致 - 分类头更稳 - 更便于构造离散坐标 embedding ### 9.2 objectness 头 推荐增加: - `objectness_head: D -> 1` 用于: - 判断当前 token 是否对应真实声源 - 作为 Hungarian matching 的一部分 - 推理时做 token 保留/裁剪 ### 9.3 类别头 类别头建议作为: - `auxiliary head` 而不是最终 LLM 的主要输入内容。 这样做的作用: - 让 query token 更容易学会 source slot 对齐 - 但不把 Spatial-BEATs 变成第二个强语义 encoder ## 10. Loss 设计 推荐总损失: ```text L_total = lambda_obj * L_obj + lambda_azi * L_azi + lambda_ele * L_ele + lambda_dist * L_dist + lambda_cls * L_cls_aux ``` ### 10.1 匹配方式 使用 `Hungarian matching`: - 预测:`K` 个 token - GT:`N` 个 sources - 成本由以下项构成: - objectness cost - azimuth cost - elevation cost - distance cost - optional class cost ### 10.2 损失项定义 推荐: - `L_obj`: BCE 或 focal loss - `L_azi`: cross entropy - `L_ele`: cross entropy - `L_dist`: cross entropy - `L_cls_aux`: cross entropy 或 BCE ### 10.3 初始 loss 权重建议 第一版建议从以下权重起步: ```text lambda_obj = 1.0 lambda_azi = 2.0 lambda_ele = 2.0 lambda_dist = 1.0 lambda_cls = 0.25 ``` 解释: - 方向任务通常更关键 - 距离次之 - objectness 必须稳定 - 类别监督只作为辅助 ### 10.4 不建议的做法 第一版不建议: - 重分类损失压倒空间损失 - 直接照搬 Spatial-AST 的 `1250 * cls` 原因: - Spatial-AST 的目标之一是保住 sound event detection - 这里 `Spatial-BEATs` 的主要目标是空间 token - 原项目已有独立语义 encoder ## 11. 训练策略 ### 11.1 第一阶段是否需要 SSL 当前最终结论: - 第一版 **不需要** 重新做 BEATs 式 SSL 因为当前已经有: - 多源监督 - 每个源的空间标注 - 可复用的 BEATs 主干预训练 所以第一阶段应优先做: - `supervised multi-source spatial training` ### 11.2 分阶段训练建议 #### Stage A:Warmup 冻结: - 大部分 trunk 只训练: - FOA preprocessor - patch embedding - query decoder - spatial heads - LLM projector 目的: - 让新输入 stem 和新输出头稳定接入 trunk #### Stage B:Upper-trunk finetune 解冻: - trunk 上层若干层 目的: - 让主干逐步适应 FOA 空间任务 #### Stage C:Near-full finetune 进一步解冻: - 更多 encoder layers 目的: - 提升空间表示上限 ### 11.3 学习率建议 推荐: - trunk:较小 lr - 新模块:较大学习率 例如: ```text lr_trunk = 1e-5 ~ 5e-5 lr_new = 1e-4 ~ 5e-4 ``` 并配合: - layer-wise lr decay ## 12. 最终输出给 LLM 的 spatial token 形式 这是本项目最关键的接口定义之一。 ### 12.1 内部 token 形式 `Spatial-BEATs` 内部输出: - `Z in R^{B x K x D}` 其中: - `B`: batch size - `K`: source token 数 - `D`: Spatial-BEATs hidden dim,建议与 BEATs trunk 一致 ### 12.2 不建议直接把 raw logits 喂给 LLM 不建议直接给 LLM: - azimuth logits - elevation logits - distance logits - objectness logits 这些是监督头,不是最终模态表示。 ### 12.3 最终推荐的 LLM spatial token 形式 最终推荐送给 LLM 的每个 token 形式为: ```text s_i = Proj([z_i ; e_azi(i) ; e_ele(i) ; e_dist(i) ; e_obj(i)]) ``` 其中: - `z_i`: query decoder 输出的 latent token - `e_azi(i)`: 由预测 azimuth bin 查表得到的 embedding - `e_ele(i)`: 由预测 elevation bin 查表得到的 embedding - `e_dist(i)`: 由预测 distance bin 查表得到的 embedding - `e_obj(i)`: 由 objectness/confidence 产生的 embedding - `Proj`: 投影到 LLM hidden size 的 MLP/Linear 最终: - `s_i in R^{d_llm}` ### 12.4 为什么采用“latent + structured embedding”的混合形式 原因: 1. `z_i` 保留丰富的隐式空间结构信息 2. `坐标 embedding` 给 LLM 显式离散空间线索 3. `confidence` 有助于 LLM 区分可靠/不可靠 token 这比单纯只传: - raw latent token 或者只传: - 显式坐标 one-hot / scalar 都更合适。 ### 12.5 最终序列形式 送入 LLM 时推荐: ```text , s_1, s_2, ..., s_K, ``` 并且: - 按 `objectness` 从高到低排序 - 对低置信 token 可直接截断或 mask ### 12.6 是否保留全部 K 个 token 默认推荐: - 训练时保留全部 `K` - 推理时按 `objectness` 过滤 例如: - 保留前 `K_keep` - 或保留 `obj > threshold` 的 token ## 13. 与原语义 audio encoder 的关系 为了避免“两个 encoder 在做同样的事”,推荐如下职责划分: - 原语义 audio encoder:负责 `what` - Spatial-BEATs:负责 `where / spatial structure / relations` ### 13.1 是否允许 Spatial-BEATs 学类别 允许,但只作为辅助。 建议: - 类别头只用于训练 - 最终输入给 LLM 的空间 token 不直接暴露完整类别 logits ### 13.2 是否需要和语义 encoder 做对齐 第一版不是必须。 若后续希望更强的 source grounding,可进一步加入: - semantic distillation - cross-encoder alignment - source-wise contrastive loss 但这些应放到第二阶段。 ## 14. 第一版推荐配置 第一版默认建议: - 输入特征:`WXYZ + IVxyz` - `C_foa = 7` - 采样率:`16k` - mel bins:`128` - patch 配置:与 BEATs 保持一致 - 预训练权重:`BEATs_iter3+ AS2M pre-trained` - trunk:最大化加载 - patch stem:`W` 继承,其余通道小初始化 - 输出:`K` 个 source tokens - token 解码:轻量 query decoder - 监督:Hungarian matching + 多头空间分类 - LLM 输入:`latent + structured coordinate embedding` 的混合 token ## 15. 实现优先级 推荐按如下优先级推进: 1. 实现 `FOA preprocessor` 2. 实现多通道 `patch embedding` 3. 完成 trunk ckpt 加载 4. 实现 `query decoder` 5. 实现 `objectness / azi / ele / dist` heads 6. 实现 `Hungarian matcher + criterion` 7. 实现 `LLM projector` 8. 完成训练脚本 ## 16. 当前仍需用户确认的问题 以下问题会直接影响第一版实现细节: 1. `FOA` 数据当前主要采样率是多少?是 `16k`、`24k`、`32k` 还是 `48k`? 2. 每个样本中 `最大同时源数` 大概是多少?这会影响 `K` 的默认设定。 3. 每个源是否都有 `source-level class label`?如果有,类别头和匹配会更稳。 4. 你希望 `distance` 是离散分类还是连续回归?当前默认推荐离散分类。 5. 下游 LLM 的 hidden size 是多少?是否已有固定的 audio token projector? 6. 你是否希望 Spatial-BEATs 在第一版就具备一定的 source semantic 辅助能力,还是严格只做空间? ## 17. 结论 当前最终方案已经明确: - **完整 FOA 特征进入 BEATs 主干** - **最大化复用 trunk 预训练** - **重做输入 stem** - **重做输出为多源 spatial tokens** - **第一版采用监督式 set prediction** - **最终给 LLM 的不是 raw logits,而是融合 latent 与坐标 embedding 的 spatial tokens** 这是当前最符合项目目标、也最稳妥的 `Spatial-BEATs` 方案。