Spatial-BEATs / docs /spatial_beats_simplified_implementation.md
dieKarotte's picture
Add files using upload-large-folder tool
86cbd36 verified
|
Raw
History Blame Contribute Delete
17 kB
# Spatial-BEATs 简化版实现文档
## 1. 文档目的
本文档给出当前推荐的 `Spatial-BEATs` 简化版实现方案。
这份方案基于一个更务实的判断:
- 真正重要的是让 `FOA -> BEATs trunk` 学到稳定的空间表征
- 后面的模块只需要承担 `readout / decode / supervision` 的作用
- 不需要一开始就引入复杂的 `slot query decoder`
- 最终给 LLM 的 token 应尽量来自前面的空间 embedding,而不是最终任务头输出
- 当前阶段需要先支持 `encoder-only training`
因此,本方案不再以“内部 4 个 slots + decoder 聚合”作为主线,而改为:
```text
FOA waveform
-> Spatial preprocessor
-> Multi-channel patch embedding
-> BEATs trunk
-> Temporal readout
-> Spatial embeddings at 2.5 Hz
-> Fixed-slot prediction heads
-> Projector
-> LLM spatial tokens
```
## 2. 当前确定的约束
- 采样率:`16 kHz`
- 输入:`FOA waveform`
- 多源数据:有
- 最大同时源数:约 `4`
- 每个源有稳定 `source-level class label`
- source vocabulary:`/apdcephfs_cq12/share_302080740/user/schmittzhu/data/fsd50k/FSD50K.ground_truth/final_vocabulary.csv`
- source class count:当前默认 `65`
- 距离预测:`连续回归`
- mel 前端参数:对齐 `Qwen-2.5-Omni` audio tower 的底层配置
- `sample_rate=16000`
- `num_mel_bins=128`
- `n_fft=400`
- `win_length=400`
- `hop_length=160`
- `dither=0.0`
- 时间 supervision:`弱时间窗口`
- 当前位置:`clip 内固定`
- 未来会扩展到:`位置随时间变化`
- 目标输出 token rate:`2.5 Hz`
- 对于任意 clip,第 `i` 个样本的有效 token 数是 `T_s_i = round(duration_i * 2.5)`
- 对于 `10 s` clip,`T_s_i = 25`
- batch 内部按 `T_s_max = max_i T_s_i` 做 padding
- 主干初始化:`BEATs_iter3+ AS2M pre-trained`
- 当前第一阶段:优先支持只训练 encoder 的监督方案
## 3. 方案核心思想
### 3.1 什么是主角
主角是:
- `BEATs trunk`
它负责从 FOA 空间特征中学习空间表征。
### 3.2 什么是配角
配角是:
- temporal readout
- prediction heads
- projector
这些模块的作用只是:
- 从 trunk 特征中“读出”空间信息
- 建立 loss 回传路径
- 把空间 embedding 投影到 LLM 接口
### 3.3 给 LLM 的 token 从哪里来
最终给 LLM 的 token 应该来自:
- trunk 后
- 或 trunk 后再经过一层很浅的 temporal readout 之后
而不是来自:
- 最终 logits
- 复杂 decoder 的末端输出
## 4. 最终结构总览
```text
FOA waveform [B, 4, T]
-> SpatialBEATsPreprocessor
-> FOA feature map [B, C_foa, T_f, F]
-> SpatialPatchEmbedding
-> patch tokens [B, N_p, D_in]
-> BEATs trunk
-> encoder memory [B, N_p, D]
-> reshape / frequency pooling
-> temporal tokens [B, T_s_max, D]
-> shallow temporal readout
-> spatial embeddings [B, T_s_max, D]
-> prediction heads
-> Spatial projector
-> llm spatial tokens [B, T_s_max, d_llm]
```
其中:
- 每个样本:
- `T_s_i = round(duration_i * 2.5)`
- 一个 batch 内:
- `T_s_max = max_i T_s_i`
-`10 s` 输入:
- `T_s_i = 25`
## 5. 从 FOA 到 spatial token 的完整过程
## 5.1 输入层
输入:
- `waveform: [B, 4, T]`
四个通道分别是:
- `W`
- `X`
- `Y`
- `Z`
`10 s, 16kHz`
- `T = 160000`
## 5.2 SpatialBEATsPreprocessor
目标:
- 把原始 FOA 波形转成多通道空间特征图
推荐输出通道:
- `W_logmel`
- `X_logmel`
- `Y_logmel`
- `Z_logmel`
- `IVx`
- `IVy`
- `IVz`
因此:
- `C_foa = 7`
内部步骤:
1.`WXYZ` 做 STFT
2. 计算各通道 `log-mel`
3. 计算 `IVx, IVy, IVz`
4.`IV` 映射到 mel 维
5. 拼接成特征图
输出:
- `foa_feat: [B, 7, T_f, 128]`
其中:
- `128` 是 mel bins
- `T_f` 约等于 `1000`,如果帧移为 `10 ms`
## 5.3 SpatialPatchEmbedding
目标:
-`7` 通道 FOA 特征图切成 patch token
原始 BEATs 的 stem 是单通道:
```text
Conv2d(1, embed_dim, kernel_size=patch, stride=patch)
```
新的 stem 改成:
```text
Conv2d(7, embed_dim, kernel_size=patch, stride=patch)
```
输入:
- `foa_feat: [B, 7, T_f, 128]`
输出:
- `patch_tokens: [B, N_p, D_in]`
- `grid_hw = (T_p, F_p)`
建议:
- `D_in = 512`
原因:
- 与 BEATs patch embedding dim 对齐
## 5.4 BEATs trunk
这是整个模型最重要的部分。
保留并复用的模块包括:
- patch token layer norm
- `post_extract_proj`
- `dropout_input`
- `TransformerEncoder`
- `conv_pos`
- 所有 transformer layers
- LayerNorm / FFN / attention
输入:
- `patch_tokens: [B, N_p, 512]`
输出:
- `encoder_memory: [B, N_p, 768]`
建议:
- 主干 hidden dim 保持 `768`
- 直接加载 `BEATs_iter3+ AS2M pre-trained`
## 5.5 Reshape + Frequency Pooling
目标:
- 把 trunk 输出变成可读的时间序列表示
步骤:
1.`encoder_memory [B, N_p, D]` reshape 成:
- `grid_memory [B, T_p, F_p, D]`
2.`F_p` 维做 pooling
推荐第一版:
- `mean pooling over frequency`
输出:
- `temporal_patch_tokens: [B, T_p, D]`
这一步的意义是:
- 先把频率信息压到时间轴上
- 便于后面构造固定 `2.5 Hz` 的时序空间 token
## 5.6 Temporal Resampler
目标:
- 把 patch 级时间序列压到目标 token rate
输入:
- `temporal_patch_tokens: [B, T_p, D]`
输出:
- `temporal_tokens: [B, T_s, D]`
其中:
-`i` 个样本:
- `T_s_i = round(duration_i * 2.5)`
- batch padding 后:
- `temporal_tokens: [B, T_s_max, D]`
-`10 s` clip:
- `T_s_i = 25`
推荐第一版:
- 线性插值
- 或轻量 `Conv1d` 下采样
注意:
- `2.5 Hz` 明确指最终给 LLM 的 token rate
- 对单个 `10 s` 样本,`10 s -> 25` 个有效 token
- 对 mixed-length batch,张量会 pad 到 `T_s_max`
## 5.7 Shallow Temporal Readout
目标:
- 在时间维上再做一层轻量整理
- 让 trunk 输出更适合做空间监督和给 LLM 使用
推荐做法:
- `1~2` 层 transformer encoder
输入:
- `temporal_tokens: [B, T_s_max, 768]`
输出:
- `spatial_embeddings: [B, T_s_max, 768]`
这一层的作用是:
- 做一个浅层 readout neck
- 不承担复杂检测器职责
- 不强调 decoder 身份
- 只是把 trunk 表征整理成更干净的时间步级空间 embedding
如果你想先做最简单版本,也可以直接:
- `temporal_tokens -> LayerNorm -> spatial_embeddings`
把 shallow transformer 作为后续增强项。
## 5.8 Prediction Heads
目标:
-`spatial_embeddings` 上接显式监督
- 通过 loss 让前面的 trunk 真正学到空间表征
- 在不引入复杂 decoder 的前提下提供多源监督出口
输入:
- `spatial_embeddings: [B, T_s_max, 768]`
### 5.8.1 Encoder-only 阶段的固定槽位 readout
虽然简化版不再使用复杂 `slot query decoder`,但当前仍然有:
- 最大同时源数约为 `4`
因此,encoder-only 训练阶段推荐在 `spatial_embeddings` 后接一个很轻的固定槽位 readout:
```text
spatial_embeddings [B, T_s_max, 768]
-> Linear / MLP expand
-> slot_latents [B, T_s_max, 4, H]
-> shared prediction heads
```
推荐默认:
- `H = 768`
最简单实现:
```text
Linear(768 -> 4 * 768)
reshape -> [B, 25, 4, 768]
```
更稳一点的实现:
```text
Linear(768 -> 768)
-> GELU
-> Linear(768 -> 4 * 768)
reshape -> [B, 25, 4, 768]
```
这一步的作用不是做复杂目标解析,而只是:
- 给每个时间步提供 `4` 个固定 source 槽位
- 让多源监督有明确着力点
- 把 loss 稳定回传到前面的 trunk
### 5.8.2 预测头设计
`slot_latents [B, 25, 4, H]` 接共享或独立 heads。
建议输出头:
- `activity / objectness head`
- `azimuth head`
- `elevation head`
- `distance head`
- `class auxiliary head`
输出形式建议:
- `pred_activity: [B, T_s_max, 4]`
- `pred_azi_logits: [B, T_s_max, 4, 360]`
- `pred_ele_logits: [B, T_s_max, 4, 180]`
- `pred_dist: [B, T_s_max, 4, 1]`
- `pred_class_logits: [B, T_s_max, 4, C_cls]`
其中:
- `pred_activity` 负责当前时间步当前槽位是否解释某个源
- `pred_azi_logits` / `pred_ele_logits` 负责方向
- `pred_dist` 负责连续距离回归
- `pred_class_logits` 负责 source-level 辅助类别监督
当前默认建议:
- `C_cls = 65`
- label vocabulary 来自:
- `final_vocabulary.csv`
- 推荐字段:
- `label_id`
- `final_label`
### 5.8.3 为什么这里仍然保留 K=4
在简化版中:
- `K=4` 不再体现在复杂 decoder 结构里
- 但仍然通过固定槽位 readout 出现在监督头中
也就是说:
- 不需要复杂 query-based object slots
- 但仍然需要一个简单的多槽位 readout 来承载多源标签
这更符合当前“先把主干训练出来”的目标。
## 5.9 Spatial Projector
目标:
-`spatial_embeddings` 投影到 LLM hidden size
输入:
- `spatial_embeddings: [B, T_s_max, 768]`
推荐:
- 独立的 `MLP projector`
形式:
```text
Linear(768 -> D_mid)
-> GELU
-> LayerNorm
-> Linear(D_mid -> d_llm)
```
输出:
- `llm_spatial_tokens: [B, T_s_max, d_llm]`
这就是最终喂给 LLM 的 spatial tokens。
### 5.9.1 Encoder-only 阶段 projector 的角色
当前第一阶段如果只训练 encoder,本质目标是:
- 用监督把 `spatial_embeddings` 训好
因此 projector 在这一阶段有两种合理策略:
#### 方案 A:先不训练 projector,推荐默认
- 只训练:
- preprocessor
- patch embedding
- BEATs trunk
- temporal readout
- fixed-slot prediction heads
- projector 只保留接口,不参与训练
#### 方案 B:一并训练 projector
- 当你希望尽早固定 LLM 接口维度时可以启用
但第一版默认推荐:
- **先把 encoder 训好,再训练 projector**
## 6. 每层的输入输出总结
### 6.1 SpatialBEATsPreprocessor
- 输入:`[B, 4, T]`
- 输出:`[B, 7, T_f, 128]`
### 6.2 SpatialPatchEmbedding
- 输入:`[B, 7, T_f, 128]`
- 输出:`[B, N_p, 512]`
### 6.3 BEATs trunk
- 输入:`[B, N_p, 512]`
- 输出:`[B, N_p, 768]`
### 6.4 Reshape + Frequency Pooling
- 输入:`[B, N_p, 768]`
- 输出:`[B, T_p, 768]`
### 6.5 Temporal Resampler
- 输入:`[B, T_p, 768]`
- 输出:`[B, 25, 768]`
### 6.6 Shallow Temporal Readout
- 输入:`[B, T_s_max, 768]`
- 输出:`[B, T_s_max, 768]`
### 6.7 Prediction Heads
- 输入:`[B, T_s_max, 768]`
- 先经 `Linear / MLP expand` 变成:`[B, T_s_max, 4, H]`
- 输出:
- `activity [B, T_s_max, 4]`
- `azimuth [B, T_s_max, 4, 360]`
- `elevation [B, T_s_max, 4, 180]`
- `distance [B, T_s_max, 4, 1]`
- `class [B, T_s_max, 4, C_cls]`
### 6.8 Spatial Projector
- 输入:`[B, T_s_max, 768]`
- 输出:`[B, T_s_max, d_llm]`
## 7. Loss 如何作用到前面的表征
这个简化方案的关键点在于:
- 后面的 heads 并不是模型重点
- 它们只是为了接监督
loss 从这些 heads 回传后,会更新:
1. readout neck
2. temporal resampler
3. BEATs trunk
4. patch embedding
5. FOA preprocessor
因此:
- 只要后面的 decode/head 能稳定预测显式空间信息
- 前面的 trunk 就会被训练成空间 encoder
### 推荐 loss
当前建议:
- `L_activity`
- `L_azi`
- `L_ele`
- `L_dist`
- `L_cls_aux`
- `L_temp`
总损失:
```text
L_total =
lambda_act * L_activity
+ lambda_azi * L_azi
+ lambda_ele * L_ele
+ lambda_dist * L_dist
+ lambda_cls * L_cls_aux
+ lambda_temp * L_temp
```
当前可继续保留:
- 方向分类
- 仰角分类
- 距离连续回归
- 辅助类别监督
- 时间一致性正则
### 7.1 Encoder-only 训练阶段的 loss 定义
#### `L_activity`
用于监督:
- 当前时间步当前槽位是否激活
建议:
- `BCEWithLogitsLoss` 或 `focal loss`
结合当前的弱时间窗口 supervision:
- 窗口外:负样本
- 窗口内:弱正样本
#### `L_cls_aux`
用于监督:
- 被分配到某个 GT source 的槽位类别
建议:
- `CrossEntropyLoss`
#### `L_azi` 和 `L_ele`
用于监督:
- 槽位的方向分类
建议:
- `CrossEntropyLoss`
- `azimuth`: `360` bins
- `elevation`: `180` bins
#### `L_dist`
用于监督:
- 槽位的连续距离回归
建议:
- 先将距离归一化到 `[0, 1]`
- 使用 `SmoothL1Loss`
#### `L_temp`
由于当前位置当前是 clip 内固定,建议加入:
- 时间一致性约束
例如对同一 source 在相邻时间步对应的槽位加:
- class 分布平滑
- direction 分布平滑
- distance 平滑
第一版可以先只加:
- distance smoothness
- activity smoothness
## 7.2 Encoder-only 训练时的匹配方式
因为当前 heads 是固定 `4` 槽位,而不是 query decoder,所以建议:
- 使用轻量匹配
- 不必依赖复杂 decoder 结构
推荐策略:
1. 对每个 GT source,根据其时间窗口筛出候选时间步
2. 在该时间步的 `4` 个固定槽位中,选择 cost 最小的槽位
3. 对该槽位施加:
- activity
- class
- azi
- ele
- dist
推荐 cost:
```text
cost =
w_act * cost_act
+ w_cls * cost_cls
+ w_azi * cost_azi
+ w_ele * cost_ele
+ w_dist * cost_dist
```
第一版可以不做全局 Hungarian,直接做:
- `per-step fixed-slot matching`
如果后面发现多源冲突严重,再升级 matching。
## 8. 关于多源 supervision 的理解
虽然当前结构没有显式 `slot query decoder`,但这并不等于完全不考虑多源。
当前更合理的理解是:
- 让 `spatial_embeddings [B, 25, D]` 表示该时间步的空间场景表征
- 再通过固定 `4` 槽位 readout head 承载多源标签
- supervision 用这些多源标签约束 trunk 必须编码多个源的空间信息
这相当于:
- 先学一个强的 time-step level spatial embedding
- 再决定是否需要升级成更复杂的 object-centric / query-based 版本
这是更稳的第一版路径。
## 9. 为什么这个简化版更合适当前阶段
### 9.1 更符合老师的建议
老师的核心意见可以总结为:
- 不必执着于 decoder 结构本身
- 后面的模块只要能 decode 出空间监督即可
- 真正要训好的是前面的 trunk
这个简化版完全符合这个思路。
### 9.2 工程复杂度更低
不需要一开始就实现:
- slot query decoder
- cross-attention decoder
- objectness pooling
- 复杂 slot matching
### 9.3 更利于先验证 trunk 是否真的学到空间表征
如果这版能成功:
- 说明 trunk + temporal readout 本身已经足够表达空间信息
如果这版不够:
- 再升级到 slot/object 版本
这样路线更清晰。
## 10. 推荐实现版本
### V1:最小可行版
```text
FOA -> preprocessor -> patch embed -> BEATs trunk
-> frequency pooling -> temporal resampler
-> LayerNorm
-> spatial embeddings
-> Linear expand to 4 slots
-> heads
```
### V2:推荐正式版
```text
FOA -> preprocessor -> patch embed -> BEATs trunk
-> frequency pooling -> temporal resampler
-> 1~2 层 shallow transformer readout
-> spatial embeddings
-> Linear / MLP expand to 4 slots
-> heads
-> projector
```
### V3:后续增强版
如果未来发现:
- 多源关系建模仍然不足
- 动态轨迹下表达不够
再升级为:
- slot-based decoder
- query-based object-centric readout
## 11. 推荐代码划分
建议新增或保留以下文件:
- `spatial_beats.py`
- `spatial_modules.py`
- `spatial_loss.py`
- `spatial_dataset.py`
- `train_spatial_beats.py`
### `spatial_modules.py`
建议包含:
- `SpatialBEATsPreprocessor`
- `SpatialPatchEmbedding`
- `TemporalResampler`
- `TemporalReadoutTransformer`
- `FixedSlotReadoutHead`
- `SpatialPredictionHeads`
- `SpatialProjector`
### `spatial_beats.py`
建议主类:
- `SpatialBEATsConfig`
- `SpatialBEATs`
主类中建议暴露:
- `extract_spatial_embeddings()`
- `project_for_llm()`
- `forward()`
## 12. 最终输出接口
推荐 `forward()` 返回:
```python
{
"encoder_memory": FloatTensor[B, N_p, 768],
"temporal_tokens": FloatTensor[B, 25, 768],
"spatial_embeddings": FloatTensor[B, 25, 768],
"slot_latents": FloatTensor[B, 25, 4, H],
"pred_activity": FloatTensor[B, 25, 4],
"pred_azi_logits": FloatTensor[B, 25, 4, 360],
"pred_ele_logits": FloatTensor[B, 25, 4, 180],
"pred_dist": FloatTensor[B, 25, 4, 1],
"pred_class_logits": FloatTensor[B, 25, 4, C_cls],
"llm_spatial_tokens": FloatTensor[B, 25, d_llm],
}
```
其中:
- `spatial_embeddings` 是最核心的中间表示
- `slot_latents` 只是 encoder-only 监督出口
- `llm_spatial_tokens` 是最终给 LLM 的接口
## 13. 结论
当前推荐的简化版最终架构是:
- `FOA -> spatial features -> BEATs trunk -> 2.5Hz temporal spatial embeddings -> 固定4槽位 readout heads -> projector`
这套方案的重点是:
- 用显式空间监督把前面的 `BEATs trunk` 训练成空间 encoder
- 后面的 fixed-slot head 只承担监督和 readout 的职责
- 给 LLM 的 token 直接来自前面的 `spatial_embeddings`
这是当前最适合进入代码实现的主线方案。