2026-04-22 记录:v7h 基线、v8 融合升级与当前诊断
本文档记录 2026-04-22 这轮关于 v7h -> v8 的设计、问答、结果和代码修改。
相关旧文档:
docs/0422_v7h_v7j.md:主要记录v7h / v7j的诊断和尝试
本文聚焦:
- 当前最可用版本
v7h到底是什么 - 为什么要做
v8 v8的结构改了什么,没改什么- 目前
epoch0/epoch1的真实现象 - 后续应该继续看什么
1. 当前结论
截至本轮:
- 当前最可用的稳定基线仍然是
v7h v8不是坏形态,值得继续训练v8当前的主要问题不是 duplicate 崩坏,而是ov2/ov3上的 class binding 还不稳,尤其ov3仍有明显 class collapsev8现在还处于 two-stage 的 stage 1,前 3 个 epoch 本来就不训dir/dist,所以F20 / LE_CD / oazi在epoch0-2不适合过早下结论
2. 关键问答与结论
2.1 BEATs 适不适合做逐帧、多源预测
结论:
- 原始
BEATs官方下游头更偏clip-level audio tagging - 但
BEATs作为语义 backbone 是适合逐帧任务的 - 真正不够的是“直接拿官方 head 做 strong-label / 多源 SELD”
对当前项目更准确的判断是:
BEATs适合作为语义分支local spatial branch适合作为空间分支- 真正的难点在于:
- 如何融合语义/空间表征
- 如何让 query decoder 在同一时刻分辨多个源
- 如何让 per-frame class / DoA 监督不互相拖累
2.2 现在 4 个 track 到底在预测什么
当前 local_spatial_track 路线的输出是:
- 每帧固定产出
K=4组 candidate tracks - 每条 track 在每帧预测:
activityclassdirectiondistance
它不是“每帧一定有 4 个真实声源”,而是“每帧最多从 4 个候选槽位里激活若干条”。
结构链路是:
fused_embeddings [B, T_s, D]
-> SourceQueryDecoder
-> track_latents [B, K, D]
-> track_time_features [B, K, T_s, D]
-> FrameTrackPredictionHeads
-> pred_activity
-> pred_class_logits
-> pred_direction
-> pred_distance
2.3 oracle_cls / oracle_azi / oracle_ele 是什么
这几个量是 track 头的 oracle 诊断指标,不是官方 DCASE metric。
定义:
- 先在 GT-active 的 frame-source 对上做 matching
- matching 时不走 activity threshold,只看 head 本身
- 在 matched pair 上统计:
oracle_cls:class 是否对oracle_azi:azimuth MAEoracle_ele:elevation MAE
用途:
- 看 class/DoA head 本身有没有学到
- 不包含最终检测误差,不等于
F20
2.4 要不要先训练一个独立的 per-frame class-only BEATs
结论:不建议单独另起一个 class-only 模型作为主路线。
原因:
- 最终任务的 class 是从
fused -> source_query_decoder -> track_time_features -> class_head这条路径读出来的 - 单独训一个普通 per-frame classifier,学不到 query/track binding
更合理的做法是:
- 保留同一套
track-query架构 - 先做
activity + classwarmup - 再打开
dir/dist
这就是当前 v8 的 two-stage 设计。
2.5 v7 阶段关于初始化和解冻的结论
已经确认的结论:
source_query_decoder和activity_head随机初始化是合理的frame_track_prediction_heads.class/direction/distance可以从旧的local_spatial_prediction_heads迁移同语义权重- 如果
v7只在ov1上学过,到了ov123不建议永远冻结 trunk - 当前合理策略是:
- 继续从已有可用 checkpoint 热启动
- 解冻 trunk 顶部少量层,例如 top-4
3. 当前最可用基线:v7h
v7h 继承自 v7f_ov123_top4,是当前最可用的稳定基线。
核心训练设定:
readout_scheme = local_spatial_trackenable_clip_aux_head = Falseov1:ov2:ov3 = 1:3:3- 去掉 focal BCE
frame_activity_pos_weight = 3.0- class-cost warmup 采用
1 + 2epoch - trunk 顶部 4 层解冻
相关配置可参考:
train_spatial_beats.py:1465docs/0422_v7h_v7j.md
历史最好结果记录:
v7h ep3 val:
F20=0.246
LE_CD=34.3°
LR_CD=0.504
SELD=0.608
这也是当前继续做 v8 的热启动来源。
4. v8 的设计目标
4.1 为什么要做 v8
对 v7h 的判断是:
- 前端不要乱改
source_query_decoder和frame_track_prediction_heads也先不要乱改- 真正可能偏弱的是 fused token 的构造方式
v7h 的融合只有:
fused = LayerNorm(semantic_embeddings + local_update)
这隐含假设:
- semantic token 和 spatial token 已经天然对齐
- 只要相加,query decoder 就能自动学会利用空间信息
这个假设偏强,因此设计 v8:
- 前端完全不改
- 只升级 semantic/spatial 的 fusion
4.2 v8 设计原则
- 以 semantic token 为主骨架
- spatial token 不直接硬加,而是通过 cross-attention 注入
- 新模块要能安全从
v7h热启动
因此 v8 的 fusion 采用:
semantic <- spatial cross-attention (2 layers)
+ gated direct spatial residual
+ same final LayerNorm
5. v8 新模型架构
5.1 整体链路
v8 只改 fused 部分,其他都和 v7h 一致:
FOA waveform [B, 4, T]
-> SpatialBEATsPreprocessor
-> patch embedding + BEATs trunk
-> frequency_pool
-> TemporalResampler (2.5 Hz)
-> temporal_readout
-> semantic_embeddings [B, T_s, D]
-> LocalSpatialEncoder
-> local_spatial_resampler
-> local_spatial_proj
-> local_update [B, T_s, D]
-> LocalSpatialCrossFuser (new in v8)
semantic <- spatial cross-attn x 2
+ gated spatial residual
-> local_spatial_fusion_norm
-> fused_embeddings [B, T_s, D]
-> SourceQueryDecoder
-> FrameTrackPredictionHeads
5.2 代码落点
新增或修改位置:
spatial_modules.py:1723-1842LocalSpatialCrossFusionBlockLocalSpatialCrossFuser
spatial_beats.py:173-182- 新增 fusion 配置字段
spatial_beats.py:1104-1118build_local_spatial_fusion()改为支持cross_attn_gated
train_spatial_beats.py:1536-1549- 新增
make_ov1_local_spatial_v8_ov123_top4_config()
- 新增
run_ov1_v8_ov123_top4.sh:1-68- 新的启动脚本
5.3 v8 的训练可学习参数
为了保证 v8 新融合模块真的训练到,代码里做了两处接线:
train_spatial_beats.py:2528-2535local_spatial_fuser被加入always_train_prefixes
train_spatial_beats.py:2646-2653local_spatial_fuser.被加入_SPATIAL_PREFIXES
也就是说:
- 新 fuser 走
spatial_lr组 - 不会被当成 trunk 或 head 漏掉
6. v8 的两阶段训练
6.1 设计原因
直接在新融合结构上同时训练 class + direction + distance,早期很容易出现:
- class 还没稳定
- DoA 噪声先进入 matching
- assignment 被错误的 spatial cost 带偏
所以 v8 采用 two-stage:
- stage 1:先训
activity + class - stage 2:再恢复
dir/dist
6.2 具体配置
v8 preset 中:
frame_spatial_loss_warmup_epochs = 3
frame_spatial_loss_warmup_scale = 0.0
对应代码:
train_spatial_beats.py:1547-1548train_spatial_beats.py:4079-4102
含义:
epoch 0-2lambda_frame_direction = 0lambda_frame_distance = 0frame_match_dir_cost_weight = 0frame_match_dist_cost_weight = 0
epoch 3+- 自动恢复完整方向/距离监督
7. v8 热启动方式
默认脚本:
./run_ov1_v8_ov123_top4.sh
默认行为:
- 从
v7h的best.pt热启动 - 只新增
local_spatial_fuser.*参数 - 其余 trunk / local_spatial / source_query_decoder / frame-track heads 继承已有权重
- 训练比例继续保持
ov1:ov2:ov3 = 1:3:3
脚本位置:
run_ov1_v8_ov123_top4.sh:16-37
8. v8 当前结果
8.1 epoch 0 日志
用户记录:
[Epoch 0] train:
loss=1.3445 activity=0.3099 cls_aux=1.0346
direction=0.4527 dist=0.4894
act↑=0.887 act↓=0.136 sep=0.751
ocls=0.786 oazi=45.2° oele=18.4°
[Epoch 0] val:
loss=2.1304 activity=0.2310 cls_aux=1.8995
direction=0.4377 dist=0.4337
act↑=0.925 act↓=0.092 sep=0.832
ocls=0.652 oazi=46.9° oele=20.1°
ER20=1.047 F20=0.096 LE_CD=51.8° LR_CD=0.486 SELD=0.688
8.2 epoch 0 CSV 诊断
epoch_0000_csv 的诊断结论:
- 不是
v7j那种 3-4 条轨全亮的崩坏 activity分离已经很好- 当前主要瓶颈是
class + angle,不是 duplicate
粗统计:
TP=207 FP=2147 FN=1844
P=0.088 R=0.101 F1=0.094
FP breakdown:
class_wrong = 1211
angle_wrong = 911
duplicate = 25
active_hist:
0 active: 185
1 active: 748
2 active: 683
3 active: 80
no_gt_frames=286
no_gt_with_pred=119
ratio=0.416
mean_maxprob_gt = 0.972
mean_maxprob_nogt = 0.400
解释:
act↑/act↓/sep很强,说明 activity head 已经把有源/无源分开了- 但
class_wrong和angle_wrong很高 F20低并不意外,因为 stage 1 根本还没训dir/dist
8.3 epoch 0 的 ov1 / ov2 / ov3 形态
代表样本观察:
ov1
mean_act_by_track = {0: 0.88, 1: 0.007, 2: 0.001, 3: 0.0}
active_tracks_per_frame_hist = {1: 37}
top_pred_classes = singing(37)
说明单源样本已经能稳定只亮一条轨。
ov2
mean_act_by_track = {0: 0.999, 1: 0.745, 2: 0.133, 3: 0.0}
active_tracks_per_frame_hist = {1: 5, 2: 45}
top_pred_classes = frog(50), bird(37), tool(8)
说明双源样本已经在尝试输出两条轨。
ov3
mean_act_by_track = {0: 0.977, 1: 0.519, 2: 0.036, 3: 0.0}
active_tracks_per_frame_hist = {1: 17, 2: 32}
top_pred_classes = tool(81)
说明三源样本里第二条轨开始亮,但有明显 class collapse。
9. v8 的 epoch 1:重点看 ov2 / ov3
第二个 epoch 之后,重点检查了 ov23。
9.1 active track 数量变化
在“有预测的帧”上,平均亮起的轨数:
epoch0:
ov2 = 1.578
ov3 = 1.767
epoch1:
ov2 = 1.736
ov3 = 1.942
这说明:
ov2/ov3上第二条轨更积极了- 模型更愿意在 overlap 场景输出多轨
这本身不是坏事,但如果 class/DoA 没跟上,就会先表现为 FP 上升。
9.2 代表样本对比
ov2 代表样本 valid__ov2_000000__pred.csv
epoch0
mean_act_by_track = {0: 0.999, 1: 0.745, 2: 0.133, 3: 0.0}
active_hist = {1: 5, 2: 45}
top_classes = frog(50), bird(37), tool(8)
epoch1
mean_act_by_track = {0: 0.996, 1: 0.312, 2: 0.625, 3: 0.339}
active_hist = {1: 11, 2: 39}
top_classes = frog(50), wind(39)
解释:
- 第二条活跃轨从
track1转向track2 - 类别也从
bird变成了wind - 说明 query 责任在重排,但 class binding 还不稳定
ov3 代表样本 valid__ov3_000004__pred.csv
epoch0
mean_act_by_track = {0: 0.977, 1: 0.519, 2: 0.036, 3: 0.0}
active_hist = {1: 17, 2: 32}
top_classes = tool(81)
epoch1
mean_act_by_track = {0: 0.966, 1: 0.665, 2: 0.065, 3: 0.127}
active_hist = {1: 11, 2: 39}
top_classes = tool(89)
解释:
- 第二条轨更稳定地亮了
- 但 class collapse 没缓解,反而更统一地塌成
tool
因此当前对 epoch1 的判断是:
ov2:有变化,但不能算明显变好ov3:暂时没有变好,仍然是当前最大问题点
10. 当前诊断
截至目前,对 v8 的判断是:
10.1 已经证明的事
v8没有走向v7j那种 duplicate 崩坏activity分离做得比担心中要好ov2/ov3上的多轨输出能力确实在长出来
10.2 还没解决的事
ov2/ov3的 class binding 还不稳ov3仍然存在明显 class collapse- stage 1 期间不训
dir/dist,所以F20 / LE_CD / oazi现在还不能作为最终判断
10.3 当前最值得关注的信号
epoch 3+进入 stage 2 后:ocls是否继续上升oazi是否明显下降F20是否出现拐点
ov3的 top predicted classes 是否开始从单一tool分裂成多个类ov23的平均 active tracks 是否继续上涨过快
11. 本轮实际代码修改
本轮已经落地的修改:
11.1 新增 v8 融合模块
spatial_modules.py:1723-1842LocalSpatialCrossFusionBlockLocalSpatialCrossFuser
11.2 新增 fusion 配置
spatial_beats.py:173-182local_spatial_fusion_modelocal_spatial_fusion_layerslocal_spatial_fusion_headslocal_spatial_fusion_dropoutlocal_spatial_fusion_gate_biaslocal_spatial_fusion_direct_gate_bias
11.3 改 fused token 构造
spatial_beats.py:1104-1118- 从
semantic + local_update - 改成支持
local_spatial_fuser(...)
- 从
11.4 新增 v8 preset
train_spatial_beats.py:1536-1549- 继承
v7h - 打开
cross_attn_gated - 打开 two-stage spatial warmup
- 输出目录改为
v8_ov123_exp/03_ov123_top4
- 继承
11.5 训练侧接线
train_spatial_beats.py:2528-2535local_spatial_fuser加入always_train_prefixes
train_spatial_beats.py:2646-2653local_spatial_fuser.加入_SPATIAL_PREFIXES
11.6 新增脚本
run_ov1_v8_ov123_top4.sh:1-68
12. 下一步建议
当前建议不改结构,先继续训练 v8:
- 至少跑到
epoch 3之后,确认 stage 2 开启后的趋势 - 优先看
ov3是否开始摆脱toolcollapse - 如果
epoch 3-5之后仍然:ov3继续单类塌缩ocls不升oazi不降F20没有明显抬升 再考虑下一轮结构修改
当前最合理的工作顺序是:
- 先把
v8跑穿 stage 1 / stage 2 - 再根据
ov23的 class collapse 是否缓解,决定下一轮改:- query decoder
- matching
- finer token rate
- 或额外的 class-preserving auxiliary