Spatial-BEATs / docs /0422.md
dieKarotte's picture
Add files using upload-large-folder tool
29ab2d0 verified
|
Raw
History Blame Contribute Delete
13.9 kB

2026-04-22 记录:v7h 基线、v8 融合升级与当前诊断

本文档记录 2026-04-22 这轮关于 v7h -> v8 的设计、问答、结果和代码修改。

相关旧文档:

  • docs/0422_v7h_v7j.md:主要记录 v7h / v7j 的诊断和尝试

本文聚焦:

  • 当前最可用版本 v7h 到底是什么
  • 为什么要做 v8
  • v8 的结构改了什么,没改什么
  • 目前 epoch0/epoch1 的真实现象
  • 后续应该继续看什么

1. 当前结论

截至本轮:

  • 当前最可用的稳定基线仍然是 v7h
  • v8 不是坏形态,值得继续训练
  • v8 当前的主要问题不是 duplicate 崩坏,而是 ov2/ov3 上的 class binding 还不稳,尤其 ov3 仍有明显 class collapse
  • v8 现在还处于 two-stage 的 stage 1,前 3 个 epoch 本来就不训 dir/dist,所以 F20 / LE_CD / oaziepoch0-2 不适合过早下结论

2. 关键问答与结论

2.1 BEATs 适不适合做逐帧、多源预测

结论:

  • 原始 BEATs 官方下游头更偏 clip-level audio tagging
  • BEATs 作为语义 backbone 是适合逐帧任务的
  • 真正不够的是“直接拿官方 head 做 strong-label / 多源 SELD”

对当前项目更准确的判断是:

  • BEATs 适合作为语义分支
  • local spatial branch 适合作为空间分支
  • 真正的难点在于:
    • 如何融合语义/空间表征
    • 如何让 query decoder 在同一时刻分辨多个源
    • 如何让 per-frame class / DoA 监督不互相拖累

2.2 现在 4 个 track 到底在预测什么

当前 local_spatial_track 路线的输出是:

  • 每帧固定产出 K=4 组 candidate tracks
  • 每条 track 在每帧预测:
    • activity
    • class
    • direction
    • distance

它不是“每帧一定有 4 个真实声源”,而是“每帧最多从 4 个候选槽位里激活若干条”。

结构链路是:

fused_embeddings [B, T_s, D]
  -> SourceQueryDecoder
     -> track_latents [B, K, D]
     -> track_time_features [B, K, T_s, D]
  -> FrameTrackPredictionHeads
     -> pred_activity
     -> pred_class_logits
     -> pred_direction
     -> pred_distance

2.3 oracle_cls / oracle_azi / oracle_ele 是什么

这几个量是 track 头的 oracle 诊断指标,不是官方 DCASE metric。

定义:

  • 先在 GT-active 的 frame-source 对上做 matching
  • matching 时不走 activity threshold,只看 head 本身
  • 在 matched pair 上统计:
    • oracle_cls:class 是否对
    • oracle_azi:azimuth MAE
    • oracle_ele:elevation MAE

用途:

  • 看 class/DoA head 本身有没有学到
  • 不包含最终检测误差,不等于 F20

2.4 要不要先训练一个独立的 per-frame class-only BEATs

结论:不建议单独另起一个 class-only 模型作为主路线。

原因:

  • 最终任务的 class 是从 fused -> source_query_decoder -> track_time_features -> class_head 这条路径读出来的
  • 单独训一个普通 per-frame classifier,学不到 query/track binding

更合理的做法是:

  • 保留同一套 track-query 架构
  • 先做 activity + class warmup
  • 再打开 dir/dist

这就是当前 v8 的 two-stage 设计。

2.5 v7 阶段关于初始化和解冻的结论

已经确认的结论:

  • source_query_decoderactivity_head 随机初始化是合理的
  • frame_track_prediction_heads.class/direction/distance 可以从旧的 local_spatial_prediction_heads 迁移同语义权重
  • 如果 v7 只在 ov1 上学过,到了 ov123 不建议永远冻结 trunk
  • 当前合理策略是:
    • 继续从已有可用 checkpoint 热启动
    • 解冻 trunk 顶部少量层,例如 top-4

