Spatial-BEATs 实现规格
1. 目标
本规格文档用于将前期讨论收敛为一个可以直接实施的 Spatial-BEATs 方案。
目标是构建一个独立的 Spatial Encoder:
- 输入为完整
FOA音频及其派生空间特征 - 完整的
FOA特征经过BEATs backbone - 最大化复用
BEATs预训练权重 - 输出一组
source-level spatial tokens - 这些 token 作为独立模态输入给 LLM
- 原有语义 audio encoder 保持不动
这里的关键原则是:
不是让
W-only走主干,再外挂一个小空间 adapter;而是让完整 FOA 空间特征真正进入 BEATs 主干,并在主干之后产出结构化空间 token。
2. 最终任务定义
2.1 核心任务
Spatial-BEATs 的主任务定义为:
- 给定一个多源
FOA音频片段 - 预测其中最多
K个潜在声源的空间表示 - 每个表示对应一个
source token
每个 source token 至少承载:
objectnessazimuthelevationdistance
可选承载:
source class auxiliary logitssource embedding
2.2 推荐监督形式
如果训练数据中每个源都有标注,则推荐采用:
set predictionK个预测 token 对N个 GT sources- 用
Hungarian matching做一一匹配
不建议采用:
- 单一 scene-level spatial token
- 仅回归整段音频的全局空间摘要
原因是这会损失多源结构,不利于后续 LLM 做关系推理。
3. 最终架构
推荐最终架构:
FOA waveform
-> SpatialBEATsPreprocessor
-> FOA feature map [B, C_foa, T, F]
-> FOA patch embedding
-> BEATs trunk
-> Spatial query decoder
-> K source tokens
-> Spatial prediction heads
-> LLM projector
为了最大化复用 BEATs 主干,本方案尽量不改 trunk 内部的 Transformer 结构。
4. 输入特征定义
4.1 默认推荐特征
第一版推荐输入通道:
W_logmelX_logmelY_logmelZ_logmelIVxIVyIVz
即:
C_foa = 7
这是默认推荐方案。
4.2 备选输入特征
若希望先降低复杂度,可以使用:
WXYZ logmel
即:
C_foa = 4
但这只适合最小原型。
如果目标是稳定学习空间方向与结构,优先使用 WXYZ + IV。
4.3 前端参数建议
为了最大化复用 BEATs 主干,推荐保持与 BEATs 接近的时频分辨率:
- sample rate:优先
16k - mel bins:
128 - frame length:
25 ms - frame shift:
10 ms
原因:
- 这能让 trunk 看到与原始 BEATs 更接近的 patch 几何结构
- patch embedding 和后续序列长度更容易保持一致
- 预训练权重复用更稳定
4.4 为什么不沿用 Spatial-AST 的 binaural 前端
Spatial-AST 采用的是:
- 双耳 log-mel
- IPD
这适合 binaural,不适合直接迁移到 FOA。
FOA 下应优先利用:
- ambisonic 通道本身
- intensity vector
- 其他 FOA 物理特征
5. 对 BEATs 具体修改哪些模块
下面按模块说明修改方案。
5.1 保留不动的模块
建议尽量保留:
TransformerEncoderTransformerSentenceEncoderLayerMultiheadAttentionconv_posLayerNormFFNpost_extract_proj
也就是 backbone.py 内的主干结构和 BEATs.py 中的 trunk 逻辑尽量不动。
5.2 必须修改的模块
必须重做:
preprocesspatch_embeddingextract_features输出头部逻辑- 下游
predictor
5.3 推荐新增的模块
建议新增:
SpatialBEATsPreprocessorSpatialPatchEmbeddingSpatialQueryDecoderSpatialPredictionHeadSpatialTokenProjectorHungarianMatcherSpatialSetCriterion
6. 代码级映射建议
6.1 现有文件建议
建议保留和复用:
建议新增:
spatial_beats.pyspatial_modules.pyspatial_loss.pyspatial_dataset.pytrain_spatial_beats.py
6.2 spatial_beats.py 建议包含
建议实现:
SpatialBEATsConfigSpatialBEATsSpatialBEATs.extract_spatial_tokens()SpatialBEATs.forward()
6.3 spatial_modules.py 建议包含
建议实现:
SpatialBEATsPreprocessorSpatialPatchEmbeddingSpatialQueryDecoderSpatialPredictionHeadSpatialTokenProjector
6.4 spatial_loss.py 建议包含
建议实现:
HungarianMatcherSpatialSetCriterion
7. 预训练权重如何复用
7.1 默认推荐权重
默认推荐:
BEATs_iter3+ (AS2M) pre-trained
而不是:
- fine-tuned checkpoints
原因:
pre-trained更适合作为 trunk 初始化fine-tuned更偏向 AudioSet 分类判别- 你这里的 spatial encoder 应与原语义 encoder 职责分离
7.2 必须直接加载的层
这些层建议直接加载原 BEATs checkpoint:
post_extract_projencoder.pos_convencoder.layers.*encoder.layer_normlayer_norm
即除了输入 stem 和输出头,主干参数都尽量继承。
7.3 需要特殊初始化的层
以下层因为 shape 不同,不能直接 strict load:
patch_embedding- 新增的
query decoder - 新增的
spatial heads - 新增的
LLM projector
7.4 新 patch embedding 的初始化策略
原 BEATs stem 是:
Conv2d(1, embed_dim, kernel_size=patch, stride=patch)
新 stem 建议是:
Conv2d(C_foa, embed_dim, kernel_size=patch, stride=patch)
推荐初始化策略:
方案 A:保守初始化,默认推荐
W_logmel通道继承原 stem 权重- 其他空间通道初始化为
0或较小随机值
优点:
- 最大程度保留原 BEATs 初始分布
- trunk 适配更稳
缺点:
- 训练初期空间通道利用较慢
方案 B:通道 inflation
- 把原 stem 权重复制到全部输入通道
- 再按通道数做归一化
优点:
- 所有通道一开始都能进入主干
缺点:
- 初始统计更可能偏离原 BEATs
最终推荐:
- 第一版用
方案 A - 后续做 ablation 再比较
方案 B
8. Spatial token 模块的最终设计
8.1 为什么不用全局池化
原始 BEATs 的输出方式更接近:
- patch sequence
- mean pooling
- clip-level prediction
这不适合多源空间任务。
8.2 最终推荐:Query Decoder
在 trunk 输出后新增:
K个 learnable source queries- 一个轻量
cross-attention decoder
输入:
- encoder memory:
H in R^{B x T x D} - source queries:
Q in R^{B x K x D}
输出:
Z in R^{B x K x D}
这里的 Z[:, i, :] 即第 i 个 source token
8.3 为什么 query decoder 是当前最优解
它的优点:
- 不改 trunk 内部结构
- 仍然让完整 FOA 特征经过 backbone
- 适合多源 set prediction
- 最利于最大化复用 trunk 权重
9. 输出头设计
对每个 source token z_i,预测:
objectnessazimuthelevationdistance- 可选
class_aux
9.1 离散还是连续
第一版推荐全部使用离散分类头:
azimuth: 360 binselevation: 180 binsdistance: 按数据分桶,例如0.5m一档
原因:
- 与已有 Spatial-AST/BAT 经验一致
- 分类头更稳
- 更便于构造离散坐标 embedding
9.2 objectness 头
推荐增加:
objectness_head: D -> 1
用于:
- 判断当前 token 是否对应真实声源
- 作为 Hungarian matching 的一部分
- 推理时做 token 保留/裁剪
9.3 类别头
类别头建议作为:
auxiliary head
而不是最终 LLM 的主要输入内容。
这样做的作用:
- 让 query token 更容易学会 source slot 对齐
- 但不把 Spatial-BEATs 变成第二个强语义 encoder
10. Loss 设计
推荐总损失:
L_total =
lambda_obj * L_obj
+ lambda_azi * L_azi
+ lambda_ele * L_ele
+ lambda_dist * L_dist
+ lambda_cls * L_cls_aux
10.1 匹配方式
使用 Hungarian matching:
- 预测:
K个 token - GT:
N个 sources - 成本由以下项构成:
- objectness cost
- azimuth cost
- elevation cost
- distance cost
- optional class cost
10.2 损失项定义
推荐:
L_obj: BCE 或 focal lossL_azi: cross entropyL_ele: cross entropyL_dist: cross entropyL_cls_aux: cross entropy 或 BCE
10.3 初始 loss 权重建议
第一版建议从以下权重起步:
lambda_obj = 1.0
lambda_azi = 2.0
lambda_ele = 2.0
lambda_dist = 1.0
lambda_cls = 0.25
解释:
- 方向任务通常更关键
- 距离次之
- objectness 必须稳定
- 类别监督只作为辅助
10.4 不建议的做法
第一版不建议:
- 重分类损失压倒空间损失
- 直接照搬 Spatial-AST 的
1250 * cls
原因:
- Spatial-AST 的目标之一是保住 sound event detection
- 这里
Spatial-BEATs的主要目标是空间 token - 原项目已有独立语义 encoder
11. 训练策略
11.1 第一阶段是否需要 SSL
当前最终结论:
- 第一版 不需要 重新做 BEATs 式 SSL
因为当前已经有:
- 多源监督
- 每个源的空间标注
- 可复用的 BEATs 主干预训练
所以第一阶段应优先做:
supervised multi-source spatial training
11.2 分阶段训练建议
Stage A:Warmup
冻结:
- 大部分 trunk
只训练:
- FOA preprocessor
- patch embedding
- query decoder
- spatial heads
- LLM projector
目的:
- 让新输入 stem 和新输出头稳定接入 trunk
Stage B:Upper-trunk finetune
解冻:
- trunk 上层若干层
目的:
- 让主干逐步适应 FOA 空间任务
Stage C:Near-full finetune
进一步解冻:
- 更多 encoder layers
目的:
- 提升空间表示上限
11.3 学习率建议
推荐:
- trunk:较小 lr
- 新模块:较大学习率
例如:
lr_trunk = 1e-5 ~ 5e-5
lr_new = 1e-4 ~ 5e-4
并配合:
- layer-wise lr decay
12. 最终输出给 LLM 的 spatial token 形式
这是本项目最关键的接口定义之一。
12.1 内部 token 形式
Spatial-BEATs 内部输出:
Z in R^{B x K x D}
其中:
B: batch sizeK: source token 数D: Spatial-BEATs hidden dim,建议与 BEATs trunk 一致
12.2 不建议直接把 raw logits 喂给 LLM
不建议直接给 LLM:
- azimuth logits
- elevation logits
- distance logits
- objectness logits
这些是监督头,不是最终模态表示。
12.3 最终推荐的 LLM spatial token 形式
最终推荐送给 LLM 的每个 token 形式为:
s_i = Proj([z_i ; e_azi(i) ; e_ele(i) ; e_dist(i) ; e_obj(i)])
其中:
z_i: query decoder 输出的 latent tokene_azi(i): 由预测 azimuth bin 查表得到的 embeddinge_ele(i): 由预测 elevation bin 查表得到的 embeddinge_dist(i): 由预测 distance bin 查表得到的 embeddinge_obj(i): 由 objectness/confidence 产生的 embeddingProj: 投影到 LLM hidden size 的 MLP/Linear
最终:
s_i in R^{d_llm}
12.4 为什么采用“latent + structured embedding”的混合形式
原因:
z_i保留丰富的隐式空间结构信息坐标 embedding给 LLM 显式离散空间线索confidence有助于 LLM 区分可靠/不可靠 token
这比单纯只传:
- raw latent token
或者只传:
- 显式坐标 one-hot / scalar
都更合适。
12.5 最终序列形式
送入 LLM 时推荐:
<SPATIAL_START>, s_1, s_2, ..., s_K, <SPATIAL_END>
并且:
- 按
objectness从高到低排序 - 对低置信 token 可直接截断或 mask
12.6 是否保留全部 K 个 token
默认推荐:
- 训练时保留全部
K - 推理时按
objectness过滤
例如:
- 保留前
K_keep - 或保留
obj > threshold的 token
13. 与原语义 audio encoder 的关系
为了避免“两个 encoder 在做同样的事”,推荐如下职责划分:
- 原语义 audio encoder:负责
what - Spatial-BEATs:负责
where / spatial structure / relations
13.1 是否允许 Spatial-BEATs 学类别
允许,但只作为辅助。
建议:
- 类别头只用于训练
- 最终输入给 LLM 的空间 token 不直接暴露完整类别 logits
13.2 是否需要和语义 encoder 做对齐
第一版不是必须。
若后续希望更强的 source grounding,可进一步加入:
- semantic distillation
- cross-encoder alignment
- source-wise contrastive loss
但这些应放到第二阶段。
14. 第一版推荐配置
第一版默认建议:
- 输入特征:
WXYZ + IVxyz C_foa = 7- 采样率:
16k - mel bins:
128 - patch 配置:与 BEATs 保持一致
- 预训练权重:
BEATs_iter3+ AS2M pre-trained - trunk:最大化加载
- patch stem:
W继承,其余通道小初始化 - 输出:
K个 source tokens - token 解码:轻量 query decoder
- 监督:Hungarian matching + 多头空间分类
- LLM 输入:
latent + structured coordinate embedding的混合 token
15. 实现优先级
推荐按如下优先级推进:
- 实现
FOA preprocessor - 实现多通道
patch embedding - 完成 trunk ckpt 加载
- 实现
query decoder - 实现
objectness / azi / ele / distheads - 实现
Hungarian matcher + criterion - 实现
LLM projector - 完成训练脚本
16. 当前仍需用户确认的问题
以下问题会直接影响第一版实现细节:
FOA数据当前主要采样率是多少?是16k、24k、32k还是48k?- 每个样本中
最大同时源数大概是多少?这会影响K的默认设定。 - 每个源是否都有
source-level class label?如果有,类别头和匹配会更稳。 - 你希望
distance是离散分类还是连续回归?当前默认推荐离散分类。 - 下游 LLM 的 hidden size 是多少?是否已有固定的 audio token projector?
- 你是否希望 Spatial-BEATs 在第一版就具备一定的 source semantic 辅助能力,还是严格只做空间?
17. 结论
当前最终方案已经明确:
- 完整 FOA 特征进入 BEATs 主干
- 最大化复用 trunk 预训练
- 重做输入 stem
- 重做输出为多源 spatial tokens
- 第一版采用监督式 set prediction
- 最终给 LLM 的不是 raw logits,而是融合 latent 与坐标 embedding 的 spatial tokens
这是当前最符合项目目标、也最稳妥的 Spatial-BEATs 方案。