Spatial-BEATs / docs /spatial_beats_implementation_spec.md
dieKarotte's picture
Add files using upload-large-folder tool
bf04039 verified
|
Raw
History Blame Contribute Delete
14.4 kB

Spatial-BEATs 实现规格

1. 目标

本规格文档用于将前期讨论收敛为一个可以直接实施的 Spatial-BEATs 方案。

目标是构建一个独立的 Spatial Encoder

  • 输入为完整 FOA 音频及其派生空间特征
  • 完整的 FOA 特征经过 BEATs backbone
  • 最大化复用 BEATs 预训练权重
  • 输出一组 source-level spatial tokens
  • 这些 token 作为独立模态输入给 LLM
  • 原有语义 audio encoder 保持不动

这里的关键原则是:

不是让 W-only 走主干,再外挂一个小空间 adapter;而是让完整 FOA 空间特征真正进入 BEATs 主干,并在主干之后产出结构化空间 token。

2. 最终任务定义

2.1 核心任务

Spatial-BEATs 的主任务定义为:

  • 给定一个多源 FOA 音频片段
  • 预测其中最多 K 个潜在声源的空间表示
  • 每个表示对应一个 source token

每个 source token 至少承载:

  • objectness
  • azimuth
  • elevation
  • distance

可选承载:

  • source class auxiliary logits
  • source embedding

2.2 推荐监督形式

如果训练数据中每个源都有标注,则推荐采用:

  • set prediction
  • K 个预测 token 对 N 个 GT sources
  • Hungarian matching 做一一匹配

不建议采用:

  • 单一 scene-level spatial token
  • 仅回归整段音频的全局空间摘要

原因是这会损失多源结构,不利于后续 LLM 做关系推理。

3. 最终架构

推荐最终架构:

FOA waveform
  -> SpatialBEATsPreprocessor
  -> FOA feature map [B, C_foa, T, F]
  -> FOA patch embedding
  -> BEATs trunk
  -> Spatial query decoder
  -> K source tokens
  -> Spatial prediction heads
  -> LLM projector

为了最大化复用 BEATs 主干,本方案尽量不改 trunk 内部的 Transformer 结构。

4. 输入特征定义

4.1 默认推荐特征

第一版推荐输入通道:

  • W_logmel
  • X_logmel
  • Y_logmel
  • Z_logmel
  • IVx
  • IVy
  • IVz

即:

  • C_foa = 7

这是默认推荐方案。

4.2 备选输入特征

若希望先降低复杂度,可以使用:

  • WXYZ logmel

即:

  • C_foa = 4

但这只适合最小原型。
如果目标是稳定学习空间方向与结构,优先使用 WXYZ + IV

4.3 前端参数建议

为了最大化复用 BEATs 主干,推荐保持与 BEATs 接近的时频分辨率:

  • sample rate:优先 16k
  • mel bins:128
  • frame length:25 ms
  • frame shift:10 ms

原因:

  • 这能让 trunk 看到与原始 BEATs 更接近的 patch 几何结构
  • patch embedding 和后续序列长度更容易保持一致
  • 预训练权重复用更稳定

4.4 为什么不沿用 Spatial-AST 的 binaural 前端

Spatial-AST 采用的是:

  • 双耳 log-mel
  • IPD

这适合 binaural,不适合直接迁移到 FOA。

FOA 下应优先利用:

  • ambisonic 通道本身
  • intensity vector
  • 其他 FOA 物理特征

5. 对 BEATs 具体修改哪些模块

下面按模块说明修改方案。

5.1 保留不动的模块

建议尽量保留:

  • TransformerEncoder
  • TransformerSentenceEncoderLayer
  • MultiheadAttention
  • conv_pos
  • LayerNorm
  • FFN
  • post_extract_proj

也就是 backbone.py 内的主干结构和 BEATs.py 中的 trunk 逻辑尽量不动。

5.2 必须修改的模块

必须重做:

  1. preprocess
  2. patch_embedding
  3. extract_features 输出头部逻辑
  4. 下游 predictor

5.3 推荐新增的模块

