Spatial-BEATs / docs /0422.md
dieKarotte's picture
Add files using upload-large-folder tool
29ab2d0 verified
|
Raw
History Blame Contribute Delete
13.9 kB
# 2026-04-22 记录:v7h 基线、v8 融合升级与当前诊断
本文档记录 2026-04-22 这轮关于 `v7h -> v8` 的设计、问答、结果和代码修改。
相关旧文档:
- `docs/0422_v7h_v7j.md`:主要记录 `v7h / v7j` 的诊断和尝试
本文聚焦:
- 当前最可用版本 `v7h` 到底是什么
- 为什么要做 `v8`
- `v8` 的结构改了什么,没改什么
- 目前 `epoch0/epoch1` 的真实现象
- 后续应该继续看什么
## 1. 当前结论
截至本轮:
- 当前最可用的稳定基线仍然是 `v7h`
- `v8` 不是坏形态,值得继续训练
- `v8` 当前的主要问题不是 duplicate 崩坏,而是 `ov2/ov3` 上的 class binding 还不稳,尤其 `ov3` 仍有明显 class collapse
- `v8` 现在还处于 two-stage 的 stage 1,前 3 个 epoch 本来就不训 `dir/dist`,所以 `F20 / LE_CD / oazi``epoch0-2` 不适合过早下结论
## 2. 关键问答与结论
### 2.1 BEATs 适不适合做逐帧、多源预测
结论:
- 原始 `BEATs` 官方下游头更偏 `clip-level audio tagging`
-`BEATs` 作为语义 backbone 是适合逐帧任务的
- 真正不够的是“直接拿官方 head 做 strong-label / 多源 SELD”
对当前项目更准确的判断是:
- `BEATs` 适合作为语义分支
- `local spatial branch` 适合作为空间分支
- 真正的难点在于:
- 如何融合语义/空间表征
- 如何让 query decoder 在同一时刻分辨多个源
- 如何让 per-frame class / DoA 监督不互相拖累
### 2.2 现在 4 个 track 到底在预测什么
当前 `local_spatial_track` 路线的输出是:
- 每帧固定产出 `K=4` 组 candidate tracks
- 每条 track 在每帧预测:
- `activity`
- `class`
- `direction`
- `distance`
它不是“每帧一定有 4 个真实声源”,而是“每帧最多从 4 个候选槽位里激活若干条”。
结构链路是:
```text
fused_embeddings [B, T_s, D]
-> SourceQueryDecoder
-> track_latents [B, K, D]
-> track_time_features [B, K, T_s, D]
-> FrameTrackPredictionHeads
-> pred_activity
-> pred_class_logits
-> pred_direction
-> pred_distance
```
### 2.3 `oracle_cls / oracle_azi / oracle_ele` 是什么
这几个量是 `track` 头的 oracle 诊断指标,不是官方 DCASE metric。
定义:
- 先在 GT-active 的 frame-source 对上做 matching
- matching 时不走 activity threshold,只看 head 本身
- 在 matched pair 上统计:
- `oracle_cls`:class 是否对
- `oracle_azi`:azimuth MAE
- `oracle_ele`:elevation MAE
用途:
- 看 class/DoA head 本身有没有学到
- 不包含最终检测误差,不等于 `F20`
### 2.4 要不要先训练一个独立的 per-frame class-only BEATs
结论:不建议单独另起一个 class-only 模型作为主路线。
原因:
- 最终任务的 class 是从
`fused -> source_query_decoder -> track_time_features -> class_head`
这条路径读出来的
- 单独训一个普通 per-frame classifier,学不到 query/track binding
更合理的做法是:
- 保留同一套 `track-query` 架构
- 先做 `activity + class` warmup
- 再打开 `dir/dist`
这就是当前 `v8` 的 two-stage 设计。
### 2.5 v7 阶段关于初始化和解冻的结论
已经确认的结论:
- `source_query_decoder``activity_head` 随机初始化是合理的
- `frame_track_prediction_heads.class/direction/distance` 可以从旧的 `local_spatial_prediction_heads` 迁移同语义权重
- 如果 `v7` 只在 `ov1` 上学过,到了 `ov123` 不建议永远冻结 trunk
- 当前合理策略是:
- 继续从已有可用 checkpoint 热启动
- 解冻 trunk 顶部少量层,例如 top-4
## 3. 当前最可用基线:v7h
`v7h` 继承自 `v7f_ov123_top4`,是当前最可用的稳定基线。
核心训练设定:
- `readout_scheme = local_spatial_track`
- `enable_clip_aux_head = False`
- `ov1:ov2:ov3 = 1:3:3`
- 去掉 focal BCE
- `frame_activity_pos_weight = 3.0`
- class-cost warmup 采用 `1 + 2` epoch
- trunk 顶部 4 层解冻
相关配置可参考:
- `train_spatial_beats.py:1465`
- `docs/0422_v7h_v7j.md`
历史最好结果记录:
```text
v7h ep3 val:
F20=0.246
LE_CD=34.3°
LR_CD=0.504
SELD=0.608
```
这也是当前继续做 `v8` 的热启动来源。
## 4. v8 的设计目标
### 4.1 为什么要做 v8
`v7h` 的判断是:
- 前端不要乱改
- `source_query_decoder``frame_track_prediction_heads` 也先不要乱改
- 真正可能偏弱的是 fused token 的构造方式
`v7h` 的融合只有:
```text
fused = LayerNorm(semantic_embeddings + local_update)
```
这隐含假设:
- semantic token 和 spatial token 已经天然对齐
- 只要相加,query decoder 就能自动学会利用空间信息
这个假设偏强,因此设计 `v8`
- 前端完全不改
- 只升级 semantic/spatial 的 fusion
### 4.2 v8 设计原则
- 以 semantic token 为主骨架
- spatial token 不直接硬加,而是通过 cross-attention 注入
- 新模块要能安全从 `v7h` 热启动
因此 `v8` 的 fusion 采用:
```text
semantic <- spatial cross-attention (2 layers)
+ gated direct spatial residual
+ same final LayerNorm
```
## 5. v8 新模型架构
### 5.1 整体链路
`v8` 只改 fused 部分,其他都和 `v7h` 一致:
```text
FOA waveform [B, 4, T]
-> SpatialBEATsPreprocessor
-> patch embedding + BEATs trunk
-> frequency_pool
-> TemporalResampler (2.5 Hz)
-> temporal_readout
-> semantic_embeddings [B, T_s, D]
-> LocalSpatialEncoder
-> local_spatial_resampler
-> local_spatial_proj
-> local_update [B, T_s, D]
-> LocalSpatialCrossFuser (new in v8)
semantic <- spatial cross-attn x 2
+ gated spatial residual
-> local_spatial_fusion_norm
-> fused_embeddings [B, T_s, D]
-> SourceQueryDecoder
-> FrameTrackPredictionHeads
```
### 5.2 代码落点
新增或修改位置:
- `spatial_modules.py:1723-1842`
- `LocalSpatialCrossFusionBlock`
- `LocalSpatialCrossFuser`
- `spatial_beats.py:173-182`
- 新增 fusion 配置字段
- `spatial_beats.py:1104-1118`
- `build_local_spatial_fusion()` 改为支持 `cross_attn_gated`
- `train_spatial_beats.py:1536-1549`
- 新增 `make_ov1_local_spatial_v8_ov123_top4_config()`
- `run_ov1_v8_ov123_top4.sh:1-68`
- 新的启动脚本
### 5.3 v8 的训练可学习参数
为了保证 `v8` 新融合模块真的训练到,代码里做了两处接线:
- `train_spatial_beats.py:2528-2535`
- `local_spatial_fuser` 被加入 `always_train_prefixes`
- `train_spatial_beats.py:2646-2653`
- `local_spatial_fuser.` 被加入 `_SPATIAL_PREFIXES`
也就是说:
- 新 fuser 走 `spatial_lr`
- 不会被当成 trunk 或 head 漏掉
## 6. v8 的两阶段训练
### 6.1 设计原因
直接在新融合结构上同时训练 `class + direction + distance`,早期很容易出现:
- class 还没稳定
- DoA 噪声先进入 matching
- assignment 被错误的 spatial cost 带偏
所以 `v8` 采用 two-stage:
- stage 1:先训 `activity + class`
- stage 2:再恢复 `dir/dist`
### 6.2 具体配置
`v8` preset 中:
```text
frame_spatial_loss_warmup_epochs = 3
frame_spatial_loss_warmup_scale = 0.0
```
对应代码:
- `train_spatial_beats.py:1547-1548`
- `train_spatial_beats.py:4079-4102`
含义:
- `epoch 0-2`
- `lambda_frame_direction = 0`
- `lambda_frame_distance = 0`
- `frame_match_dir_cost_weight = 0`
- `frame_match_dist_cost_weight = 0`
- `epoch 3+`
- 自动恢复完整方向/距离监督
## 7. v8 热启动方式
默认脚本:
```text
./run_ov1_v8_ov123_top4.sh
```
默认行为:
-`v7h``best.pt` 热启动
- 只新增 `local_spatial_fuser.*` 参数
- 其余 trunk / local_spatial / source_query_decoder / frame-track heads 继承已有权重
- 训练比例继续保持 `ov1:ov2:ov3 = 1:3:3`
脚本位置:
- `run_ov1_v8_ov123_top4.sh:16-37`
## 8. v8 当前结果
### 8.1 epoch 0 日志
用户记录:
```text
[Epoch 0] train:
loss=1.3445 activity=0.3099 cls_aux=1.0346
direction=0.4527 dist=0.4894
act↑=0.887 act↓=0.136 sep=0.751
ocls=0.786 oazi=45.2° oele=18.4°
[Epoch 0] val:
loss=2.1304 activity=0.2310 cls_aux=1.8995
direction=0.4377 dist=0.4337
act↑=0.925 act↓=0.092 sep=0.832
ocls=0.652 oazi=46.9° oele=20.1°
ER20=1.047 F20=0.096 LE_CD=51.8° LR_CD=0.486 SELD=0.688
```
### 8.2 epoch 0 CSV 诊断
`epoch_0000_csv` 的诊断结论:
- 不是 `v7j` 那种 3-4 条轨全亮的崩坏
- `activity` 分离已经很好
- 当前主要瓶颈是 `class + angle`,不是 duplicate
粗统计:
```text
TP=207 FP=2147 FN=1844
P=0.088 R=0.101 F1=0.094
FP breakdown:
class_wrong = 1211
angle_wrong = 911
duplicate = 25
active_hist:
0 active: 185
1 active: 748
2 active: 683
3 active: 80
no_gt_frames=286
no_gt_with_pred=119
ratio=0.416
mean_maxprob_gt = 0.972
mean_maxprob_nogt = 0.400
```
解释:
- `act↑/act↓/sep` 很强,说明 activity head 已经把有源/无源分开了
- 但 `class_wrong` 和 `angle_wrong` 很高
- `F20` 低并不意外,因为 stage 1 根本还没训 `dir/dist`
### 8.3 epoch 0 的 ov1 / ov2 / ov3 形态
代表样本观察:
`ov1`
```text
mean_act_by_track = {0: 0.88, 1: 0.007, 2: 0.001, 3: 0.0}
active_tracks_per_frame_hist = {1: 37}
top_pred_classes = singing(37)
```
说明单源样本已经能稳定只亮一条轨。
`ov2`
```text
mean_act_by_track = {0: 0.999, 1: 0.745, 2: 0.133, 3: 0.0}
active_tracks_per_frame_hist = {1: 5, 2: 45}
top_pred_classes = frog(50), bird(37), tool(8)
```
说明双源样本已经在尝试输出两条轨。
`ov3`
```text
mean_act_by_track = {0: 0.977, 1: 0.519, 2: 0.036, 3: 0.0}
active_tracks_per_frame_hist = {1: 17, 2: 32}
top_pred_classes = tool(81)
```
说明三源样本里第二条轨开始亮,但有明显 class collapse。
## 9. v8 的 epoch 1:重点看 ov2 / ov3
第二个 epoch 之后,重点检查了 `ov23`。
### 9.1 active track 数量变化
在“有预测的帧”上,平均亮起的轨数:
```text
epoch0:
ov2 = 1.578
ov3 = 1.767
epoch1:
ov2 = 1.736
ov3 = 1.942
```
这说明:
- `ov2/ov3` 上第二条轨更积极了
- 模型更愿意在 overlap 场景输出多轨
这本身不是坏事,但如果 class/DoA 没跟上,就会先表现为 FP 上升。
### 9.2 代表样本对比
`ov2` 代表样本 `valid__ov2_000000__pred.csv`
`epoch0`
```text
mean_act_by_track = {0: 0.999, 1: 0.745, 2: 0.133, 3: 0.0}
active_hist = {1: 5, 2: 45}
top_classes = frog(50), bird(37), tool(8)
```
`epoch1`
```text
mean_act_by_track = {0: 0.996, 1: 0.312, 2: 0.625, 3: 0.339}
active_hist = {1: 11, 2: 39}
top_classes = frog(50), wind(39)
```
解释:
- 第二条活跃轨从 `track1` 转向 `track2`
- 类别也从 `bird` 变成了 `wind`
- 说明 query 责任在重排,但 class binding 还不稳定
`ov3` 代表样本 `valid__ov3_000004__pred.csv`
`epoch0`
```text
mean_act_by_track = {0: 0.977, 1: 0.519, 2: 0.036, 3: 0.0}
active_hist = {1: 17, 2: 32}
top_classes = tool(81)
```
`epoch1`
```text
mean_act_by_track = {0: 0.966, 1: 0.665, 2: 0.065, 3: 0.127}
active_hist = {1: 11, 2: 39}
top_classes = tool(89)
```
解释:
- 第二条轨更稳定地亮了
- 但 class collapse 没缓解,反而更统一地塌成 `tool`
因此当前对 `epoch1` 的判断是:
- `ov2`:有变化,但不能算明显变好
- `ov3`:暂时没有变好,仍然是当前最大问题点
## 10. 当前诊断
截至目前,对 `v8` 的判断是:
### 10.1 已经证明的事
- `v8` 没有走向 `v7j` 那种 duplicate 崩坏
- `activity` 分离做得比担心中要好
- `ov2/ov3` 上的多轨输出能力确实在长出来
### 10.2 还没解决的事
- `ov2/ov3` 的 class binding 还不稳
- `ov3` 仍然存在明显 class collapse
- stage 1 期间不训 `dir/dist`,所以 `F20 / LE_CD / oazi` 现在还不能作为最终判断
### 10.3 当前最值得关注的信号
- `epoch 3+` 进入 stage 2 后:
- `ocls` 是否继续上升
- `oazi` 是否明显下降
- `F20` 是否出现拐点
- `ov3` 的 top predicted classes 是否开始从单一 `tool` 分裂成多个类
- `ov23` 的平均 active tracks 是否继续上涨过快
## 11. 本轮实际代码修改
本轮已经落地的修改:
### 11.1 新增 v8 融合模块
- `spatial_modules.py:1723-1842`
- `LocalSpatialCrossFusionBlock`
- `LocalSpatialCrossFuser`
### 11.2 新增 fusion 配置
- `spatial_beats.py:173-182`
- `local_spatial_fusion_mode`
- `local_spatial_fusion_layers`
- `local_spatial_fusion_heads`
- `local_spatial_fusion_dropout`
- `local_spatial_fusion_gate_bias`
- `local_spatial_fusion_direct_gate_bias`
### 11.3 改 fused token 构造
- `spatial_beats.py:1104-1118`
- 从 `semantic + local_update`
- 改成支持 `local_spatial_fuser(...)`
### 11.4 新增 v8 preset
- `train_spatial_beats.py:1536-1549`
- 继承 `v7h`
- 打开 `cross_attn_gated`
- 打开 two-stage spatial warmup
- 输出目录改为 `v8_ov123_exp/03_ov123_top4`
### 11.5 训练侧接线
- `train_spatial_beats.py:2528-2535`
- `local_spatial_fuser` 加入 `always_train_prefixes`
- `train_spatial_beats.py:2646-2653`
- `local_spatial_fuser.` 加入 `_SPATIAL_PREFIXES`
### 11.6 新增脚本
- `run_ov1_v8_ov123_top4.sh:1-68`
## 12. 下一步建议
当前建议不改结构,先继续训练 `v8`:
1. 至少跑到 `epoch 3` 之后,确认 stage 2 开启后的趋势
2. 优先看 `ov3` 是否开始摆脱 `tool` collapse
3. 如果 `epoch 3-5` 之后仍然:
- `ov3` 继续单类塌缩
- `ocls` 不升
- `oazi` 不降
- `F20` 没有明显抬升
再考虑下一轮结构修改
当前最合理的工作顺序是:
- 先把 `v8` 跑穿 stage 1 / stage 2
- 再根据 `ov23` 的 class collapse 是否缓解,决定下一轮改:
- query decoder
- matching
- finer token rate
- 或额外的 class-preserving auxiliary