3. 当前最可用基线:v7h

v7h 继承自 v7f_ov123_top4,是当前最可用的稳定基线。

核心训练设定:

  • readout_scheme = local_spatial_track
  • enable_clip_aux_head = False
  • ov1:ov2:ov3 = 1:3:3
  • 去掉 focal BCE
  • frame_activity_pos_weight = 3.0
  • class-cost warmup 采用 1 + 2 epoch
  • trunk 顶部 4 层解冻

相关配置可参考:

  • train_spatial_beats.py:1465
  • docs/0422_v7h_v7j.md

历史最好结果记录:

v7h ep3 val:
  F20=0.246
  LE_CD=34.3°
  LR_CD=0.504
  SELD=0.608

这也是当前继续做 v8 的热启动来源。

4. v8 的设计目标

4.1 为什么要做 v8

v7h 的判断是:

  • 前端不要乱改
  • source_query_decoderframe_track_prediction_heads 也先不要乱改
  • 真正可能偏弱的是 fused token 的构造方式

v7h 的融合只有:

fused = LayerNorm(semantic_embeddings + local_update)

这隐含假设:

  • semantic token 和 spatial token 已经天然对齐
  • 只要相加,query decoder 就能自动学会利用空间信息

这个假设偏强,因此设计 v8

  • 前端完全不改
  • 只升级 semantic/spatial 的 fusion

4.2 v8 设计原则

  • 以 semantic token 为主骨架
  • spatial token 不直接硬加,而是通过 cross-attention 注入
  • 新模块要能安全从 v7h 热启动

因此 v8 的 fusion 采用:

semantic <- spatial cross-attention (2 layers)
  + gated direct spatial residual
  + same final LayerNorm

5. v8 新模型架构

5.1 整体链路

v8 只改 fused 部分,其他都和 v7h 一致:

FOA waveform [B, 4, T]
  -> SpatialBEATsPreprocessor
  -> patch embedding + BEATs trunk
  -> frequency_pool
  -> TemporalResampler (2.5 Hz)
  -> temporal_readout
  -> semantic_embeddings [B, T_s, D]

  -> LocalSpatialEncoder
  -> local_spatial_resampler
  -> local_spatial_proj
  -> local_update [B, T_s, D]

  -> LocalSpatialCrossFuser (new in v8)
     semantic <- spatial cross-attn x 2
     + gated spatial residual
  -> local_spatial_fusion_norm
  -> fused_embeddings [B, T_s, D]

  -> SourceQueryDecoder
  -> FrameTrackPredictionHeads

5.2 代码落点

新增或修改位置:

  • spatial_modules.py:1723-1842
    • LocalSpatialCrossFusionBlock
    • LocalSpatialCrossFuser
  • spatial_beats.py:173-182
    • 新增 fusion 配置字段
  • spatial_beats.py:1104-1118
    • build_local_spatial_fusion() 改为支持 cross_attn_gated
  • train_spatial_beats.py:1536-1549
    • 新增 make_ov1_local_spatial_v8_ov123_top4_config()
  • run_ov1_v8_ov123_top4.sh:1-68
    • 新的启动脚本

5.3 v8 的训练可学习参数

为了保证 v8 新融合模块真的训练到,代码里做了两处接线:

  • train_spatial_beats.py:2528-2535
    • local_spatial_fuser 被加入 always_train_prefixes
  • train_spatial_beats.py:2646-2653
    • local_spatial_fuser. 被加入 _SPATIAL_PREFIXES

也就是说:

  • 新 fuser 走 spatial_lr
  • 不会被当成 trunk 或 head 漏掉

6. v8 的两阶段训练

6.1 设计原因

直接在新融合结构上同时训练 class + direction + distance,早期很容易出现:

  • class 还没稳定
  • DoA 噪声先进入 matching
  • assignment 被错误的 spatial cost 带偏

所以 v8 采用 two-stage:

  • stage 1:先训 activity + class
  • stage 2:再恢复 dir/dist

6.2 具体配置

v8 preset 中:

frame_spatial_loss_warmup_epochs = 3
frame_spatial_loss_warmup_scale = 0.0

