# 2026-04-23 — v9_ov123_top4:class-first 针对性修复(A→D→E→B→C→F) ## 0. 背景与本轮任务 v8 / v8a 在 `03_ov123_top4` 上完成了 10+ 个 epoch 的训练,整体 spatial 指标比 v7h 有小幅改善,但 `class_ok` 一直卡在 **~45%** 左右。用户要求: 1. 专注 frame 级预测,不引入任何 clip 级监督; 2. 分析 v8 / v8a 的 val CSV(epoch 9 / epoch 11),找出 class 准确率上不去的真正瓶颈; 3. 在不破坏现有代码框架和训练逻辑的前提下,按顺序落地 Fix A→D→E→B→C→F; 4. 全部作为 v9 新 preset + 新脚本,从 v8a `best.pt` 热启动(保证 epoch0 前向输出与 v8a 完全相同)。 本文档记录 2026-04-23 这一轮的 **所有代码修改细节**,以及每个修改背后的 CSV 诊断证据。 相关旧文档: - `docs/0422.md`:v7h → v8 架构升级(cross-attention fusion) - `docs/0422_v7h_v7j.md`:v7h / v7j / v7i 的 per-frame 多源诊断与 class-weighted CE 的引入 - `docs/0421.md`:v7f → ov123 per-frame 扩展与 frame-track CSV dump --- ## 1. CSV 诊断:v8 ep9 / v8a ep11 的错误模式 ### 1.1 整体准确率(按 "activity ≥ 0.5 的 active track 里 DOA 最近的一条" 取 class) | 指标 | v8 ep9 | v8a ep11 | |---|---|---| | overall `cls_ok` | 933/2051 = **45.5%** | 934/2051 = **45.5%** | | DOA≤20° 命中内 `cls_ok` | 48.1% | 51.9% | | **Oracle cls**(无视 activity,取 DOA 最近 track 的 class) | **38.7%** | **46.4%** | 两个关键数字几乎相等:`cls_ok(nearest active) ≈ oracle_cls`。这说明: > **activity 选的 track 和 DOA 选的 track 语义上一致**,K=4 的 track binding 是 ok 的;真正错的是 **"那条被选中的 track 自己的 class_logits 就预测错了"**。 换句话说:**瓶颈不在 matching,不在 activity,而在 class head 本身的输出分布**。 ### 1.2 按 ov 分组 v8 ep9: ``` ov2 : gt= 682 DOA_ok=51.6% cls_ok=37.1% any_track_cls_ok=38.9% cls&DOA=23.9% ov3 : gt=1158 DOA_ok=42.5% cls_ok=47.1% any_track_cls_ok=59.6% cls&DOA=18.0% other: gt= 211 DOA_ok=97.6% cls_ok=64.0% any_track_cls_ok=64.0% cls&DOA=63.0% ``` v8a ep11 类似,ov3 略退: ``` ov2 : cls_ok=44.0% cls&DOA=33.7% ov3 : cls_ok=43.1% cls&DOA=14.2% other: cls_ok=64.0% cls&DOA=63.0% ``` `other`(近似 ov1 单源)是 64%,ov2 是 37-44%,ov3 降到 43% 且 `cls&DOA` 只有 14%。**多源重叠帧的 class 比单源帧低 20 个点**,这是 ov3 demixing 失败的直接证据。 ### 1.3 每类错误模式(Oracle:无视 activity 取 DOA 最近) ``` aircraft (n=100): speech(50%) human_vocalization(50%) → 永远 0% 正确 vehicle (n= 68): machine(74%) aircraft(13%) train(10%) → 永远 0% 正确 frog (n= 50): bird(98~100%) → 永远预测成 bird speech (n= 53): human_vocalization(94%) breathing(6%) → 永远 0% 正确 crackle (n= 51): rain(53%) machine(41%) → 永远 0% 正确 wind (n= 16): fire(50~88%) vehicle/wood → 永远 0% 正确 drawer_cab(n= 50): tool(74~84%) → 永远 0% 正确 tape (n= 50): human_vocalization(88%) typing(46%) → 永远 0% 正确 knock (n= 36): home_sound(64%) human_vocalization → 永远 0% 正确 train (n=115): train(43%) crushing(37%) vehicle(15%) → ~10% ``` 错误全都是 **同父类 "sibling collapse"**: - `aircraft / vehicle / train` 同属 transportation; - `speech / human_vocalization / breathing / laughter` 同属 human-voice; - `frog / bird / insect` 同属 animal-vocal; - `drawer_cabinet / tool / home_sound / door` 同属 indoor-mechanical。 ### 1.4 Label bug 复核:frog→bird 是不是数据映射错了? 查了 `build_ov{1,23}_foa_dataset.py`、`final_vocabulary.csv`、`spatial_dataset.py:_resolve_class_index`,以及 jsonl 里具体样本: ``` ov1 train frog: 113 ov1 train bird: 3270 → 1:29 ov1 valid frog: 9 ov1 valid bird: 56 ov1 test frog: 3 ov1 test bird: 51 ov2 train frog:1078 ov2 train bird: 1255 ov3 train frog:1384 ov3 train bird: 1432 ``` **frog 标签完全正确**,不是 label bug。100% 预测成 bird 的根因是 ov1 训练集里 bird 比 frog 多 **29 倍**,纯粹的 class imbalance;其他两个 manifest (ov2/ov3) 已经平衡,但 ov1 训练数据量占比仍然不小,整体 prior 偏向 bird。 ## 2. 六个 Fix 的整体设计原则 1. **全部 additive**:v9 只 **增加** 新 config 字段与新子模块,从不删除或覆盖 v8a 已有的路径; 2. **全部 zero-init 或 identity-init**:新参数在 ckpt 加载时贡献 0(可验证 `v9(v8a.pt) - v8a(v8a.pt)` 前向最大 abs diff = 0); 3. **strict=False 兼容**:v9 的 18 个新参数在 v8a ckpt 中缺失,走现有 `_load_spatial_init_checkpoint` / `load_state_dict(..., strict=False)` 流程自然初始化; 4. **不改 forward signature 的兼容性**:`FrameTrackPredictionHeads.forward` 新增的几个关键字参数默认 `None`,旧调用链(如 v7 / v8 非-track readout)不受影响; 5. **所有 LR / schedule 默认行为等价于 v8a**:`class_head_lr_scale=1.0` 时走原 3-group 快速路径,`frame_class_ontology_smoothing=0.0` 时走原 `F.cross_entropy`。 ## 3. 文件级修改清单 | 文件 | 修改内容 | |---|---| | `spatial_loss.py` | 新增 `frame_class_ontology_smoothing`、`frame_class_ontology_groups` 字段;CE 分支支持本体软标签 | | `spatial_modules.py` | `FrameTrackPredictionHeads` 新增 6 个构造参数 + 新增 `ClassHeadSpectralDemixer` 模块 | | `spatial_beats.py` | `SpatialBEATsConfig` 新增 8 个字段;两处 `FrameTrackPredictionHeads` 构造传入新参数;forward 两处 call 传入 `pre_pool_features`;新增 `_derive_pre_pool_time_mask` helper | | `train_spatial_beats.py` | 新增 `class_head_lr_scale`、`class_head_freeze_during_ramp_epochs`、`class_head_lr_scale_during_ramp` 字段;`build_optimizer` 拆出 `cls_head` group;epoch loop 动态写 cls_head LR;新增 `_V9_CLASS_WEIGHTS`、`_V9_ONTOLOGY_GROUPS`、`make_ov1_local_spatial_v9_ov123_top4_config`;preset dispatch 和 `--preset` choices 注册 `v9_ov123_top4` | | `run_ov1_v9_ov123_top4.sh` | 新增启动脚本(chmod +x) | --- ## 4. Fix A —— frog/稀有类的诊断(无代码改动) Fix A 是诊断性质的,没有代码落地。结论:frog 不是 label bug,是 ov1 train 的 29:1 imbalance。解决方案被折进 Fix D: - `bird` 权重:1.0 → **0.6**(压 catch-all) - `frog` 权重:1.0 → **3.0**(提 minority) - `insect` 权重:从 v7I 的 4.0 → **1.0**(v8/v8a 已经 100%,不需要加权) ## 5. Fix D —— 重新设计 class 权重 `_V9_CLASS_WEIGHTS` ### 5.1 数据证据 v7I 的权重表给了 aircraft / vehicle / insect **4×** 权重: ``` v8 ep9: aircraft 0%, vehicle 0%, insect 100% v8a ep11: aircraft 0%, vehicle 0%, insect 100%, printer 100%→59%(退化!) ``` - aircraft / vehicle 的 4× 没有任何效果,它们的声学特征与 speech / machine 真实接近; - 4× 让误分类(aircraft→speech)的 loss 放大 4 倍,模型为了降 loss 把 **speech 的预测分布也拉保守**,导致 speech 也掉到 0%(双输); - v8a ep11 还有一个异常:printer 从 v8 ep9 的 100% 掉到 59%,原因见 Fix E。 ### 5.2 设计规则 - **"catch-all" 类**(GT frame 上被当万金油预测的类)→ 权重下调: - `human_vocalization` 0.4(被 speech / tape / typing / knock 当靶子) - `bird` 0.6(frog 几乎全部坍缩到 bird) - `machine` 0.5(vehicle / crackle 坍缩到 machine) - `rain` 0.5(crackle / singing 的 collapse 目标) - `breathing` 0.5(speech 94% 错分到这里) - `home_sound` 0.6(knock / printer 的 collapse) - **易被 collapse 的稀有类** → 权重上调(但不过分,≤ 3×): - `frog` 3.0 - `crackle` 2.0、`tape` 2.0、`knock` 2.0、`drawer_cabinet` 2.0、`speech` 2.0 - `aircraft` / `vehicle` / `train` 回到 1.0-1.5(4× 已证实无效且伤 sibling) - **过度自信的稳健类** → 略压: - `singing` 0.7、`printer` 0.7(在 ov123 任务里吸收错 FP) ### 5.3 代码改动 位置:`train_spatial_beats.py`(紧挨 `_V7I_CLASS_WEIGHTS` 之后) ```python _V9_CLASS_WEIGHTS: List[float] = [ # 0 wind_instrument 1 string_instrument 2 guitar 3 body_sound 1.0, 1.0, 1.0, 1.0, # 4 drum 5 water 6 human_vocalization 7 keyboard_instrument 1.0, 1.0, 0.4, 1.0, # 8 bird 9 tool 10 machine 11 war_sound 0.6, 1.0, 0.5, 1.0, # 12 metal_clink 13 breathing 14 laughter 15 percussion 1.0, 0.5, 1.0, 1.0, # 16 speech 17 bell 18 dog 19 vehicle 2.0, 1.0, 1.0, 1.5, # 20 alarm 21 footsteps 22 train 23 telephone_alarm 1.0, 1.0, 1.5, 1.0, # 24 glass 25 wind 26 kitchenware 27 animal 1.0, 1.5, 1.0, 1.0, # 28 musical_instrument 29 thunderstorm 30 door 31 male_speech 1.0, 1.0, 1.0, 1.0, # 32 female_speech 33 cat 34 home_sound 35 insect 1.0, 1.0, 0.6, 1.0, # 36 typing 37 zipper 38 camera 39 clock 1.0, 1.0, 1.0, 1.0, # 40 fire 41 singing 42 tearing 43 writing 1.0, 0.7, 1.0, 1.0, # 44 car 45 rain 46 scratch 47 gong 1.0, 0.5, 1.0, 1.0, # 48 appliance 49 paper 50 drawer_cabinet 51 ocean 1.0, 1.0, 2.0, 1.0, # 52 knock 53 crackle 54 finger_snapping 55 aircraft 2.0, 2.0, 1.0, 1.0, # 56 crushing 57 printer 58 tape 59 wood 1.0, 0.7, 2.0, 1.0, # 60 crack 61 cooking 62 frog 1.0, 1.0, 3.0, ] assert len(_V9_CLASS_WEIGHTS) == 63 ``` v9 preset 里:`cfg.loss.frame_class_loss_weights = list(_V9_CLASS_WEIGHTS)`。 ## 6. Fix E —— class head LR 单独分组 + DOA ramp 期冻结 ### 6.1 数据证据 v8a ep5 vs ep11(stage 2 DOA ramp 打开前后): ``` ep5 ov3: ocls=44.4% printer=100% ep11 ov3: ocls=44.4% printer=59% ← class 被 DOA 梯度扰动掉了 ``` 打开 dir/dist 的 Hungarian cost 和 loss 之后,class binding 被扰动。传统做法是整体降 LR,但那会连带把 trunk / decoder 的学习也拖慢。更好的做法是 **单独把 class_head 从 optimizer 拉出来,DOA ramp 期间冻结**。 ### 6.2 代码改动 #### 6.2.1 新增 config 字段(`train_spatial_beats.py`) ```python trunk_lr_scale: float = 1.0 spatial_lr_scale: float = 1.0 # v9: isolated LR multiplier for the class_head inside # frame_track_prediction_heads. When < 1.0 the class head is put in its # own param group with lr = base_lr * class_head_lr_scale. Used during # DOA ramp (stage 2) to prevent class binding from being perturbed by # the newly-unlocked dir/dist gradients. 1.0 = legacy behaviour. class_head_lr_scale: float = 1.0 # Optional epoch-range override that further scales the class head LR # specifically during the DOA ramp. When set, between # frame_spatial_loss_warmup_epochs and frame_spatial_loss_warmup_epochs # + class_head_freeze_during_ramp_epochs the class head LR is set to # class_head_lr_scale_during_ramp (defaults to 0.0 = frozen). After the # ramp window the LR returns to class_head_lr_scale. class_head_freeze_during_ramp_epochs: int = 0 class_head_lr_scale_during_ramp: float = 0.0 ``` #### 6.2.2 `build_optimizer` 拆出 cls_head group - 新增 `_CLASS_HEAD_PREFIXES`: ```python _CLASS_HEAD_PREFIXES = ( "frame_track_prediction_heads.class_head.", "frame_track_prediction_heads.class_head_mlp.", "frame_track_prediction_heads.class_head_demixer.", ) ``` - fast path 条件从 `trunk_scale == 1.0 and spatial_scale == 1.0` 改为再加 `and cls_head_scale == 1.0`; - 当 `cls_head_scale != 1.0` 时,匹配 `_CLASS_HEAD_PREFIXES` 的参数从 head_params 抽出,放入 `cls_head_params`; - 每个 param group 增加 `group_name` 字段(`"trunk"` / `"spatial"` / `"head"` / `"cls_head"`),便于 epoch loop 按 name 定位。 关键代码片段: ```python for name, param in model.named_parameters(): if not param.requires_grad: continue if name.startswith(_CLASS_HEAD_PREFIXES) and cls_head_scale != 1.0: cls_head_params.append(param) elif name.startswith(_TRUNK_PREFIXES): trunk_params.append(param) elif name.startswith(_SPATIAL_PREFIXES): spatial_params.append(param) else: head_params.append(param) param_groups = [] if trunk_params: param_groups.append({"params": trunk_params, "lr": base_lr * trunk_scale, "weight_decay": wd, "group_name": "trunk"}) if spatial_params: param_groups.append({"params": spatial_params, "lr": base_lr * spatial_scale, "weight_decay": wd, "group_name": "spatial"}) if head_params: param_groups.append({"params": head_params, "lr": base_lr, "weight_decay": wd, "group_name": "head"}) if cls_head_params:param_groups.append({"params": cls_head_params,"lr": base_lr * cls_head_scale, "weight_decay": wd, "group_name": "cls_head"}) ``` #### 6.2.3 epoch loop 动态写 cls_head LR 紧跟 spatial loss schedule 之后(在 `_log(f"[Epoch {epoch}] start")` 之前): ```python _cls_ramp_len = int(train_cfg.class_head_freeze_during_ramp_epochs) if _cls_ramp_len > 0 and _sp_warmup > 0 and train_cfg.class_head_lr_scale != 1.0: in_ramp = _sp_warmup <= epoch < _sp_warmup + _cls_ramp_len if in_ramp: _cls_scale = train_cfg.class_head_lr_scale_during_ramp else: _cls_scale = train_cfg.class_head_lr_scale for _g in optimizer.param_groups: if _g.get("group_name") == "cls_head": _g["lr"] = train_cfg.learning_rate * _cls_scale _log( f"[Epoch {epoch}] cls_head_lr scale={_cls_scale:.3f} " f"lr={train_cfg.learning_rate * _cls_scale:.2e} " f"(ramp_window={_sp_warmup}..{_sp_warmup + _cls_ramp_len - 1})" ) ``` #### 6.2.4 v9 preset 设定 ```python cfg.class_head_lr_scale = 0.3 cfg.class_head_freeze_during_ramp_epochs = 4 cfg.class_head_lr_scale_during_ramp = 0.0 ``` v9 的 `SPATIAL_LR=1.5e-5`,各阶段 cls_head LR: | 阶段 | epoch | lambda_dir | cls_head_lr | |---|---|---|---| | stage 1(class-only warmup) | 0-2 | 0.0 | 4.5e-6 | | stage 2(DOA ramp,cls_head 冻结) | 3-6 | ramp 0→1 | **0.0** | | stage 3(全放开) | 7+ | 1.0 | 4.5e-6 | ## 7. Fix B —— 本体论(hierarchical)标签平滑 ### 7.1 核心思想 v8/v8a 的 class 错误 **≥70% 是 sibling collapse**(同 AudioSet 父类)。Hard CE 把 `frog→bird` 和 `frog→aircraft` 一视同仁地惩罚满 log loss,这不合理: - 对下游 LLM 而言 `frog↔bird` 混淆可以靠语义上下文恢复; - `aircraft↔speech` 这种跨域错误则完全荒谬。 希望 loss 的惩罚强度 **匹配错误的 "语义距离"**。最简单的实现:在同父类内部做 label smoothing,跨父类保持硬 CE。 ### 7.2 代码改动 #### 7.2.1 `SpatialLossConfig` 新字段(`spatial_loss.py`) ```python # v9 hierarchical label smoothing for the frame-track class head. # When frame_class_ontology_smoothing > 0, the CE target becomes a soft # label distribution: # target[c_gt] = 1 - eps # target[c_sib] = eps / |siblings| (for each sibling in same ontology # group as c_gt; excludes c_gt itself) # target[c_other] = 0 # ... frame_class_ontology_smoothing: float = 0.0 # Parallel list of sibling groups. Each entry is a list of class indices # belonging to the same AudioSet ontology parent. A class may appear in # only one group. Empty list = no hierarchical smoothing. frame_class_ontology_groups: List[List[int]] = None ``` `__post_init__` 把 `None` 替换为 `[]` 以防意外。 #### 7.2.2 CE 分支改写(`spatial_loss.py:compute_frame_track_losses`) 原有分支: ```python loss_class = F.cross_entropy( class_logits_flat, class_target_flat, weight=_cls_weights ) ``` 改为: ```python eps_onto = float(config.frame_class_ontology_smoothing) onto_groups = config.frame_class_ontology_groups if eps_onto > 0.0 and onto_groups: num_classes = class_logits_flat.size(-1) # Build a [C, C] soft-target "mixing" table on first use and # cache it on the config object to avoid per-batch rebuild. if ( not hasattr(config, "_onto_mixing_table") or config._onto_mixing_table is None or config._onto_mixing_table.shape[0] != num_classes or config._onto_mixing_table.dtype != class_logits_flat.dtype or config._onto_mixing_table.device != device ): table = torch.zeros((num_classes, num_classes), dtype=class_logits_flat.dtype, device=device) table.fill_diagonal_(1.0) for group in onto_groups: members = [int(c) for c in group if 0 <= int(c) < num_classes] if len(members) < 2: continue for c in members: siblings = [s for s in members if s != c] table[c].zero_() table[c, c] = 1.0 - eps_onto sib_mass = eps_onto / len(siblings) for s in siblings: table[c, s] = sib_mass config._onto_mixing_table = table soft_target = config._onto_mixing_table[class_target_flat] log_probs = F.log_softmax(class_logits_flat, dim=-1) if _cls_weights is not None: sample_w = _cls_weights[class_target_flat] per_sample_loss = -(soft_target * log_probs).sum(dim=-1) loss_class = (per_sample_loss * sample_w).sum() / sample_w.sum().clamp_min(1e-8) else: loss_class = -(soft_target * log_probs).sum(dim=-1).mean) else: loss_class = F.cross_entropy(class_logits_flat, class_target_flat, weight=_cls_weights) ``` 实现要点: - **mixing table 缓存**:表存在 `config` 对象上(不是模块),第一次构建并缓存,之后按形状/device/dtype 重用,开销可忽略; - **per-class weight 兼容**:当同时启用 ontology smoothing 和 class weights 时,正类权重照常生效(按 GT hard class idx 加权每样本 loss); - **不在 group 里的类**:table 默认对角线 = 1,所以未列入任何 group 的类自然退化为 hard one-hot,零影响; - **完全开关**:`frame_class_ontology_smoothing=0` 或 `frame_class_ontology_groups=[]` 都会走原 `F.cross_entropy`,v8/v8a 训练复现不受影响。 ### 7.3 Ontology groups(`_V9_ONTOLOGY_GROUPS`) 位置:`train_spatial_beats.py`,v9 preset 之前。 ```python _V9_ONTOLOGY_GROUPS: List[List[int]] = [ # transportation: aircraft, vehicle, train, car [55, 19, 22, 44], # human voice (non-singing): speech, human_vocalization, male_speech, # female_speech, breathing, laughter [16, 6, 31, 32, 13, 14], # animal vocal: bird, frog, insect, dog, cat, animal [8, 62, 35, 18, 33, 27], # indoor mechanical + appliances: tool, machine, appliance, printer, # home_sound, door, drawer_cabinet, kitchenware, camera, clock, typing, # zipper, tape, cooking [9, 10, 48, 57, 34, 30, 50, 26, 38, 39, 36, 37, 58, 61], # percussive / impact: knock, footsteps, crack, crackle, crushing, # scratch, finger_snapping, tearing, writing, paper [52, 21, 60, 53, 56, 46, 54, 42, 43, 49], # weather / water / ambience: wind, rain, thunderstorm, ocean, water, # fire, glass, metal_clink, wood [25, 45, 29, 51, 5, 40, 24, 12, 59], # musical instruments: wind_instrument, string_instrument, guitar, drum, # keyboard_instrument, percussion, musical_instrument, gong, bell, # singing [0, 1, 2, 4, 7, 15, 28, 47, 17, 41], # alarms / signals: alarm, telephone_alarm, war_sound [20, 23, 11], ] ``` **8 组,覆盖 62/63 个 class**(只有 `body_sound` 没归组,走 hard CE)。每个 class 仅出现在一个组中。 v9 preset: ```python cfg.loss.frame_class_ontology_smoothing = 0.1 cfg.loss.frame_class_ontology_groups = [list(g) for g in _V9_ONTOLOGY_GROUPS] ``` ## 8. Fix C —— 频谱级 demixing cross-attention ### 8.1 动机 - `track_time_features[B, K, T_s, D]` 是 `SourceQueryDecoder` 从 `fused_embeddings[B, T_s, D]` 里 decode 出来的;`fused_embeddings` 在 `frequency_pool` 之后,**频率维已经被池化掉了**; - DOA 能在多源重叠帧 demix(IV 通道物理上就编码了方向),但 class 没有等价物理解混通路; - 需要让每个 track latent 能"回看"池化前的 trunk 输出,从 F_p 个频率 token 里挑自己负责的那部分。 ### 8.2 模型结构 - 输入: - `track_time_features: [B, K, T_s, D]`(已经过 `input_norm`) - `pre_pool_features: [B, T_p * F_p, D]`(BEATs trunk 输出,无 task tokens,未经 frequency_pool) - `pre_pool_grid_size: (T_p, F_p)` - `pre_pool_time_mask: [B, T_p]`(True = 有效时间步) - 时间对齐:frame `t ∈ [0, T_s)` → trunk 时间步 `t_p = round(t * T_p / T_s)`,clip 到 `[0, T_p-1]`; - KV 构造:`kv_grid[:, t_p, :, :]` → `[B, T_s, F_p, D]`,然后在 K 轴 expand 到 `[B, K, T_s, F_p, D]`,flatten 成 `[B*K*T_s, F_p, D]`; - Query:`track_time_features` reshape 成 `[B*K*T_s, 1, D]`; - 经 1 层 `nn.MultiheadAttention`(`num_heads=8`, `dropout=0.1`,`batch_first=True`),再 `out_proj(Linear 768→768)`,乘以标量 `gate`,加回 class head 的 `class_input`。 ### 8.3 初始化策略(关键) - `out_proj.weight` 全零、`out_proj.bias` 全零 → **加载时 demixer 输出 = 0**,class_logits 与 v8a 完全相同; - `gate = 1e-2`(**不是 0**!)→ 前向 = `gate * 0 = 0`(身份保证),但 `∂L/∂out_proj.weight = gate × ...` **非零**,梯度从 step 0 就能流进 demixer 的 attention 权重;这是关键的"gradient-warmup trick",否则两端 zero 会把 demixer 永久冻在 0。 ### 8.4 `ClassHeadSpectralDemixer` 源码(`spatial_modules.py`) ```python class ClassHeadSpectralDemixer(nn.Module): def __init__(self, embed_dim=768, num_layers=1, num_heads=8, dropout=0.1): super().__init__() self.embed_dim = embed_dim self.num_layers = max(1, int(num_layers)) self.kv_norm = nn.LayerNorm(embed_dim) self.q_norm = nn.LayerNorm(embed_dim) self.layers = nn.ModuleList([ nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout, batch_first=True) for _ in range(self.num_layers) ]) self.out_proj = nn.Linear(embed_dim, embed_dim) nn.init.zeros_(self.out_proj.weight) nn.init.zeros_(self.out_proj.bias) self.gate = nn.Parameter(torch.full((1,), 1e-2)) def forward(self, track_time_features, pre_pool_features, pre_pool_grid_size, pre_pool_time_mask=None): B, K, T_s, D = track_time_features.shape T_p, F_p = int(pre_pool_grid_size[0]), int(pre_pool_grid_size[1]) if pre_pool_features.size(-1) != D: raise ValueError(...) expected = T_p * F_p if pre_pool_features.size(1) != expected: # Fall back gracefully — demixer is additive & zero-gated. return track_time_features.new_zeros(track_time_features.shape) kv_grid = pre_pool_features.view(B, T_p, F_p, D) if T_s > 0 and T_p > 0: time_idx = torch.arange(T_s, device=kv_grid.device).float() * (T_p / max(1, T_s)) time_idx = time_idx.round().clamp_(0, T_p - 1).long() else: time_idx = torch.zeros((T_s,), dtype=torch.long, device=kv_grid.device) kv_per_frame = kv_grid[:, time_idx, :, :] kv_per_frame = kv_per_frame.unsqueeze(1).expand(B, K, T_s, F_p, D).contiguous() kv_flat = kv_per_frame.view(B * K * T_s, F_p, D) kv_flat = self.kv_norm(kv_flat) q_flat = track_time_features.reshape(B * K * T_s, 1, D) q_flat = self.q_norm(q_flat) key_padding_mask = None if pre_pool_time_mask is not None: per_frame_valid = pre_pool_time_mask[:, time_idx] per_frame_valid = per_frame_valid.unsqueeze(1).expand(B, K, T_s).reshape(-1) if not per_frame_valid.all(): ignore = ~per_frame_valid key_padding_mask = ignore.unsqueeze(1).expand(-1, F_p).contiguous() attn_out = q_flat for layer in self.layers: attn_out, _ = layer(attn_out, kv_flat, kv_flat, key_padding_mask=key_padding_mask, need_weights=False) residual = self.out_proj(attn_out).view(B, K, T_s, D) return residual * self.gate ``` ### 8.5 `_derive_pre_pool_time_mask` helper(`spatial_beats.py`) 紧跟 `_build_patch_padding_mask` 之后新增: ```python def _derive_pre_pool_time_mask( self, patch_padding_mask: Optional[Tensor], grid_size: Tuple[int, int], ) -> Optional[Tensor]: """Return a [B, T_p] boolean mask where True marks *valid* trunk time steps. Used by the v9 class-head spectral demixer to ignore padded tail frames.""" if patch_padding_mask is None: return None t_p, f_p = grid_size B = patch_padding_mask.size(0) pad_grid = patch_padding_mask.view(B, t_p, f_p) time_valid = ~pad_grid.all(dim=-1) return time_valid ``` `patch_padding_mask` 的语义是 `True = padded`,time-valid 是 "某时间步还有非 padded 频率位置" → `~pad_grid.all(dim=-1)`。 ### 8.6 forward 两处 call 更新(`spatial_beats.py`) 两处调用 `self.frame_track_prediction_heads(...)` 都扩展: ```python _pre_pool_time_mask = self._derive_pre_pool_time_mask( patch_padding_mask=patch_padding_mask, grid_size=grid_size, ) frame_track_prediction_output = self.frame_track_prediction_heads( track_time_features=track_time_features, track_latents=track_latents, pre_pool_features=encoder_memory, pre_pool_grid_size=grid_size, pre_pool_time_mask=_pre_pool_time_mask, ) ``` - 第 1 处在 `readout_scheme == "local_spatial"` 分支下 frame-track parallel 路径; - 第 2 处在 `readout_scheme == "local_spatial_track"` 分支下纯 per-frame 路径(v9 走这里)。 两处都用到的上下文变量: - `encoder_memory`:`self.encode_patches(...)` 的输出,`[B, T_p*F_p, D]`,**已经过 trunk 但未 frequency_pool**,正好是 demixer 需要的 pre_pool features; - `grid_size`:`self.extract_patch_tokens(...)` 返回的 `(T_p, F_p)`; - `patch_padding_mask`:`self._build_patch_padding_mask(...)` 返回。 ## 9. Fix F —— 2-layer MLP 残差 ### 9.1 动机 当前 class head 是 `nn.Linear(768, 63)`,对多源混合 token 的表达能力可能不够。加一个 2-layer MLP 残差 branch: ``` class_logits = class_head(x) + gate * class_head_mlp(x) class_head_mlp = Linear(768, 1536) → GELU → Dropout → LayerNorm → Linear(1536, 63) ``` ### 9.2 初始化策略 与 Fix C 的 demixer 同构: - `class_head_mlp[-1].weight / bias` 全零 → 残差输出 = 0; - `class_head_mlp_gate = 1e-2` → 前向仍然 = 0,但梯度可流进 MLP 最后一层(非零); - 经过 1-2 个 step MLP 最后一层有非零权重后,前一层(GELU 前)也开始获得梯度。 ### 9.3 `FrameTrackPredictionHeads` 重构(`spatial_modules.py`) 构造签名扩展: ```python def __init__( self, embed_dim: int = 768, num_classes: int = 63, dropout: float = 0.1, use_class_head_mlp_residual: bool = False, class_head_mlp_hidden_multiplier: int = 2, class_head_mlp_dropout: float = 0.1, use_class_head_demixer: bool = False, class_head_demixer_layers: int = 1, class_head_demixer_heads: int = 8, class_head_demixer_dropout: float = 0.1, ) -> None: ``` 构造体内新增: ```python self.use_class_head_mlp_residual = bool(use_class_head_mlp_residual) if self.use_class_head_mlp_residual: hidden = embed_dim * max(1, int(class_head_mlp_hidden_multiplier)) self.class_head_mlp = nn.Sequential( nn.Linear(embed_dim, hidden), nn.GELU(), nn.Dropout(class_head_mlp_dropout), nn.LayerNorm(hidden), nn.Linear(hidden, num_classes), ) nn.init.zeros_(self.class_head_mlp[-1].weight) nn.init.zeros_(self.class_head_mlp[-1].bias) self.class_head_mlp_gate = nn.Parameter(torch.full((1,), 1e-2)) else: self.class_head_mlp = None self.class_head_mlp_gate = None self.use_class_head_demixer = bool(use_class_head_demixer) if self.use_class_head_demixer: self.class_head_demixer = ClassHeadSpectralDemixer( embed_dim=embed_dim, num_layers=class_head_demixer_layers, num_heads=class_head_demixer_heads, dropout=class_head_demixer_dropout, ) else: self.class_head_demixer = None ``` forward 改写: ```python def forward( self, track_time_features: Tensor, track_latents: Tensor, pre_pool_features: Optional[Tensor] = None, pre_pool_grid_size: Optional[Tuple[int, int]] = None, pre_pool_time_mask: Optional[Tensor] = None, ) -> FrameTrackPredictionOutput: ... x = self.input_norm(track_time_features) activity = self.activity_head(x).squeeze(-1) class_input = x if ( self.class_head_demixer is not None and pre_pool_features is not None and pre_pool_grid_size is not None ): demix_residual = self.class_head_demixer( track_time_features=x, pre_pool_features=pre_pool_features, pre_pool_grid_size=pre_pool_grid_size, pre_pool_time_mask=pre_pool_time_mask, ) class_input = class_input + demix_residual class_logits = self.class_head(class_input) if self.class_head_mlp is not None and self.class_head_mlp_gate is not None: class_logits = class_logits + self.class_head_mlp_gate * self.class_head_mlp(class_input) direction = F.normalize(self.direction_head(x), dim=-1) distance = F.softplus(self.distance_head(x)).squeeze(-1) ... ``` **关键细节**: - demixer 的 residual 加到 `class_input` 上(即 class_head 的输入),而不是加到 logits 上;这样 demixer 得到的"纠偏"信号先经过 `class_head` 的共享投影再生成 logits; - MLP 分支也看 `class_input`(包含 demixer residual),这样 demixer 的信息同样能进入 MLP 分支; - `direction_head` 和 `distance_head` 的输入 `x` 不加 demixer(demixer 是 class-specific),保持 DOA head 对 v8a ckpt 的完全等价。 ### 9.4 `SpatialBEATsConfig` 新字段(`spatial_beats.py`) ```python self.frame_track_dropout: float = 0.1 self.frame_accdoa_hidden_dim: int = 256 self.frame_accdoa_dropout: float = 0.1 # v9: optional zero-initialised MLP residual branch inside # FrameTrackPredictionHeads. ... self.use_class_head_mlp_residual: bool = False self.class_head_mlp_hidden_multiplier: int = 2 self.class_head_mlp_dropout: float = 0.1 # v9: optional spectral demixing cross-attention branch. ... self.use_class_head_demixer: bool = False self.class_head_demixer_layers: int = 1 self.class_head_demixer_heads: int = 8 self.class_head_demixer_dropout: float = 0.1 ``` ### 9.5 两处构造都传入新参数 `spatial_beats.py` 两处 `FrameTrackPredictionHeads(...)` 调用都改为: ```python self.frame_track_prediction_heads = FrameTrackPredictionHeads( embed_dim=cfg.encoder_embed_dim, num_classes=cfg.source_num_classes, dropout=cfg.frame_track_dropout, use_class_head_mlp_residual=cfg.use_class_head_mlp_residual, class_head_mlp_hidden_multiplier=cfg.class_head_mlp_hidden_multiplier, class_head_mlp_dropout=cfg.class_head_mlp_dropout, use_class_head_demixer=cfg.use_class_head_demixer, class_head_demixer_layers=cfg.class_head_demixer_layers, class_head_demixer_heads=cfg.class_head_demixer_heads, class_head_demixer_dropout=cfg.class_head_demixer_dropout, ) ``` ## 10. v9 preset 与启动脚本 ### 10.1 `make_ov1_local_spatial_v9_ov123_top4_config`(`train_spatial_beats.py`) 位置:v7k 系列的最后一个 preset 之后,`v3: top-8 unfreeze` 注释分割线之前。 ```python def make_ov1_local_spatial_v9_ov123_top4_config( ov1_manifest_path: str = DEFAULT_OV1_MANIFEST, ov2_manifest_path: str = DEFAULT_OV2_MANIFEST, ov3_manifest_path: str = DEFAULT_OV3_MANIFEST, ) -> TrainSpatialBEATsConfig: """v9 = v8a + class-first cleanup (fixes A..F). Inherits v8a (cross-attn fusion + segment matching + 4-epoch DOA ramp), applies: - _V9_CLASS_WEIGHTS (suppress catch-all classes, boost frog/crackle/tape) - ontology-aware label smoothing (eps=0.1) - class head residual MLP + spectral demixer (both zero-init) - class_head_lr_scale=0.3 with full freeze during the 4-epoch DOA ramp Frontend / trunk / source_query_decoder / activity / dir / dist heads are unchanged. Hot-start from v8a best.pt works with strict=False. """ cfg = make_ov1_local_spatial_v8a_ov123_top4_config( ov1_manifest_path=ov1_manifest_path, ov2_manifest_path=ov2_manifest_path, ov3_manifest_path=ov3_manifest_path, ) # (D) Re-balanced class weights driven by v8/v8a CSV confusion analysis. cfg.loss.frame_class_loss_weights = list(_V9_CLASS_WEIGHTS) # (B) Hierarchical (ontology-aware) label smoothing. cfg.loss.frame_class_ontology_smoothing = 0.1 cfg.loss.frame_class_ontology_groups = [list(g) for g in _V9_ONTOLOGY_GROUPS] # (F) Zero-gated MLP residual on the class head. cfg.model.use_class_head_mlp_residual = True cfg.model.class_head_mlp_hidden_multiplier = 2 cfg.model.class_head_mlp_dropout = 0.1 # (C) Zero-gated spectral demixing cross-attention on the class head. cfg.model.use_class_head_demixer = True cfg.model.class_head_demixer_layers = 1 cfg.model.class_head_demixer_heads = 8 cfg.model.class_head_demixer_dropout = 0.1 # (E) Class head gets its own LR group. cfg.class_head_lr_scale = 0.3 cfg.class_head_freeze_during_ramp_epochs = 4 cfg.class_head_lr_scale_during_ramp = 0.0 cfg.num_epochs = 12 cfg.output_dir = "checkpoints/spatial_beats_ov1_local_spatial_v9_ov123_exp/03_ov123_top4" return cfg ``` ### 10.2 preset 注册 `train_spatial_beats.py` 中两处: 1. 在 `args.preset == "ov1_local_spatial_v8a_ov123_top4"` 分支之后,新增: ```python elif args.preset == "ov1_local_spatial_v9_ov123_top4": cfg = make_ov1_local_spatial_v9_ov123_top4_config( ov1_manifest_path=args.ov1_manifest, ov2_manifest_path=args.ov2_manifest, ov3_manifest_path=args.ov3_manifest, ) ``` 2. argparse `--preset` 的 `choices=(...)` 列表里在 `"ov1_local_spatial_v8a_ov123_top4"` 之后加入 `"ov1_local_spatial_v9_ov123_top4"`。 ### 10.3 `run_ov1_v9_ov123_top4.sh` 新建文件(chmod +x): ```bash #!/usr/bin/env bash set -euo pipefail # v9_ov123_top4: v8a + class-first cleanup (Fix A..F from docs/0423.md analysis) # 所有 fix 都是 additive + zero-init,从 v8a best.pt 热启动 epoch-0 输出与 v8a 完全相同 GPUS="${GPUS:-8}" BATCH_SIZE="${BATCH_SIZE:-8}" NUM_WORKERS="${NUM_WORKERS:-8}" SPATIAL_EPOCHS="${SPATIAL_EPOCHS:-12}" SPATIAL_LR="${SPATIAL_LR:-1.5e-5}" AMP="${AMP:-fp32}" OV1_MANIFEST="${OV1_MANIFEST:-/apdcephfs_cq10/.../ov1_foa.jsonl}" OV2_MANIFEST="${OV2_MANIFEST:-/apdcephfs_cq10/.../ov2_foa.jsonl}" OV3_MANIFEST="${OV3_MANIFEST:-/apdcephfs_cq10/.../ov3_foa.jsonl}" RESUME_CKPT="${RESUME_CKPT:-checkpoints/spatial_beats_ov1_local_spatial_v8a_ov123_exp/03_ov123_top4/best.pt}" OUT_DIR="${OUT_DIR:-checkpoints/spatial_beats_ov1_local_spatial_v9_ov123_exp/03_ov123_top4}" torchrun --nproc_per_node="${GPUS}" --master-port="${MASTER_PORT:-29557}" train_spatial_beats.py \ --preset ov1_local_spatial_v9_ov123_top4 \ --resume "${RESUME_CKPT}" \ --output-dir "${OUT_DIR}" \ --ov1-manifest "${OV1_MANIFEST}" \ --ov2-manifest "${OV2_MANIFEST}" \ --ov3-manifest "${OV3_MANIFEST}" \ --batch-size "${BATCH_SIZE}" \ --num-workers "${NUM_WORKERS}" \ --num-epochs "${SPATIAL_EPOCHS}" \ --learning-rate "${SPATIAL_LR}" \ --amp "${AMP}" \ --no-resume-optimizer \ --reset-epoch-on-resume \ --reset-best-on-resume ``` --- ## 11. 正确性验证(已完成) ### 11.1 语法检查 四个 py 文件(`train_spatial_beats.py` / `spatial_loss.py` / `spatial_beats.py` / `spatial_modules.py`)+ `run_ov1_v9_ov123_top4.sh` 全部通过 ast.parse / bash -n。 ### 11.2 模块级单测 `FrameTrackPredictionHeads` + `ClassHeadSpectralDemixer`: - 开启所有 v9 选项后,**max abs diff = 0.00e+00** 相对于无 v9 选项的同一模型; - 带 padding mask 时 `pred_class_logits` 无 NaN; - 打开 MLP gate / demixer gate + 轻微扰动权重后,logits 差异 ~1e-2(符合预期)。 ### 11.3 v9 vs v8a 端到端前向恒等 ```python torch.manual_seed(42) m_v8a = SpatialBEATs(make_ov1_local_spatial_v8a_ov123_top4_config().model).eval() torch.manual_seed(42) m_v9 = SpatialBEATs(make_ov1_local_spatial_v9_ov123_top4_config().model).eval() # 两个模型都 load v8a/best.pt,strict=False # 同一 waveform 输入 max_abs_diff = (o_v8a - o_v9).abs().max() # 实测:0.00e+00 ``` **v9 加载 v8a ckpt 后的 `pred_class_logits` 与 v8a 模型加载同一 ckpt 的输出逐元素完全相等**(0.00e+00)。 ### 11.4 v8a ckpt 加载统计 ``` ckpt keys = 425 loadable (v9 shapes) = 425 v9 missing params = 18 ← class_head_mlp.* (7 个) + class_head_demixer.* (11 个) ckpt unexpected = 0 ``` 18 个新参数通过 `strict=False` 默认初始化(按 Fix C / F 的 zero-out-proj + tiny-gate 策略)。 ### 11.5 梯度流 假 CE loss + backward: - `class_head.weight`:grad_norm ≈ 9.06(正常) - `class_head_mlp.4.weight`(MLP 最后一层,zero-init):grad_norm ≈ 1.32e-1(**非零,因为 gate = 1e-2**) - `class_head_mlp.0.weight`(MLP 第一层):grad_norm = 0(正常,要等最后一层非零后才有梯度,1 step 内开解) - `class_head_demixer.out_proj.weight`:grad_norm ≈ 1.60e-2(**非零**) - `class_head_demixer.layers.0.in_proj_weight`:grad_norm = 0(同理) ### 11.6 Optimizer 分组 ``` group 0: name=trunk lr=3.00e-06 n_params=238 group 1: name=spatial lr=9.00e-06 n_params= 84 group 2: name=head lr=3.00e-05 n_params= 91 group 3: name=cls_head lr=9.00e-06 n_params= 19 ``` (上例 base_lr=3e-5;v9 启动脚本里 SPATIAL_LR=1.5e-5,对应 cls_head lr = 4.5e-6。) ## 12. 关键训练时间线(SPATIAL_EPOCHS=12) 继承 v8a 的 `frame_spatial_loss_warmup_epochs=3`、`frame_spatial_loss_ramp_epochs=4`,叠加 v9 新的 cls_head LR 调度: | epoch | lambda_dir / dist | dir/dist match cost | cls_head lr | 说明 | |---|---|---|---|---| | 0-2 | 0.0 | 0.0 | 4.5e-6 | Stage 1:class-only warmup,cls_head 低 LR 继续微调(但本来就 v8a 延续) | | 3 | ramp 0.25 | 0.25 | **0.0** | Stage 2 起点,cls_head 冻结 | | 4 | ramp 0.50 | 0.50 | **0.0** | cls_head 冻结 | | 5 | ramp 0.75 | 0.75 | **0.0** | cls_head 冻结 | | 6 | ramp 1.00 | 1.00 | **0.0** | cls_head 冻结,DOA 全开 | | 7-11 | 1.0 | 1.0 | 4.5e-6 | Stage 3:全部放开,cls_head 低 LR 微调 | ## 13. 验证清单 / 观察优先级 启动后按以下顺序观察指标: 1. **epoch 0 validation metrics 是否与 v8a 的最后一个 epoch 等价** - 预期 val loss ≈ v8a best val loss(因为前向 = v8a) - 若不等价,说明 hot-start 出问题; 2. **epoch 0-2(stage 1)**: - `ocls`(oracle class acc)是否比 v8a ep2 同期高?关键指标,主要受 Fix D + B + C + F 驱动; - 若 Fix D 生效:aircraft / vehicle 不再 0%,frog 不再 100%→bird; - 若 Fix B 生效:smooth CE loss 比 hard CE 降得更快(从 1 - eps 开始); - 若 Fix C/F 生效:每 epoch 训完后 `class_head_mlp_gate` / `class_head_demixer.gate` 应从 0.01 慢慢增长; 3. **epoch 3-6(stage 2,cls_head 冻结)**: - DOA 指标改善(`LE_CD` 下降,`oazi` 下降),`ocls` **不应退化**(cls_head LR=0,梯度不回传); - 若 `ocls` 退化:可能是 fusion/trunk 的梯度通过 demixer 影响了 class_head 输入分布,此时需要考虑把 demixer 也一起冻; 4. **epoch 7+(stage 3,全开)**: - `F20` 是否突破 v8a 的上限(v7h 基线 0.246,v8a ~ 0.22-0.25); - ov3 class_ok 是否从 43% 往上走; - 若 ov2/ov3 class_ok 都有改善但 ov3 仍然落后:Fix C 的频谱 demixing 还需要再加层数或 heads。 ## 14. 代码 diff 总览 ### 14.1 `spatial_loss.py` - 新增 2 个 `SpatialLossConfig` 字段:`frame_class_ontology_smoothing`、`frame_class_ontology_groups`; - `__post_init__` 处理 `frame_class_ontology_groups = None`; - `compute_frame_track_losses` 里 CE 分支支持 soft target + class weight 复合。 ### 14.2 `spatial_modules.py` - `FrameTrackPredictionHeads.__init__` 新增 6 个 kwargs; - `FrameTrackPredictionHeads.forward` 新增 3 个 Optional kwargs(`pre_pool_features` / `pre_pool_grid_size` / `pre_pool_time_mask`),内部整合 demixer + MLP residual; - 新增 `ClassHeadSpectralDemixer` 模块。 ### 14.3 `spatial_beats.py` - `SpatialBEATsConfig.__init__` 新增 8 个 `self.*` 字段; - 两处 `FrameTrackPredictionHeads(...)` 构造传入全部新参数; - 两处 `self.frame_track_prediction_heads(...)` 调用传入 pre-pool 参数; - 新增 `SpatialBEATs._derive_pre_pool_time_mask` 实例方法。 ### 14.4 `train_spatial_beats.py` - `TrainSpatialBEATsConfig` 新增 3 个字段:`class_head_lr_scale`、`class_head_freeze_during_ramp_epochs`、`class_head_lr_scale_during_ramp`; - 新增常量 `_V9_CLASS_WEIGHTS`(63 长度)+ `_V9_ONTOLOGY_GROUPS`(8 组,覆盖 62/63 class); - 新增 preset 函数 `make_ov1_local_spatial_v9_ov123_top4_config`; - `build_optimizer` 新增 `_CLASS_HEAD_PREFIXES`、`cls_head_params` 分组、`group_name` 标签、fast-path 条件更新; - epoch loop 动态写 cls_head LR(紧跟 spatial loss schedule 后); - preset dispatch 和 `--preset` argparse choices 注册 `ov1_local_spatial_v9_ov123_top4`。 ### 14.5 `run_ov1_v9_ov123_top4.sh` - 新建 shell 脚本,默认 8 GPU / BS=8 / 12 epoch / LR=1.5e-5 / fp32 / master_port=29557; - 默认 RESUME_CKPT 指向 v8a best.pt; - 复用 `--no-resume-optimizer --reset-epoch-on-resume --reset-best-on-resume` 三开关。 ## 15. 关键文件与行数 | 文件 | 新增代码主要位置 | |---|---| | `spatial_loss.py` | `SpatialLossConfig` 新增字段紧跟 `frame_class_loss_weights` 之后;CE 分支改写在 `compute_frame_track_losses` 内 | | `spatial_modules.py` | `FrameTrackPredictionHeads` 整块重写;`ClassHeadSpectralDemixer` 定义在 `FrameTrackPredictionHeads` 之后 | | `spatial_beats.py` | `SpatialBEATsConfig.__init__` 的 v9 字段紧跟 `frame_accdoa_dropout`;`_derive_pre_pool_time_mask` 紧跟 `_build_patch_padding_mask` | | `train_spatial_beats.py` | `_V9_CLASS_WEIGHTS` 紧跟 `_V7I_CLASS_WEIGHTS` 之后;`_V9_ONTOLOGY_GROUPS` 和 `make_ov1_local_spatial_v9_ov123_top4_config` 紧跟 v7k 系列结束处;`class_head_*_scale` 字段紧跟 `spatial_lr_scale` 之后 | ## 16. 后续计划 如果 v9 ep3 `ocls` 仍不超过 v8a 同期,下一步(v9a/v9b)候选: 1. **demixer 加层**:`class_head_demixer_layers = 2` 或增加 heads 到 16; 2. **demixer 冻结共进退**:DOA ramp 期把 `class_head_demixer.*` 也当作 cls_head group 的一部分(目前 `_CLASS_HEAD_PREFIXES` 已经包含它,在 ramp 期也会冻结——已经生效,但需要验证效果); 3. **ontology smoothing eps 提升**:0.1 → 0.2,进一步容忍 sibling collapse; 4. **source_query_decoder 加正交正则**(前一版建议):`L_orth = ||Q Q^T - I||_F^2`,λ=0.01; 5. **重新审视 matching**:如果 cls 瓶颈解了但 F20 仍不涨,回到 `docs/0422.md` 的 track-dead / duplicate 诊断。 但当前 v9 已经把 "class head 本身" 这一块做得比较彻底,应该先跑满 12 epoch 再决定。