| # Spatial-BEATs 设计与训练指南 |
|
|
| ## 1. 文档目标 |
|
|
| 本文档用于整理本项目中 `Spatial-BEATs` 的任务定义、模型改造方向、保留与替换的模块、训练方法,以及后续接入 LLM 的接口约定。 |
|
|
| 目标是基于公开的 BEATs 框架,构建一个独立的 `Spatial Encoder`: |
|
|
| - 输入为 `FOA 音频` 及其派生的 `FOA 空间特征` |
| - 主干尽可能复用 `BEATs backbone` 和其预训练权重 |
| - 输出为一组 `Spatial Tokens` |
| - 这些 `Spatial Tokens` 作为独立模态输入给 LLM |
| - 原有的语义 audio encoder 保持不动,避免直接与空间 encoder 混合后产生职责不清或语义冲突 |
|
|
| 本设计不是在原有 LLM audio encoder 上强行加入空间分支,而是单独训练一个 `Spatial-BEATs`,让其专注于空间感知和空间结构建模。 |
|
|
| ## 2. 任务定义 |
|
|
| ### 2.1 核心任务 |
|
|
| `Spatial-BEATs` 的核心任务不是通用音频语义分类,而是: |
|
|
| 1. 从 `FOA` 输入中提取与空间位置相关的结构化表示 |
| 2. 在多源场景中输出 `source-level spatial tokens` |
| 3. 每个 token 尽量对应一个潜在声源,编码其空间信息 |
| 4. 后续由 LLM 使用这些 token 完成空间关系理解和推理 |
|
|
| ### 2.2 输入与输出 |
|
|
| 输入: |
|
|
| - 原始 `FOA waveform` |
| - 或由 `FOA waveform` 计算得到的 `FOA 特征图` |
|
|
| 推荐特征: |
|
|
| - `W, X, Y, Z` 的 log-mel |
| - `IV (Intensity Vector)`,例如 `IVx, IVy, IVz` |
| - 可选的 `diffuseness / coherence / phase-related` 特征 |
|
|
| 输出: |
|
|
| - `K` 个 `Spatial Tokens` |
| - 每个 token 对应一个潜在声源或一个空间实体 |
| - 每个 token 供下游预测: |
| - `objectness` |
| - `azimuth` |
| - `elevation` |
| - `distance` |
| - 可选的 `class embedding / source type embedding` |
|
|
| ### 2.3 为什么不是只用 W |
|
|
| 只让 `W` 通道经过 backbone,本质上更像单通道语义编码,空间线索主要被放到外挂 adapter 中。 |
| 这不符合本项目目标,因为这里希望: |
|
|
| - 整个 `FOA` 特征都经过主干 |
| - 主干本身学习空间结构 |
| - `Spatial-BEATs` 成为一个真正的空间 encoder,而不是一个“语义 encoder + 小空间补丁” |
|
|
| 因此,本项目的推荐路线是: |
|
|
| - `整套 FOA 特征 -> patch embedding -> BEATs backbone -> source-level spatial tokens` |
|
|
| 而不是: |
|
|
| - `W-only -> BEATs` |
| - `W-only BEATs + 外挂小 adapter` |
|
|
| ## 3. 与原始 BEATs 的关系 |
|
|
| ### 3.1 BEATs 中值得最大化复用的部分 |
|
|
| 当前仓库中的 BEATs 主干主要包括: |
|
|
| - `post_extract_proj` |
| - `TransformerEncoder` |
| - `Transformer layers` |
| - `conv_pos` |
| - `LayerNorm / FFN / attention` |
|
|
| 这些模块位于: |
|
|
| - `BEATs.py` |
| - `backbone.py` |
|
|
| 这些部分是最应该保留并加载预训练权重的。 |
|
|
| ### 3.2 BEATs 中不适合直接保留的部分 |
|
|
| 原始 BEATs 代码是单通道设计,关键假设包括: |
|
|
| - `preprocess()` 只生成单通道 `fbank` |
| - `patch_embedding` 是 `Conv2d(1, embed_dim, ...)` |
| - 下游输出默认是整段时间序列平均后的分类预测 |
|
|
| 因此,下列部分不应直接照搬: |
|
|
| 1. 单通道 `preprocess` |
| 2. 单通道 `patch_embedding` |
| 3. 最终的 clip-level 平均池化分类输出方式 |
| 4. 原始 `predictor` 作为最终目标头 |
|
|
| ### 3.3 对 BEATs 的总体改造原则 |
|
|
| 原则是: |
|
|
| - **尽量保留 trunk** |
| - **必要时重做 stem** |
| - **完全重做 spatial head** |
|
|
| 也就是: |
|
|
| - `输入端` 改 |
| - `输出端` 改 |
| - `中间主干` 尽量不改 |
|
|
| ## 4. Spatial-AST 相比 AudioMAE 的改造经验 |
|
|
| Spatial-AST 对本项目最有借鉴价值的不是其 binaural 细节,而是其改造模式。 |
|
|
| ### 4.1 Spatial-AST 做了什么 |
|
|
| 相对于原始 AudioMAE/ViT,Spatial-AST 主要改了四类模块: |
|
|
| 1. **输入前端** |
| - 从原始单通道 spectrogram 输入,改成 `双耳 log-mel + IPD` |
| - 在输入前端加入 `STFT / LogMel / IPD / conv_downsample` |
|
|
| 2. **token 设计** |
| - 把原来的单个 `cls token` 改为 `3 个任务专用 token` |
| - 分别对应: |
| - 分类 |
| - 距离 |
| - 方向 |
|
|
| 3. **输出头** |
| - 除原分类 head 外,新增: |
| - `distance_head` |
| - `azimuth_head` |
| - `elevation_head` |
|
|
| 4. **训练目标** |
| - 从单任务分类,改成多任务训练 |
| - 同时训练: |
| - sound event detection |
| - distance prediction |
| - direction prediction |
|
|
| ### 4.2 Spatial-AST 没有怎么改 |
|
|
| Spatial-AST 没有重写 Transformer block 本体。 |
| 它保留了 AudioMAE/ViT 的核心 encoder 结构,而把改动集中在: |
|
|
| - front-end |
| - tokens |
| - heads |
| - objectives |
|
|
| ### 4.3 对本项目的可迁移结论 |
|
|
| 可直接借鉴的思想: |
|
|
| 1. 用预训练音频主干初始化空间 encoder |
| 2. 重新设计输入前端,让空间特征真正进入 backbone |
| 3. 使用专门的空间 token,而不是只做全局池化 |
| 4. 使用多任务监督训练空间能力 |
|
|
| 不能直接照搬的部分: |
|
|
| 1. Spatial-AST 的 `binaural + IPD` 前端 |
| 2. 只面向单/双耳的空间 cue 设计 |
| 3. 只输出全局 token 的思路 |
|
|
| 本项目是 `FOA`,因此应该把输入前端换成 `FOA 专属空间特征`。 |
|
|
| ## 5. Spatial-BEATs 的推荐任务形式 |
|
|
| ### 5.1 单场景 token 不够 |
|
|
| 如果只输出一个全局 spatial token,它只能表示整个 scene 的压缩摘要,不适合做: |
|
|
| - 多声源关系理解 |
| - “谁在谁左边” |
| - “某个类对应的声源在什么位置” |
| - source-wise grounding |
|
|
| 既然数据中已经有多源标注,推荐直接把任务定义为: |
|
|
| - `multi-source set prediction` |
|
|
| ### 5.2 推荐输出形式 |
|
|
| 令模型输出固定数量的 `K` 个 spatial queries/tokens。 |
|
|
| 每个 token 预测: |
|
|
| - `p(obj)` |
| - `azimuth` |
| - `elevation` |
| - `distance` |
| - 可选 `class logits` 或 `class embedding` |
|
|
| 训练时使用: |
|
|
| - `Hungarian matching` |
| - 或者其他 set prediction matching |
|
|
| 把 `K` 个预测 token 与当前样本中的 `N` 个 GT 声源做一一匹配。 |
|
|
| 这会比单一 scene token 更适合后续接入 LLM 做空间关系推理。 |
|
|
| ## 6. Spatial-BEATs 的推荐模型结构 |
|
|
| 推荐结构如下: |
|
|
| ```text |
| FOA waveform |
| -> FOA front-end |
| -> FOA feature map |
| -> FOA patch embedding |
| -> BEATs Transformer trunk |
| -> source queries / spatial decoder |
| -> K spatial tokens |
| -> spatial heads |
| ``` |
|
|
| ### 6.1 FOA front-end |
|
|
| 输入可以是: |
|
|
| - `WXYZ log-mel` |
| - `WXYZ log-mel + IVx/IVy/IVz` |
|
|
| 推荐第一版就至少使用: |
|
|
| - `WXYZ` |
| - `IV` |
|
|
| 原因: |
|
|
| - 只用 `WXYZ` 仍然需要模型自己从通道关系中恢复空间线索 |
| - `IV` 直接提供有物理意义的方向信息 |
| - 对空间收敛会更稳 |
|
|
| ### 6.2 FOA patch embedding |
|
|
| 这一层应替换原始 BEATs 的单通道 patch stem。 |
|
|
| 原始: |
|
|
| - `Conv2d(1, embed_dim, kernel_size=patch, stride=patch)` |
|
|
| 新的思路: |
|
|
| - `Conv2d(C_foa, embed_dim, kernel_size=patch, stride=patch)` |
|
|
| 其中 `C_foa` 可以是: |
|
|
| - `4`,如果只用 `WXYZ` |
| - `7`,如果用 `WXYZ + IVxyz` |
| - 更大,如果加入更多派生空间特征 |
|
|
| ### 6.3 BEATs trunk |
|
|
| 尽量保留以下模块: |
|
|
| - `post_extract_proj` |
| - `TransformerEncoder` |
| - `attention` |
| - `FFN` |
| - `conv_pos` |
| - `LayerNorm` |
|
|
| 这是整个“最大化复用预训练权重”的核心。 |
|
|
| ### 6.4 Spatial token 模块 |
|
|
| 不要继续使用原始 BEATs 的: |
|
|
| - mean pooling |
| - clip-level predictor |
|
|
| 推荐改为: |
|
|
| - `K` 个 learnable source queries |
| - queries 对 trunk 输出做 attention |
| - 得到 `K` 个 source-level spatial tokens |
|
|
| 如果实现上希望更简单,第一版也可以: |
|
|
| - 先直接在 trunk 输出后接一个轻量 decoder |
| - 再输出 `K` 个 tokens |
|
|
| ### 6.5 Spatial heads |
|
|
| 每个 token 对应: |
|
|
| - `objectness head` |
| - `azimuth head` |
| - `elevation head` |
| - `distance head` |
| - 可选 `class head` |
|
|
| 如果担心与原 LLM audio encoder 产生语义冲突,则建议: |
|
|
| - 把 `class head` 只作为辅助监督 |
| - 不把其输出作为最终送入 LLM 的主要表示 |
|
|
| ## 7. 保留的部分 |
|
|
| 下面这些建议尽量保留: |
|
|
| ### 7.1 保留原有 LLM audio encoder |
|
|
| 原始语义 audio encoder 不动,继续负责: |
|
|
| - 音频内容语义 |
| - 事件类别理解 |
| - 与现有 LLM 接口保持兼容 |
|
|
| ### 7.2 保留 Spatial-BEATs 作为独立 encoder |
|
|
| `Spatial-BEATs` 单独负责: |
|
|
| - 方向 |
| - 距离 |
| - 多源空间结构 |
| - 可选 source-wise 辅助类别信息 |
|
|
| ### 7.3 保留 BEATs 主干参数初始化 |
|
|
| 建议保留: |
|
|
| - trunk 的预训练参数 |
| - 尽量避免从零训练整个 spatial encoder |
|
|
| ## 8. 需要替换或新增的部分 |
|
|
| ### 8.1 必改模块 |
|
|
| 必须修改: |
|
|
| 1. `preprocess` |
| 2. `patch_embedding` |
| 3. `forward / extract_features` 的输出方式 |
| 4. 下游 `predictor` |
|
|
| ### 8.2 必增模块 |
|
|
| 必须新增: |
|
|
| 1. `FOA spatial front-end` |
| 2. `spatial query / token module` |
| 3. `multi-head spatial prediction heads` |
| 4. `set matching / multi-source loss` |
|
|
| ### 8.3 建议新增模块 |
|
|
| 建议新增: |
|
|
| 1. `source confidence / objectness` |
| 2. `auxiliary class supervision` |
| 3. `LLM projection head` |
|
|
| ## 9. 训练方法 |
|
|
| ## 9.1 第一阶段是否需要 SSL |
|
|
| 当前结论是: |
|
|
| - **第一版不需要重新做 BEATs 式 SSL** |
|
|
| 原因: |
|
|
| 1. 已经有多源 `GT relative positions` |
| 2. 目标不是再学通用音频语义,而是让模型学空间结构 |
| 3. 已有 BEATs 预训练权重可作为稳定初始化 |
| 4. 先做监督式空间学习,工程收益最高 |
|
|
| 因此,推荐第一阶段直接做 `supervised multi-task training`。 |
|
|
| ### 9.2 第一阶段训练目标 |
|
|
| 基础目标: |
|
|
| - `L_obj` |
| - `L_azimuth` |
| - `L_elevation` |
| - `L_distance` |
|
|
| 可选目标: |
|
|
| - `L_class_aux` |
|
|
| 总损失可写为: |
|
|
| ```text |
| L = lambda_obj * L_obj |
| + lambda_azi * L_azimuth |
| + lambda_ele * L_elevation |
| + lambda_dist * L_distance |
| + lambda_cls * L_class_aux |
| ``` |
|
|
| 这里建议: |
|
|
| - 空间任务作为主目标 |
| - 类别只做辅助目标 |
|
|
| 因为本项目中语义主责已经由原始 audio encoder 承担。 |
|
|
| ### 9.3 多源匹配训练 |
|
|
| 如果每个样本有多个声源标注,推荐: |
|
|
| 1. 模型输出固定数量 `K` 个 tokens |
| 2. 用 Hungarian matching 在 token 与 GT 源之间做匹配 |
| 3. 对 matched token 计算位置损失 |
| 4. 对 unmatched token 计算 no-object loss |
|
|
| 这是比“对所有 token 平均做 scene 监督”更合适的做法。 |
|
|
| ### 9.4 训练阶段建议 |
|
|
| 推荐三步走: |
|
|
| #### Stage A: Stem + Head Warmup |
|
|
| - 冻结大部分 BEATs trunk |
| - 只训练: |
| - FOA front-end |
| - FOA patch embedding |
| - spatial token/query module |
| - spatial heads |
|
|
| 目的: |
|
|
| - 让新输入 stem 和新 heads 先适配预训练 trunk |
|
|
| #### Stage B: Upper Trunk Finetune |
|
|
| - 解冻 BEATs 上层若干层 |
| - 使用较小学习率微调 |
| - 使用 layer-wise lr decay |
|
|
| 目的: |
|
|
| - 让 trunk 逐步适配 FOA 分布和空间任务 |
|
|
| #### Stage C: Full or Near-Full Finetune |
|
|
| - 在稳定后解冻更多层 |
| - 继续以空间目标微调 |
|
|
| 目的: |
|
|
| - 提升空间 token 的表达能力 |
|
|
| ### 9.5 训练数据组织 |
|
|
| 每个样本应包含: |
|
|
| - `foa waveform` |
| - `num_sources` |
| - `per-source azimuth` |
| - `per-source elevation` |
| - `per-source distance` |
| - 可选 `per-source class` |
|
|
| 推荐统一成: |
|
|
| ```text |
| sample = { |
| "audio": ..., |
| "sources": [ |
| {"azimuth": ..., "elevation": ..., "distance": ..., "label": ...}, |
| ... |
| ] |
| } |
| ``` |
|
|
| ## 10. 与 LLM 的接口 |
|
|
| ### 10.1 推荐输入形式 |
|
|
| 最终不要把 Spatial-BEATs 的全部 dense patch tokens 都喂给 LLM。 |
| 推荐只输出: |
|
|
| - `K` 个 `Spatial Tokens` |
|
|
| 每个 token 代表一个潜在空间实体。 |
|
|
| ### 10.2 避免语义冲突的策略 |
|
|
| 本项目中“避免语义冲突”的关键不是完全不学类别,而是: |
|
|
| 1. 原始 audio encoder 继续承担主要语义理解 |
| 2. Spatial-BEATs 主要承担空间结构建模 |
| 3. Spatial-BEATs 输出给 LLM 的是 `source-level spatial tokens` |
| 4. 辅助类别监督只用于训练,不一定直接暴露为最终模态表示 |
|
|
| 这样两套 encoder 的职责边界更清晰: |
|
|
| - 原语义 encoder:`what` |
| - Spatial-BEATs:`where / relation / spatial structure` |
|
|
| ## 11. 当前推荐方案总结 |
|
|
| 当前最推荐的路线不是: |
|
|
| - `W-only BEATs` |
| - `W-only + adapter` |
|
|
| 而是: |
|
|
| - `FOA full-feature Spatial-BEATs` |
| - `独立空间 encoder` |
| - `最大化复用 BEATs trunk` |
| - `重做输入 stem` |
| - `重做多源 spatial token heads` |
|
|
| 用一句话总结就是: |
|
|
| > 用 BEATs 的主干做 FOA 空间建模,而不是只拿 BEATs 当单通道语义骨干,再在旁边打一个空间补丁。 |
|
|
| ## 12. 推荐实施顺序 |
|
|
| 建议按以下顺序推进: |
|
|
| 1. 明确 `FOA feature schema` |
| - 是否使用 `WXYZ` |
| - 是否加入 `IV` |
| - 是否加入其他物理特征 |
|
|
| 2. 设计 `Spatial-BEATs` 的新输入 stem |
| - 替换单通道 preprocess |
| - 替换 patch embedding |
|
|
| 3. 设计 `K-source spatial tokens` |
| - 确定 token 数量 |
| - 确定 query 机制 |
|
|
| 4. 实现多头空间预测 |
| - objectness |
| - distance |
| - azimuth |
| - elevation |
| - optional class |
|
|
| 5. 实现多源匹配训练 |
| - Hungarian matching |
| - no-object loss |
|
|
| 6. 先做监督训练 |
| - 不急于加入 SSL |
|
|
| 7. 训练稳定后,再评估是否需要第二阶段自监督 |
| - masked spatial cue prediction |
| - view consistency |
| - teacher distillation |
|
|
| ## 13. 第二阶段可选方向 |
|
|
| 在第一版监督训练稳定后,可考虑加入: |
|
|
| 1. `Spatial SSL` |
| - masked FOA cue prediction |
| - spatial consistency loss |
| - teacher-student distillation |
|
|
| 2. `source-conditioned tokens` |
| - 类别条件的 source token |
|
|
| 3. `LLM-side alignment` |
| - 将 spatial token 投影到 LLM hidden space |
| - 与文本空间词汇进行弱对齐 |
|
|
| 4. `更复杂的多源推理训练` |
| - source relation supervision |
| - pairwise spatial relation labels |
|
|
| ## 14. 结论 |
|
|
| 本项目的推荐方向已经比较明确: |
|
|
| - 使用 `FOA 全特征` |
| - 让 `FOA 特征` 真正经过 `BEATs backbone` |
| - 将 `BEATs` 改造成独立的 `Spatial-BEATs` |
| - 保留原语义 audio encoder 不动 |
| - 以 `多源 spatial token prediction` 为核心任务 |
| - 第一阶段采用 `监督式空间训练` |
| - 最大化复用 `BEATs trunk` 的预训练权重 |
|
|
| 如果后续开始实现,建议优先落地以下最小可行版本: |
|
|
| 1. `WXYZ + IV` 输入 |
| 2. 多通道 patch embedding |
| 3. BEATs trunk 复用 |
| 4. `K` 个 spatial queries |
| 5. `objectness + azimuth + elevation + distance` 多头监督 |
|
|
| 这会比任何 `W-only` 或“仅外挂 adapter”的方案更符合项目目标,也更适合最终作为 LLM 的空间模态输入。 |
|
|