建议新增:

  1. SpatialBEATsPreprocessor
  2. SpatialPatchEmbedding
  3. SpatialQueryDecoder
  4. SpatialPredictionHead
  5. SpatialTokenProjector
  6. HungarianMatcher
  7. SpatialSetCriterion

6. 代码级映射建议

6.1 现有文件建议

建议保留和复用:

建议新增:

  • spatial_beats.py
  • spatial_modules.py
  • spatial_loss.py
  • spatial_dataset.py
  • train_spatial_beats.py

6.2 spatial_beats.py 建议包含

建议实现:

  • SpatialBEATsConfig
  • SpatialBEATs
  • SpatialBEATs.extract_spatial_tokens()
  • SpatialBEATs.forward()

6.3 spatial_modules.py 建议包含

建议实现:

  • SpatialBEATsPreprocessor
  • SpatialPatchEmbedding
  • SpatialQueryDecoder
  • SpatialPredictionHead
  • SpatialTokenProjector

6.4 spatial_loss.py 建议包含

建议实现:

  • HungarianMatcher
  • SpatialSetCriterion

7. 预训练权重如何复用

7.1 默认推荐权重

默认推荐:

  • BEATs_iter3+ (AS2M) pre-trained

而不是:

  • fine-tuned checkpoints

原因:

  • pre-trained 更适合作为 trunk 初始化
  • fine-tuned 更偏向 AudioSet 分类判别
  • 你这里的 spatial encoder 应与原语义 encoder 职责分离

7.2 必须直接加载的层

这些层建议直接加载原 BEATs checkpoint:

  • post_extract_proj
  • encoder.pos_conv
  • encoder.layers.*
  • encoder.layer_norm
  • layer_norm

即除了输入 stem 和输出头,主干参数都尽量继承。

7.3 需要特殊初始化的层

以下层因为 shape 不同,不能直接 strict load:

  • patch_embedding
  • 新增的 query decoder
  • 新增的 spatial heads
  • 新增的 LLM projector

7.4 新 patch embedding 的初始化策略

原 BEATs stem 是:

  • Conv2d(1, embed_dim, kernel_size=patch, stride=patch)

新 stem 建议是:

  • Conv2d(C_foa, embed_dim, kernel_size=patch, stride=patch)

推荐初始化策略:

方案 A:保守初始化,默认推荐

  • W_logmel 通道继承原 stem 权重
  • 其他空间通道初始化为 0 或较小随机值

优点:

  • 最大程度保留原 BEATs 初始分布
  • trunk 适配更稳

缺点:

  • 训练初期空间通道利用较慢

方案 B:通道 inflation

  • 把原 stem 权重复制到全部输入通道
  • 再按通道数做归一化

优点:

  • 所有通道一开始都能进入主干

缺点:

  • 初始统计更可能偏离原 BEATs

最终推荐:

  • 第一版用 方案 A
  • 后续做 ablation 再比较 方案 B

8. Spatial token 模块的最终设计

8.1 为什么不用全局池化

原始 BEATs 的输出方式更接近:

  • patch sequence
  • mean pooling
  • clip-level prediction

这不适合多源空间任务。

8.2 最终推荐:Query Decoder

在 trunk 输出后新增:

  • K 个 learnable source queries
  • 一个轻量 cross-attention decoder

输入:

  • encoder memory:H in R^{B x T x D}
  • source queries:Q in R^{B x K x D}

输出:

  • Z in R^{B x K x D}

这里的 Z[:, i, :] 即第 isource token

8.3 为什么 query decoder 是当前最优解

它的优点:

  • 不改 trunk 内部结构
  • 仍然让完整 FOA 特征经过 backbone
  • 适合多源 set prediction
  • 最利于最大化复用 trunk 权重

9. 输出头设计

对每个 source token z_i,预测:

  • objectness
  • azimuth
  • elevation
  • distance
  • 可选 class_aux

9.1 离散还是连续

第一版推荐全部使用离散分类头:

  • azimuth: 360 bins
  • elevation: 180 bins
  • distance: 按数据分桶,例如 0.5m 一档