对应代码:

  • train_spatial_beats.py:1547-1548
  • train_spatial_beats.py:4079-4102

含义:

  • epoch 0-2
    • lambda_frame_direction = 0
    • lambda_frame_distance = 0
    • frame_match_dir_cost_weight = 0
    • frame_match_dist_cost_weight = 0
  • epoch 3+
    • 自动恢复完整方向/距离监督

7. v8 热启动方式

默认脚本:

./run_ov1_v8_ov123_top4.sh

默认行为:

  • v7hbest.pt 热启动
  • 只新增 local_spatial_fuser.* 参数
  • 其余 trunk / local_spatial / source_query_decoder / frame-track heads 继承已有权重
  • 训练比例继续保持 ov1:ov2:ov3 = 1:3:3

脚本位置:

  • run_ov1_v8_ov123_top4.sh:16-37

8. v8 当前结果

8.1 epoch 0 日志

用户记录:

[Epoch 0] train:
  loss=1.3445  activity=0.3099  cls_aux=1.0346
  direction=0.4527  dist=0.4894
  act↑=0.887  act↓=0.136  sep=0.751
  ocls=0.786  oazi=45.2°  oele=18.4°

[Epoch 0] val:
  loss=2.1304  activity=0.2310  cls_aux=1.8995
  direction=0.4377  dist=0.4337
  act↑=0.925  act↓=0.092  sep=0.832
  ocls=0.652  oazi=46.9°  oele=20.1°
  ER20=1.047 F20=0.096 LE_CD=51.8° LR_CD=0.486 SELD=0.688

8.2 epoch 0 CSV 诊断

epoch_0000_csv 的诊断结论:

  • 不是 v7j 那种 3-4 条轨全亮的崩坏
  • activity 分离已经很好
  • 当前主要瓶颈是 class + angle,不是 duplicate

粗统计:

TP=207  FP=2147  FN=1844
P=0.088  R=0.101  F1=0.094

FP breakdown:
  class_wrong = 1211
  angle_wrong = 911
  duplicate   = 25

active_hist:
  0 active: 185
  1 active: 748
  2 active: 683
  3 active: 80

no_gt_frames=286
no_gt_with_pred=119
ratio=0.416

mean_maxprob_gt   = 0.972
mean_maxprob_nogt = 0.400

解释:

  • act↑/act↓/sep 很强,说明 activity head 已经把有源/无源分开了
  • class_wrongangle_wrong 很高
  • F20 低并不意外,因为 stage 1 根本还没训 dir/dist

8.3 epoch 0 的 ov1 / ov2 / ov3 形态

代表样本观察:

ov1

mean_act_by_track = {0: 0.88, 1: 0.007, 2: 0.001, 3: 0.0}
active_tracks_per_frame_hist = {1: 37}
top_pred_classes = singing(37)

说明单源样本已经能稳定只亮一条轨。

ov2

mean_act_by_track = {0: 0.999, 1: 0.745, 2: 0.133, 3: 0.0}
active_tracks_per_frame_hist = {1: 5, 2: 45}
top_pred_classes = frog(50), bird(37), tool(8)

说明双源样本已经在尝试输出两条轨。

ov3

mean_act_by_track = {0: 0.977, 1: 0.519, 2: 0.036, 3: 0.0}
active_tracks_per_frame_hist = {1: 17, 2: 32}
top_pred_classes = tool(81)

说明三源样本里第二条轨开始亮,但有明显 class collapse。

9. v8 的 epoch 1:重点看 ov2 / ov3

第二个 epoch 之后,重点检查了 ov23

9.1 active track 数量变化

在“有预测的帧”上,平均亮起的轨数:

epoch0:
  ov2 = 1.578
  ov3 = 1.767

epoch1:
  ov2 = 1.736
  ov3 = 1.942

这说明:

  • ov2/ov3 上第二条轨更积极了
  • 模型更愿意在 overlap 场景输出多轨

这本身不是坏事,但如果 class/DoA 没跟上,就会先表现为 FP 上升。

9.2 代表样本对比

ov2 代表样本 valid__ov2_000000__pred.csv

epoch0

