Spatial-BEATs / docs /spatial_beats_design_guide.md
dieKarotte's picture
Add files using upload-large-folder tool
29ab2d0 verified
|
Raw
History Blame Contribute Delete
14.1 kB

Spatial-BEATs 设计与训练指南

1. 文档目标

本文档用于整理本项目中 Spatial-BEATs 的任务定义、模型改造方向、保留与替换的模块、训练方法,以及后续接入 LLM 的接口约定。

目标是基于公开的 BEATs 框架,构建一个独立的 Spatial Encoder

  • 输入为 FOA 音频 及其派生的 FOA 空间特征
  • 主干尽可能复用 BEATs backbone 和其预训练权重
  • 输出为一组 Spatial Tokens
  • 这些 Spatial Tokens 作为独立模态输入给 LLM
  • 原有的语义 audio encoder 保持不动,避免直接与空间 encoder 混合后产生职责不清或语义冲突

本设计不是在原有 LLM audio encoder 上强行加入空间分支,而是单独训练一个 Spatial-BEATs,让其专注于空间感知和空间结构建模。

2. 任务定义

2.1 核心任务

Spatial-BEATs 的核心任务不是通用音频语义分类,而是:

  1. FOA 输入中提取与空间位置相关的结构化表示
  2. 在多源场景中输出 source-level spatial tokens
  3. 每个 token 尽量对应一个潜在声源,编码其空间信息
  4. 后续由 LLM 使用这些 token 完成空间关系理解和推理

2.2 输入与输出

输入:

  • 原始 FOA waveform
  • 或由 FOA waveform 计算得到的 FOA 特征图

推荐特征:

  • W, X, Y, Z 的 log-mel
  • IV (Intensity Vector),例如 IVx, IVy, IVz
  • 可选的 diffuseness / coherence / phase-related 特征

输出:

  • KSpatial Tokens
  • 每个 token 对应一个潜在声源或一个空间实体
  • 每个 token 供下游预测:
    • objectness
    • azimuth
    • elevation
    • distance
    • 可选的 class embedding / source type embedding

2.3 为什么不是只用 W

只让 W 通道经过 backbone,本质上更像单通道语义编码,空间线索主要被放到外挂 adapter 中。
这不符合本项目目标,因为这里希望:

  • 整个 FOA 特征都经过主干
  • 主干本身学习空间结构
  • Spatial-BEATs 成为一个真正的空间 encoder,而不是一个“语义 encoder + 小空间补丁”

因此,本项目的推荐路线是:

  • 整套 FOA 特征 -> patch embedding -> BEATs backbone -> source-level spatial tokens

而不是:

  • W-only -> BEATs
  • W-only BEATs + 外挂小 adapter

3. 与原始 BEATs 的关系

3.1 BEATs 中值得最大化复用的部分

当前仓库中的 BEATs 主干主要包括:

  • post_extract_proj
  • TransformerEncoder
  • Transformer layers
  • conv_pos
  • LayerNorm / FFN / attention

这些模块位于:

  • BEATs.py
  • backbone.py

这些部分是最应该保留并加载预训练权重的。

3.2 BEATs 中不适合直接保留的部分

原始 BEATs 代码是单通道设计,关键假设包括:

  • preprocess() 只生成单通道 fbank
  • patch_embeddingConv2d(1, embed_dim, ...)
  • 下游输出默认是整段时间序列平均后的分类预测

因此,下列部分不应直接照搬:

  1. 单通道 preprocess
  2. 单通道 patch_embedding
  3. 最终的 clip-level 平均池化分类输出方式
  4. 原始 predictor 作为最终目标头

3.3 对 BEATs 的总体改造原则

原则是:

  • 尽量保留 trunk
  • 必要时重做 stem
  • 完全重做 spatial head

也就是:

  • 输入端
  • 输出端
  • 中间主干 尽量不改

4. Spatial-AST 相比 AudioMAE 的改造经验