原因:

  • 与已有 Spatial-AST/BAT 经验一致
  • 分类头更稳
  • 更便于构造离散坐标 embedding

9.2 objectness 头

推荐增加:

  • objectness_head: D -> 1

用于:

  • 判断当前 token 是否对应真实声源
  • 作为 Hungarian matching 的一部分
  • 推理时做 token 保留/裁剪

9.3 类别头

类别头建议作为:

  • auxiliary head

而不是最终 LLM 的主要输入内容。

这样做的作用:

  • 让 query token 更容易学会 source slot 对齐
  • 但不把 Spatial-BEATs 变成第二个强语义 encoder

10. Loss 设计

推荐总损失:

L_total =
  lambda_obj * L_obj
  + lambda_azi * L_azi
  + lambda_ele * L_ele
  + lambda_dist * L_dist
  + lambda_cls * L_cls_aux

10.1 匹配方式

使用 Hungarian matching

  • 预测:K 个 token
  • GT:N 个 sources
  • 成本由以下项构成:
    • objectness cost
    • azimuth cost
    • elevation cost
    • distance cost
    • optional class cost

10.2 损失项定义

推荐:

  • L_obj: BCE 或 focal loss
  • L_azi: cross entropy
  • L_ele: cross entropy
  • L_dist: cross entropy
  • L_cls_aux: cross entropy 或 BCE

10.3 初始 loss 权重建议

第一版建议从以下权重起步:

lambda_obj = 1.0
lambda_azi = 2.0
lambda_ele = 2.0
lambda_dist = 1.0
lambda_cls = 0.25

解释:

  • 方向任务通常更关键
  • 距离次之
  • objectness 必须稳定
  • 类别监督只作为辅助

10.4 不建议的做法

第一版不建议:

  • 重分类损失压倒空间损失
  • 直接照搬 Spatial-AST 的 1250 * cls

原因:

  • Spatial-AST 的目标之一是保住 sound event detection
  • 这里 Spatial-BEATs 的主要目标是空间 token
  • 原项目已有独立语义 encoder

11. 训练策略

11.1 第一阶段是否需要 SSL

当前最终结论:

  • 第一版 不需要 重新做 BEATs 式 SSL

因为当前已经有:

  • 多源监督
  • 每个源的空间标注
  • 可复用的 BEATs 主干预训练

所以第一阶段应优先做:

  • supervised multi-source spatial training

11.2 分阶段训练建议

Stage A:Warmup

冻结:

  • 大部分 trunk

只训练:

  • FOA preprocessor
  • patch embedding
  • query decoder
  • spatial heads
  • LLM projector

目的:

  • 让新输入 stem 和新输出头稳定接入 trunk

Stage B:Upper-trunk finetune

解冻:

  • trunk 上层若干层

目的:

  • 让主干逐步适应 FOA 空间任务

Stage C:Near-full finetune

进一步解冻:

  • 更多 encoder layers

目的:

  • 提升空间表示上限

11.3 学习率建议

推荐:

  • trunk:较小 lr
  • 新模块:较大学习率

例如:

lr_trunk = 1e-5 ~ 5e-5
lr_new = 1e-4 ~ 5e-4

并配合:

  • layer-wise lr decay

12. 最终输出给 LLM 的 spatial token 形式

这是本项目最关键的接口定义之一。

12.1 内部 token 形式

Spatial-BEATs 内部输出:

  • Z in R^{B x K x D}

其中:

  • B: batch size
  • K: source token 数
  • D: Spatial-BEATs hidden dim,建议与 BEATs trunk 一致

12.2 不建议直接把 raw logits 喂给 LLM

不建议直接给 LLM:

  • azimuth logits
  • elevation logits
  • distance logits
  • objectness logits

这些是监督头,不是最终模态表示。

12.3 最终推荐的 LLM spatial token 形式

最终推荐送给 LLM 的每个 token 形式为:

s_i = Proj([z_i ; e_azi(i) ; e_ele(i) ; e_dist(i) ; e_obj(i)])

