Spatial-BEATs / docs /spatial_beats_implementation_spec.md
dieKarotte's picture
Add files using upload-large-folder tool
bf04039 verified
|
Raw
History Blame Contribute Delete
14.4 kB
# Spatial-BEATs 实现规格
## 1. 目标
本规格文档用于将前期讨论收敛为一个可以直接实施的 `Spatial-BEATs` 方案。
目标是构建一个独立的 `Spatial Encoder`
- 输入为完整 `FOA` 音频及其派生空间特征
- 完整的 `FOA` 特征经过 `BEATs backbone`
- 最大化复用 `BEATs` 预训练权重
- 输出一组 `source-level spatial tokens`
- 这些 token 作为独立模态输入给 LLM
- 原有语义 audio encoder 保持不动
这里的关键原则是:
> 不是让 `W-only` 走主干,再外挂一个小空间 adapter;而是让完整 FOA 空间特征真正进入 BEATs 主干,并在主干之后产出结构化空间 token。
## 2. 最终任务定义
### 2.1 核心任务
`Spatial-BEATs` 的主任务定义为:
- 给定一个多源 `FOA` 音频片段
- 预测其中最多 `K` 个潜在声源的空间表示
- 每个表示对应一个 `source token`
每个 source token 至少承载:
- `objectness`
- `azimuth`
- `elevation`
- `distance`
可选承载:
- `source class auxiliary logits`
- `source embedding`
### 2.2 推荐监督形式
如果训练数据中每个源都有标注,则推荐采用:
- `set prediction`
- `K` 个预测 token 对 `N` 个 GT sources
-`Hungarian matching` 做一一匹配
不建议采用:
- 单一 scene-level spatial token
- 仅回归整段音频的全局空间摘要
原因是这会损失多源结构,不利于后续 LLM 做关系推理。
## 3. 最终架构
推荐最终架构:
```text
FOA waveform
-> SpatialBEATsPreprocessor
-> FOA feature map [B, C_foa, T, F]
-> FOA patch embedding
-> BEATs trunk
-> Spatial query decoder
-> K source tokens
-> Spatial prediction heads
-> LLM projector
```
为了最大化复用 BEATs 主干,本方案尽量不改 trunk 内部的 Transformer 结构。
## 4. 输入特征定义
### 4.1 默认推荐特征
第一版推荐输入通道:
- `W_logmel`
- `X_logmel`
- `Y_logmel`
- `Z_logmel`
- `IVx`
- `IVy`
- `IVz`
即:
- `C_foa = 7`
这是默认推荐方案。
### 4.2 备选输入特征
若希望先降低复杂度,可以使用:
- `WXYZ logmel`
即:
- `C_foa = 4`
但这只适合最小原型。
如果目标是稳定学习空间方向与结构,优先使用 `WXYZ + IV`
### 4.3 前端参数建议
为了最大化复用 BEATs 主干,推荐保持与 BEATs 接近的时频分辨率:
- sample rate:优先 `16k`
- mel bins:`128`
- frame length:`25 ms`
- frame shift:`10 ms`
原因:
- 这能让 trunk 看到与原始 BEATs 更接近的 patch 几何结构
- patch embedding 和后续序列长度更容易保持一致
- 预训练权重复用更稳定
### 4.4 为什么不沿用 Spatial-AST 的 binaural 前端
Spatial-AST 采用的是:
- 双耳 log-mel
- IPD
这适合 binaural,不适合直接迁移到 FOA。
FOA 下应优先利用:
- ambisonic 通道本身
- intensity vector
- 其他 FOA 物理特征
## 5. 对 BEATs 具体修改哪些模块
下面按模块说明修改方案。
### 5.1 保留不动的模块
建议尽量保留:
- `TransformerEncoder`
- `TransformerSentenceEncoderLayer`
- `MultiheadAttention`
- `conv_pos`
- `LayerNorm`
- `FFN`
- `post_extract_proj`
也就是 `backbone.py` 内的主干结构和 `BEATs.py` 中的 trunk 逻辑尽量不动。
### 5.2 必须修改的模块
必须重做:
1. `preprocess`
2. `patch_embedding`
3. `extract_features` 输出头部逻辑
4. 下游 `predictor`
### 5.3 推荐新增的模块
建议新增:
1. `SpatialBEATsPreprocessor`
2. `SpatialPatchEmbedding`
3. `SpatialQueryDecoder`
4. `SpatialPredictionHead`
5. `SpatialTokenProjector`
6. `HungarianMatcher`
7. `SpatialSetCriterion`
## 6. 代码级映射建议
### 6.1 现有文件建议
建议保留和复用:
- [BEATs.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/BEATs.py)
- [backbone.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/backbone.py)
建议新增:
- `spatial_beats.py`
- `spatial_modules.py`
- `spatial_loss.py`
- `spatial_dataset.py`
- `train_spatial_beats.py`
### 6.2 `spatial_beats.py` 建议包含
建议实现:
- `SpatialBEATsConfig`
- `SpatialBEATs`
- `SpatialBEATs.extract_spatial_tokens()`
- `SpatialBEATs.forward()`
### 6.3 `spatial_modules.py` 建议包含
建议实现:
- `SpatialBEATsPreprocessor`
- `SpatialPatchEmbedding`
- `SpatialQueryDecoder`
- `SpatialPredictionHead`
- `SpatialTokenProjector`
### 6.4 `spatial_loss.py` 建议包含
建议实现:
- `HungarianMatcher`
- `SpatialSetCriterion`
## 7. 预训练权重如何复用
## 7.1 默认推荐权重
默认推荐:
- `BEATs_iter3+ (AS2M) pre-trained`
而不是:
- fine-tuned checkpoints
原因:
- `pre-trained` 更适合作为 trunk 初始化
- `fine-tuned` 更偏向 AudioSet 分类判别
- 你这里的 spatial encoder 应与原语义 encoder 职责分离
### 7.2 必须直接加载的层
这些层建议直接加载原 BEATs checkpoint:
- `post_extract_proj`
- `encoder.pos_conv`
- `encoder.layers.*`
- `encoder.layer_norm`
- `layer_norm`
即除了输入 stem 和输出头,主干参数都尽量继承。
### 7.3 需要特殊初始化的层
以下层因为 shape 不同,不能直接 strict load:
- `patch_embedding`
- 新增的 `query decoder`
- 新增的 `spatial heads`
- 新增的 `LLM projector`
### 7.4 新 patch embedding 的初始化策略
原 BEATs stem 是:
- `Conv2d(1, embed_dim, kernel_size=patch, stride=patch)`
新 stem 建议是:
- `Conv2d(C_foa, embed_dim, kernel_size=patch, stride=patch)`
推荐初始化策略:
#### 方案 A:保守初始化,默认推荐
- `W_logmel` 通道继承原 stem 权重
- 其他空间通道初始化为 `0` 或较小随机值
优点:
- 最大程度保留原 BEATs 初始分布
- trunk 适配更稳
缺点:
- 训练初期空间通道利用较慢
#### 方案 B:通道 inflation
- 把原 stem 权重复制到全部输入通道
- 再按通道数做归一化
优点:
- 所有通道一开始都能进入主干
缺点:
- 初始统计更可能偏离原 BEATs
最终推荐:
- 第一版用 `方案 A`
- 后续做 ablation 再比较 `方案 B`
## 8. Spatial token 模块的最终设计
### 8.1 为什么不用全局池化
原始 BEATs 的输出方式更接近:
- patch sequence
- mean pooling
- clip-level prediction
这不适合多源空间任务。
### 8.2 最终推荐:Query Decoder
在 trunk 输出后新增:
- `K` 个 learnable source queries
- 一个轻量 `cross-attention decoder`
输入:
- encoder memory:`H in R^{B x T x D}`
- source queries:`Q in R^{B x K x D}`
输出:
- `Z in R^{B x K x D}`
这里的 `Z[:, i, :]` 即第 `i``source token`
### 8.3 为什么 query decoder 是当前最优解
它的优点:
- 不改 trunk 内部结构
- 仍然让完整 FOA 特征经过 backbone
- 适合多源 set prediction
- 最利于最大化复用 trunk 权重
## 9. 输出头设计
对每个 source token `z_i`,预测:
- `objectness`
- `azimuth`
- `elevation`
- `distance`
- 可选 `class_aux`
### 9.1 离散还是连续
第一版推荐全部使用离散分类头:
- `azimuth`: 360 bins
- `elevation`: 180 bins
- `distance`: 按数据分桶,例如 `0.5m` 一档
原因:
- 与已有 Spatial-AST/BAT 经验一致
- 分类头更稳
- 更便于构造离散坐标 embedding
### 9.2 objectness 头
推荐增加:
- `objectness_head: D -> 1`
用于:
- 判断当前 token 是否对应真实声源
- 作为 Hungarian matching 的一部分
- 推理时做 token 保留/裁剪
### 9.3 类别头
类别头建议作为:
- `auxiliary head`
而不是最终 LLM 的主要输入内容。
这样做的作用:
- 让 query token 更容易学会 source slot 对齐
- 但不把 Spatial-BEATs 变成第二个强语义 encoder
## 10. Loss 设计
推荐总损失:
```text
L_total =
lambda_obj * L_obj
+ lambda_azi * L_azi
+ lambda_ele * L_ele
+ lambda_dist * L_dist
+ lambda_cls * L_cls_aux
```
### 10.1 匹配方式
使用 `Hungarian matching`
- 预测:`K` 个 token
- GT:`N` 个 sources
- 成本由以下项构成:
- objectness cost
- azimuth cost
- elevation cost
- distance cost
- optional class cost
### 10.2 损失项定义
推荐:
- `L_obj`: BCE 或 focal loss
- `L_azi`: cross entropy
- `L_ele`: cross entropy
- `L_dist`: cross entropy
- `L_cls_aux`: cross entropy 或 BCE
### 10.3 初始 loss 权重建议
第一版建议从以下权重起步:
```text
lambda_obj = 1.0
lambda_azi = 2.0
lambda_ele = 2.0
lambda_dist = 1.0
lambda_cls = 0.25
```
解释:
- 方向任务通常更关键
- 距离次之
- objectness 必须稳定
- 类别监督只作为辅助
### 10.4 不建议的做法
第一版不建议:
- 重分类损失压倒空间损失
- 直接照搬 Spatial-AST 的 `1250 * cls`
原因:
- Spatial-AST 的目标之一是保住 sound event detection
- 这里 `Spatial-BEATs` 的主要目标是空间 token
- 原项目已有独立语义 encoder
## 11. 训练策略
### 11.1 第一阶段是否需要 SSL
当前最终结论:
- 第一版 **不需要** 重新做 BEATs 式 SSL
因为当前已经有:
- 多源监督
- 每个源的空间标注
- 可复用的 BEATs 主干预训练
所以第一阶段应优先做:
- `supervised multi-source spatial training`
### 11.2 分阶段训练建议
#### Stage A:Warmup
冻结:
- 大部分 trunk
只训练:
- FOA preprocessor
- patch embedding
- query decoder
- spatial heads
- LLM projector
目的:
- 让新输入 stem 和新输出头稳定接入 trunk
#### Stage B:Upper-trunk finetune
解冻:
- trunk 上层若干层
目的:
- 让主干逐步适应 FOA 空间任务
#### Stage C:Near-full finetune
进一步解冻:
- 更多 encoder layers
目的:
- 提升空间表示上限
### 11.3 学习率建议
推荐:
- trunk:较小 lr
- 新模块:较大学习率
例如:
```text
lr_trunk = 1e-5 ~ 5e-5
lr_new = 1e-4 ~ 5e-4
```
并配合:
- layer-wise lr decay
## 12. 最终输出给 LLM 的 spatial token 形式
这是本项目最关键的接口定义之一。
### 12.1 内部 token 形式
`Spatial-BEATs` 内部输出:
- `Z in R^{B x K x D}`
其中:
- `B`: batch size
- `K`: source token 数
- `D`: Spatial-BEATs hidden dim,建议与 BEATs trunk 一致
### 12.2 不建议直接把 raw logits 喂给 LLM
不建议直接给 LLM:
- azimuth logits
- elevation logits
- distance logits
- objectness logits
这些是监督头,不是最终模态表示。
### 12.3 最终推荐的 LLM spatial token 形式
最终推荐送给 LLM 的每个 token 形式为:
```text
s_i = Proj([z_i ; e_azi(i) ; e_ele(i) ; e_dist(i) ; e_obj(i)])
```
其中:
- `z_i`: query decoder 输出的 latent token
- `e_azi(i)`: 由预测 azimuth bin 查表得到的 embedding
- `e_ele(i)`: 由预测 elevation bin 查表得到的 embedding
- `e_dist(i)`: 由预测 distance bin 查表得到的 embedding
- `e_obj(i)`: 由 objectness/confidence 产生的 embedding
- `Proj`: 投影到 LLM hidden size 的 MLP/Linear
最终:
- `s_i in R^{d_llm}`
### 12.4 为什么采用“latent + structured embedding”的混合形式
原因:
1. `z_i` 保留丰富的隐式空间结构信息
2. `坐标 embedding` 给 LLM 显式离散空间线索
3. `confidence` 有助于 LLM 区分可靠/不可靠 token
这比单纯只传:
- raw latent token
或者只传:
- 显式坐标 one-hot / scalar
都更合适。
### 12.5 最终序列形式
送入 LLM 时推荐:
```text
<SPATIAL_START>, s_1, s_2, ..., s_K, <SPATIAL_END>
```
并且:
-`objectness` 从高到低排序
- 对低置信 token 可直接截断或 mask
### 12.6 是否保留全部 K 个 token
默认推荐:
- 训练时保留全部 `K`
- 推理时按 `objectness` 过滤
例如:
- 保留前 `K_keep`
- 或保留 `obj > threshold` 的 token
## 13. 与原语义 audio encoder 的关系
为了避免“两个 encoder 在做同样的事”,推荐如下职责划分:
- 原语义 audio encoder:负责 `what`
- Spatial-BEATs:负责 `where / spatial structure / relations`
### 13.1 是否允许 Spatial-BEATs 学类别
允许,但只作为辅助。
建议:
- 类别头只用于训练
- 最终输入给 LLM 的空间 token 不直接暴露完整类别 logits
### 13.2 是否需要和语义 encoder 做对齐
第一版不是必须。
若后续希望更强的 source grounding,可进一步加入:
- semantic distillation
- cross-encoder alignment
- source-wise contrastive loss
但这些应放到第二阶段。
## 14. 第一版推荐配置
第一版默认建议:
- 输入特征:`WXYZ + IVxyz`
- `C_foa = 7`
- 采样率:`16k`
- mel bins:`128`
- patch 配置:与 BEATs 保持一致
- 预训练权重:`BEATs_iter3+ AS2M pre-trained`
- trunk:最大化加载
- patch stem:`W` 继承,其余通道小初始化
- 输出:`K` 个 source tokens
- token 解码:轻量 query decoder
- 监督:Hungarian matching + 多头空间分类
- LLM 输入:`latent + structured coordinate embedding` 的混合 token
## 15. 实现优先级
推荐按如下优先级推进:
1. 实现 `FOA preprocessor`
2. 实现多通道 `patch embedding`
3. 完成 trunk ckpt 加载
4. 实现 `query decoder`
5. 实现 `objectness / azi / ele / dist` heads
6. 实现 `Hungarian matcher + criterion`
7. 实现 `LLM projector`
8. 完成训练脚本
## 16. 当前仍需用户确认的问题
以下问题会直接影响第一版实现细节:
1. `FOA` 数据当前主要采样率是多少?是 `16k``24k``32k` 还是 `48k`
2. 每个样本中 `最大同时源数` 大概是多少?这会影响 `K` 的默认设定。
3. 每个源是否都有 `source-level class label`?如果有,类别头和匹配会更稳。
4. 你希望 `distance` 是离散分类还是连续回归?当前默认推荐离散分类。
5. 下游 LLM 的 hidden size 是多少?是否已有固定的 audio token projector?
6. 你是否希望 Spatial-BEATs 在第一版就具备一定的 source semantic 辅助能力,还是严格只做空间?
## 17. 结论
当前最终方案已经明确:
- **完整 FOA 特征进入 BEATs 主干**
- **最大化复用 trunk 预训练**
- **重做输入 stem**
- **重做输出为多源 spatial tokens**
- **第一版采用监督式 set prediction**
- **最终给 LLM 的不是 raw logits,而是融合 latent 与坐标 embedding 的 spatial tokens**
这是当前最符合项目目标、也最稳妥的 `Spatial-BEATs` 方案。