mean_act_by_track = {0: 0.999, 1: 0.745, 2: 0.133, 3: 0.0}
active_hist = {1: 5, 2: 45}
top_classes = frog(50), bird(37), tool(8)

epoch1

mean_act_by_track = {0: 0.996, 1: 0.312, 2: 0.625, 3: 0.339}
active_hist = {1: 11, 2: 39}
top_classes = frog(50), wind(39)

解释:

  • 第二条活跃轨从 track1 转向 track2
  • 类别也从 bird 变成了 wind
  • 说明 query 责任在重排,但 class binding 还不稳定

ov3 代表样本 valid__ov3_000004__pred.csv

epoch0

mean_act_by_track = {0: 0.977, 1: 0.519, 2: 0.036, 3: 0.0}
active_hist = {1: 17, 2: 32}
top_classes = tool(81)

epoch1

mean_act_by_track = {0: 0.966, 1: 0.665, 2: 0.065, 3: 0.127}
active_hist = {1: 11, 2: 39}
top_classes = tool(89)

解释:

  • 第二条轨更稳定地亮了
  • 但 class collapse 没缓解,反而更统一地塌成 tool

因此当前对 epoch1 的判断是:

  • ov2:有变化,但不能算明显变好
  • ov3:暂时没有变好,仍然是当前最大问题点

10. 当前诊断

截至目前,对 v8 的判断是:

10.1 已经证明的事

  • v8 没有走向 v7j 那种 duplicate 崩坏
  • activity 分离做得比担心中要好
  • ov2/ov3 上的多轨输出能力确实在长出来

10.2 还没解决的事

  • ov2/ov3 的 class binding 还不稳
  • ov3 仍然存在明显 class collapse
  • stage 1 期间不训 dir/dist,所以 F20 / LE_CD / oazi 现在还不能作为最终判断

10.3 当前最值得关注的信号

  • epoch 3+ 进入 stage 2 后:
    • ocls 是否继续上升
    • oazi 是否明显下降
    • F20 是否出现拐点
  • ov3 的 top predicted classes 是否开始从单一 tool 分裂成多个类
  • ov23 的平均 active tracks 是否继续上涨过快

11. 本轮实际代码修改

本轮已经落地的修改:

11.1 新增 v8 融合模块

  • spatial_modules.py:1723-1842
    • LocalSpatialCrossFusionBlock
    • LocalSpatialCrossFuser

11.2 新增 fusion 配置

  • spatial_beats.py:173-182
    • local_spatial_fusion_mode
    • local_spatial_fusion_layers
    • local_spatial_fusion_heads
    • local_spatial_fusion_dropout
    • local_spatial_fusion_gate_bias
    • local_spatial_fusion_direct_gate_bias

11.3 改 fused token 构造

  • spatial_beats.py:1104-1118
    • semantic + local_update
    • 改成支持 local_spatial_fuser(...)

11.4 新增 v8 preset

  • train_spatial_beats.py:1536-1549
    • 继承 v7h
    • 打开 cross_attn_gated
    • 打开 two-stage spatial warmup
    • 输出目录改为 v8_ov123_exp/03_ov123_top4

11.5 训练侧接线

  • train_spatial_beats.py:2528-2535
    • local_spatial_fuser 加入 always_train_prefixes
  • train_spatial_beats.py:2646-2653
    • local_spatial_fuser. 加入 _SPATIAL_PREFIXES

11.6 新增脚本

  • run_ov1_v8_ov123_top4.sh:1-68

12. 下一步建议

当前建议不改结构,先继续训练 v8

  1. 至少跑到 epoch 3 之后,确认 stage 2 开启后的趋势
  2. 优先看 ov3 是否开始摆脱 tool collapse
  3. 如果 epoch 3-5 之后仍然:
    • ov3 继续单类塌缩
    • ocls 不升
    • oazi 不降
    • F20 没有明显抬升 再考虑下一轮结构修改

当前最合理的工作顺序是:

  • 先把 v8 跑穿 stage 1 / stage 2
  • 再根据 ov23 的 class collapse 是否缓解,决定下一轮改:
    • query decoder
    • matching
    • finer token rate
    • 或额外的 class-preserving auxiliary