Spatial-BEATs / docs /0416.md
dieKarotte's picture
Add files using upload-large-folder tool
29ab2d0 verified
|
Raw
History Blame Contribute Delete
9.18 kB

0416 Spatial-BEATs 实验记录

本文档记录 2026-04-16 对话期间的所有改动、实验和结果。

1. Bug 修复

1.1 方位角坐标系修复(val_predictions 显示 bug)

问题:val_predictions jsonl 文件中 gt_azimuth_deg 是 DCASE 标准 [-180°, 180°],但 pred_azimuth_degtorch.remainder(..., 360.0) 转成了 [0°, 360°]。导致 GT=-168° 和 Pred=192° 看起来差 360° 但实际是同一个方向。

影响:仅影响 val_predictions 文件的可读性。训练 loss 和 epoch log metrics 不受影响_circular_distance_deg 对两种坐标系都能正确计算圆周距离)。

修复spatial_loss.py):

  • _azi_ele_deg_from_direction_vector:去掉 torch.remainder(..., 360.0)atan2 直接返回 [-180°, 180°]
  • 新增 _to_dcase_azimuth() 工具函数
  • build_validation_examplesbuild_pretrunk_ast_validation_examples 中 pred_azi 使用 _to_dcase_azimuth 转换

1.2 Vocabulary 标签修复(65→63 类)

问题 1female_singing(171例) 和 male_singing(104例) 与 singing 是父子关系,模型在 65 类 softmax 下完全区分不了(test accuracy=0%)。

问题 2string_instrument 类中有 1644 个样本(22.8%)是 Hi-hat/Crash_cymbal/Cymbal 打击乐,被错误映射。这些应该归入 percussion

修复fix_vocabulary_and_manifests.py):

  • female_singing + male_singing → 合并到 singing(65→63 类)
  • string_instrumentmono_primary_label 为 Hi-hat/Crash_cymbal/Cymbal 的样本 → 重标为 percussion
  • 对 ov1/ov2/ov3 manifest 全部原地修复,备份为 .bak_20260416
  • final_vocabulary.csv 重编号为连续的 1~63

受影响文件

文件 变更
final_vocabulary.csv 65→63 类,去掉 female_singing/male_singing
ov1_foa.jsonl 460 条 singing 合并 + 1644 条 cymbal 修正
ov2_foa.jsonl / ov3_foa.jsonl 无变动(这些类别未出现)

代码 default 值更新

文件 变更
spatial_beats.py source_num_classes: 65 → 63
spatial_dataset.py num_classes: 65 → 63
spatial_modules.py 6 处函数参数默认值 65 → 63
spatial_atst.py source_num_classes: 65 → 63

1.3 load_checkpoint strict 改为 non-strict

问题:stage1→stage2 resume 时,如果两阶段 model config 不完全一致(比如 stage1 无 semantic_anchor 但 stage2 有),strict=True 会报错 missing key。

修复train_spatial_beats.py):load_checkpoint 改用 strict=False,missing/unexpected key 打印 warning 不报错。

2. v2 Test 集完整评测

编写了 eval_spatial_beats.py 评测脚本,在 ov1 完整 test 集(1800 样本) 上评测了 v2 stage2 best.pt。

checkpointcheckpoints/spatial_beats_ov1_local_spatial_v2_exp/02_spatial/best.pt

指标
class_acc 43.83%
azi_mae 19.87°
ele_mae 8.66°
dist_mae 0.554 m
SELD F1 0.3139
SELD LR 0.6956
SELD LE 8.64°
SELD Score ↓ 0.6027

63 类映射后(post-hoc)

指标 65类 63类映射
class_acc 43.83% 44.44% (+0.6%)
空间指标 不变 不变

提升来自 singing 子类合并(+11 个正确样本)。

Per-class 分析亮点

  • Accuracy=0 的 4 个类:female_singing, male_singing, tape, wood
  • Accuracy≥0.7 的 5 个类:bird(70.6%), guitar(70.9%), thunderstorm(76%), knock(77.8%), car(100%)
  • 最大混淆来源:父子标签层级冲突(guitar↔keyboard_instrument, wind_instrument→musical_instrument, male_singing→singing)

3. bypass / purify 实验结果

两个新架构实验的 stage1 都全 trunk 解冻(12 层),目标是用更激进的策略达到更高 class_acc。

实验 Stage1 策略 Stage1 class_acc Stage2 SELD
bypass 全解冻 + bypass_local_fusion + 零空间 ~25% ❌ 0.743
purify 全解冻 + freeze_local_spatial + 零空间 ~25% ❌ 0.816
v2(对照) top-2 + semantic_anchor + 空间多任务 ~56% 0.603

结论:全 trunk 解冻 + SpatialBEATs 架构 = 崩。top-2 解冻更稳定。

4. 分类能力损失分析

纯 BEATs 分类实验(无空间任务)

解冻策略 val_acc
head_only(trunk 全冻) 62.6%
top-8(层 4-11) 69.1%
full(全 12 层) 70.0%

SpatialBEATs 分类损失链