其中:

  • z_i: query decoder 输出的 latent token
  • e_azi(i): 由预测 azimuth bin 查表得到的 embedding
  • e_ele(i): 由预测 elevation bin 查表得到的 embedding
  • e_dist(i): 由预测 distance bin 查表得到的 embedding
  • e_obj(i): 由 objectness/confidence 产生的 embedding
  • Proj: 投影到 LLM hidden size 的 MLP/Linear

最终:

  • s_i in R^{d_llm}

12.4 为什么采用“latent + structured embedding”的混合形式

原因:

  1. z_i 保留丰富的隐式空间结构信息
  2. 坐标 embedding 给 LLM 显式离散空间线索
  3. confidence 有助于 LLM 区分可靠/不可靠 token

这比单纯只传:

  • raw latent token

或者只传:

  • 显式坐标 one-hot / scalar

都更合适。

12.5 最终序列形式

送入 LLM 时推荐:

<SPATIAL_START>, s_1, s_2, ..., s_K, <SPATIAL_END>

并且:

  • objectness 从高到低排序
  • 对低置信 token 可直接截断或 mask

12.6 是否保留全部 K 个 token

默认推荐:

  • 训练时保留全部 K
  • 推理时按 objectness 过滤

例如:

  • 保留前 K_keep
  • 或保留 obj > threshold 的 token

13. 与原语义 audio encoder 的关系

为了避免“两个 encoder 在做同样的事”,推荐如下职责划分:

  • 原语义 audio encoder:负责 what
  • Spatial-BEATs:负责 where / spatial structure / relations

13.1 是否允许 Spatial-BEATs 学类别

允许,但只作为辅助。

建议:

  • 类别头只用于训练
  • 最终输入给 LLM 的空间 token 不直接暴露完整类别 logits

13.2 是否需要和语义 encoder 做对齐

第一版不是必须。

若后续希望更强的 source grounding,可进一步加入:

  • semantic distillation
  • cross-encoder alignment
  • source-wise contrastive loss

但这些应放到第二阶段。

14. 第一版推荐配置

第一版默认建议:

  • 输入特征:WXYZ + IVxyz
  • C_foa = 7
  • 采样率:16k
  • mel bins:128
  • patch 配置:与 BEATs 保持一致
  • 预训练权重:BEATs_iter3+ AS2M pre-trained
  • trunk:最大化加载
  • patch stem:W 继承,其余通道小初始化
  • 输出:K 个 source tokens
  • token 解码:轻量 query decoder
  • 监督:Hungarian matching + 多头空间分类
  • LLM 输入:latent + structured coordinate embedding 的混合 token

15. 实现优先级

推荐按如下优先级推进:

  1. 实现 FOA preprocessor
  2. 实现多通道 patch embedding
  3. 完成 trunk ckpt 加载
  4. 实现 query decoder
  5. 实现 objectness / azi / ele / dist heads
  6. 实现 Hungarian matcher + criterion
  7. 实现 LLM projector
  8. 完成训练脚本

16. 当前仍需用户确认的问题

以下问题会直接影响第一版实现细节:

  1. FOA 数据当前主要采样率是多少?是 16k24k32k 还是 48k
  2. 每个样本中 最大同时源数 大概是多少?这会影响 K 的默认设定。
  3. 每个源是否都有 source-level class label?如果有,类别头和匹配会更稳。
  4. 你希望 distance 是离散分类还是连续回归?当前默认推荐离散分类。
  5. 下游 LLM 的 hidden size 是多少?是否已有固定的 audio token projector?
  6. 你是否希望 Spatial-BEATs 在第一版就具备一定的 source semantic 辅助能力,还是严格只做空间?

17. 结论

当前最终方案已经明确:

  • 完整 FOA 特征进入 BEATs 主干
  • 最大化复用 trunk 预训练
  • 重做输入 stem
  • 重做输出为多源 spatial tokens
  • 第一版采用监督式 set prediction
  • 最终给 LLM 的不是 raw logits,而是融合 latent 与坐标 embedding 的 spatial tokens

这是当前最符合项目目标、也最稳妥的 Spatial-BEATs 方案。