Spatial-BEATs / docs /spatial_beats_design_guide.md
dieKarotte's picture
Add files using upload-large-folder tool
29ab2d0 verified
|
Raw
History Blame Contribute Delete
14.1 kB
# Spatial-BEATs 设计与训练指南
## 1. 文档目标
本文档用于整理本项目中 `Spatial-BEATs` 的任务定义、模型改造方向、保留与替换的模块、训练方法,以及后续接入 LLM 的接口约定。
目标是基于公开的 BEATs 框架,构建一个独立的 `Spatial Encoder`
- 输入为 `FOA 音频` 及其派生的 `FOA 空间特征`
- 主干尽可能复用 `BEATs backbone` 和其预训练权重
- 输出为一组 `Spatial Tokens`
- 这些 `Spatial Tokens` 作为独立模态输入给 LLM
- 原有的语义 audio encoder 保持不动,避免直接与空间 encoder 混合后产生职责不清或语义冲突
本设计不是在原有 LLM audio encoder 上强行加入空间分支,而是单独训练一个 `Spatial-BEATs`,让其专注于空间感知和空间结构建模。
## 2. 任务定义
### 2.1 核心任务
`Spatial-BEATs` 的核心任务不是通用音频语义分类,而是:
1.`FOA` 输入中提取与空间位置相关的结构化表示
2. 在多源场景中输出 `source-level spatial tokens`
3. 每个 token 尽量对应一个潜在声源,编码其空间信息
4. 后续由 LLM 使用这些 token 完成空间关系理解和推理
### 2.2 输入与输出
输入:
- 原始 `FOA waveform`
- 或由 `FOA waveform` 计算得到的 `FOA 特征图`
推荐特征:
- `W, X, Y, Z` 的 log-mel
- `IV (Intensity Vector)`,例如 `IVx, IVy, IVz`
- 可选的 `diffuseness / coherence / phase-related` 特征
输出:
- `K``Spatial Tokens`
- 每个 token 对应一个潜在声源或一个空间实体
- 每个 token 供下游预测:
- `objectness`
- `azimuth`
- `elevation`
- `distance`
- 可选的 `class embedding / source type embedding`
### 2.3 为什么不是只用 W
只让 `W` 通道经过 backbone,本质上更像单通道语义编码,空间线索主要被放到外挂 adapter 中。
这不符合本项目目标,因为这里希望:
- 整个 `FOA` 特征都经过主干
- 主干本身学习空间结构
- `Spatial-BEATs` 成为一个真正的空间 encoder,而不是一个“语义 encoder + 小空间补丁”
因此,本项目的推荐路线是:
- `整套 FOA 特征 -> patch embedding -> BEATs backbone -> source-level spatial tokens`
而不是:
- `W-only -> BEATs`
- `W-only BEATs + 外挂小 adapter`
## 3. 与原始 BEATs 的关系
### 3.1 BEATs 中值得最大化复用的部分
当前仓库中的 BEATs 主干主要包括:
- `post_extract_proj`
- `TransformerEncoder`
- `Transformer layers`
- `conv_pos`
- `LayerNorm / FFN / attention`
这些模块位于:
- `BEATs.py`
- `backbone.py`
这些部分是最应该保留并加载预训练权重的。
### 3.2 BEATs 中不适合直接保留的部分
原始 BEATs 代码是单通道设计,关键假设包括:
- `preprocess()` 只生成单通道 `fbank`
- `patch_embedding``Conv2d(1, embed_dim, ...)`
- 下游输出默认是整段时间序列平均后的分类预测
因此,下列部分不应直接照搬:
1. 单通道 `preprocess`
2. 单通道 `patch_embedding`
3. 最终的 clip-level 平均池化分类输出方式
4. 原始 `predictor` 作为最终目标头
### 3.3 对 BEATs 的总体改造原则
原则是:
- **尽量保留 trunk**
- **必要时重做 stem**
- **完全重做 spatial head**
也就是:
- `输入端`
- `输出端`
- `中间主干` 尽量不改
## 4. Spatial-AST 相比 AudioMAE 的改造经验
Spatial-AST 对本项目最有借鉴价值的不是其 binaural 细节,而是其改造模式。
### 4.1 Spatial-AST 做了什么
相对于原始 AudioMAE/ViT,Spatial-AST 主要改了四类模块:
1. **输入前端**
- 从原始单通道 spectrogram 输入,改成 `双耳 log-mel + IPD`
- 在输入前端加入 `STFT / LogMel / IPD / conv_downsample`
2. **token 设计**
- 把原来的单个 `cls token` 改为 `3 个任务专用 token`
- 分别对应:
- 分类
- 距离
- 方向
3. **输出头**
- 除原分类 head 外,新增:
- `distance_head`
- `azimuth_head`
- `elevation_head`
4. **训练目标**
- 从单任务分类,改成多任务训练
- 同时训练:
- sound event detection
- distance prediction
- direction prediction
### 4.2 Spatial-AST 没有怎么改
Spatial-AST 没有重写 Transformer block 本体。
它保留了 AudioMAE/ViT 的核心 encoder 结构,而把改动集中在:
- front-end
- tokens
- heads
- objectives
### 4.3 对本项目的可迁移结论
可直接借鉴的思想:
1. 用预训练音频主干初始化空间 encoder
2. 重新设计输入前端,让空间特征真正进入 backbone
3. 使用专门的空间 token,而不是只做全局池化
4. 使用多任务监督训练空间能力
不能直接照搬的部分:
1. Spatial-AST 的 `binaural + IPD` 前端
2. 只面向单/双耳的空间 cue 设计
3. 只输出全局 token 的思路
本项目是 `FOA`,因此应该把输入前端换成 `FOA 专属空间特征`
## 5. Spatial-BEATs 的推荐任务形式
### 5.1 单场景 token 不够
如果只输出一个全局 spatial token,它只能表示整个 scene 的压缩摘要,不适合做:
- 多声源关系理解
- “谁在谁左边”
- “某个类对应的声源在什么位置”
- source-wise grounding
既然数据中已经有多源标注,推荐直接把任务定义为:
- `multi-source set prediction`
### 5.2 推荐输出形式
令模型输出固定数量的 `K` 个 spatial queries/tokens。
每个 token 预测:
- `p(obj)`
- `azimuth`
- `elevation`
- `distance`
- 可选 `class logits``class embedding`
训练时使用:
- `Hungarian matching`
- 或者其他 set prediction matching
`K` 个预测 token 与当前样本中的 `N` 个 GT 声源做一一匹配。
这会比单一 scene token 更适合后续接入 LLM 做空间关系推理。
## 6. Spatial-BEATs 的推荐模型结构
推荐结构如下:
```text
FOA waveform
-> FOA front-end
-> FOA feature map
-> FOA patch embedding
-> BEATs Transformer trunk
-> source queries / spatial decoder
-> K spatial tokens
-> spatial heads
```
### 6.1 FOA front-end
输入可以是:
- `WXYZ log-mel`
- `WXYZ log-mel + IVx/IVy/IVz`
推荐第一版就至少使用:
- `WXYZ`
- `IV`
原因:
- 只用 `WXYZ` 仍然需要模型自己从通道关系中恢复空间线索
- `IV` 直接提供有物理意义的方向信息
- 对空间收敛会更稳
### 6.2 FOA patch embedding
这一层应替换原始 BEATs 的单通道 patch stem。
原始:
- `Conv2d(1, embed_dim, kernel_size=patch, stride=patch)`
新的思路:
- `Conv2d(C_foa, embed_dim, kernel_size=patch, stride=patch)`
其中 `C_foa` 可以是:
- `4`,如果只用 `WXYZ`
- `7`,如果用 `WXYZ + IVxyz`
- 更大,如果加入更多派生空间特征
### 6.3 BEATs trunk
尽量保留以下模块:
- `post_extract_proj`
- `TransformerEncoder`
- `attention`
- `FFN`
- `conv_pos`
- `LayerNorm`
这是整个“最大化复用预训练权重”的核心。
### 6.4 Spatial token 模块
不要继续使用原始 BEATs 的:
- mean pooling
- clip-level predictor
推荐改为:
- `K` 个 learnable source queries
- queries 对 trunk 输出做 attention
- 得到 `K` 个 source-level spatial tokens
如果实现上希望更简单,第一版也可以:
- 先直接在 trunk 输出后接一个轻量 decoder
- 再输出 `K` 个 tokens
### 6.5 Spatial heads
每个 token 对应:
- `objectness head`
- `azimuth head`
- `elevation head`
- `distance head`
- 可选 `class head`
如果担心与原 LLM audio encoder 产生语义冲突,则建议:
-`class head` 只作为辅助监督
- 不把其输出作为最终送入 LLM 的主要表示
## 7. 保留的部分
下面这些建议尽量保留:
### 7.1 保留原有 LLM audio encoder
原始语义 audio encoder 不动,继续负责:
- 音频内容语义
- 事件类别理解
- 与现有 LLM 接口保持兼容
### 7.2 保留 Spatial-BEATs 作为独立 encoder
`Spatial-BEATs` 单独负责:
- 方向
- 距离
- 多源空间结构
- 可选 source-wise 辅助类别信息
### 7.3 保留 BEATs 主干参数初始化
建议保留:
- trunk 的预训练参数
- 尽量避免从零训练整个 spatial encoder
## 8. 需要替换或新增的部分
### 8.1 必改模块
必须修改:
1. `preprocess`
2. `patch_embedding`
3. `forward / extract_features` 的输出方式
4. 下游 `predictor`
### 8.2 必增模块
必须新增:
1. `FOA spatial front-end`
2. `spatial query / token module`
3. `multi-head spatial prediction heads`
4. `set matching / multi-source loss`
### 8.3 建议新增模块
建议新增:
1. `source confidence / objectness`
2. `auxiliary class supervision`
3. `LLM projection head`
## 9. 训练方法
## 9.1 第一阶段是否需要 SSL
当前结论是:
- **第一版不需要重新做 BEATs 式 SSL**
原因:
1. 已经有多源 `GT relative positions`
2. 目标不是再学通用音频语义,而是让模型学空间结构
3. 已有 BEATs 预训练权重可作为稳定初始化
4. 先做监督式空间学习,工程收益最高
因此,推荐第一阶段直接做 `supervised multi-task training`
### 9.2 第一阶段训练目标
基础目标:
- `L_obj`
- `L_azimuth`
- `L_elevation`
- `L_distance`
可选目标:
- `L_class_aux`
总损失可写为:
```text
L = lambda_obj * L_obj
+ lambda_azi * L_azimuth
+ lambda_ele * L_elevation
+ lambda_dist * L_distance
+ lambda_cls * L_class_aux
```
这里建议:
- 空间任务作为主目标
- 类别只做辅助目标
因为本项目中语义主责已经由原始 audio encoder 承担。
### 9.3 多源匹配训练
如果每个样本有多个声源标注,推荐:
1. 模型输出固定数量 `K` 个 tokens
2. 用 Hungarian matching 在 token 与 GT 源之间做匹配
3. 对 matched token 计算位置损失
4. 对 unmatched token 计算 no-object loss
这是比“对所有 token 平均做 scene 监督”更合适的做法。
### 9.4 训练阶段建议
推荐三步走:
#### Stage A: Stem + Head Warmup
- 冻结大部分 BEATs trunk
- 只训练:
- FOA front-end
- FOA patch embedding
- spatial token/query module
- spatial heads
目的:
- 让新输入 stem 和新 heads 先适配预训练 trunk
#### Stage B: Upper Trunk Finetune
- 解冻 BEATs 上层若干层
- 使用较小学习率微调
- 使用 layer-wise lr decay
目的:
- 让 trunk 逐步适配 FOA 分布和空间任务
#### Stage C: Full or Near-Full Finetune
- 在稳定后解冻更多层
- 继续以空间目标微调
目的:
- 提升空间 token 的表达能力
### 9.5 训练数据组织
每个样本应包含:
- `foa waveform`
- `num_sources`
- `per-source azimuth`
- `per-source elevation`
- `per-source distance`
- 可选 `per-source class`
推荐统一成:
```text
sample = {
"audio": ...,
"sources": [
{"azimuth": ..., "elevation": ..., "distance": ..., "label": ...},
...
]
}
```
## 10. 与 LLM 的接口
### 10.1 推荐输入形式
最终不要把 Spatial-BEATs 的全部 dense patch tokens 都喂给 LLM。
推荐只输出:
- `K``Spatial Tokens`
每个 token 代表一个潜在空间实体。
### 10.2 避免语义冲突的策略
本项目中“避免语义冲突”的关键不是完全不学类别,而是:
1. 原始 audio encoder 继续承担主要语义理解
2. Spatial-BEATs 主要承担空间结构建模
3. Spatial-BEATs 输出给 LLM 的是 `source-level spatial tokens`
4. 辅助类别监督只用于训练,不一定直接暴露为最终模态表示
这样两套 encoder 的职责边界更清晰:
- 原语义 encoder:`what`
- Spatial-BEATs:`where / relation / spatial structure`
## 11. 当前推荐方案总结
当前最推荐的路线不是:
- `W-only BEATs`
- `W-only + adapter`
而是:
- `FOA full-feature Spatial-BEATs`
- `独立空间 encoder`
- `最大化复用 BEATs trunk`
- `重做输入 stem`
- `重做多源 spatial token heads`
用一句话总结就是:
> 用 BEATs 的主干做 FOA 空间建模,而不是只拿 BEATs 当单通道语义骨干,再在旁边打一个空间补丁。
## 12. 推荐实施顺序
建议按以下顺序推进:
1. 明确 `FOA feature schema`
- 是否使用 `WXYZ`
- 是否加入 `IV`
- 是否加入其他物理特征
2. 设计 `Spatial-BEATs` 的新输入 stem
- 替换单通道 preprocess
- 替换 patch embedding
3. 设计 `K-source spatial tokens`
- 确定 token 数量
- 确定 query 机制
4. 实现多头空间预测
- objectness
- distance
- azimuth
- elevation
- optional class
5. 实现多源匹配训练
- Hungarian matching
- no-object loss
6. 先做监督训练
- 不急于加入 SSL
7. 训练稳定后,再评估是否需要第二阶段自监督
- masked spatial cue prediction
- view consistency
- teacher distillation
## 13. 第二阶段可选方向
在第一版监督训练稳定后,可考虑加入:
1. `Spatial SSL`
- masked FOA cue prediction
- spatial consistency loss
- teacher-student distillation
2. `source-conditioned tokens`
- 类别条件的 source token
3. `LLM-side alignment`
- 将 spatial token 投影到 LLM hidden space
- 与文本空间词汇进行弱对齐
4. `更复杂的多源推理训练`
- source relation supervision
- pairwise spatial relation labels
## 14. 结论
本项目的推荐方向已经比较明确:
- 使用 `FOA 全特征`
-`FOA 特征` 真正经过 `BEATs backbone`
-`BEATs` 改造成独立的 `Spatial-BEATs`
- 保留原语义 audio encoder 不动
-`多源 spatial token prediction` 为核心任务
- 第一阶段采用 `监督式空间训练`
- 最大化复用 `BEATs trunk` 的预训练权重
如果后续开始实现,建议优先落地以下最小可行版本:
1. `WXYZ + IV` 输入
2. 多通道 patch embedding
3. BEATs trunk 复用
4. `K` 个 spatial queries
5. `objectness + azimuth + elevation + distance` 多头监督
这会比任何 `W-only` 或“仅外挂 adapter”的方案更符合项目目标,也更适合最终作为 LLM 的空间模态输入。