Spatial-AST 对本项目最有借鉴价值的不是其 binaural 细节,而是其改造模式。

4.1 Spatial-AST 做了什么

相对于原始 AudioMAE/ViT,Spatial-AST 主要改了四类模块:

  1. 输入前端

    • 从原始单通道 spectrogram 输入,改成 双耳 log-mel + IPD
    • 在输入前端加入 STFT / LogMel / IPD / conv_downsample
  2. token 设计

    • 把原来的单个 cls token 改为 3 个任务专用 token
    • 分别对应:
      • 分类
      • 距离
      • 方向
  3. 输出头

    • 除原分类 head 外,新增:
      • distance_head
      • azimuth_head
      • elevation_head
  4. 训练目标

    • 从单任务分类,改成多任务训练
    • 同时训练:
      • sound event detection
      • distance prediction
      • direction prediction

4.2 Spatial-AST 没有怎么改

Spatial-AST 没有重写 Transformer block 本体。
它保留了 AudioMAE/ViT 的核心 encoder 结构,而把改动集中在:

  • front-end
  • tokens
  • heads
  • objectives

4.3 对本项目的可迁移结论

可直接借鉴的思想:

  1. 用预训练音频主干初始化空间 encoder
  2. 重新设计输入前端,让空间特征真正进入 backbone
  3. 使用专门的空间 token,而不是只做全局池化
  4. 使用多任务监督训练空间能力

不能直接照搬的部分:

  1. Spatial-AST 的 binaural + IPD 前端
  2. 只面向单/双耳的空间 cue 设计
  3. 只输出全局 token 的思路

本项目是 FOA,因此应该把输入前端换成 FOA 专属空间特征

5. Spatial-BEATs 的推荐任务形式

5.1 单场景 token 不够

如果只输出一个全局 spatial token,它只能表示整个 scene 的压缩摘要,不适合做:

  • 多声源关系理解
  • “谁在谁左边”
  • “某个类对应的声源在什么位置”
  • source-wise grounding

既然数据中已经有多源标注,推荐直接把任务定义为:

  • multi-source set prediction

5.2 推荐输出形式

令模型输出固定数量的 K 个 spatial queries/tokens。

每个 token 预测:

  • p(obj)
  • azimuth
  • elevation
  • distance
  • 可选 class logitsclass embedding

训练时使用:

  • Hungarian matching
  • 或者其他 set prediction matching

K 个预测 token 与当前样本中的 N 个 GT 声源做一一匹配。

这会比单一 scene token 更适合后续接入 LLM 做空间关系推理。

6. Spatial-BEATs 的推荐模型结构

推荐结构如下:

FOA waveform
  -> FOA front-end
  -> FOA feature map
  -> FOA patch embedding
  -> BEATs Transformer trunk
  -> source queries / spatial decoder
  -> K spatial tokens
  -> spatial heads

6.1 FOA front-end

输入可以是:

  • WXYZ log-mel
  • WXYZ log-mel + IVx/IVy/IVz

推荐第一版就至少使用:

  • WXYZ
  • IV

原因:

  • 只用 WXYZ 仍然需要模型自己从通道关系中恢复空间线索
  • IV 直接提供有物理意义的方向信息
  • 对空间收敛会更稳

6.2 FOA patch embedding

这一层应替换原始 BEATs 的单通道 patch stem。

原始:

  • Conv2d(1, embed_dim, kernel_size=patch, stride=patch)

新的思路:

  • Conv2d(C_foa, embed_dim, kernel_size=patch, stride=patch)

其中 C_foa 可以是:

  • 4,如果只用 WXYZ
  • 7,如果用 WXYZ + IVxyz
  • 更大,如果加入更多派生空间特征

6.3 BEATs trunk

尽量保留以下模块:

  • post_extract_proj
  • TransformerEncoder
  • attention
  • FFN
  • conv_pos
  • LayerNorm

