| # 2026-04-23: v9_ov123_top4, class-first targeted fixes (A→D→E→B→C→F) |
|
|
| ## 0. Background and goals for this round |
|
|
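| v8 / v8a were trained for 10+ epochs on `03_ov123_top4`. Overall spatial metrics improved slightly over v7h, but `class_ok` stayed stuck at around **45%**. The user asked for the following: |
|
|
| 1. Focus on frame-level prediction and introduce no clip-level supervision; |
| 2. Analyze the v8 / v8a val CSVs (epoch 9 / epoch 11) and find the real bottleneck behind the stalled class accuracy; |
| 3. Land Fix A→D→E→B→C→F in that order without breaking the existing code framework or training logic; |
| 4. Ship everything as a new v9 preset plus a new script, warm-started from the v8a `best.pt` (so that the epoch-0 forward output is identical to v8a). |
|
|
| This document records **every code change** made in the 2026-04-23 round, together with the CSV diagnostic evidence behind each change. |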
|
|
| Related earlier documents: |
| - `docs/0422.md`: v7h → v8 architecture upgrade (cross-attention fusion) |
| - `docs/0422_v7h_v7j.md`: per-frame multi-source diagnosis of v7h / v7j / v7i and the introduction of class-weighted CE |
| - `docs/0421.md`: v7f → ov123 per-frame extension and the frame-track CSV dump |
|
|
| --- |
|
|
| ## 1. CSV diagnosis: error patterns in v8 ep9 / v8a ep11 |
|
|
| ### 1.1 Overall accuracy (class taken from the DOA-nearest track among active tracks with activity ≥ 0.5) |
|
|
| Metric | v8 ep9 | v8a ep11 | |
| |---|---|---| |
| | overall `cls_ok` | 933/2051 = **45.5%** | 934/2051 = **45.5%** | |
| | DOA≤20° 命中内 `cls_ok` | 48.1% | 51.9% | |
| **Oracle cls** (ignore activity, take the class of the DOA-nearest track) | **38.7%** | **46.4%** | |
|
|
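| The two key numbers are close, most clearly on v8a ep11 (45.5% vs 46.4%): `cls_ok(nearest active) ≈ oracle_cls`. This shows that: |
|
|
| > The track picked by **activity** and the track picked by **DOA** agree semantically, so the K=4 track binding is fine; what actually goes wrong is that **the selected track's own `class_logits` predict the wrong class**. |
|
| In other words: **the bottleneck is not matching and not activity, but the output distribution of the class head itself**. |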
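
The two selection rules compared above can be sketched in plain Python. This is a minimal illustration, not the evaluation script: the track dictionaries, the `pick_class` helper, and the toy directions are all hypothetical, and only the rule itself (DOA-nearest track, with or without the activity ≥ 0.5 filter) comes from the text.

```python
import math

def angle_deg(u, v):
    """Angle in degrees between two 3-D direction vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    cos = max(-1.0, min(1.0, dot / norm))
    return math.degrees(math.acos(cos))

def pick_class(tracks, gt_dir, require_active, act_thresh=0.5):
    """Class of the DOA-nearest track; optionally restricted to active tracks."""
    cands = [t for t in tracks if not require_active or t["activity"] >= act_thresh]
    if not cands:
        return None
    return min(cands, key=lambda t: angle_deg(t["dir"], gt_dir))["cls"]

# Toy frame with K=4 tracks (hypothetical values).
tracks = [
    {"activity": 0.9, "dir": (1.0, 0.0, 0.0), "cls": "bird"},
    {"activity": 0.2, "dir": (0.0, 1.0, 0.0), "cls": "frog"},
    {"activity": 0.7, "dir": (0.0, 0.0, 1.0), "cls": "speech"},
    {"activity": 0.1, "dir": (-1.0, 0.0, 0.0), "cls": "train"},
]

# Case 1: the nearest track is also active, so both rules agree.
agree = (pick_class(tracks, (0.9, 0.1, 0.0), True),
         pick_class(tracks, (0.9, 0.1, 0.0), False))

# Case 2: the nearest track is inactive, so the rules pick different tracks.
differ = (pick_class(tracks, (0.1, 0.9, 0.0), True),
          pick_class(tracks, (0.1, 0.9, 0.0), False))
```

When the two rules give nearly the same accuracy over a whole validation set, cases like the second one are rare or do not change the outcome, which is the reading the section relies on.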
| |
| ### 1.2 Breakdown by ov group |
| |
| v8 ep9: |
| ``` |
| ov2 : gt= 682 DOA_ok=51.6% cls_ok=37.1% any_track_cls_ok=38.9% cls&DOA=23.9% |
| ov3 : gt=1158 DOA_ok=42.5% cls_ok=47.1% any_track_cls_ok=59.6% cls&DOA=18.0% |
| other: gt= 211 DOA_ok=97.6% cls_ok=64.0% any_track_cls_ok=64.0% cls&DOA=63.0% |
| ``` |
| |
| v8a ep11 is similar, with ov3 slightly worse: |
| ``` |
| ov2 : cls_ok=44.0% cls&DOA=33.7% |
| ov3 : cls_ok=43.1% cls&DOA=14.2% |
| other: cls_ok=64.0% cls&DOA=63.0% |
| ``` |
| |
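| `other` (approximately ov1, single source) reaches 64%, ov2 sits at 37-44%, and ov3 drops to 43% (v8a ep11) with `cls&DOA` at only 14%. **Class accuracy on multi-source overlapping frames is about 20 points lower than on single-source frames**, which is direct evidence that ov3 demixing fails. |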
| |
| ### 1.3 Per-class error patterns (oracle: ignore activity, take the DOA-nearest track) |
| |
| ``` |
| aircraft (n=100): speech(50%) human_vocalization(50%) → always 0% correct |
| vehicle (n= 68): machine(74%) aircraft(13%) train(10%) → always 0% correct |
| frog (n= 50): bird(98~100%) → always predicted as bird |
| speech (n= 53): human_vocalization(94%) breathing(6%) → always 0% correct |
| crackle (n= 51): rain(53%) machine(41%) → always 0% correct |
| wind (n= 16): fire(50~88%) vehicle/wood → always 0% correct |
| drawer_cab(n= 50): tool(74~84%) → always 0% correct |
| tape (n= 50): human_vocalization(88%) typing(46%) → always 0% correct |
| knock (n= 36): home_sound(64%) human_vocalization → always 0% correct |
| train (n=115): train(43%) crushing(37%) vehicle(15%) → ~10% |
| ``` |
| |
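| Most of these errors are **"sibling collapse" within the same parent class** (cross-domain cases such as aircraft→speech are the exception): |
| |
| - `aircraft / vehicle / train` all belong to transportation; |
| - `speech / human_vocalization / breathing / laughter` all belong to human-voice; |
| - `frog / bird / insect` all belong to animal-vocal; |
| - `drawer_cabinet / tool / home_sound / door` all belong to indoor-mechanical. |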
| |
| ### 1.4 Label-bug check: is frog→bird a data-mapping error? |
| |
| Checked `build_ov{1,23}_foa_dataset.py`, `final_vocabulary.csv`, `spatial_dataset.py:_resolve_class_index`, and concrete samples in the jsonl: |
| |
| ``` |
| ov1 train frog: 113 ov1 train bird: 3270 → 1:29 |
| ov1 valid frog: 9 ov1 valid bird: 56 |
| ov1 test frog: 3 ov1 test bird: 51 |
| ov2 train frog:1078 ov2 train bird: 1255 |
| ov3 train frog:1384 ov3 train bird: 1432 |
| ``` |
| |
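| **The frog labels are entirely correct**; this is not a label bug. The root cause of the 100% bird predictions is that bird outnumbers frog by **29×** in the ov1 training set, which is plain class imbalance. The other two manifests (ov2 / ov3) are already balanced, but ov1 still makes up a sizeable share of the training data, so the overall prior leans toward bird. |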
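
The imbalance figures can be recomputed directly from the train counts listed above. The pooled ratio below is only a rough sketch: it assumes the three manifests are simply concatenated with no resampling, which this document does not state.

```python
# Train-split counts copied from the table above.
train_counts = {
    "ov1": {"frog": 113, "bird": 3270},
    "ov2": {"frog": 1078, "bird": 1255},
    "ov3": {"frog": 1384, "bird": 1432},
}

# bird : frog ratio per manifest.
per_manifest = {
    name: c["bird"] / c["frog"] for name, c in train_counts.items()
}

# Pooled ratio, ASSUMING plain concatenation of the three manifests.
total_bird = sum(c["bird"] for c in train_counts.values())
total_frog = sum(c["frog"] for c in train_counts.values())
pooled_ratio = total_bird / total_frog

for name, ratio in per_manifest.items():
    print(f"{name}: bird/frog = {ratio:.2f}")
print(f"pooled: bird/frog = {pooled_ratio:.2f}")
```

Under that assumption ov1 alone is close to 29:1 while ov2 and ov3 are near 1:1, so the pooled prior still favours bird, only far less sharply than ov1 does on its own.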
| |
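| ## 2. Overall design principles for the six fixes |
| |
| 1. **All additive**: v9 only **adds** new config fields and new submodules; it never deletes or overrides an existing v8a path; |
| 2. **All zero-init or identity-init**: new parameters contribute 0 right after the ckpt is loaded (verifiable: the max abs forward diff of `v9(v8a.pt) - v8a(v8a.pt)` is 0); |
| 3. **strict=False compatible**: the 18 new v9 parameters are missing from the v8a ckpt and are initialized through the existing `_load_spatial_init_checkpoint` / `load_state_dict(..., strict=False)` flow; |
| 4. **Forward signature stays backward compatible**: the new keyword arguments of `FrameTrackPredictionHeads.forward` default to `None`, so old call chains (e.g. the v7 / v8 non-track readout) are unaffected; |
| 5. **All LR / schedule defaults are equivalent to v8a**: with `class_head_lr_scale=1.0` the original 3-group fast path is taken, and with `frame_class_ontology_smoothing=0.0` the original `F.cross_entropy` is used. |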
|
|
| ## 3. File-level change list |
|
|
| File | Change | |
|---|---| |
| `spatial_loss.py` | New fields `frame_class_ontology_smoothing` and `frame_class_ontology_groups`; the CE branch supports ontology soft labels | |
| `spatial_modules.py` | `FrameTrackPredictionHeads` gains 7 constructor parameters (see section 9.3), plus a new `ClassHeadSpectralDemixer` module | |
| `spatial_beats.py` | `SpatialBEATsConfig` gains 8 fields; both `FrameTrackPredictionHeads` constructions pass the new parameters; both forward calls pass `pre_pool_features`; new `_derive_pre_pool_time_mask` helper | |
| `train_spatial_beats.py` | New fields `class_head_lr_scale`, `class_head_freeze_during_ramp_epochs`, `class_head_lr_scale_during_ramp`; `build_optimizer` splits out a `cls_head` group; the epoch loop sets the cls_head LR dynamically; new `_V9_CLASS_WEIGHTS`, `_V9_ONTOLOGY_GROUPS`, `make_ov1_local_spatial_v9_ov123_top4_config`; the preset dispatch and the `--preset` choices register `ov1_local_spatial_v9_ov123_top4` | |
| `run_ov1_v9_ov123_top4.sh` | New launch script (chmod +x) | |
|
|
| --- |
|
|
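| ## 4. Fix A: diagnosis of frog / rare classes (no code change) |
|
|
| Fix A is diagnostic only and lands no code. Conclusion: frog is not a label bug; it is the 29:1 imbalance in ov1 train. The remedy is folded into Fix D: |
|
|
| - `bird` weight: 1.0 → **0.6** (suppress the catch-all) |
| - `frog` weight: 1.0 → **3.0** (boost the minority) |
| - `insect` weight: from 4.0 in v7I → **1.0** (already at 100% in v8 / v8a, so no extra weight is needed) |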
|
|
| ## 5. Fix D: redesigned class weights `_V9_CLASS_WEIGHTS` |
| |
| ### 5.1 Evidence from the data |
| |
| The v7I weight table gave aircraft / vehicle / insect a **4×** weight: |
| |
| ``` |
| v8 ep9: aircraft 0%, vehicle 0%, insect 100% |
| v8a ep11: aircraft 0%, vehicle 0%, insect 100%, printer 100%→59% (regression!) |
| ``` |
| |
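| - The 4× weight had no effect on aircraft / vehicle; their acoustic features really are close to speech / machine; |
| - 4× multiplies the loss of a misclassification (aircraft→speech) by 4, and to lower that loss the model also **makes its speech predictions more conservative**, so speech drops to 0% as well (a lose-lose); |
| - v8a ep11 shows one more anomaly: printer falls from 100% at v8 ep9 to 59%; the cause is covered under Fix E. |
| |
| ### 5.2 Design rules |
| |
| - **"Catch-all" classes** (classes predicted as a default answer on other classes' GT frames) → weight lowered: |
|   - `human_vocalization` 0.4 (the target that speech / tape / typing / knock collapse into) |
|   - `bird` 0.6 (frog collapses almost entirely into bird) |
|   - `machine` 0.5 (vehicle / crackle collapse into machine) |
|   - `rain` 0.5 (collapse target for crackle / singing) |
|   - `breathing` 0.5 (absorbs part of the speech errors, 6% in the section 1.3 table) |
|   - `home_sound` 0.6 (collapse target for knock / printer) |
| - **Rare classes that are easily collapsed** → weight raised (moderately, ≤ 3×): |
|   - `frog` 3.0 |
|   - `crackle` 2.0, `tape` 2.0, `knock` 2.0, `drawer_cabinet` 2.0, `speech` 2.0 |
|   - `aircraft` / `vehicle` / `train` back to 1.0-1.5 (4× proved ineffective and hurts siblings) |
| - **Over-confident robust classes** → slightly lowered: |
|   - `singing` 0.7, `printer` 0.7 (they absorb false positives in the ov123 task) |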
|
|
| ### 5.3 Code change |
|
|
| Location: `train_spatial_beats.py` (right after `_V7I_CLASS_WEIGHTS`) |
|
|
| ```python |
| _V9_CLASS_WEIGHTS: List[float] = [ |
| # 0 wind_instrument 1 string_instrument 2 guitar 3 body_sound |
| 1.0, 1.0, 1.0, 1.0, |
| # 4 drum 5 water 6 human_vocalization 7 keyboard_instrument |
| 1.0, 1.0, 0.4, 1.0, |
| # 8 bird 9 tool 10 machine 11 war_sound |
| 0.6, 1.0, 0.5, 1.0, |
| # 12 metal_clink 13 breathing 14 laughter 15 percussion |
| 1.0, 0.5, 1.0, 1.0, |
| # 16 speech 17 bell 18 dog 19 vehicle |
| 2.0, 1.0, 1.0, 1.5, |
| # 20 alarm 21 footsteps 22 train 23 telephone_alarm |
| 1.0, 1.0, 1.5, 1.0, |
| # 24 glass 25 wind 26 kitchenware 27 animal |
| 1.0, 1.5, 1.0, 1.0, |
| # 28 musical_instrument 29 thunderstorm 30 door 31 male_speech |
| 1.0, 1.0, 1.0, 1.0, |
| # 32 female_speech 33 cat 34 home_sound 35 insect |
| 1.0, 1.0, 0.6, 1.0, |
| # 36 typing 37 zipper 38 camera 39 clock |
| 1.0, 1.0, 1.0, 1.0, |
| # 40 fire 41 singing 42 tearing 43 writing |
| 1.0, 0.7, 1.0, 1.0, |
| # 44 car 45 rain 46 scratch 47 gong |
| 1.0, 0.5, 1.0, 1.0, |
| # 48 appliance 49 paper 50 drawer_cabinet 51 ocean |
| 1.0, 1.0, 2.0, 1.0, |
| # 52 knock 53 crackle 54 finger_snapping 55 aircraft |
| 2.0, 2.0, 1.0, 1.0, |
| # 56 crushing 57 printer 58 tape 59 wood |
| 1.0, 0.7, 2.0, 1.0, |
| # 60 crack 61 cooking 62 frog |
| 1.0, 1.0, 3.0, |
| ] |
| assert len(_V9_CLASS_WEIGHTS) == 63 |
| ``` |
|
|
| In the v9 preset: `cfg.loss.frame_class_loss_weights = list(_V9_CLASS_WEIGHTS)`. |
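
A positional list of 63 floats is easy to get wrong by one slot. As a cross-check, the same table can be rebuilt from class names. This is a sketch only: `build_class_weights` is a hypothetical helper, the vocabulary order is copied from the index comments in the table above, and `wind` 1.5 is taken from the table because the rules in section 5.2 do not mention it.

```python
# Vocabulary order copied from the index comments in _V9_CLASS_WEIGHTS.
VOCAB = """
wind_instrument string_instrument guitar body_sound
drum water human_vocalization keyboard_instrument
bird tool machine war_sound
metal_clink breathing laughter percussion
speech bell dog vehicle
alarm footsteps train telephone_alarm
glass wind kitchenware animal
musical_instrument thunderstorm door male_speech
female_speech cat home_sound insect
typing zipper camera clock
fire singing tearing writing
car rain scratch gong
appliance paper drawer_cabinet ocean
knock crackle finger_snapping aircraft
crushing printer tape wood
crack cooking frog
""".split()

# Every class not listed here keeps the default weight 1.0.
OVERRIDES = {
    "human_vocalization": 0.4, "bird": 0.6, "machine": 0.5, "rain": 0.5,
    "breathing": 0.5, "home_sound": 0.6,
    "frog": 3.0, "crackle": 2.0, "tape": 2.0, "knock": 2.0,
    "drawer_cabinet": 2.0, "speech": 2.0,
    "vehicle": 1.5, "train": 1.5, "wind": 1.5,
    "singing": 0.7, "printer": 0.7,
}

def build_class_weights(vocab, overrides, default=1.0):
    """Hypothetical helper: name-keyed overrides to a positional weight list."""
    weights = [default] * len(vocab)
    for name, w in overrides.items():
        weights[vocab.index(name)] = w  # raises ValueError on an unknown name
    return weights

weights = build_class_weights(VOCAB, OVERRIDES)
```

Comparing `weights` against the hand-written list would catch a value placed in the wrong column.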
|
|
| ## 6. Fix E: separate LR group for the class head, frozen during the DOA ramp |
|
|
| ### 6.1 Evidence from the data |
|
|
| v8a ep5 vs ep11 (before and after the stage-2 DOA ramp is switched on): |
|
|
| ``` |
| ep5 ov3: ocls=44.4% printer=100% |
| ep11 ov3: ocls=44.4% printer=59% ← class perturbed by the DOA gradients |
| ``` |
|
|
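| Once the Hungarian cost and the loss for dir/dist are switched on, class binding gets perturbed. The conventional remedy is to lower the LR globally, but that also slows learning in the trunk / decoder. A better option is to **pull class_head out of the optimizer into its own group and freeze it during the DOA ramp**. |
| |
| ### 6.2 Code change |
| |
| #### 6.2.1 New config fields (`train_spatial_beats.py`) |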
| |
| ```python |
| trunk_lr_scale: float = 1.0 |
| spatial_lr_scale: float = 1.0 |
| # v9: isolated LR multiplier for the class_head inside |
| # frame_track_prediction_heads. When < 1.0 the class head is put in its |
| # own param group with lr = base_lr * class_head_lr_scale. Used during |
| # DOA ramp (stage 2) to prevent class binding from being perturbed by |
| # the newly-unlocked dir/dist gradients. 1.0 = legacy behaviour. |
| class_head_lr_scale: float = 1.0 |
| # Optional epoch-range override that further scales the class head LR |
| # specifically during the DOA ramp. When set, between |
| # frame_spatial_loss_warmup_epochs and frame_spatial_loss_warmup_epochs |
| # + class_head_freeze_during_ramp_epochs the class head LR is set to |
| # class_head_lr_scale_during_ramp (defaults to 0.0 = frozen). After the |
| # ramp window the LR returns to class_head_lr_scale. |
| class_head_freeze_during_ramp_epochs: int = 0 |
| class_head_lr_scale_during_ramp: float = 0.0 |
| ``` |
| |
| #### 6.2.2 `build_optimizer` splits out the cls_head group |
| |
| - New `_CLASS_HEAD_PREFIXES`: |
| ```python |
| _CLASS_HEAD_PREFIXES = ( |
| "frame_track_prediction_heads.class_head.", |
| "frame_track_prediction_heads.class_head_mlp.", |
| "frame_track_prediction_heads.class_head_demixer.", |
| ) |
| ``` |
| - The fast-path condition changes from `trunk_scale == 1.0 and spatial_scale == 1.0` to additionally require `cls_head_scale == 1.0`; |
| - When `cls_head_scale != 1.0`, parameters matching `_CLASS_HEAD_PREFIXES` are pulled out of head_params and placed in `cls_head_params`; |
| - Every param group gains a `group_name` field (`"trunk"` / `"spatial"` / `"head"` / `"cls_head"`), so the epoch loop can locate a group by name. |
| |
| Key snippet: |
| ```python |
| for name, param in model.named_parameters(): |
|     if not param.requires_grad: |
|         continue |
|     if name.startswith(_CLASS_HEAD_PREFIXES) and cls_head_scale != 1.0: |
|         cls_head_params.append(param) |
|     elif name.startswith(_TRUNK_PREFIXES): |
|         trunk_params.append(param) |
|     elif name.startswith(_SPATIAL_PREFIXES): |
|         spatial_params.append(param) |
|     else: |
|         head_params.append(param) |
| |
| param_groups = [] |
| if trunk_params: param_groups.append({"params": trunk_params, "lr": base_lr * trunk_scale, "weight_decay": wd, "group_name": "trunk"}) |
| if spatial_params: param_groups.append({"params": spatial_params, "lr": base_lr * spatial_scale, "weight_decay": wd, "group_name": "spatial"}) |
| if head_params: param_groups.append({"params": head_params, "lr": base_lr, "weight_decay": wd, "group_name": "head"}) |
| if cls_head_params: param_groups.append({"params": cls_head_params, "lr": base_lr * cls_head_scale, "weight_decay": wd, "group_name": "cls_head"}) |
| ``` |
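
Because routing is done purely by name prefix, it is worth checking where each new parameter lands. The sketch below replays the same `startswith` logic on a few parameter names. The trunk / spatial prefixes and the non-class-head names are hypothetical placeholders (the real ones are not shown in this document); the class-head prefixes are copied from above, and the gate names assume the attribute names used in sections 8.4 and 9.3.

```python
# Copied from the section above.
_CLASS_HEAD_PREFIXES = (
    "frame_track_prediction_heads.class_head.",
    "frame_track_prediction_heads.class_head_mlp.",
    "frame_track_prediction_heads.class_head_demixer.",
)
# HYPOTHETICAL placeholders: the real prefixes are not listed in this document.
_TRUNK_PREFIXES = ("trunk.",)
_SPATIAL_PREFIXES = ("spatial_frontend.",)

def group_of(name: str, cls_head_scale: float) -> str:
    """Return the param-group name a parameter would be routed to."""
    if name.startswith(_CLASS_HEAD_PREFIXES) and cls_head_scale != 1.0:
        return "cls_head"
    if name.startswith(_TRUNK_PREFIXES):
        return "trunk"
    if name.startswith(_SPATIAL_PREFIXES):
        return "spatial"
    return "head"

names = [
    "frame_track_prediction_heads.class_head.weight",
    "frame_track_prediction_heads.class_head_mlp.4.weight",
    "frame_track_prediction_heads.class_head_mlp_gate",
    "frame_track_prediction_heads.class_head_demixer.gate",
    "frame_track_prediction_heads.activity_head.weight",  # hypothetical name
]
routing = {n: group_of(n, cls_head_scale=0.3) for n in names}
for n, g in routing.items():
    print(f"{g:9s} {n}")
```

With these names, `class_head_mlp_gate` does not match the `class_head_mlp.` prefix (an underscore follows `mlp`, not a dot), so it would stay in the generic `head` group and keep training during the ramp. Whether that is intended is not stated here, so it may be worth confirming against the real `named_parameters()` output.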
| |
| #### 6.2.3 The epoch loop sets the cls_head LR dynamically |
| |
| Right after the spatial loss schedule (before `_log(f"[Epoch {epoch}] start")`): |
| |
| ```python |
| _cls_ramp_len = int(train_cfg.class_head_freeze_during_ramp_epochs) |
| if _cls_ramp_len > 0 and _sp_warmup > 0 and train_cfg.class_head_lr_scale != 1.0: |
|     in_ramp = _sp_warmup <= epoch < _sp_warmup + _cls_ramp_len |
|     if in_ramp: |
|         _cls_scale = train_cfg.class_head_lr_scale_during_ramp |
|     else: |
|         _cls_scale = train_cfg.class_head_lr_scale |
|     for _g in optimizer.param_groups: |
|         if _g.get("group_name") == "cls_head": |
|             _g["lr"] = train_cfg.learning_rate * _cls_scale |
|     _log( |
|         f"[Epoch {epoch}] cls_head_lr scale={_cls_scale:.3f} " |
|         f"lr={train_cfg.learning_rate * _cls_scale:.2e} " |
|         f"(ramp_window={_sp_warmup}..{_sp_warmup + _cls_ramp_len - 1})" |
|     ) |
| ``` |
| |
| #### 6.2.4 v9 preset settings |
| |
| ```python |
| cfg.class_head_lr_scale = 0.3 |
| cfg.class_head_freeze_during_ramp_epochs = 4 |
| cfg.class_head_lr_scale_during_ramp = 0.0 |
| ``` |
| |
| With v9's `SPATIAL_LR=1.5e-5`, the cls_head LR per stage is: |
| |
| Stage | epoch | lambda_dir | cls_head_lr | |
|---|---|---|---| |
| stage 1 (class-only warmup) | 0-2 | 0.0 | 4.5e-6 | |
| stage 2 (DOA ramp, cls_head frozen) | 3-6 | ramp 0→1 | **0.0** | |
| stage 3 (everything released) | 7+ | 1.0 | 4.5e-6 | |
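
The stage table can be reproduced with a small pure function that mirrors the epoch-loop condition. This is a sketch: `cls_head_lr` is a hypothetical helper, and `sp_warmup=3` is inferred from the table (stage 1 covers epochs 0-2) rather than read from the actual config.

```python
import math

def cls_head_lr(epoch, base_lr=1.5e-5, sp_warmup=3, ramp_len=4,
                scale=0.3, ramp_scale=0.0):
    """Class-head LR for a given epoch, following the epoch-loop logic."""
    if ramp_len > 0 and sp_warmup > 0 and scale != 1.0:
        in_ramp = sp_warmup <= epoch < sp_warmup + ramp_len
        return base_lr * (ramp_scale if in_ramp else scale)
    # No override active: the group keeps the LR set in build_optimizer.
    return base_lr * scale

schedule = [cls_head_lr(e) for e in range(12)]
for e, lr in enumerate(schedule):
    print(f"epoch {e:2d}: cls_head_lr = {lr:.2e}")
```

Note that an LR of 0.0 stops optimizer updates for the group but does not stop gradient computation; setting `requires_grad=False` would be a different (and not equivalent) mechanism.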
|
|
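| ## 7. Fix B: ontology-aware (hierarchical) label smoothing |
|
|
| ### 7.1 Core idea |
|
|
| **At least 70% of the class errors in v8 / v8a are sibling collapse** (same AudioSet parent class). Hard CE penalizes `frog→bird` and `frog→aircraft` with the same full log loss, which is unreasonable: |
|
|
| - For the downstream LLM, a `frog↔bird` confusion can be recovered from semantic context; |
| - A cross-domain error such as `aircraft↔speech` is simply absurd. |
|
|
| The penalty should **match the "semantic distance" of the error**. The simplest implementation: apply label smoothing inside each parent class and keep hard CE across parent classes. |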
|
|
| ### 7.2 Code change |
|
|
| #### 7.2.1 New `SpatialLossConfig` fields (`spatial_loss.py`) |
| |
| ```python |
| # v9 hierarchical label smoothing for the frame-track class head. |
| # When frame_class_ontology_smoothing > 0, the CE target becomes a soft |
| # label distribution: |
| # target[c_gt] = 1 - eps |
| # target[c_sib] = eps / |siblings| (for each sibling in same ontology |
| # group as c_gt; excludes c_gt itself) |
| # target[c_other] = 0 |
| # ... |
| frame_class_ontology_smoothing: float = 0.0 |
| # Parallel list of sibling groups. Each entry is a list of class indices |
| # belonging to the same AudioSet ontology parent. A class may appear in |
| # only one group. Empty list = no hierarchical smoothing. |
| frame_class_ontology_groups: List[List[int]] = None |
| ``` |
| |
| `__post_init__` replaces `None` with `[]` as a safeguard. |
|
|
| #### 7.2.2 Rewritten CE branch (`spatial_loss.py:compute_frame_track_losses`) |
|
|
| Original branch: |
| ```python |
| loss_class = F.cross_entropy( |
|     class_logits_flat, class_target_flat, weight=_cls_weights |
| ) |
| ``` |
|
|
| Replaced with: |
| ```python |
| eps_onto = float(config.frame_class_ontology_smoothing) |
| onto_groups = config.frame_class_ontology_groups |
| if eps_onto > 0.0 and onto_groups: |
|     num_classes = class_logits_flat.size(-1) |
|     # Build a [C, C] soft-target "mixing" table on first use and |
|     # cache it on the config object to avoid per-batch rebuild. |
|     if ( |
|         not hasattr(config, "_onto_mixing_table") |
|         or config._onto_mixing_table is None |
|         or config._onto_mixing_table.shape[0] != num_classes |
|         or config._onto_mixing_table.dtype != class_logits_flat.dtype |
|         or config._onto_mixing_table.device != device |
|     ): |
|         table = torch.zeros((num_classes, num_classes), dtype=class_logits_flat.dtype, device=device) |
|         table.fill_diagonal_(1.0) |
|         for group in onto_groups: |
|             members = [int(c) for c in group if 0 <= int(c) < num_classes] |
|             if len(members) < 2: |
|                 continue |
|             for c in members: |
|                 siblings = [s for s in members if s != c] |
|                 table[c].zero_() |
|                 table[c, c] = 1.0 - eps_onto |
|                 sib_mass = eps_onto / len(siblings) |
|                 for s in siblings: |
|                     table[c, s] = sib_mass |
|         config._onto_mixing_table = table |
|     soft_target = config._onto_mixing_table[class_target_flat] |
|     log_probs = F.log_softmax(class_logits_flat, dim=-1) |
|     if _cls_weights is not None: |
|         sample_w = _cls_weights[class_target_flat] |
|         per_sample_loss = -(soft_target * log_probs).sum(dim=-1) |
|         loss_class = (per_sample_loss * sample_w).sum() / sample_w.sum().clamp_min(1e-8) |
|     else: |
|         loss_class = -(soft_target * log_probs).sum(dim=-1).mean() |
| else: |
|     loss_class = F.cross_entropy(class_logits_flat, class_target_flat, weight=_cls_weights) |
| ``` |
|
|
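| Implementation notes: |
| - **Mixing-table cache**: the table lives on the `config` object (not on a module); it is built and cached on first use and then reused as long as shape / device / dtype match, so the overhead is negligible; |
| - **Per-class weight compatibility**: when ontology smoothing and class weights are both enabled, the class weights still apply (each sample's loss is weighted by its hard GT class index); |
| - **Classes outside any group**: the table defaults to diagonal = 1, so a class not listed in any group falls back to a hard one-hot target, with zero impact; |
| - **Full off switch**: either `frame_class_ontology_smoothing=0` or `frame_class_ontology_groups=[]` takes the original `F.cross_entropy` path, so reproducing v8 / v8a training is unaffected. |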
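
To make the soft targets concrete, here is a dependency-free sketch of the same table construction using nested lists instead of tensors. `build_mixing_table` is a hypothetical stand-in for the cached tensor in the real code; the animal-vocal group indices and `eps = 0.1` are taken from this document.

```python
import math

def build_mixing_table(num_classes, groups, eps):
    """Row c is the soft target for GT class c (pure-Python sketch)."""
    table = [[1.0 if i == j else 0.0 for j in range(num_classes)]
             for i in range(num_classes)]
    for group in groups:
        members = [c for c in group if 0 <= c < num_classes]
        if len(members) < 2:
            continue  # a singleton group has no siblings to smooth toward
        for c in members:
            siblings = [s for s in members if s != c]
            row = [0.0] * num_classes
            row[c] = 1.0 - eps
            for s in siblings:
                row[s] = eps / len(siblings)
            table[c] = row
    return table

# animal vocal: bird, frog, insect, dog, cat, animal
animal_group = [8, 62, 35, 18, 33, 27]
table = build_mixing_table(63, [animal_group], eps=0.1)

frog_row = table[62]       # soft target when the GT class is frog
body_sound_row = table[3]  # class outside every group stays one-hot
print(frog_row[62], frog_row[8], sum(frog_row))
```

For a GT frog frame the target keeps 0.9 on frog and spreads 0.1 evenly over the five siblings (0.02 each), so predicting bird is penalized slightly less than predicting a class outside the group.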
|
|
| ### 7.3 Ontology groups (`_V9_ONTOLOGY_GROUPS`) |
| |
| Location: `train_spatial_beats.py`, just before the v9 preset. |
| |
| ```python |
| _V9_ONTOLOGY_GROUPS: List[List[int]] = [ |
| # transportation: aircraft, vehicle, train, car |
| [55, 19, 22, 44], |
| # human voice (non-singing): speech, human_vocalization, male_speech, |
| # female_speech, breathing, laughter |
| [16, 6, 31, 32, 13, 14], |
| # animal vocal: bird, frog, insect, dog, cat, animal |
| [8, 62, 35, 18, 33, 27], |
| # indoor mechanical + appliances: tool, machine, appliance, printer, |
| # home_sound, door, drawer_cabinet, kitchenware, camera, clock, typing, |
| # zipper, tape, cooking |
| [9, 10, 48, 57, 34, 30, 50, 26, 38, 39, 36, 37, 58, 61], |
| # percussive / impact: knock, footsteps, crack, crackle, crushing, |
| # scratch, finger_snapping, tearing, writing, paper |
| [52, 21, 60, 53, 56, 46, 54, 42, 43, 49], |
| # weather / water / ambience: wind, rain, thunderstorm, ocean, water, |
| # fire, glass, metal_clink, wood |
| [25, 45, 29, 51, 5, 40, 24, 12, 59], |
| # musical instruments: wind_instrument, string_instrument, guitar, drum, |
| # keyboard_instrument, percussion, musical_instrument, gong, bell, |
| # singing |
| [0, 1, 2, 4, 7, 15, 28, 47, 17, 41], |
| # alarms / signals: alarm, telephone_alarm, war_sound |
| [20, 23, 11], |
| ] |
| ``` |
| |
| **8 groups covering 62 of the 63 classes** (only `body_sound` is ungrouped and falls back to hard CE). Every class appears in exactly one group. |
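
The coverage and disjointness claims are cheap to verify mechanically. The sketch below copies the eight groups from the listing above and checks them against a 63-class vocabulary; the variable names are illustrative only.

```python
from collections import Counter

# Groups copied verbatim from _V9_ONTOLOGY_GROUPS above.
GROUPS = [
    [55, 19, 22, 44],
    [16, 6, 31, 32, 13, 14],
    [8, 62, 35, 18, 33, 27],
    [9, 10, 48, 57, 34, 30, 50, 26, 38, 39, 36, 37, 58, 61],
    [52, 21, 60, 53, 56, 46, 54, 42, 43, 49],
    [25, 45, 29, 51, 5, 40, 24, 12, 59],
    [0, 1, 2, 4, 7, 15, 28, 47, 17, 41],
    [20, 23, 11],
]
NUM_CLASSES = 63

flat = [c for g in GROUPS for c in g]
duplicates = sorted(c for c, n in Counter(flat).items() if n > 1)
missing = sorted(set(range(NUM_CLASSES)) - set(flat))
out_of_range = sorted(c for c in flat if not 0 <= c < NUM_CLASSES)

print(f"groups={len(GROUPS)} covered={len(set(flat))} "
      f"missing={missing} duplicates={duplicates}")
```

Index 3 is `body_sound` in the weight table's index comments, which matches the one ungrouped class.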
|
|
| v9 preset: |
| ```python |
| cfg.loss.frame_class_ontology_smoothing = 0.1 |
| cfg.loss.frame_class_ontology_groups = [list(g) for g in _V9_ONTOLOGY_GROUPS] |
| ``` |
|
|
| ## 8. Fix C: spectral-level demixing cross-attention |
|
|
| ### 8.1 Motivation |
|
|
| - `track_time_features[B, K, T_s, D]` is decoded by `SourceQueryDecoder` from `fused_embeddings[B, T_s, D]`; `fused_embeddings` sits after `frequency_pool`, so **the frequency axis has already been pooled away**; |
| - DOA can be demixed on multi-source overlapping frames (the IV channels physically encode direction), but class has no equivalent physical demixing path; |
| - Each track latent therefore needs a way to "look back" at the trunk output before pooling and pick, out of the F_p frequency tokens, the part it is responsible for. |
| |
| ### 8.2 Model structure |
| |
| - Inputs: |
|   - `track_time_features: [B, K, T_s, D]` (already passed through `input_norm`) |
|   - `pre_pool_features: [B, T_p * F_p, D]` (BEATs trunk output, no task tokens, before frequency_pool) |
|   - `pre_pool_grid_size: (T_p, F_p)` |
|   - `pre_pool_time_mask: [B, T_p]` (True = valid time step) |
| - Time alignment: frame `t ∈ [0, T_s)` → trunk time step `t_p = round(t * T_p / T_s)`, clipped to `[0, T_p-1]`; |
| - KV construction: `kv_grid[:, t_p, :, :]` → `[B, T_s, F_p, D]`, then expanded along the K axis to `[B, K, T_s, F_p, D]` and flattened to `[B*K*T_s, F_p, D]`; |
| - Query: `track_time_features` reshaped to `[B*K*T_s, 1, D]`; |
| - One `nn.MultiheadAttention` layer (`num_heads=8`, `dropout=0.1`, `batch_first=True`), followed by `out_proj` (Linear 768→768), multiplied by the scalar `gate`, and added back to the class head's `class_input`. |
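
The time-alignment rule above is easiest to see on small grids. This is a plain-Python sketch with a hypothetical `align_frames` helper; it mirrors the round-then-clamp rule but works in Python floats rather than tensors.

```python
def align_frames(T_s: int, T_p: int) -> list:
    """Map each frame index t in [0, T_s) to a trunk time step in [0, T_p-1]."""
    if T_s <= 0 or T_p <= 0:
        return [0] * max(0, T_s)
    scale = T_p / max(1, T_s)
    # round() then clamp, as in t_p = round(t * T_p / T_s).
    return [min(max(int(round(t * scale)), 0), T_p - 1) for t in range(T_s)]

upsampled = align_frames(4, 8)    # fewer frames than trunk steps
downsampled = align_frames(8, 4)  # more frames than trunk steps
print(upsampled)
print(downsampled)
```

Python's `round` and `torch.round` both round exact halves to the nearest even integer, which is why the second mapping is not perfectly uniform. The clamp matters at the tail, where the rounded index can reach `T_p`.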
|
|
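| ### 8.3 Initialization strategy (critical) |
|
|
| - `out_proj.weight` all zeros and `out_proj.bias` all zeros → **the demixer output is 0 at load time**, so class_logits are identical to v8a; |
| - `gate = 1e-2` (**not 0**!) → the forward pass is still `gate * 0 = 0` (identity guarantee), but `∂L/∂out_proj.weight = gate × ...` is **non-zero**, so `out_proj` receives gradient from step 0; once `out_proj` has moved away from zero, gradient also reaches the attention weights. This is the key "gradient-warmup trick": with zeros at both ends the demixer would stay frozen at 0 forever. |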
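
The argument can be checked on a scalar toy model `y = gate * w * h`, where `w` stands in for the zero-initialized `out_proj` weight and `h` for the attention output. The sketch below writes the three partial derivatives out by hand; it is an illustration of the chain rule, not the module's real gradient.

```python
def grads(gate: float, w: float, h: float, upstream: float = 1.0):
    """For y = gate * w * h, return (dL/dw, dL/dgate, dL/dh) given dL/dy."""
    d_w = upstream * gate * h     # reaches the zero-init projection
    d_gate = upstream * w * h     # zero while w is still zero
    d_h = upstream * gate * w     # what flows back into attention
    return d_w, d_gate, d_h

# v9 choice: zero projection, tiny non-zero gate.
tiny_gate = grads(gate=1e-2, w=0.0, h=0.7)

# Naive choice: both ends zero.
both_zero = grads(gate=0.0, w=0.0, h=0.7)

print("tiny gate :", tiny_gate)
print("both zero :", both_zero)
```

With the tiny gate only the projection gets a gradient at step 0; the gate and the upstream layers start learning after the projection leaves zero. With both ends at zero every partial derivative is zero, so nothing ever moves.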
|
|
| ### 8.4 `ClassHeadSpectralDemixer` source (`spatial_modules.py`) |
| |
| ```python |
| class ClassHeadSpectralDemixer(nn.Module): |
|     def __init__(self, embed_dim=768, num_layers=1, num_heads=8, dropout=0.1): |
|         super().__init__() |
|         self.embed_dim = embed_dim |
|         self.num_layers = max(1, int(num_layers)) |
|         self.kv_norm = nn.LayerNorm(embed_dim) |
|         self.q_norm = nn.LayerNorm(embed_dim) |
|         self.layers = nn.ModuleList([ |
|             nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, |
|                                   dropout=dropout, batch_first=True) |
|             for _ in range(self.num_layers) |
|         ]) |
|         self.out_proj = nn.Linear(embed_dim, embed_dim) |
|         nn.init.zeros_(self.out_proj.weight) |
|         nn.init.zeros_(self.out_proj.bias) |
|         self.gate = nn.Parameter(torch.full((1,), 1e-2)) |
| |
|     def forward(self, track_time_features, pre_pool_features, |
|                 pre_pool_grid_size, pre_pool_time_mask=None): |
|         B, K, T_s, D = track_time_features.shape |
|         T_p, F_p = int(pre_pool_grid_size[0]), int(pre_pool_grid_size[1]) |
|         if pre_pool_features.size(-1) != D: |
|             raise ValueError(...) |
|         expected = T_p * F_p |
|         if pre_pool_features.size(1) != expected: |
|             # Fall back gracefully: the demixer is additive & zero-gated. |
|             return track_time_features.new_zeros(track_time_features.shape) |
|         kv_grid = pre_pool_features.view(B, T_p, F_p, D) |
|         if T_s > 0 and T_p > 0: |
|             time_idx = torch.arange(T_s, device=kv_grid.device).float() * (T_p / max(1, T_s)) |
|             time_idx = time_idx.round().clamp_(0, T_p - 1).long() |
|         else: |
|             time_idx = torch.zeros((T_s,), dtype=torch.long, device=kv_grid.device) |
|         kv_per_frame = kv_grid[:, time_idx, :, :] |
|         kv_per_frame = kv_per_frame.unsqueeze(1).expand(B, K, T_s, F_p, D).contiguous() |
|         kv_flat = kv_per_frame.view(B * K * T_s, F_p, D) |
|         kv_flat = self.kv_norm(kv_flat) |
|         q_flat = track_time_features.reshape(B * K * T_s, 1, D) |
|         q_flat = self.q_norm(q_flat) |
|         key_padding_mask = None |
|         if pre_pool_time_mask is not None: |
|             per_frame_valid = pre_pool_time_mask[:, time_idx] |
|             per_frame_valid = per_frame_valid.unsqueeze(1).expand(B, K, T_s).reshape(-1) |
|             if not per_frame_valid.all(): |
|                 ignore = ~per_frame_valid |
|                 key_padding_mask = ignore.unsqueeze(1).expand(-1, F_p).contiguous() |
|         attn_out = q_flat |
|         for layer in self.layers: |
|             attn_out, _ = layer(attn_out, kv_flat, kv_flat, |
|                                 key_padding_mask=key_padding_mask, need_weights=False) |
|         residual = self.out_proj(attn_out).view(B, K, T_s, D) |
|         return residual * self.gate |
| ``` |
| |
| ### 8.5 `_derive_pre_pool_time_mask` helper (`spatial_beats.py`) |
|
|
| Added right after `_build_patch_padding_mask`: |
|
|
| ```python |
| def _derive_pre_pool_time_mask( |
|     self, |
|     patch_padding_mask: Optional[Tensor], |
|     grid_size: Tuple[int, int], |
| ) -> Optional[Tensor]: |
|     """Return a [B, T_p] boolean mask where True marks *valid* trunk time |
|     steps. Used by the v9 class-head spectral demixer to ignore padded |
|     tail frames.""" |
|     if patch_padding_mask is None: |
|         return None |
|     t_p, f_p = grid_size |
|     B = patch_padding_mask.size(0) |
|     pad_grid = patch_padding_mask.view(B, t_p, f_p) |
|     time_valid = ~pad_grid.all(dim=-1) |
|     return time_valid |
| ``` |
|
|
| `patch_padding_mask` uses `True = padded`; a time step counts as valid when it still has at least one non-padded frequency position, hence `~pad_grid.all(dim=-1)`. |
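
The mask rule can be sanity-checked without tensors. The sketch below is a list-based stand-in with a hypothetical `derive_time_mask` helper; it assumes the flat patch axis is laid out time-major, which is what `view(B, t_p, f_p)` implies.

```python
def derive_time_mask(pad_mask, t_p: int, f_p: int):
    """pad_mask: per-sample flat lists of bools (True = padded), length t_p*f_p.

    Returns per-sample lists of length t_p where True marks a valid time step,
    i.e. a step with at least one non-padded frequency position.
    """
    out = []
    for sample in pad_mask:
        if len(sample) != t_p * f_p:
            raise ValueError("mask length does not match the grid size")
        rows = [sample[t * f_p:(t + 1) * f_p] for t in range(t_p)]
        out.append([not all(row) for row in rows])
    return out

P, V = True, False  # padded / valid patch
pad_mask = [
    [V, V,  V, P,  P, P],   # last time step fully padded
    [V, V,  V, V,  V, V],   # nothing padded
]
time_valid = derive_time_mask(pad_mask, t_p=3, f_p=2)
print(time_valid)
```

A partially padded step (the middle one in the first sample) is still treated as valid; only a step whose every frequency position is padded is masked out.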
|
|
| ### 8.6 Both forward calls updated (`spatial_beats.py`) |
| |
| Both calls to `self.frame_track_prediction_heads(...)` are extended: |
|
|
| ```python |
| _pre_pool_time_mask = self._derive_pre_pool_time_mask( |
| patch_padding_mask=patch_padding_mask, |
| grid_size=grid_size, |
| ) |
| frame_track_prediction_output = self.frame_track_prediction_heads( |
| track_time_features=track_time_features, |
| track_latents=track_latents, |
| pre_pool_features=encoder_memory, |
| pre_pool_grid_size=grid_size, |
| pre_pool_time_mask=_pre_pool_time_mask, |
| ) |
| ``` |
|
|
| - The first call is on the frame-track parallel path under the `readout_scheme == "local_spatial"` branch; |
| - The second call is on the pure per-frame path under the `readout_scheme == "local_spatial_track"` branch (v9 goes through this one). |
|
|
| Context variables used at both sites: |
| - `encoder_memory`: output of `self.encode_patches(...)`, `[B, T_p*F_p, D]`, **past the trunk but before frequency_pool**, which is exactly the pre_pool features the demixer needs; |
| - `grid_size`: the `(T_p, F_p)` returned by `self.extract_patch_tokens(...)`; |
| - `patch_padding_mask`: returned by `self._build_patch_padding_mask(...)`. |
| |
| ## 9. Fix F: 2-layer MLP residual |
| |
| ### 9.1 Motivation |
| |
| The current class head is `nn.Linear(768, 63)`, whose capacity may be insufficient for multi-source mixed tokens. Add a 2-layer MLP residual branch: |
| |
| ``` |
| class_logits = class_head(x) + gate * class_head_mlp(x) |
| class_head_mlp = Linear(768, 1536) → GELU → Dropout → LayerNorm → Linear(1536, 63) |
| ``` |
| |
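| ### 9.2 Initialization strategy |
| |
| Same construction as the Fix C demixer: |
| - `class_head_mlp[-1].weight / bias` all zeros → the residual output is 0; |
| - `class_head_mlp_gate = 1e-2` → the forward pass is still 0, but gradient can flow into the last MLP layer (non-zero); |
| - After 1-2 steps, once the last MLP layer has non-zero weights, the earlier layer (before the GELU) starts receiving gradient as well. |
| |
| ### 9.3 `FrameTrackPredictionHeads` refactor (`spatial_modules.py`) |
| |
| Extended constructor signature: |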
| ```python |
| def __init__( |
| self, |
| embed_dim: int = 768, |
| num_classes: int = 63, |
| dropout: float = 0.1, |
| use_class_head_mlp_residual: bool = False, |
| class_head_mlp_hidden_multiplier: int = 2, |
| class_head_mlp_dropout: float = 0.1, |
| use_class_head_demixer: bool = False, |
| class_head_demixer_layers: int = 1, |
| class_head_demixer_heads: int = 8, |
| class_head_demixer_dropout: float = 0.1, |
| ) -> None: |
| ``` |
| |
| Added inside the constructor body: |
| ```python |
| self.use_class_head_mlp_residual = bool(use_class_head_mlp_residual) |
| if self.use_class_head_mlp_residual: |
|     hidden = embed_dim * max(1, int(class_head_mlp_hidden_multiplier)) |
|     self.class_head_mlp = nn.Sequential( |
|         nn.Linear(embed_dim, hidden), |
|         nn.GELU(), |
|         nn.Dropout(class_head_mlp_dropout), |
|         nn.LayerNorm(hidden), |
|         nn.Linear(hidden, num_classes), |
|     ) |
|     nn.init.zeros_(self.class_head_mlp[-1].weight) |
|     nn.init.zeros_(self.class_head_mlp[-1].bias) |
|     self.class_head_mlp_gate = nn.Parameter(torch.full((1,), 1e-2)) |
| else: |
|     self.class_head_mlp = None |
|     self.class_head_mlp_gate = None |
| |
| self.use_class_head_demixer = bool(use_class_head_demixer) |
| if self.use_class_head_demixer: |
|     self.class_head_demixer = ClassHeadSpectralDemixer( |
|         embed_dim=embed_dim, |
|         num_layers=class_head_demixer_layers, |
|         num_heads=class_head_demixer_heads, |
|         dropout=class_head_demixer_dropout, |
|     ) |
| else: |
|     self.class_head_demixer = None |
| ``` |
| |
| Rewritten forward: |
| ```python |
| def forward( |
|     self, |
|     track_time_features: Tensor, |
|     track_latents: Tensor, |
|     pre_pool_features: Optional[Tensor] = None, |
|     pre_pool_grid_size: Optional[Tuple[int, int]] = None, |
|     pre_pool_time_mask: Optional[Tensor] = None, |
| ) -> FrameTrackPredictionOutput: |
|     ... |
|     x = self.input_norm(track_time_features) |
|     activity = self.activity_head(x).squeeze(-1) |
|     class_input = x |
|     if ( |
|         self.class_head_demixer is not None |
|         and pre_pool_features is not None |
|         and pre_pool_grid_size is not None |
|     ): |
|         demix_residual = self.class_head_demixer( |
|             track_time_features=x, |
|             pre_pool_features=pre_pool_features, |
|             pre_pool_grid_size=pre_pool_grid_size, |
|             pre_pool_time_mask=pre_pool_time_mask, |
|         ) |
|         class_input = class_input + demix_residual |
|     class_logits = self.class_head(class_input) |
|     if self.class_head_mlp is not None and self.class_head_mlp_gate is not None: |
|         class_logits = class_logits + self.class_head_mlp_gate * self.class_head_mlp(class_input) |
|     direction = F.normalize(self.direction_head(x), dim=-1) |
|     distance = F.softplus(self.distance_head(x)).squeeze(-1) |
|     ... |
| ``` |
| |
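| **Key details**: |
| - The demixer residual is added to `class_input` (the input of class_head), not to the logits; this way the demixer's "correction" signal passes through the shared `class_head` projection before becoming logits; |
| - The MLP branch also reads `class_input` (which includes the demixer residual), so the demixer's information reaches the MLP branch as well; |
| - The input `x` of `direction_head` and `distance_head` does not get the demixer residual (the demixer is class-specific), which keeps the DOA heads exactly equivalent to the v8a ckpt. |
| |
| ### 9.4 New `SpatialBEATsConfig` fields (`spatial_beats.py`) |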
| |
| ```python |
| self.frame_track_dropout: float = 0.1 |
| self.frame_accdoa_hidden_dim: int = 256 |
| self.frame_accdoa_dropout: float = 0.1 |
| |
| # v9: optional zero-initialised MLP residual branch inside |
| # FrameTrackPredictionHeads. ... |
| self.use_class_head_mlp_residual: bool = False |
| self.class_head_mlp_hidden_multiplier: int = 2 |
| self.class_head_mlp_dropout: float = 0.1 |
| # v9: optional spectral demixing cross-attention branch. ... |
| self.use_class_head_demixer: bool = False |
| self.class_head_demixer_layers: int = 1 |
| self.class_head_demixer_heads: int = 8 |
| self.class_head_demixer_dropout: float = 0.1 |
| ``` |
| |
| ### 9.5 Both construction sites pass the new parameters |
| |
| Both `FrameTrackPredictionHeads(...)` calls in `spatial_beats.py` become: |
| |
| ```python |
| self.frame_track_prediction_heads = FrameTrackPredictionHeads( |
| embed_dim=cfg.encoder_embed_dim, |
| num_classes=cfg.source_num_classes, |
| dropout=cfg.frame_track_dropout, |
| use_class_head_mlp_residual=cfg.use_class_head_mlp_residual, |
| class_head_mlp_hidden_multiplier=cfg.class_head_mlp_hidden_multiplier, |
| class_head_mlp_dropout=cfg.class_head_mlp_dropout, |
| use_class_head_demixer=cfg.use_class_head_demixer, |
| class_head_demixer_layers=cfg.class_head_demixer_layers, |
| class_head_demixer_heads=cfg.class_head_demixer_heads, |
| class_head_demixer_dropout=cfg.class_head_demixer_dropout, |
| ) |
| ``` |
| |
| ## 10. v9 preset and launch script |
| |
| ### 10.1 `make_ov1_local_spatial_v9_ov123_top4_config` (`train_spatial_beats.py`) |
| |
| Location: after the last preset of the v7k series, before the `v3: top-8 unfreeze` comment divider. |
| |
| ```python |
| def make_ov1_local_spatial_v9_ov123_top4_config( |
|     ov1_manifest_path: str = DEFAULT_OV1_MANIFEST, |
|     ov2_manifest_path: str = DEFAULT_OV2_MANIFEST, |
|     ov3_manifest_path: str = DEFAULT_OV3_MANIFEST, |
| ) -> TrainSpatialBEATsConfig: |
|     """v9 = v8a + class-first cleanup (fixes A..F). |
| |
|     Inherits v8a (cross-attn fusion + segment matching + 4-epoch DOA ramp), |
|     applies: |
|     - _V9_CLASS_WEIGHTS (suppress catch-all classes, boost frog/crackle/tape) |
|     - ontology-aware label smoothing (eps=0.1) |
|     - class head residual MLP + spectral demixer (both zero-init) |
|     - class_head_lr_scale=0.3 with full freeze during the 4-epoch DOA ramp |
| |
|     Frontend / trunk / source_query_decoder / activity / dir / dist heads |
|     are unchanged. Hot-start from v8a best.pt works with strict=False. |
|     """ |
|     cfg = make_ov1_local_spatial_v8a_ov123_top4_config( |
|         ov1_manifest_path=ov1_manifest_path, |
|         ov2_manifest_path=ov2_manifest_path, |
|         ov3_manifest_path=ov3_manifest_path, |
|     ) |
| |
|     # (D) Re-balanced class weights driven by v8/v8a CSV confusion analysis. |
|     cfg.loss.frame_class_loss_weights = list(_V9_CLASS_WEIGHTS) |
| |
|     # (B) Hierarchical (ontology-aware) label smoothing. |
|     cfg.loss.frame_class_ontology_smoothing = 0.1 |
|     cfg.loss.frame_class_ontology_groups = [list(g) for g in _V9_ONTOLOGY_GROUPS] |
| |
|     # (F) Zero-gated MLP residual on the class head. |
|     cfg.model.use_class_head_mlp_residual = True |
|     cfg.model.class_head_mlp_hidden_multiplier = 2 |
|     cfg.model.class_head_mlp_dropout = 0.1 |
| |
|     # (C) Zero-gated spectral demixing cross-attention on the class head. |
|     cfg.model.use_class_head_demixer = True |
|     cfg.model.class_head_demixer_layers = 1 |
|     cfg.model.class_head_demixer_heads = 8 |
|     cfg.model.class_head_demixer_dropout = 0.1 |
| |
|     # (E) Class head gets its own LR group. |
|     cfg.class_head_lr_scale = 0.3 |
|     cfg.class_head_freeze_during_ramp_epochs = 4 |
|     cfg.class_head_lr_scale_during_ramp = 0.0 |
| |
|     cfg.num_epochs = 12 |
|     cfg.output_dir = "checkpoints/spatial_beats_ov1_local_spatial_v9_ov123_exp/03_ov123_top4" |
|     return cfg |
| ``` |
| |
| ### 10.2 Preset registration |
| |
| Two places in `train_spatial_beats.py`: |
| |
| 1. After the `args.preset == "ov1_local_spatial_v8a_ov123_top4"` branch, add: |
| ```python |
| elif args.preset == "ov1_local_spatial_v9_ov123_top4": |
|     cfg = make_ov1_local_spatial_v9_ov123_top4_config( |
|         ov1_manifest_path=args.ov1_manifest, |
|         ov2_manifest_path=args.ov2_manifest, |
|         ov3_manifest_path=args.ov3_manifest, |
|     ) |
| ``` |
| 2. In the argparse `--preset` `choices=(...)` list, add `"ov1_local_spatial_v9_ov123_top4"` after `"ov1_local_spatial_v8a_ov123_top4"`. |
| |
| ### 10.3 `run_ov1_v9_ov123_top4.sh` |
| |
| New file (chmod +x): |
| |
| ```bash |
| #!/usr/bin/env bash |
| set -euo pipefail |
| |
| # v9_ov123_top4: v8a + class-first cleanup (Fix A..F from docs/0423.md analysis) |
| # All fixes are additive + zero-init; warm-starting from v8a best.pt gives an epoch-0 output identical to v8a |
| |
| GPUS="${GPUS:-8}" |
| BATCH_SIZE="${BATCH_SIZE:-8}" |
| NUM_WORKERS="${NUM_WORKERS:-8}" |
| SPATIAL_EPOCHS="${SPATIAL_EPOCHS:-12}" |
| SPATIAL_LR="${SPATIAL_LR:-1.5e-5}" |
| AMP="${AMP:-fp32}" |
| |
| OV1_MANIFEST="${OV1_MANIFEST:-/apdcephfs_cq10/.../ov1_foa.jsonl}" |
| OV2_MANIFEST="${OV2_MANIFEST:-/apdcephfs_cq10/.../ov2_foa.jsonl}" |
| OV3_MANIFEST="${OV3_MANIFEST:-/apdcephfs_cq10/.../ov3_foa.jsonl}" |
| |
| RESUME_CKPT="${RESUME_CKPT:-checkpoints/spatial_beats_ov1_local_spatial_v8a_ov123_exp/03_ov123_top4/best.pt}" |
| OUT_DIR="${OUT_DIR:-checkpoints/spatial_beats_ov1_local_spatial_v9_ov123_exp/03_ov123_top4}" |
| |
| torchrun --nproc_per_node="${GPUS}" --master-port="${MASTER_PORT:-29557}" train_spatial_beats.py \ |
| --preset ov1_local_spatial_v9_ov123_top4 \ |
| --resume "${RESUME_CKPT}" \ |
| --output-dir "${OUT_DIR}" \ |
| --ov1-manifest "${OV1_MANIFEST}" \ |
| --ov2-manifest "${OV2_MANIFEST}" \ |
| --ov3-manifest "${OV3_MANIFEST}" \ |
| --batch-size "${BATCH_SIZE}" \ |
| --num-workers "${NUM_WORKERS}" \ |
| --num-epochs "${SPATIAL_EPOCHS}" \ |
| --learning-rate "${SPATIAL_LR}" \ |
| --amp "${AMP}" \ |
| --no-resume-optimizer \ |
| --reset-epoch-on-resume \ |
| --reset-best-on-resume |
| ``` |
| |
| --- |
| |
| ## 11. Correctness verification (done)
| 
| ### 11.1 Syntax check
| 
| All four Python files (`train_spatial_beats.py` / `spatial_loss.py` / `spatial_beats.py` / `spatial_modules.py`) pass `ast.parse`, and `run_ov1_v9_ov123_top4.sh` passes `bash -n`.
| 
| ### 11.2 Module-level unit tests
| 
| `FrameTrackPredictionHeads` + `ClassHeadSpectralDemixer`:
| - With all v9 options enabled, **max abs diff = 0.00e+00** against the same model without the v9 options;
| - With a padding mask, `pred_class_logits` contains no NaN;
| - After opening the MLP gate / demixer gate and lightly perturbing the weights, the logits differ by ~1e-2 (as expected).
| 
| ### 11.3 v9 vs v8a end-to-end forward identity
| |
| ```python |
| torch.manual_seed(42) |
| m_v8a = SpatialBEATs(make_ov1_local_spatial_v8a_ov123_top4_config().model).eval() |
| torch.manual_seed(42) |
| m_v9 = SpatialBEATs(make_ov1_local_spatial_v9_ov123_top4_config().model).eval() |
| |
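| # Both models load v8a/best.pt with strict=False, then run on the same waveform input.
| # o_v8a / o_v9 are the resulting pred_class_logits (loading and forward calls omitted here).
| max_abs_diff = (o_v8a - o_v9).abs().max()
| # Measured: 0.00e+00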
| ``` |
| |
| **After v9 loads the v8a ckpt, its `pred_class_logits` are element-wise identical to those of the v8a model loading the same ckpt** (0.00e+00).
| 
| ### 11.4 v8a ckpt load statistics
| |
| ``` |
| ckpt keys = 425 |
| loadable (v9 shapes) = 425 |
| v9 missing params = 18 ← class_head_mlp.* (7) + class_head_demixer.* (11)
| ckpt unexpected = 0 |
| ``` |
| |
| The 18 new parameters are absent from the checkpoint, so under `strict=False` they keep their default initialization (the zero-out-proj + tiny-gate strategy of Fix C / F).
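
The statistics above come down to set arithmetic over parameter names and shapes. The following is a minimal sketch of that bookkeeping, not the script actually used here; the helper `diff_state_dict_keys` and the toy key names and shapes are hypothetical.

```python
def diff_state_dict_keys(ckpt_shapes, model_shapes):
    """Compare two {param_name: shape} mappings, the way a strict=False load would see them."""
    ckpt_names, model_names = set(ckpt_shapes), set(model_shapes)
    shared = ckpt_names & model_names
    return {
        # Present on both sides with matching shapes: safe to copy.
        "loadable": sorted(k for k in shared if ckpt_shapes[k] == model_shapes[k]),
        # Present only in the new model: these keep their constructor init.
        "missing": sorted(model_names - ckpt_names),
        # Present only in the checkpoint: ignored by strict=False.
        "unexpected": sorted(ckpt_names - model_names),
    }


# Toy stand-ins for the v8a checkpoint and the v9 model (names and shapes are made up).
ckpt_shapes = {"trunk.proj.weight": (8, 8), "class_head.weight": (63, 8)}
model_shapes = dict(ckpt_shapes)
model_shapes["class_head_mlp.4.weight"] = (63, 8)
model_shapes["class_head_demixer.gate"] = (1,)

report = diff_state_dict_keys(ckpt_shapes, model_shapes)
for field, names in report.items():
    print(f"{field:<10} = {len(names)}")
```

With real models, the two mappings could be built from `state_dict()` by mapping each name to its tensor shape.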
| |
| ### 11.5 Gradient flow
| 
| Fake CE loss + backward:
| - `class_head.weight`: grad_norm ≈ 9.06 (normal)
| - `class_head_mlp.4.weight` (last MLP layer, zero-init): grad_norm ≈ 1.32e-1 (**nonzero, because gate = 1e-2**)
| - `class_head_mlp.0.weight` (first MLP layer): grad_norm = 0 (expected: it receives no gradient until the last layer becomes nonzero, which unlocks within 1 step)
| - `class_head_demixer.out_proj.weight`: grad_norm ≈ 1.60e-2 (**nonzero**)
| - `class_head_demixer.layers.0.in_proj_weight`: grad_norm = 0 (same reason)
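
The pattern above (last layer gets gradient, first layer does not) can be reproduced by hand on a scalar toy. This is only a sketch under assumptions: a residual branch of the form `base + gate * w_out * relu(w_in * x)` with a squared-error loss. The real MLP activation, gate placement, and loss are not shown in this section.

```python
def toy_residual_grads(x, target, w_base, w_in, w_out, gate):
    """Scalar residual branch: y = w_base*x + gate * w_out * relu(w_in*x), L = 0.5*(y-target)^2."""
    pre = w_in * x
    hidden = max(pre, 0.0)                      # ReLU (assumed activation)
    y = w_base * x + gate * w_out * hidden
    dy = y - target                             # dL/dy
    grad_w_out = dy * gate * hidden             # does not depend on w_out, so nonzero at init
    relu_slope = 1.0 if pre > 0 else 0.0
    grad_w_in = dy * gate * w_out * relu_slope * x  # proportional to w_out, so zero at init
    return y, grad_w_out, grad_w_in


# Step 0: zero-init last layer, tiny gate. Output equals the base path exactly.
y0, g_out0, g_in0 = toy_residual_grads(x=2.0, target=0.0, w_base=1.5, w_in=0.5, w_out=0.0, gate=1e-2)
print(y0, g_out0, g_in0)

# One SGD step moves w_out off zero, which unlocks the first layer's gradient.
w_out1 = 0.0 - 0.1 * g_out0
y1, g_out1, g_in1 = toy_residual_grads(x=2.0, target=0.0, w_base=1.5, w_in=0.5, w_out=w_out1, gate=1e-2)
print(y1, g_out1, g_in1)
```

In this toy, the zero-initialized weight is what keeps the forward pass identical at step 0, while the nonzero gate is what lets that same weight receive gradient.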
| |
| ### 11.6 Optimizer grouping
| |
| ``` |
| group 0: name=trunk lr=3.00e-06 n_params=238 |
| group 1: name=spatial lr=9.00e-06 n_params= 84 |
| group 2: name=head lr=3.00e-05 n_params= 91 |
| group 3: name=cls_head lr=9.00e-06 n_params= 19 |
| ``` |
| |
| (The example above uses base_lr=3e-5; the v9 launch script sets SPATIAL_LR=1.5e-5, which gives cls_head lr = 4.5e-6.)
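
For orientation, prefix-based grouping of this kind can be sketched as follows. The prefixes and parameter names are hypothetical (the real `_CLASS_HEAD_PREFIXES` and module paths may differ), and the trunk/spatial scales of 0.1 and 0.3 are only inferred from the printed LRs above; `cls_head` uses `class_head_lr_scale = 0.3` from the preset.

```python
# Hypothetical prefixes, for illustration only.
CLASS_HEAD_PREFIXES = ("class_head.", "class_head_mlp.", "class_head_demixer.")
TRUNK_PREFIXES = ("trunk.",)
SPATIAL_PREFIXES = ("spatial.",)

# LR multipliers per group, relative to base_lr (trunk/spatial inferred from the log above).
LR_SCALES = {"trunk": 0.1, "spatial": 0.3, "head": 1.0, "cls_head": 0.3}


def group_of(param_name):
    """Route a parameter name to its optimizer group; class-head prefixes are checked first."""
    if param_name.startswith(CLASS_HEAD_PREFIXES):
        return "cls_head"
    if param_name.startswith(TRUNK_PREFIXES):
        return "trunk"
    if param_name.startswith(SPATIAL_PREFIXES):
        return "spatial"
    return "head"


def build_param_groups(param_names, base_lr):
    """Return optimizer-style group dicts, each with its own name and scaled LR."""
    buckets = {name: [] for name in LR_SCALES}
    for param_name in param_names:
        buckets[group_of(param_name)].append(param_name)
    return [
        {"name": name, "lr": base_lr * LR_SCALES[name], "params": members}
        for name, members in buckets.items()
        if members
    ]


toy_names = [
    "trunk.layer0.weight", "spatial.fusion.weight", "activity_head.weight",
    "class_head.weight", "class_head_mlp.4.weight", "class_head_demixer.out_proj.weight",
]
groups = build_param_groups(toy_names, base_lr=1.5e-5)
for i, g in enumerate(groups):
    print(f"group {i}: name={g['name']:<8} lr={g['lr']:.2e} n_params={len(g['params'])}")
```

In a real optimizer the `params` entries would be tensors rather than names; keeping the `name` key in each group is what allows the epoch loop to find the `cls_head` group and rewrite its LR later.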
| |
| ## 12. Key training timeline (SPATIAL_EPOCHS=12)
| 
| Inherits v8a's `frame_spatial_loss_warmup_epochs=3` and `frame_spatial_loss_ramp_epochs=4`, with v9's new cls_head LR schedule layered on top:
| 
| | epoch | lambda_dir / dist | dir/dist match cost | cls_head lr | Notes |
| |---|---|---|---|---|
| | 0-2 | 0.0 | 0.0 | 4.5e-6 | Stage 1: class-only warmup; cls_head keeps fine-tuning at low LR (a straight continuation of v8a) |
| | 3 | ramp 0.25 | 0.25 | **0.0** | Start of Stage 2; cls_head frozen |
| | 4 | ramp 0.50 | 0.50 | **0.0** | cls_head frozen |
| | 5 | ramp 0.75 | 0.75 | **0.0** | cls_head frozen |
| | 6 | ramp 1.00 | 1.00 | **0.0** | cls_head frozen; DOA loss fully on |
| | 7-11 | 1.0 | 1.0 | 4.5e-6 | Stage 3: everything unfrozen; cls_head fine-tunes at low LR |
|
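The timeline can be reproduced with a few lines of arithmetic. This is a sketch under assumptions: the linear ramp formula is inferred from the 0.25 / 0.50 / 0.75 / 1.00 values in the table, and the function names are hypothetical rather than taken from `train_spatial_beats.py`.

```python
def spatial_ramp_factor(epoch, warmup_epochs=3, ramp_epochs=4):
    """Spatial loss weight: 0 during warmup, then a linear ramp reaching 1.0 (inferred formula)."""
    if epoch < warmup_epochs:
        return 0.0
    return min(1.0, (epoch - warmup_epochs + 1) / ramp_epochs)


def cls_head_lr(epoch, base_lr=1.5e-5, lr_scale=0.3,
                warmup_epochs=3, freeze_epochs=4, lr_scale_during_ramp=0.0):
    """Class-head LR: scaled base LR, replaced by the ramp-time scale while the DOA loss ramps in."""
    in_ramp = warmup_epochs <= epoch < warmup_epochs + freeze_epochs
    scale = lr_scale_during_ramp if in_ramp else lr_scale
    return base_lr * scale


for epoch in range(12):
    print(f"epoch {epoch:2d}  lambda={spatial_ramp_factor(epoch):.2f}  "
          f"cls_head_lr={cls_head_lr(epoch):.1e}")
```

The defaults mirror `class_head_lr_scale = 0.3`, `class_head_freeze_during_ramp_epochs = 4`, and `class_head_lr_scale_during_ramp = 0.0` from the preset, with `SPATIAL_LR=1.5e-5` from the launch script.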
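| 
| ## 13. Verification checklist / observation priority
| 
| After launch, watch the metrics in this order:
| 
| 1. **Are the epoch 0 validation metrics equivalent to v8a's last epoch?**
| - Expect val loss ≈ v8a best val loss (because the forward pass = v8a);
| - If they are not equivalent, the hot-start is broken;
| 2. **epoch 0-2 (stage 1)**:
| - Is `ocls` (oracle class acc) higher than v8a at the same point (ep2)? This is the key metric, driven mainly by Fix D + B + C + F;
| - If Fix D works: aircraft / vehicle are no longer at 0%, and frog is no longer 100% predicted as bird;
| - If Fix B works: the smoothed CE loss drops faster than hard CE (starting from 1 - eps);
| - If Fix C/F works: after each epoch, `class_head_mlp_gate` / `class_head_demixer.gate` should grow slowly from 0.01;
| 3. **epoch 3-6 (stage 2, cls_head frozen)**:
| - DOA metrics improve (`LE_CD` down, `oazi` down) and `ocls` **should not regress** (cls_head LR = 0, so its weights receive no updates);
| - If `ocls` regresses: updates to fusion/trunk may be shifting the class_head input distribution through the demixer, in which case consider freezing the demixer as well;
| 4. **epoch 7+ (stage 3, everything unfrozen)**:
| - Does `F20` break through the v8a ceiling (v7h baseline 0.246, v8a ~ 0.22-0.25)?
| - Does ov3 class_ok move up from 43%?
| - If class_ok improves for both ov2 and ov3 but ov3 still lags: the spectral demixing in Fix C needs more layers or heads.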
| |
| ## 14. Code diff overview
| 
| ### 14.1 `spatial_loss.py`
| - Two new `SpatialLossConfig` fields: `frame_class_ontology_smoothing`, `frame_class_ontology_groups`;
| - `__post_init__` handles `frame_class_ontology_groups = None`;
| - The CE branch in `compute_frame_track_losses` supports soft targets combined with class weights.
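
One plausible form of the soft target is sketched below. It is an assumption, not the code in `spatial_loss.py`: here the true class keeps `1 - eps` and `eps` is split evenly among its siblings in the same ontology group, the groups are assumed disjoint, and the loss is scaled by the true class's weight. The toy groups and weights are made up.

```python
import math


def ontology_soft_target(label, num_classes, groups, eps=0.1):
    """Soft target: 1-eps on the true class, eps shared by its group siblings (assumed scheme)."""
    target = [0.0] * num_classes
    siblings = [c for group in groups if label in group for c in group if c != label]
    if not siblings:                 # class outside every group: plain one-hot target
        target[label] = 1.0
        return target
    target[label] = 1.0 - eps
    for c in siblings:
        target[c] += eps / len(siblings)
    return target


def weighted_soft_ce(logits, target, class_weights, label):
    """Soft-target cross entropy, scaled by the true class's weight (assumed composition)."""
    peak = max(logits)
    log_z = peak + math.log(sum(math.exp(v - peak) for v in logits))
    log_probs = [v - log_z for v in logits]
    return -class_weights[label] * sum(t * lp for t, lp in zip(target, log_probs))


# Toy setup: 5 classes, one group {0, 1, 2}; classes 3 and 4 are ungrouped.
toy_groups = [(0, 1, 2)]
soft = ontology_soft_target(label=0, num_classes=5, groups=toy_groups, eps=0.1)
hard = ontology_soft_target(label=4, num_classes=5, groups=toy_groups, eps=0.1)
loss = weighted_soft_ce([2.0, 1.0, 0.5, -1.0, -1.0], soft, [1.0, 1.0, 2.0, 1.0, 1.0], label=0)
print(soft, hard, round(loss, 4))
```

Under this scheme, predicting a sibling class costs less than predicting an unrelated class, while an ungrouped class is trained exactly as with hard CE.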
| |
| ### 14.2 `spatial_modules.py`
| - `FrameTrackPredictionHeads.__init__` gains 6 new kwargs;
| - `FrameTrackPredictionHeads.forward` gains 3 new Optional kwargs (`pre_pool_features` / `pre_pool_grid_size` / `pre_pool_time_mask`) and integrates the demixer + MLP residual internally;
| - New module `ClassHeadSpectralDemixer`.
| |
| ### 14.3 `spatial_beats.py`
| - `SpatialBEATsConfig.__init__` gains 8 new `self.*` fields;
| - Both `FrameTrackPredictionHeads(...)` constructor calls pass all the new arguments;
| - Both `self.frame_track_prediction_heads(...)` calls pass the pre-pool arguments;
| - New instance method `SpatialBEATs._derive_pre_pool_time_mask`.
| |
| ### 14.4 `train_spatial_beats.py`
| - `TrainSpatialBEATsConfig` gains 3 new fields: `class_head_lr_scale`, `class_head_freeze_during_ramp_epochs`, `class_head_lr_scale_during_ramp`;
| - New constants `_V9_CLASS_WEIGHTS` (length 63) and `_V9_ONTOLOGY_GROUPS` (8 groups, covering 62 of the 63 classes);
| - New preset function `make_ov1_local_spatial_v9_ov123_top4_config`;
| - `build_optimizer` gains `_CLASS_HEAD_PREFIXES`, a `cls_head_params` group, `group_name` labels, and an updated fast-path condition;
| - The epoch loop writes the cls_head LR dynamically (right after the spatial loss schedule);
| - Preset dispatch and the `--preset` argparse choices register `ov1_local_spatial_v9_ov123_top4`.
| |
| ### 14.5 `run_ov1_v9_ov123_top4.sh`
| - New shell script; defaults are 8 GPUs / BS=8 / 12 epochs / LR=1.5e-5 / fp32 / master_port=29557;
| - The default RESUME_CKPT points at the v8a best.pt;
| - Reuses the three switches `--no-resume-optimizer --reset-epoch-on-resume --reset-best-on-resume`.
| |
| ## 15. Key files and code locations
| 
| | File | Main location of the new code |
| |---|---|
| | `spatial_loss.py` | New `SpatialLossConfig` fields right after `frame_class_loss_weights`; rewritten CE branch inside `compute_frame_track_losses` |
| | `spatial_modules.py` | `FrameTrackPredictionHeads` rewritten as a whole; `ClassHeadSpectralDemixer` defined right after `FrameTrackPredictionHeads` |
| | `spatial_beats.py` | v9 fields in `SpatialBEATsConfig.__init__` right after `frame_accdoa_dropout`; `_derive_pre_pool_time_mask` right after `_build_patch_padding_mask` |
| | `train_spatial_beats.py` | `_V9_CLASS_WEIGHTS` right after `_V7I_CLASS_WEIGHTS`; `_V9_ONTOLOGY_GROUPS` and `make_ov1_local_spatial_v9_ov123_top4_config` right after the end of the v7k series; `class_head_*_scale` fields right after `spatial_lr_scale` |
| |
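| ## 16. Next steps
| 
| If `ocls` at v9 ep3 still does not exceed v8a at the same epoch, the candidates for the next round (v9a/v9b) are:
| 
| 1. **More demixer capacity**: `class_head_demixer_layers = 2`, or raise heads to 16;
| 2. **Freeze the demixer in lockstep**: during the DOA ramp, treat `class_head_demixer.*` as part of the cls_head group (`_CLASS_HEAD_PREFIXES` already includes it, so it is already frozen during the ramp; the effect still needs to be verified);
| 3. **Raise the ontology smoothing eps**: 0.1 → 0.2, to tolerate sibling collapse further;
| 4. **Orthogonality regularizer on source_query_decoder** (suggested in the previous round): `L_orth = ||Q Q^T - I||_F^2`, λ=0.01;
| 5. **Revisit matching**: if the cls bottleneck is resolved but F20 still does not rise, go back to the track-dead / duplicate diagnosis in `docs/0422.md`.
| 
| That said, v9 already addresses the "class head itself" fairly thoroughly, so the full 12 epochs should be run before deciding.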
| |