纯 BEATs 62.6%(trunk 全冻)
  → SpatialBEATs v2 stage1: 56%  (-6.6%: CNN噪声 + 多任务干扰 + 只解冻top-2)
  → SpatialBEATs v2 stage2: 44%  (-12%: λ_dir=12 空间梯度冲击语义)

关键发现:v4 的 trunk 起点

v4 config 中 class_finetuned_ckpt 之前指向的是 head_only/best.pt(62.6%),其 trunk 权重和原始 BEATs 完全相同(head_only 全冻 trunk 训练)。已修正为指向 02_full/best.pt(70%),获取 FSD50K 适配过的 trunk。

5. 失败实验

v3 / v3ws(bypass + top-8/top-4)

  • v3 epoch1: cls=0.116 ❌
  • v3ws epoch1: cls=0.080 ❌
  • 原因:bypass 模式 + SpatialBEATs = DDP 不稳定

v3b / v3bws(freeze_local_spatial + top-8/top-4)

  • v3b 15 epoch: cls=35.67%,azi=89.66°(空间几乎随机)
  • v3bws: 无效果
  • 原因:多变量同时改动(top-8 + freeze_local_spatial + ddp_find_unused + 无 anchor),不如 v2 的组合

6. 新增架构:Frame-Level Track Supervision

动机

当前 mono_ast 是 clip-level 预测(attention pool → 1 class + 1 direction),无法:

  1. 输出 DCASE 格式逐帧检测
  2. 扩展到 ov2/ov3 多源
  3. 约束 trunk 逐帧表征质量

实现

readout_scheme="local_spatial" 下新增可选 FrameTrack 分支,与 clip-level head 并行运行

fused_spatial_embeddings [B, T_s, D]
  ├── attention pool → clip-level mono_ast prediction (已有)
  └── SourceQueryDecoder → [B, K, T_s, D] → FrameTrack prediction (新增)

完全复用已有代码SourceQueryDecoderFrameTrackPredictionHeadscompute_frame_track_losses 一行都没改。

控制开关

  • SpatialBEATsConfig.enable_frame_track: bool = False
  • SpatialLossConfig.enable_frame_track_loss: bool = False
  • 默认关闭,所有现有实验零影响

改动文件

文件 改动
spatial_beats.py Config flag + init 创建 head + forward 产出
spatial_loss.py Config flag
train_spatial_beats.py run_train_step 追加 loss + validate 追加 examples + preset
spatial_modules.py 不改

新 presetov1_local_spatial_v4f_spatial

7. 当前实验矩阵

正在跑

实验 状态 配置要点
v4 stage1 运行中 v2 架构复刻 + 63 类 + 70% trunk init

等 v4 stage1 结束后

实验 脚本 配置要点
v4 stage2 run_ov1_v4.sh v2 复刻(λ_dir=12, anchor=0.5)
v4g stage2 run_ov1_v4g.sh 温和版(λ_dir=6, λ_cls=2, anchor=1.5)
v4f stage2 run_ov1_v4f.sh v4 + 并行 frame-level track head

预期效果

实验 预期 class_acc 预期 azi_mae 新能力
v4 stage2 ~44%(同 v2) ~20° baseline
v4g stage2 ~50%+ ~25-30° 分类更好,空间稍差
v4f stage2 ~44% + frame metrics ~20° DCASE 逐帧输出

8. 新增文件清单

文件 用途
eval_spatial_beats.py 独立评测脚本,支持所有 preset
fix_vocabulary_and_manifests.py 一次性 vocab+manifest 修复脚本
run_ov1_v3.sh v3 实验(bypass,已失败)
run_ov1_v3ws.sh v3ws 实验(bypass+warmstart,已失败)
run_ov1_v3b.sh v3b 实验(freeze_local_spatial,效果差)
run_ov1_v3bws.sh v3bws 实验(同上+warmstart)
run_ov1_v4.sh v4 实验(v2 复刻 + 63 类 + 70% trunk)
run_ov1_v4g.sh v4g 温和空间版
run_ov1_v4f.sh v4f frame-level track 版

9. 关键经验总结

  1. 不要全解冻 trunk:纯 BEATs 全解冻 OK(70%),但 SpatialBEATs 架构下全解冻必崩(25%)。top-2 是当前唯一验证过的安全策略。

  2. bypass_local_fusion 不可用:在 DDP 训练下导致不稳定,即使只是 top-8 解冻也会崩。freeze_local_spatial 稍好但仍差。

  3. Semantic anchor 有效:v2 的 anchor(λ=0.5)让 stage2 class_acc 从 25%(无 anchor 的 kaldi_spatial)保到 44%。

  4. 标签质量是分类瓶颈:65 类中父子层级冲突(singing/female_singing、musical_instrument/string_instrument)和标签映射 bug(cymbal→string_instrument)贡献了大量"假错误"。

  5. trunk 初始化很重要:v4 之前一直用的是 head_only/best.pt(trunk ≡ 原始 BEATs),现在改为 02_full/best.pt(trunk 已适配 FSD50K,70%),预期 stage1 起点更高。

  6. 一次只改一个变量:v3b/v3bws 同时改了 5 个变量导致无法诊断,v4 只改了 trunk init 这一项。