这是整个“最大化复用预训练权重”的核心。

6.4 Spatial token 模块

不要继续使用原始 BEATs 的:

  • mean pooling
  • clip-level predictor

推荐改为:

  • K 个 learnable source queries
  • queries 对 trunk 输出做 attention
  • 得到 K 个 source-level spatial tokens

如果实现上希望更简单,第一版也可以:

  • 先直接在 trunk 输出后接一个轻量 decoder
  • 再输出 K 个 tokens

6.5 Spatial heads

每个 token 对应:

  • objectness head
  • azimuth head
  • elevation head
  • distance head
  • 可选 class head

如果担心与原 LLM audio encoder 产生语义冲突,则建议:

  • class head 只作为辅助监督
  • 不把其输出作为最终送入 LLM 的主要表示

7. 保留的部分

下面这些建议尽量保留:

7.1 保留原有 LLM audio encoder

原始语义 audio encoder 不动,继续负责:

  • 音频内容语义
  • 事件类别理解
  • 与现有 LLM 接口保持兼容

7.2 保留 Spatial-BEATs 作为独立 encoder

Spatial-BEATs 单独负责:

  • 方向
  • 距离
  • 多源空间结构
  • 可选 source-wise 辅助类别信息

7.3 保留 BEATs 主干参数初始化

建议保留:

  • trunk 的预训练参数
  • 尽量避免从零训练整个 spatial encoder

8. 需要替换或新增的部分

8.1 必改模块

必须修改:

  1. preprocess
  2. patch_embedding
  3. forward / extract_features 的输出方式
  4. 下游 predictor

8.2 必增模块

必须新增:

  1. FOA spatial front-end
  2. spatial query / token module
  3. multi-head spatial prediction heads
  4. set matching / multi-source loss

8.3 建议新增模块

建议新增:

  1. source confidence / objectness
  2. auxiliary class supervision
  3. LLM projection head

9. 训练方法

9.1 第一阶段是否需要 SSL

当前结论是:

  • 第一版不需要重新做 BEATs 式 SSL

原因:

  1. 已经有多源 GT relative positions
  2. 目标不是再学通用音频语义,而是让模型学空间结构
  3. 已有 BEATs 预训练权重可作为稳定初始化
  4. 先做监督式空间学习,工程收益最高

因此,推荐第一阶段直接做 supervised multi-task training

9.2 第一阶段训练目标

基础目标:

  • L_obj
  • L_azimuth
  • L_elevation
  • L_distance

可选目标:

  • L_class_aux

总损失可写为:

L = lambda_obj * L_obj
  + lambda_azi * L_azimuth
  + lambda_ele * L_elevation
  + lambda_dist * L_distance
  + lambda_cls * L_class_aux

这里建议:

  • 空间任务作为主目标
  • 类别只做辅助目标

因为本项目中语义主责已经由原始 audio encoder 承担。

9.3 多源匹配训练

如果每个样本有多个声源标注,推荐:

  1. 模型输出固定数量 K 个 tokens
  2. 用 Hungarian matching 在 token 与 GT 源之间做匹配
  3. 对 matched token 计算位置损失
  4. 对 unmatched token 计算 no-object loss

这是比“对所有 token 平均做 scene 监督”更合适的做法。

9.4 训练阶段建议

推荐三步走:

Stage A: Stem + Head Warmup

  • 冻结大部分 BEATs trunk
  • 只训练:
    • FOA front-end
    • FOA patch embedding
    • spatial token/query module
    • spatial heads

目的:

  • 让新输入 stem 和新 heads 先适配预训练 trunk

Stage B: Upper Trunk Finetune

  • 解冻 BEATs 上层若干层
  • 使用较小学习率微调
  • 使用 layer-wise lr decay

目的:

  • 让 trunk 逐步适配 FOA 分布和空间任务

