Spatial-BEATs / docs /0416.md
dieKarotte's picture
Add files using upload-large-folder tool
29ab2d0 verified
|
Raw
History Blame Contribute Delete
9.18 kB
# 0416 Spatial-BEATs 实验记录
本文档记录 2026-04-16 对话期间的所有改动、实验和结果。
## 1. Bug 修复
### 1.1 方位角坐标系修复(val_predictions 显示 bug)
**问题**:val_predictions jsonl 文件中 `gt_azimuth_deg` 是 DCASE 标准 `[-180°, 180°]`,但 `pred_azimuth_deg` 被 `torch.remainder(..., 360.0)` 转成了 `[0°, 360°]`。导致 GT=-168° 和 Pred=192° 看起来差 360° 但实际是同一个方向。
**影响**:仅影响 val_predictions 文件的可读性。**训练 loss 和 epoch log metrics 不受影响**(`_circular_distance_deg` 对两种坐标系都能正确计算圆周距离)。
**修复**`spatial_loss.py`):
- `_azi_ele_deg_from_direction_vector`:去掉 `torch.remainder(..., 360.0)``atan2` 直接返回 `[-180°, 180°]`
- 新增 `_to_dcase_azimuth()` 工具函数
- `build_validation_examples``build_pretrunk_ast_validation_examples` 中 pred_azi 使用 `_to_dcase_azimuth` 转换
### 1.2 Vocabulary 标签修复(65→63 类)
**问题 1**`female_singing`(171例) 和 `male_singing`(104例) 与 `singing` 是父子关系,模型在 65 类 softmax 下完全区分不了(test accuracy=0%)。
**问题 2**`string_instrument` 类中有 **1644 个样本**(22.8%)是 Hi-hat/Crash_cymbal/Cymbal 打击乐,被错误映射。这些应该归入 `percussion`。
**修复**(`fix_vocabulary_and_manifests.py`):
- `female_singing` + `male_singing` → 合并到 `singing`(65→63 类)
- `string_instrument``mono_primary_label` 为 Hi-hat/Crash_cymbal/Cymbal 的样本 → 重标为 `percussion`
- 对 ov1/ov2/ov3 manifest 全部原地修复,备份为 `.bak_20260416`
- `final_vocabulary.csv` 重编号为连续的 1~63
**受影响文件**
| 文件 | 变更 |
|---|---|
| `final_vocabulary.csv` | 65→63 类,去掉 female_singing/male_singing |
| `ov1_foa.jsonl` | 460 条 singing 合并 + 1644 条 cymbal 修正 |
| `ov2_foa.jsonl` / `ov3_foa.jsonl` | 无变动(这些类别未出现) |
**代码 default 值更新**
| 文件 | 变更 |
|---|---|
| `spatial_beats.py` | `source_num_classes: 65 → 63` |
| `spatial_dataset.py` | `num_classes: 65 → 63` |
| `spatial_modules.py` | 6 处函数参数默认值 `65 → 63` |
| `spatial_atst.py` | `source_num_classes: 65 → 63` |
### 1.3 load_checkpoint strict 改为 non-strict
**问题**:stage1→stage2 resume 时,如果两阶段 model config 不完全一致(比如 stage1 无 semantic_anchor 但 stage2 有),`strict=True` 会报错 missing key。
**修复**`train_spatial_beats.py`):`load_checkpoint` 改用 `strict=False`,missing/unexpected key 打印 warning 不报错。
## 2. v2 Test 集完整评测
编写了 `eval_spatial_beats.py` 评测脚本,在 ov1 **完整 test 集(1800 样本)** 上评测了 v2 stage2 best.pt。
**checkpoint**`checkpoints/spatial_beats_ov1_local_spatial_v2_exp/02_spatial/best.pt`
| 指标 | 值 |
|---|---|
| class_acc | **43.83%** |
| azi_mae | **19.87°** |
| ele_mae | 8.66° |
| dist_mae | 0.554 m |
| SELD F1 | 0.3139 |
| SELD LR | 0.6956 |
| SELD LE | 8.64° |
| **SELD Score ↓** | **0.6027** |
### 63 类映射后(post-hoc)
| 指标 | 65类 | 63类映射 |
|---|---|---|
| class_acc | 43.83% | **44.44%** (+0.6%) |
| 空间指标 | 不变 | 不变 |
提升来自 singing 子类合并(+11 个正确样本)。
### Per-class 分析亮点
- **Accuracy=0 的 4 个类**:female_singing, male_singing, tape, wood
- **Accuracy≥0.7 的 5 个类**:bird(70.6%), guitar(70.9%), thunderstorm(76%), knock(77.8%), car(100%)
- **最大混淆来源**:父子标签层级冲突(guitar↔keyboard_instrument, wind_instrument→musical_instrument, male_singing→singing)
## 3. bypass / purify 实验结果
两个新架构实验的 stage1 都**全 trunk 解冻**(12 层),目标是用更激进的策略达到更高 class_acc。
| 实验 | Stage1 策略 | Stage1 class_acc | Stage2 SELD |
|---|---|---|---|
| bypass | 全解冻 + bypass_local_fusion + 零空间 | ~25% ❌ | 0.743 |
| purify | 全解冻 + freeze_local_spatial + 零空间 | ~25% ❌ | 0.816 |
| **v2(对照)** | top-2 + semantic_anchor + 空间多任务 | **~56%** ✅ | **0.603** |
**结论**:全 trunk 解冻 + SpatialBEATs 架构 = 崩。top-2 解冻更稳定。
## 4. 分类能力损失分析
### 纯 BEATs 分类实验(无空间任务)
| 解冻策略 | val_acc |
|---|---|
| head_only(trunk 全冻) | 62.6% |
| top-8(层 4-11) | **69.1%** |
| full(全 12 层) | 70.0% |
### SpatialBEATs 分类损失链
```
纯 BEATs 62.6%(trunk 全冻)
→ SpatialBEATs v2 stage1: 56% (-6.6%: CNN噪声 + 多任务干扰 + 只解冻top-2)
→ SpatialBEATs v2 stage2: 44% (-12%: λ_dir=12 空间梯度冲击语义)
```
### 关键发现:v4 的 trunk 起点
v4 config 中 `class_finetuned_ckpt` 之前指向的是 `head_only/best.pt`(62.6%),其 trunk 权重和原始 BEATs **完全相同**(head_only 全冻 trunk 训练)。已修正为指向 `02_full/best.pt`(70%),获取 FSD50K 适配过的 trunk。
## 5. 失败实验
### v3 / v3ws(bypass + top-8/top-4)
- v3 epoch1: cls=0.116 ❌
- v3ws epoch1: cls=0.080 ❌
- 原因:bypass 模式 + SpatialBEATs = DDP 不稳定
### v3b / v3bws(freeze_local_spatial + top-8/top-4)
- v3b 15 epoch: cls=35.67%,azi=89.66°(空间几乎随机)
- v3bws: 无效果
- 原因:多变量同时改动(top-8 + freeze_local_spatial + ddp_find_unused + 无 anchor),不如 v2 的组合
## 6. 新增架构:Frame-Level Track Supervision
### 动机
当前 `mono_ast` 是 clip-level 预测(attention pool → 1 class + 1 direction),无法:
1. 输出 DCASE 格式逐帧检测
2. 扩展到 ov2/ov3 多源
3. 约束 trunk 逐帧表征质量
### 实现
`readout_scheme="local_spatial"` 下新增可选 `FrameTrack` 分支,与 clip-level head **并行运行**
```
fused_spatial_embeddings [B, T_s, D]
├── attention pool → clip-level mono_ast prediction (已有)
└── SourceQueryDecoder → [B, K, T_s, D] → FrameTrack prediction (新增)
```
**完全复用已有代码**`SourceQueryDecoder``FrameTrackPredictionHeads``compute_frame_track_losses` 一行都没改。
**控制开关**
- `SpatialBEATsConfig.enable_frame_track: bool = False`
- `SpatialLossConfig.enable_frame_track_loss: bool = False`
- 默认关闭,所有现有实验零影响
**改动文件**
| 文件 | 改动 |
|---|---|
| `spatial_beats.py` | Config flag + __init__ 创建 head + forward 产出 |
| `spatial_loss.py` | Config flag |
| `train_spatial_beats.py` | run_train_step 追加 loss + validate 追加 examples + preset |
| `spatial_modules.py` | **不改** |
**新 preset**`ov1_local_spatial_v4f_spatial`
## 7. 当前实验矩阵
### 正在跑
| 实验 | 状态 | 配置要点 |
|---|---|---|
| v4 stage1 | 运行中 | v2 架构复刻 + 63 类 + **70% trunk init** |
### 等 v4 stage1 结束后
| 实验 | 脚本 | 配置要点 |
|---|---|---|
| v4 stage2 | `run_ov1_v4.sh` | v2 复刻(λ_dir=12, anchor=0.5) |
| v4g stage2 | `run_ov1_v4g.sh` | 温和版(λ_dir=6, λ_cls=2, anchor=1.5) |
| v4f stage2 | `run_ov1_v4f.sh` | v4 + 并行 frame-level track head |
### 预期效果
| 实验 | 预期 class_acc | 预期 azi_mae | 新能力 |
|---|---|---|---|
| v4 stage2 | ~44%(同 v2) | ~20° | baseline |
| v4g stage2 | **~50%+** | ~25-30° | 分类更好,空间稍差 |
| v4f stage2 | ~44% + frame metrics | ~20° | DCASE 逐帧输出 |
## 8. 新增文件清单
| 文件 | 用途 |
|---|---|
| `eval_spatial_beats.py` | 独立评测脚本,支持所有 preset |
| `fix_vocabulary_and_manifests.py` | 一次性 vocab+manifest 修复脚本 |
| `run_ov1_v3.sh` | v3 实验(bypass,已失败) |
| `run_ov1_v3ws.sh` | v3ws 实验(bypass+warmstart,已失败) |
| `run_ov1_v3b.sh` | v3b 实验(freeze_local_spatial,效果差) |
| `run_ov1_v3bws.sh` | v3bws 实验(同上+warmstart) |
| `run_ov1_v4.sh` | v4 实验(v2 复刻 + 63 类 + 70% trunk) |
| `run_ov1_v4g.sh` | v4g 温和空间版 |
| `run_ov1_v4f.sh` | v4f frame-level track 版 |
## 9. 关键经验总结
1. **不要全解冻 trunk**:纯 BEATs 全解冻 OK(70%),但 SpatialBEATs 架构下全解冻必崩(25%)。top-2 是当前唯一验证过的安全策略。
2. **bypass_local_fusion 不可用**:在 DDP 训练下导致不稳定,即使只是 top-8 解冻也会崩。freeze_local_spatial 稍好但仍差。
3. **Semantic anchor 有效**:v2 的 anchor(λ=0.5)让 stage2 class_acc 从 25%(无 anchor 的 kaldi_spatial)保到 44%。
4. **标签质量是分类瓶颈**:65 类中父子层级冲突(singing/female_singing、musical_instrument/string_instrument)和标签映射 bug(cymbal→string_instrument)贡献了大量"假错误"。
5. **trunk 初始化很重要**:v4 之前一直用的是 `head_only/best.pt`(trunk ≡ 原始 BEATs),现在改为 `02_full/best.pt`(trunk 已适配 FSD50K,70%),预期 stage1 起点更高。
6. **一次只改一个变量**:v3b/v3bws 同时改了 5 个变量导致无法诊断,v4 只改了 trunk init 这一项。