Stage C: Full or Near-Full Finetune

  • 在稳定后解冻更多层
  • 继续以空间目标微调

目的:

  • 提升空间 token 的表达能力

9.5 训练数据组织

每个样本应包含:

  • foa waveform
  • num_sources
  • per-source azimuth
  • per-source elevation
  • per-source distance
  • 可选 per-source class

推荐统一成:

sample = {
  "audio": ...,
  "sources": [
    {"azimuth": ..., "elevation": ..., "distance": ..., "label": ...},
    ...
  ]
}

10. 与 LLM 的接口

10.1 推荐输入形式

最终不要把 Spatial-BEATs 的全部 dense patch tokens 都喂给 LLM。
推荐只输出:

  • KSpatial Tokens

每个 token 代表一个潜在空间实体。

10.2 避免语义冲突的策略

本项目中“避免语义冲突”的关键不是完全不学类别,而是:

  1. 原始 audio encoder 继续承担主要语义理解
  2. Spatial-BEATs 主要承担空间结构建模
  3. Spatial-BEATs 输出给 LLM 的是 source-level spatial tokens
  4. 辅助类别监督只用于训练,不一定直接暴露为最终模态表示

这样两套 encoder 的职责边界更清晰:

  • 原语义 encoder:what
  • Spatial-BEATs:where / relation / spatial structure

11. 当前推荐方案总结

当前最推荐的路线不是:

  • W-only BEATs
  • W-only + adapter

而是:

  • FOA full-feature Spatial-BEATs
  • 独立空间 encoder
  • 最大化复用 BEATs trunk
  • 重做输入 stem
  • 重做多源 spatial token heads

用一句话总结就是:

用 BEATs 的主干做 FOA 空间建模,而不是只拿 BEATs 当单通道语义骨干,再在旁边打一个空间补丁。

12. 推荐实施顺序

建议按以下顺序推进:

  1. 明确 FOA feature schema

    • 是否使用 WXYZ
    • 是否加入 IV
    • 是否加入其他物理特征
  2. 设计 Spatial-BEATs 的新输入 stem

    • 替换单通道 preprocess
    • 替换 patch embedding
  3. 设计 K-source spatial tokens

    • 确定 token 数量
    • 确定 query 机制
  4. 实现多头空间预测

    • objectness
    • distance
    • azimuth
    • elevation
    • optional class
  5. 实现多源匹配训练

    • Hungarian matching
    • no-object loss
  6. 先做监督训练

    • 不急于加入 SSL
  7. 训练稳定后,再评估是否需要第二阶段自监督

    • masked spatial cue prediction
    • view consistency
    • teacher distillation

13. 第二阶段可选方向

在第一版监督训练稳定后,可考虑加入:

  1. Spatial SSL

    • masked FOA cue prediction
    • spatial consistency loss
    • teacher-student distillation
  2. source-conditioned tokens

    • 类别条件的 source token
  3. LLM-side alignment

    • 将 spatial token 投影到 LLM hidden space
    • 与文本空间词汇进行弱对齐
  4. 更复杂的多源推理训练

    • source relation supervision
    • pairwise spatial relation labels

14. 结论

本项目的推荐方向已经比较明确:

  • 使用 FOA 全特征
  • FOA 特征 真正经过 BEATs backbone
  • BEATs 改造成独立的 Spatial-BEATs
  • 保留原语义 audio encoder 不动
  • 多源 spatial token prediction 为核心任务
  • 第一阶段采用 监督式空间训练
  • 最大化复用 BEATs trunk 的预训练权重

如果后续开始实现,建议优先落地以下最小可行版本:

  1. WXYZ + IV 输入
  2. 多通道 patch embedding
  3. BEATs trunk 复用
  4. K 个 spatial queries
  5. objectness + azimuth + elevation + distance 多头监督

这会比任何 W-only 或“仅外挂 adapter”的方案更符合项目目标,也更适合最终作为 LLM 的空间模态输入。