
2026-04-23: v9_ov123_top4, class-first targeted fixes (A→D→E→B→C→F)

0. Background and task for this round

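v8 / v8a trained for 10+ epochs on 03_ov123_top4. Overall spatial metrics improved slightly over v7h, but class_ok has stayed stuck at around 45%. The user asked to:

  1. focus on frame-level prediction and introduce no clip-level supervision;
  2. analyze the v8 / v8a val CSVs (epoch 9 / epoch 11) and find the real bottleneck holding back class accuracy;
  3. implement Fix A→D→E→B→C→F in that order, without breaking the existing code framework or training logic;
  4. ship everything as a new v9 preset plus a new script, hot-started from v8a best.pt (guaranteeing that the epoch-0 forward output is identical to v8a).

This document records every code change made in the 2026-04-23 round, together with the CSV diagnostic evidence behind each change.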

Related earlier docs:

  • docs/0422.md: v7h → v8 architecture upgrade (cross-attention fusion)
  • docs/0422_v7h_v7j.md: per-frame multi-source diagnostics for v7h / v7j / v7i, and the introduction of class-weighted CE
  • docs/0421.md: v7f → ov123 per-frame extension and the frame-track CSV dump

1. CSV diagnosis: error patterns of v8 ep9 / v8a ep11

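1.1 Overall accuracy (class taken from the DOA-nearest track among the active tracks with activity ≥ 0.5)

| Metric | v8 ep9 | v8a ep11 |
|---|---|---|
| overall cls_ok | 933/2051 = 45.5% | 934/2051 = 45.5% |
| cls_ok within DOA ≤ 20° hits | 48.1% | 51.9% |
| Oracle cls (ignore activity, take the class of the DOA-nearest track) | 38.7% | 46.4% |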

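The two key numbers are close: cls_ok (nearest active) ≈ oracle_cls. For v8a ep11 they are 45.5% vs 46.4%. For v8 ep9 the activity-gated pick is actually higher than the oracle, 45.5% vs 38.7%. This means:

The track chosen by activity and the track chosen by DOA agree semantically, so the K=4 track binding is fine. What actually goes wrong is that **the selected track's own class_logits predict the wrong class**.

In other words, the bottleneck is neither matching nor activity. It is the output distribution of the class head itself.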
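The two selection rules behind 1.1 can be sketched in a few lines. This is an illustration, not the actual evaluation code. The per-track tuple layout (activity, unit DOA vector, predicted class) and all function names are assumptions. Only the activity ≥ 0.5 threshold and the "DOA-nearest" criterion come from the text above.

```python
import math
from typing import List, Optional, Sequence, Tuple

# Hypothetical per-track record: (activity, unit DOA vector (x, y, z), predicted class id).
Track = Tuple[float, Tuple[float, float, float], int]


def angular_error_deg(u: Sequence[float], v: Sequence[float]) -> float:
    """Angle in degrees between two unit vectors (dot product clipped to [-1, 1])."""
    dot = sum(a * b for a, b in zip(u, v))
    return math.degrees(math.acos(max(-1.0, min(1.0, dot))))


def nearest_track(tracks: List[Track], gt_dir: Sequence[float],
                  min_activity: Optional[float]) -> Optional[Track]:
    """DOA-nearest track; if min_activity is given, only tracks with activity >= it qualify."""
    pool = [t for t in tracks if min_activity is None or t[0] >= min_activity]
    if not pool:
        return None
    return min(pool, key=lambda t: angular_error_deg(t[1], gt_dir))


def frame_cls_ok(tracks: List[Track], gt_dir: Sequence[float], gt_cls: int) -> Tuple[bool, bool]:
    """(cls_ok via nearest active track, oracle cls_ok ignoring activity) for one GT source."""
    active = nearest_track(tracks, gt_dir, min_activity=0.5)
    oracle = nearest_track(tracks, gt_dir, min_activity=None)
    return (active is not None and active[2] == gt_cls,
            oracle is not None and oracle[2] == gt_cls)


# Toy frame: the only active track points elsewhere; the DOA-nearest one is inactive.
tracks = [(0.9, (1.0, 0.0, 0.0), 8), (0.2, (0.0, 1.0, 0.0), 62)]
print(frame_cls_ok(tracks, gt_dir=(0.0, 1.0, 0.0), gt_cls=62))  # (False, True)
```

In this toy frame the two rules disagree because the DOA-nearest track is inactive. A large gap between the two rates in 1.1 would therefore point to a binding or activity problem rather than a class-head problem.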

1.2 Breakdown by ov

v8 ep9:

ov2 : gt= 682  DOA_ok=51.6%  cls_ok=37.1%  any_track_cls_ok=38.9%  cls&DOA=23.9%
ov3 : gt=1158  DOA_ok=42.5%  cls_ok=47.1%  any_track_cls_ok=59.6%  cls&DOA=18.0%
other: gt= 211  DOA_ok=97.6%  cls_ok=64.0%  any_track_cls_ok=64.0%  cls&DOA=63.0%

v8a ep11 is similar, with ov3 slightly worse:

ov2 :  cls_ok=44.0%  cls&DOA=33.7%
ov3 :  cls_ok=43.1%  cls&DOA=14.2%
other: cls_ok=64.0%  cls&DOA=63.0%

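"other" (roughly ov1, single-source) is at 64%, ov2 at 37-44%, and ov3 drops to 43% with cls&DOA at only 14%. Class accuracy on overlapping multi-source frames is about 20 points below single-source frames. This is direct evidence that demixing fails on ov3.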

1.3 Per-class error patterns (Oracle: ignore activity, take the DOA-nearest track)

aircraft  (n=100): speech(50%)  human_vocalization(50%)   → always 0% correct
vehicle   (n= 68): machine(74%) aircraft(13%) train(10%)  → always 0% correct
frog      (n= 50): bird(98~100%)                           → always predicted as bird
speech    (n= 53): human_vocalization(94%) breathing(6%)   → always 0% correct
crackle   (n= 51): rain(53%)    machine(41%)               → always 0% correct
wind      (n= 16): fire(50~88%) vehicle/wood               → always 0% correct
drawer_cab(n= 50): tool(74~84%)                            → always 0% correct
tape      (n= 50): human_vocalization(88%) typing(46%)     → always 0% correct
knock     (n= 36): home_sound(64%) human_vocalization      → always 0% correct
train     (n=115): train(43%) crushing(37%) vehicle(15%)   → ~10%

Every one of these errors is a **"sibling collapse" within the same parent class**:

  • aircraft / vehicle / train all belong to transportation;
  • speech / human_vocalization / breathing / laughter all belong to human-voice;
  • frog / bird / insect all belong to animal-vocal;
  • drawer_cabinet / tool / home_sound / door all belong to indoor-mechanical.

1.4 Label-bug check: is frog→bird a data-mapping error?

We checked build_ov{1,23}_foa_dataset.py, final_vocabulary.csv, spatial_dataset.py:_resolve_class_index, and concrete samples in the jsonl files:

ov1 train frog: 113    ov1 train bird: 3270    →  1:29
ov1 valid frog:   9    ov1 valid bird:   56
ov1 test  frog:   3    ov1 test  bird:   51
ov2 train frog:1078    ov2 train bird: 1255
ov3 train frog:1384    ov3 train bird: 1432

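The frog labels are completely correct, so this is not a label bug. The model predicts bird 100% of the time because bird outnumbers frog 29× in the ov1 training set, which is plain class imbalance. The other two manifests (ov2/ov3) are already roughly balanced. However, ov1 still makes up a sizable share of the training data, so the overall prior is skewed toward bird.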

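2. Overall design principles for the six fixes

  1. All additive: v9 only adds new config fields and new submodules. It never removes or overrides an existing v8a path.
  2. All zero-init or identity-init: new parameters contribute 0 when a checkpoint is loaded. This is verifiable: the max abs diff between the v9(v8a.pt) and v8a(v8a.pt) forward outputs is 0.
  3. strict=False compatible: v9's 18 new parameters are absent from the v8a ckpt. They get their default initialization through the existing _load_spatial_init_checkpoint / load_state_dict(..., strict=False) flow.
  4. Forward-signature compatible: the new keyword arguments of FrameTrackPredictionHeads.forward default to None, so older call chains (e.g. the v7 / v8 non-track readout) are unaffected.
  5. All LR / schedule defaults are equivalent to v8a: class_head_lr_scale=1.0 takes the original 3-group fast path, and frame_class_ontology_smoothing=0.0 takes the original F.cross_entropy.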

3. File-level change list

| File | Changes |
|---|---|
| spatial_loss.py | New fields frame_class_ontology_smoothing and frame_class_ontology_groups; the CE branch supports ontology soft labels |
| spatial_modules.py | FrameTrackPredictionHeads gets 7 new constructor arguments; new ClassHeadSpectralDemixer module |
| spatial_beats.py | SpatialBEATsConfig gets 7 new fields; both FrameTrackPredictionHeads constructions pass the new arguments; both forward calls pass pre_pool_features; new _derive_pre_pool_time_mask helper |
| train_spatial_beats.py | New fields class_head_lr_scale, class_head_freeze_during_ramp_epochs, class_head_lr_scale_during_ramp; build_optimizer splits out a cls_head group; the epoch loop writes the cls_head LR dynamically; new _V9_CLASS_WEIGHTS, _V9_ONTOLOGY_GROUPS and make_ov1_local_spatial_v9_ov123_top4_config; ov1_local_spatial_v9_ov123_top4 registered in the preset dispatch and in the --preset choices |
| run_ov1_v9_ov123_top4.sh | New launch script (chmod +x) |

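4. Fix A: diagnosing frog and other rare classes (no code changes)

Fix A is purely diagnostic and has no code of its own. Conclusion: frog is not a label bug; it suffers from the 29:1 imbalance in ov1 train. The remedy is folded into Fix D:

  • bird weight: 1.0 → 0.6 (suppress the catch-all)
  • frog weight: 1.0 → 3.0 (boost the minority)
  • insect weight: 4.0 in v7I → 1.0 (already 100% in v8/v8a, so no upweighting is needed)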
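Here is a back-of-the-envelope check of what these weights do to the ov1 imbalance. It makes three assumptions: each sample's CE term is simply scaled by the weight of its GT class (as F.cross_entropy(weight=...) does), the 1.4 sample counts are proportional to frame counts, and the roughly balanced ov2/ov3 manifests can be ignored.

```python
# ov1 train sample counts from 1.4 and the Fix D weights (frog 3.0, bird 0.6).
ov1_train = {"frog": 113, "bird": 3270}
weights = {"frog": 3.0, "bird": 0.6}

raw_ratio = ov1_train["bird"] / ov1_train["frog"]
weighted_ratio = (ov1_train["bird"] * weights["bird"]) / (ov1_train["frog"] * weights["frog"])
print(f"bird:frog raw = {raw_ratio:.1f}:1, loss-weighted = {weighted_ratio:.1f}:1")
# prints: bird:frog raw = 28.9:1, loss-weighted = 5.8:1
```

The reweighting shrinks the gap by a factor of 5 (3.0 / 0.6) rather than fully equalising it.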

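5. Fix D: redesigned class weights _V9_CLASS_WEIGHTS

5.1 Data evidence

v7I's weight table upweighted aircraft / vehicle / insect (4×). Results:

v8 ep9:    aircraft 0%,  vehicle 0%,  insect 100%
v8a ep11:  aircraft 0%,  vehicle 0%,  insect 100%,  printer 100%→59% (regression!)
  • The 4× on aircraft / vehicle had no effect at all, because their acoustic features genuinely resemble speech / machine.
  • The 4× multiplies the loss of the misclassified frames (aircraft→speech) by 4. To reduce that loss, the model also makes its speech predictions more conservative, so speech drops to 0% as well (lose-lose).
  • v8a ep11 shows one more anomaly: printer falls from 100% at v8 ep9 to 59%. The cause is discussed in Fix E.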

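5.2 Design rules

  • "Catch-all" classes (used as a fallback prediction on GT frames of other classes) → weight lowered:
    • human_vocalization 0.4 (the target for speech / tape / typing / knock)
    • bird 0.6 (frog collapses almost entirely into bird)
    • machine 0.5 (vehicle / crackle collapse into machine)
    • rain 0.5 (collapse target for crackle / singing)
    • breathing 0.5 (a collapse target for speech, see 1.3)
    • home_sound 0.6 (collapse target for knock / printer)
  • Rare classes that tend to be collapsed → weight raised, but moderately (≤ 3×):
    • frog 3.0
    • crackle 2.0, tape 2.0, knock 2.0, drawer_cabinet 2.0, speech 2.0
    • aircraft / vehicle / train back to 1.0-1.5 (4× proved ineffective and hurt their siblings)
  • Over-confident robust classes → slightly lowered:
    • singing 0.7, printer 0.7 (they absorb false positives in the ov123 task)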

5.3 Code changes

Location: train_spatial_beats.py (right after _V7I_CLASS_WEIGHTS)

_V9_CLASS_WEIGHTS: List[float] = [
    # 0  wind_instrument    1  string_instrument  2  guitar             3  body_sound
    1.0,                    1.0,                  1.0,                  1.0,
    # 4  drum               5  water              6  human_vocalization  7  keyboard_instrument
    1.0,                    1.0,                  0.4,                  1.0,
    # 8  bird               9  tool              10  machine            11  war_sound
    0.6,                    1.0,                  0.5,                  1.0,
    # 12 metal_clink       13  breathing         14  laughter           15  percussion
    1.0,                    0.5,                  1.0,                  1.0,
    # 16 speech            17  bell              18  dog                19  vehicle
    2.0,                    1.0,                  1.0,                  1.5,
    # 20 alarm             21  footsteps         22  train              23  telephone_alarm
    1.0,                    1.0,                  1.5,                  1.0,
    # 24 glass             25  wind              26  kitchenware        27  animal
    1.0,                    1.5,                  1.0,                  1.0,
    # 28 musical_instrument 29 thunderstorm      30  door               31  male_speech
    1.0,                    1.0,                  1.0,                  1.0,
    # 32 female_speech      33 cat               34  home_sound         35  insect
    1.0,                    1.0,                  0.6,                  1.0,
    # 36 typing             37 zipper            38  camera             39  clock
    1.0,                    1.0,                  1.0,                  1.0,
    # 40 fire               41 singing           42  tearing            43  writing
    1.0,                    0.7,                  1.0,                  1.0,
    # 44 car                45 rain              46  scratch            47  gong
    1.0,                    0.5,                  1.0,                  1.0,
    # 48 appliance          49 paper             50  drawer_cabinet     51  ocean
    1.0,                    1.0,                  2.0,                  1.0,
    # 52 knock              53 crackle           54  finger_snapping    55  aircraft
    2.0,                    2.0,                  1.0,                  1.0,
    # 56 crushing           57 printer           58  tape               59  wood
    1.0,                    0.7,                  2.0,                  1.0,
    # 60 crack              61 cooking           62  frog
    1.0,                    1.0,                  3.0,
]
assert len(_V9_CLASS_WEIGHTS) == 63

In the v9 preset: cfg.loss.frame_class_loss_weights = list(_V9_CLASS_WEIGHTS)
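One detail worth keeping in mind is how PyTorch normalizes the weighted loss. With reduction='mean', it divides the weighted sum by the sum of the per-sample weights w[y_n], not by the batch size. The weights therefore change how much each class contributes within a batch rather than the overall loss scale. Below is a pure-Python mirror for sanity checks; the function name is illustrative, not part of the codebase.

```python
import math
from typing import Sequence


def weighted_ce_mean(logits: Sequence[Sequence[float]], targets: Sequence[int],
                     class_weights: Sequence[float]) -> float:
    """Pure-Python mirror of F.cross_entropy(logits, targets, weight=w), reduction='mean'.

    Each sample's -log softmax[target] is scaled by w[target]; the sum is then divided by
    the sum of w[target] over the batch (not by the number of samples).
    """
    num = 0.0
    den = 0.0
    for row, y in zip(logits, targets):
        m = max(row)  # log-sum-exp shift for numerical stability
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        nll = log_z - row[y]
        num += class_weights[y] * nll
        den += class_weights[y]
    return num / den


logits = [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
targets = [0, 1]
print(weighted_ce_mean(logits, targets, [3.0, 0.6, 1.0]))  # class 0 dominates the average
print(weighted_ce_mean(logits, targets, [1.0, 1.0, 1.0]))  # plain unweighted mean
```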

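6. Fix E: separate LR group for the class head + freeze during the DOA ramp

6.1 Data evidence

v8a ep5 vs ep11 (before and after the stage-2 DOA ramp is enabled):

ep5  ov3: ocls=44.4%  printer=100%
ep11 ov3: ocls=44.4%  printer=59%     ← class perturbed by the DOA gradients

Once the dir/dist Hungarian cost and losses are switched on, the class binding gets perturbed. The usual remedy is to lower the LR globally, but that also slows learning in the trunk / decoder. A better option is to pull class_head out into its own optimizer group and freeze it during the DOA ramp.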

6.2 Code changes

6.2.1 New config fields (train_spatial_beats.py)

    trunk_lr_scale: float = 1.0
    spatial_lr_scale: float = 1.0
    # v9: isolated LR multiplier for the class_head inside
    # frame_track_prediction_heads. When < 1.0 the class head is put in its
    # own param group with lr = base_lr * class_head_lr_scale.  Used during
    # DOA ramp (stage 2) to prevent class binding from being perturbed by
    # the newly-unlocked dir/dist gradients.  1.0 = legacy behaviour.
    class_head_lr_scale: float = 1.0
    # Optional epoch-range override that further scales the class head LR
    # specifically during the DOA ramp.  When set, between
    # frame_spatial_loss_warmup_epochs and frame_spatial_loss_warmup_epochs
    # + class_head_freeze_during_ramp_epochs the class head LR is set to
    # class_head_lr_scale_during_ramp (defaults to 0.0 = frozen).  After the
    # ramp window the LR returns to class_head_lr_scale.
    class_head_freeze_during_ramp_epochs: int = 0
    class_head_lr_scale_during_ramp: float = 0.0

6.2.2 build_optimizer splits out a cls_head group

  • New _CLASS_HEAD_PREFIXES:
    _CLASS_HEAD_PREFIXES = (
        "frame_track_prediction_heads.class_head.",
        "frame_track_prediction_heads.class_head_mlp.",
        "frame_track_prediction_heads.class_head_demixer.",
    )
    
  • The fast-path condition trunk_scale == 1.0 and spatial_scale == 1.0 gains an extra "and cls_head_scale == 1.0";
  • When cls_head_scale != 1.0, parameters matching _CLASS_HEAD_PREFIXES are taken out of head_params and put into cls_head_params;
  • Every param group gets a group_name field ("trunk" / "spatial" / "head" / "cls_head"), so the epoch loop can locate groups by name.

Key code excerpt:

for name, param in model.named_parameters():
    if not param.requires_grad:
        continue
    if name.startswith(_CLASS_HEAD_PREFIXES) and cls_head_scale != 1.0:
        cls_head_params.append(param)
    elif name.startswith(_TRUNK_PREFIXES):
        trunk_params.append(param)
    elif name.startswith(_SPATIAL_PREFIXES):
        spatial_params.append(param)
    else:
        head_params.append(param)

param_groups = []
if trunk_params:   param_groups.append({"params": trunk_params,   "lr": base_lr * trunk_scale,   "weight_decay": wd, "group_name": "trunk"})
if spatial_params: param_groups.append({"params": spatial_params, "lr": base_lr * spatial_scale, "weight_decay": wd, "group_name": "spatial"})
if head_params:    param_groups.append({"params": head_params,    "lr": base_lr,                 "weight_decay": wd, "group_name": "head"})
if cls_head_params:param_groups.append({"params": cls_head_params,"lr": base_lr * cls_head_scale, "weight_decay": wd, "group_name": "cls_head"})
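To see which parameters land in which group, the routing loop above can be exercised on bare parameter names. This sketch rests on two assumptions. First, _TRUNK_PREFIXES and _SPATIAL_PREFIXES below are hypothetical placeholders, since their real values are not listed here. Second, class_head is taken to be an nn.Linear with bias.

Every _CLASS_HEAD_PREFIXES entry ends with ".", so the gate parameter frame_track_prediction_heads.class_head_mlp_gate matches none of them and falls through to "head". That is consistent with the 19 cls_head params reported in 11.6 (class_head 2 + class_head_mlp 6 + class_head_demixer 11). It also means this gate keeps the head LR during the DOA ramp.

```python
from typing import Dict, Iterable, List

_CLASS_HEAD_PREFIXES = (
    "frame_track_prediction_heads.class_head.",
    "frame_track_prediction_heads.class_head_mlp.",
    "frame_track_prediction_heads.class_head_demixer.",
)
# Hypothetical placeholders: the real _TRUNK_PREFIXES / _SPATIAL_PREFIXES are not listed in this doc.
_TRUNK_PREFIXES = ("beats.",)
_SPATIAL_PREFIXES = ("spatial_frontend.",)


def assign_groups(names: Iterable[str], cls_head_scale: float) -> Dict[str, List[str]]:
    """Route parameter names with the same precedence as the loop above:
    cls_head first (only when cls_head_scale != 1.0), then trunk, spatial, and head as fallback."""
    groups: Dict[str, List[str]] = {"trunk": [], "spatial": [], "head": [], "cls_head": []}
    for name in names:
        if name.startswith(_CLASS_HEAD_PREFIXES) and cls_head_scale != 1.0:
            groups["cls_head"].append(name)
        elif name.startswith(_TRUNK_PREFIXES):
            groups["trunk"].append(name)
        elif name.startswith(_SPATIAL_PREFIXES):
            groups["spatial"].append(name)
        else:
            groups["head"].append(name)
    return groups


names = [
    "frame_track_prediction_heads.class_head.weight",
    "frame_track_prediction_heads.class_head_mlp.4.weight",
    "frame_track_prediction_heads.class_head_mlp_gate",      # no "." right after class_head_mlp
    "frame_track_prediction_heads.class_head_demixer.gate",
    "frame_track_prediction_heads.direction_head.weight",
    "beats.encoder.layers.0.fc1.weight",                     # hypothetical trunk name
]
for group, members in assign_groups(names, cls_head_scale=0.3).items():
    print(group, members)
```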

6.2.3 The epoch loop writes the cls_head LR dynamically

Right after the spatial loss schedule (before _log(f"[Epoch {epoch}] start")):

_cls_ramp_len = int(train_cfg.class_head_freeze_during_ramp_epochs)
if _cls_ramp_len > 0 and _sp_warmup > 0 and train_cfg.class_head_lr_scale != 1.0:
    in_ramp = _sp_warmup <= epoch < _sp_warmup + _cls_ramp_len
    if in_ramp:
        _cls_scale = train_cfg.class_head_lr_scale_during_ramp
    else:
        _cls_scale = train_cfg.class_head_lr_scale
    for _g in optimizer.param_groups:
        if _g.get("group_name") == "cls_head":
            _g["lr"] = train_cfg.learning_rate * _cls_scale
    _log(
        f"[Epoch {epoch}] cls_head_lr scale={_cls_scale:.3f}  "
        f"lr={train_cfg.learning_rate * _cls_scale:.2e}  "
        f"(ramp_window={_sp_warmup}..{_sp_warmup + _cls_ramp_len - 1})"
    )

6.2.4 v9 preset settings

cfg.class_head_lr_scale = 0.3
cfg.class_head_freeze_during_ramp_epochs = 4
cfg.class_head_lr_scale_during_ramp = 0.0

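With v9's SPATIAL_LR=1.5e-5, the cls_head LR in each stage is:

| Stage | Epochs | lambda_dir | cls_head_lr |
|---|---|---|---|
| stage 1 (class-only warmup) | 0-2 | 0.0 | 4.5e-6 |
| stage 2 (DOA ramp, cls_head frozen) | 3-6 | ramp 0→1 | 0.0 |
| stage 3 (everything unlocked) | 7+ | 1.0 | 4.5e-6 |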
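The table follows from the rule in 6.2.3, using frame_spatial_loss_warmup_epochs=3 (inherited from v8a) and the v9 preset values. Below is a small pure-Python restatement of that rule, handy for checking other settings. The function is illustrative and not part of train_spatial_beats.py.

```python
def cls_head_lr(epoch: int, base_lr: float, warmup_epochs: int, ramp_epochs: int,
                scale: float, scale_during_ramp: float) -> float:
    """cls_head LR per the 6.2.3 rule. The ramp override only applies when ramp_epochs > 0,
    warmup_epochs > 0 and scale != 1.0; otherwise the LR is base_lr * scale."""
    if ramp_epochs > 0 and warmup_epochs > 0 and scale != 1.0:
        in_ramp = warmup_epochs <= epoch < warmup_epochs + ramp_epochs
        return base_lr * (scale_during_ramp if in_ramp else scale)
    return base_lr * scale


# v9: SPATIAL_LR=1.5e-5, warmup 3 epochs, ramp window 4 epochs, scale 0.3, frozen during the ramp.
for epoch in range(12):
    print(epoch, f"{cls_head_lr(epoch, 1.5e-5, 3, 4, 0.3, 0.0):.1e}")
# epochs 0-2 and 7-11 print 4.5e-06, epochs 3-6 print 0.0e+00
```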

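7. Fix B: ontology-aware (hierarchical) label smoothing

7.1 Core idea

At least 70% of v8/v8a class errors are sibling collapse (same AudioSet parent). Hard CE penalizes frog→bird and frog→aircraft equally with the full log loss, which is unreasonable:

  • for a downstream LLM, a frog↔bird confusion can be recovered from semantic context;
  • a cross-domain error such as aircraft↔speech is simply absurd.

We want the strength of the penalty to **match the "semantic distance" of the error**. The simplest implementation is label smoothing inside each parent group, with hard CE across parents.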

7.2 Code changes

7.2.1 New SpatialLossConfig fields (spatial_loss.py)

    # v9 hierarchical label smoothing for the frame-track class head.
    # When frame_class_ontology_smoothing > 0, the CE target becomes a soft
    # label distribution:
    #   target[c_gt] = 1 - eps
    #   target[c_sib] = eps / |siblings|  (for each sibling in same ontology
    #                                       group as c_gt; excludes c_gt itself)
    #   target[c_other] = 0
    # ...
    frame_class_ontology_smoothing: float = 0.0
    # Parallel list of sibling groups.  Each entry is a list of class indices
    # belonging to the same AudioSet ontology parent.  A class may appear in
    # only one group.  Empty list = no hierarchical smoothing.
    frame_class_ontology_groups: List[List[int]] = None

__post_init__ replaces None with [] as a safeguard.

7.2.2 Rewritten CE branch (spatial_loss.py:compute_frame_track_losses)

Original branch:

loss_class = F.cross_entropy(
    class_logits_flat, class_target_flat, weight=_cls_weights
)

Changed to:

eps_onto = float(config.frame_class_ontology_smoothing)
onto_groups = config.frame_class_ontology_groups
if eps_onto > 0.0 and onto_groups:
    num_classes = class_logits_flat.size(-1)
    # Build a [C, C] soft-target "mixing" table on first use and
    # cache it on the config object to avoid per-batch rebuild.
    if (
        not hasattr(config, "_onto_mixing_table")
        or config._onto_mixing_table is None
        or config._onto_mixing_table.shape[0] != num_classes
        or config._onto_mixing_table.dtype != class_logits_flat.dtype
        or config._onto_mixing_table.device != device
    ):
        table = torch.zeros((num_classes, num_classes), dtype=class_logits_flat.dtype, device=device)
        table.fill_diagonal_(1.0)
        for group in onto_groups:
            members = [int(c) for c in group if 0 <= int(c) < num_classes]
            if len(members) < 2:
                continue
            for c in members:
                siblings = [s for s in members if s != c]
                table[c].zero_()
                table[c, c] = 1.0 - eps_onto
                sib_mass = eps_onto / len(siblings)
                for s in siblings:
                    table[c, s] = sib_mass
        config._onto_mixing_table = table
    soft_target = config._onto_mixing_table[class_target_flat]
    log_probs = F.log_softmax(class_logits_flat, dim=-1)
    if _cls_weights is not None:
        sample_w = _cls_weights[class_target_flat]
        per_sample_loss = -(soft_target * log_probs).sum(dim=-1)
        loss_class = (per_sample_loss * sample_w).sum() / sample_w.sum().clamp_min(1e-8)
    else:
        loss_class = -(soft_target * log_probs).sum(dim=-1).mean()
else:
    loss_class = F.cross_entropy(class_logits_flat, class_target_flat, weight=_cls_weights)

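Implementation notes:

  • Mixing-table cache: the table is stored on the config object (not on the module). It is built on first use and reused as long as shape / device / dtype match, so the overhead is negligible. The cache key does not include eps or the group list, so changing either mid-run would reuse a stale table.
  • Per-class weight compatibility: when ontology smoothing and class weights are both enabled, the class weights still apply. Each sample's loss is weighted by its hard GT class index.
  • Classes outside every group: the table defaults to 1 on the diagonal, so any class not listed in a group falls back to a hard one-hot target, with no effect.
  • Full off switch: frame_class_ontology_smoothing=0 or frame_class_ontology_groups=[] both take the original F.cross_entropy path, so v8/v8a training reproduces unchanged.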
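For small-scale sanity checks without torch, the mixing table and the weighted soft-target CE can be mirrored in pure Python. This sketch follows the same conventions as the code above; the function names are illustrative. With eps=0 it reduces exactly to the hard weighted CE, which is why the soft branch agrees with the plain F.cross_entropy fallback at the switch-off point.

```python
import math
from typing import List, Optional, Sequence


def build_mixing_table(num_classes: int, groups: Sequence[Sequence[int]],
                       eps: float) -> List[List[float]]:
    """Row c is the soft target for GT class c: 1 - eps on c, eps split evenly over its
    siblings. Classes outside every group (or in a group of size < 2) keep a one-hot row."""
    table = [[1.0 if i == j else 0.0 for j in range(num_classes)] for i in range(num_classes)]
    for group in groups:
        members = [int(c) for c in group if 0 <= int(c) < num_classes]
        if len(members) < 2:
            continue
        for c in members:
            siblings = [s for s in members if s != c]
            row = [0.0] * num_classes
            row[c] = 1.0 - eps
            for s in siblings:
                row[s] = eps / len(siblings)
            table[c] = row
    return table


def log_softmax(row: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax of one logit row."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_z for v in row]


def ontology_ce(logits: Sequence[Sequence[float]], targets: Sequence[int],
                table: List[List[float]],
                class_weights: Optional[Sequence[float]] = None) -> float:
    """Soft-target CE. With class_weights, each sample is weighted by w[GT class] and the sum
    is divided by the total weight, matching the weighted branch in compute_frame_track_losses."""
    losses = [-sum(p * lp for p, lp in zip(table[y], log_softmax(row)))
              for row, y in zip(logits, targets)]
    if class_weights is None:
        return sum(losses) / len(losses)
    w = [class_weights[y] for y in targets]
    return sum(l * wi for l, wi in zip(losses, w)) / max(sum(w), 1e-8)


# Toy 5-class problem with one sibling group {0, 1, 2} and eps = 0.1.
table = build_mixing_table(5, [[0, 1, 2]], eps=0.1)
print(table[0])  # soft row for class 0
print(table[3])  # class 3 is ungrouped, so its row stays one-hot
```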

7.3 Ontology groups (_V9_ONTOLOGY_GROUPS)

Location: train_spatial_beats.py, before the v9 preset.

_V9_ONTOLOGY_GROUPS: List[List[int]] = [
    # transportation: aircraft, vehicle, train, car
    [55, 19, 22, 44],
    # human voice (non-singing): speech, human_vocalization, male_speech,
    # female_speech, breathing, laughter
    [16, 6, 31, 32, 13, 14],
    # animal vocal: bird, frog, insect, dog, cat, animal
    [8, 62, 35, 18, 33, 27],
    # indoor mechanical + appliances: tool, machine, appliance, printer,
    # home_sound, door, drawer_cabinet, kitchenware, camera, clock, typing,
    # zipper, tape, cooking
    [9, 10, 48, 57, 34, 30, 50, 26, 38, 39, 36, 37, 58, 61],
    # percussive / impact: knock, footsteps, crack, crackle, crushing,
    # scratch, finger_snapping, tearing, writing, paper
    [52, 21, 60, 53, 56, 46, 54, 42, 43, 49],
    # weather / water / ambience: wind, rain, thunderstorm, ocean, water,
    # fire, glass, metal_clink, wood
    [25, 45, 29, 51, 5, 40, 24, 12, 59],
    # musical instruments: wind_instrument, string_instrument, guitar, drum,
    # keyboard_instrument, percussion, musical_instrument, gong, bell,
    # singing
    [0, 1, 2, 4, 7, 15, 28, 47, 17, 41],
    # alarms / signals: alarm, telephone_alarm, war_sound
    [20, 23, 11],
]

There are 8 groups covering 62 of the 63 classes. Only body_sound is ungrouped, so it uses hard CE. Each class appears in exactly one group.

In the v9 preset:

cfg.loss.frame_class_ontology_smoothing = 0.1
cfg.loss.frame_class_ontology_groups = [list(g) for g in _V9_ONTOLOGY_GROUPS]

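8. Fix C: spectral demixing cross-attention

8.1 Motivation

  • track_time_features ([B, K, T_s, D]) is decoded by SourceQueryDecoder from fused_embeddings ([B, T_s, D]). fused_embeddings comes after frequency_pool, so the frequency axis has already been pooled away.
  • DOA can demix overlapping sources because the IV channels physically encode direction, but class has no equivalent physical demixing path.
  • Each track latent therefore needs a way to "look back" at the trunk output before pooling and pick out, among the F_p frequency tokens, the part it is responsible for.

8.2 Model structure

  • Inputs:
    • track_time_features: [B, K, T_s, D] (already passed through input_norm)
    • pre_pool_features: [B, T_p * F_p, D] (BEATs trunk output, without task tokens, not frequency-pooled)
    • pre_pool_grid_size: (T_p, F_p)
    • pre_pool_time_mask: [B, T_p] (True = valid time step)
  • Time alignment: frame t ∈ [0, T_s) → trunk step t_p = round(t * T_p / T_s), clipped to [0, T_p-1]
  • KV construction: kv_grid[:, t_p, :, :] → [B, T_s, F_p, D], expanded along K to [B, K, T_s, F_p, D], flattened to [B*K*T_s, F_p, D], then LayerNorm (kv_norm)
  • Query: track_time_features reshaped to [B*K*T_s, 1, D], then LayerNorm (q_norm)
  • One nn.MultiheadAttention layer (num_heads=8, dropout=0.1, batch_first=True), followed by out_proj (Linear 768→768). The result is multiplied by a scalar gate and added to the class head's class_input.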
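8.3 Initialization strategy (key)

  • out_proj.weight and out_proj.bias are all zeros → after loading, the demixer output is 0 and class_logits are identical to v8a.
  • gate = 1e-2 (not 0!) → the forward is gate * 0 = 0 (identity guarantee), but ∂L/∂out_proj.weight is proportional to gate and therefore nonzero, so out_proj starts learning at step 0.
  • The demixer's attention weights still get zero gradient at step 0 (see 11.5). They start receiving gradient once out_proj.weight has moved away from zero.
  • This is the key "gradient-warmup trick". If both the projection and the gate were zero, the demixer would stay frozen at 0 forever.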
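A scalar stand-in makes this argument concrete. Treat y = g · w · x, where w plays out_proj.weight (zero-init), g plays the gate, and x plays the attention output, which depends on the attention weights upstream. This is a simplification of the real module (no bias, no matrices), written only to show which gradients are zero at each step.

```python
def grads_of_gated_zero_linear(x: float, w: float, g: float, upstream: float):
    """Scalar stand-in for y = gate * out_proj(attn_out): w ~ out_proj.weight, g ~ gate,
    x ~ attention output. Returns (dL/dw, dL/dg, dL/dx) given upstream = dL/dy."""
    dy_dw = g * x
    dy_dg = w * x
    dy_dx = g * w
    return upstream * dy_dw, upstream * dy_dg, upstream * dy_dx


x, upstream, lr = 2.0, 1.0, 0.1

# Step 0 with the v9 init (w = 0, g = 1e-2): only w receives gradient.
dw, dg, dx = grads_of_gated_zero_linear(x, w=0.0, g=1e-2, upstream=upstream)
print("step 0, gate=1e-2:", dw, dg, dx)

# Same but with gate = 0: every gradient is zero, so nothing can ever move.
print("step 0, gate=0   :", grads_of_gated_zero_linear(x, w=0.0, g=0.0, upstream=upstream))

# After one SGD step on w, gradient reaches x (hence the attention weights) and the gate.
w1 = 0.0 - lr * dw
print("step 1, gate=1e-2:", grads_of_gated_zero_linear(x, w=w1, g=1e-2, upstream=upstream))
```

The same reasoning covers the Fix F MLP branch, with class_head_mlp[-1] in the role of w.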

8.4 ClassHeadSpectralDemixer source (spatial_modules.py)

class ClassHeadSpectralDemixer(nn.Module):
    def __init__(self, embed_dim=768, num_layers=1, num_heads=8, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_layers = max(1, int(num_layers))
        self.kv_norm = nn.LayerNorm(embed_dim)
        self.q_norm = nn.LayerNorm(embed_dim)
        self.layers = nn.ModuleList([
            nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads,
                                  dropout=dropout, batch_first=True)
            for _ in range(self.num_layers)
        ])
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        nn.init.zeros_(self.out_proj.weight)
        nn.init.zeros_(self.out_proj.bias)
        self.gate = nn.Parameter(torch.full((1,), 1e-2))

    def forward(self, track_time_features, pre_pool_features,
                pre_pool_grid_size, pre_pool_time_mask=None):
        B, K, T_s, D = track_time_features.shape
        T_p, F_p = int(pre_pool_grid_size[0]), int(pre_pool_grid_size[1])
        if pre_pool_features.size(-1) != D:
            raise ValueError(...)
        expected = T_p * F_p
        if pre_pool_features.size(1) != expected:
            # Fall back gracefully — demixer is additive & zero-gated.
            return track_time_features.new_zeros(track_time_features.shape)
        kv_grid = pre_pool_features.view(B, T_p, F_p, D)
        if T_s > 0 and T_p > 0:
            time_idx = torch.arange(T_s, device=kv_grid.device).float() * (T_p / max(1, T_s))
            time_idx = time_idx.round().clamp_(0, T_p - 1).long()
        else:
            time_idx = torch.zeros((T_s,), dtype=torch.long, device=kv_grid.device)
        kv_per_frame = kv_grid[:, time_idx, :, :]
        kv_per_frame = kv_per_frame.unsqueeze(1).expand(B, K, T_s, F_p, D).contiguous()
        kv_flat = kv_per_frame.view(B * K * T_s, F_p, D)
        kv_flat = self.kv_norm(kv_flat)
        q_flat = track_time_features.reshape(B * K * T_s, 1, D)
        q_flat = self.q_norm(q_flat)
        key_padding_mask = None
        if pre_pool_time_mask is not None:
            per_frame_valid = pre_pool_time_mask[:, time_idx]
            per_frame_valid = per_frame_valid.unsqueeze(1).expand(B, K, T_s).reshape(-1)
            if not per_frame_valid.all():
                ignore = ~per_frame_valid
                key_padding_mask = ignore.unsqueeze(1).expand(-1, F_p).contiguous()
        attn_out = q_flat
        for layer in self.layers:
            attn_out, _ = layer(attn_out, kv_flat, kv_flat,
                                key_padding_mask=key_padding_mask, need_weights=False)
        residual = self.out_proj(attn_out).view(B, K, T_s, D)
        return residual * self.gate

8.5 _derive_pre_pool_time_mask helper (spatial_beats.py)

Added right after _build_patch_padding_mask:

def _derive_pre_pool_time_mask(
    self,
    patch_padding_mask: Optional[Tensor],
    grid_size: Tuple[int, int],
) -> Optional[Tensor]:
    """Return a [B, T_p] boolean mask where True marks *valid* trunk time
    steps.  Used by the v9 class-head spectral demixer to ignore padded
    tail frames."""
    if patch_padding_mask is None:
        return None
    t_p, f_p = grid_size
    B = patch_padding_mask.size(0)
    pad_grid = patch_padding_mask.view(B, t_p, f_p)
    time_valid = ~pad_grid.all(dim=-1)
    return time_valid

In patch_padding_mask, True = padded. A time step is valid if at least one of its frequency positions is not padded, which gives ~pad_grid.all(dim=-1).
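Here is the same rule for a single example, written out in pure Python. It assumes, as the view(B, t_p, f_p) call does, that the flat patch mask is time-major: all F_p patches of step 0 come first, then those of step 1, and so on. The function name is illustrative.

```python
from typing import List, Sequence


def pre_pool_time_valid(flat_padding: Sequence[bool], t_p: int, f_p: int) -> List[bool]:
    """One example's [T_p] validity from its flat [T_p * F_p] patch padding mask (True = padded).
    A time step is valid if any of its F_p patches is not padding, i.e. ~pad_grid.all(-1)."""
    if len(flat_padding) != t_p * f_p:
        raise ValueError("mask length must equal t_p * f_p")
    return [not all(flat_padding[t * f_p:(t + 1) * f_p]) for t in range(t_p)]


# 3 time steps x 2 frequency patches: step 1 is half padded, step 2 fully padded.
mask = [False, False, False, True, True, True]
print(pre_pool_time_valid(mask, t_p=3, f_p=2))  # [True, True, False]
```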

8.6 Updating the two forward calls (spatial_beats.py)

Both calls to self.frame_track_prediction_heads(...) are extended:

_pre_pool_time_mask = self._derive_pre_pool_time_mask(
    patch_padding_mask=patch_padding_mask,
    grid_size=grid_size,
)
frame_track_prediction_output = self.frame_track_prediction_heads(
    track_time_features=track_time_features,
    track_latents=track_latents,
    pre_pool_features=encoder_memory,
    pre_pool_grid_size=grid_size,
    pre_pool_time_mask=_pre_pool_time_mask,
)
  • Call 1 sits in the frame-track parallel path under readout_scheme == "local_spatial";
  • Call 2 sits in the pure per-frame path under readout_scheme == "local_spatial_track" (the path v9 uses).

Context variables used by both:

  • encoder_memory: output of self.encode_patches(...), shape [B, T_p*F_p, D]. It has gone through the trunk but not through frequency_pool, which is exactly the pre-pool input the demixer needs.
  • grid_size: the (T_p, F_p) returned by self.extract_patch_tokens(...);
  • patch_padding_mask: returned by self._build_patch_padding_mask(...).

9. Fix F: 2-layer MLP residual

9.1 Motivation

The current class head is nn.Linear(768, 63), which may lack the capacity to separate mixed multi-source tokens. A 2-layer MLP residual branch is added:

class_logits = class_head(x) + gate * class_head_mlp(x)
class_head_mlp = Linear(768, 1536) → GELU → Dropout → LayerNorm → Linear(1536, 63)

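9.2 Initialization strategy

Same pattern as the Fix C demixer:

  • class_head_mlp[-1].weight / bias are all zeros → the residual output is 0;
  • class_head_mlp_gate = 1e-2 → the forward is still 0, but gradient reaches the last MLP layer (nonzero);
  • once the last layer has nonzero weights after 1-2 steps, the earlier layer (before GELU) starts receiving gradient too.

9.3 FrameTrackPredictionHeads refactor (spatial_modules.py)

Extended constructor signature: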

def __init__(
    self,
    embed_dim: int = 768,
    num_classes: int = 63,
    dropout: float = 0.1,
    use_class_head_mlp_residual: bool = False,
    class_head_mlp_hidden_multiplier: int = 2,
    class_head_mlp_dropout: float = 0.1,
    use_class_head_demixer: bool = False,
    class_head_demixer_layers: int = 1,
    class_head_demixer_heads: int = 8,
    class_head_demixer_dropout: float = 0.1,
) -> None:

Added to the constructor body:

self.use_class_head_mlp_residual = bool(use_class_head_mlp_residual)
if self.use_class_head_mlp_residual:
    hidden = embed_dim * max(1, int(class_head_mlp_hidden_multiplier))
    self.class_head_mlp = nn.Sequential(
        nn.Linear(embed_dim, hidden),
        nn.GELU(),
        nn.Dropout(class_head_mlp_dropout),
        nn.LayerNorm(hidden),
        nn.Linear(hidden, num_classes),
    )
    nn.init.zeros_(self.class_head_mlp[-1].weight)
    nn.init.zeros_(self.class_head_mlp[-1].bias)
    self.class_head_mlp_gate = nn.Parameter(torch.full((1,), 1e-2))
else:
    self.class_head_mlp = None
    self.class_head_mlp_gate = None

self.use_class_head_demixer = bool(use_class_head_demixer)
if self.use_class_head_demixer:
    self.class_head_demixer = ClassHeadSpectralDemixer(
        embed_dim=embed_dim,
        num_layers=class_head_demixer_layers,
        num_heads=class_head_demixer_heads,
        dropout=class_head_demixer_dropout,
    )
else:
    self.class_head_demixer = None

Rewritten forward:

def forward(
    self,
    track_time_features: Tensor,
    track_latents: Tensor,
    pre_pool_features: Optional[Tensor] = None,
    pre_pool_grid_size: Optional[Tuple[int, int]] = None,
    pre_pool_time_mask: Optional[Tensor] = None,
) -> FrameTrackPredictionOutput:
    ...
    x = self.input_norm(track_time_features)
    activity = self.activity_head(x).squeeze(-1)
    class_input = x
    if (
        self.class_head_demixer is not None
        and pre_pool_features is not None
        and pre_pool_grid_size is not None
    ):
        demix_residual = self.class_head_demixer(
            track_time_features=x,
            pre_pool_features=pre_pool_features,
            pre_pool_grid_size=pre_pool_grid_size,
            pre_pool_time_mask=pre_pool_time_mask,
        )
        class_input = class_input + demix_residual
    class_logits = self.class_head(class_input)
    if self.class_head_mlp is not None and self.class_head_mlp_gate is not None:
        class_logits = class_logits + self.class_head_mlp_gate * self.class_head_mlp(class_input)
    direction = F.normalize(self.direction_head(x), dim=-1)
    distance = F.softplus(self.distance_head(x)).squeeze(-1)
    ...

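Key details

  • The demixer residual is added to class_input (the input of class_head), not to the logits. The demixer's "correction" signal therefore passes through class_head's shared projection before it becomes logits.
  • The MLP branch also reads class_input, which already includes the demixer residual, so demixer information reaches the MLP branch as well.
  • The input x of direction_head and distance_head does not receive the demixer output (the demixer is class-specific). This keeps the DOA heads exactly equivalent to the v8a ckpt.

9.4 New SpatialBEATsConfig fields (spatial_beats.py)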

self.frame_track_dropout: float = 0.1
self.frame_accdoa_hidden_dim: int = 256
self.frame_accdoa_dropout: float = 0.1

# v9: optional zero-initialised MLP residual branch inside
# FrameTrackPredictionHeads.  ...
self.use_class_head_mlp_residual: bool = False
self.class_head_mlp_hidden_multiplier: int = 2
self.class_head_mlp_dropout: float = 0.1
# v9: optional spectral demixing cross-attention branch.  ...
self.use_class_head_demixer: bool = False
self.class_head_demixer_layers: int = 1
self.class_head_demixer_heads: int = 8
self.class_head_demixer_dropout: float = 0.1

9.5 Both constructions pass the new arguments

Both FrameTrackPredictionHeads(...) calls in spatial_beats.py become:

self.frame_track_prediction_heads = FrameTrackPredictionHeads(
    embed_dim=cfg.encoder_embed_dim,
    num_classes=cfg.source_num_classes,
    dropout=cfg.frame_track_dropout,
    use_class_head_mlp_residual=cfg.use_class_head_mlp_residual,
    class_head_mlp_hidden_multiplier=cfg.class_head_mlp_hidden_multiplier,
    class_head_mlp_dropout=cfg.class_head_mlp_dropout,
    use_class_head_demixer=cfg.use_class_head_demixer,
    class_head_demixer_layers=cfg.class_head_demixer_layers,
    class_head_demixer_heads=cfg.class_head_demixer_heads,
    class_head_demixer_dropout=cfg.class_head_demixer_dropout,
)

10. v9 preset and launch script

10.1 make_ov1_local_spatial_v9_ov123_top4_config (train_spatial_beats.py)

Location: after the last preset of the v7k series, before the "v3: top-8 unfreeze" comment divider.

def make_ov1_local_spatial_v9_ov123_top4_config(
    ov1_manifest_path: str = DEFAULT_OV1_MANIFEST,
    ov2_manifest_path: str = DEFAULT_OV2_MANIFEST,
    ov3_manifest_path: str = DEFAULT_OV3_MANIFEST,
) -> TrainSpatialBEATsConfig:
    """v9 = v8a + class-first cleanup (fixes A..F).

    Inherits v8a (cross-attn fusion + segment matching + 4-epoch DOA ramp),
    applies:
      - _V9_CLASS_WEIGHTS (suppress catch-all classes, boost frog/crackle/tape)
      - ontology-aware label smoothing (eps=0.1)
      - class head residual MLP + spectral demixer (both zero-init)
      - class_head_lr_scale=0.3 with full freeze during the 4-epoch DOA ramp

    Frontend / trunk / source_query_decoder / activity / dir / dist heads
    are unchanged.  Hot-start from v8a best.pt works with strict=False.
    """
    cfg = make_ov1_local_spatial_v8a_ov123_top4_config(
        ov1_manifest_path=ov1_manifest_path,
        ov2_manifest_path=ov2_manifest_path,
        ov3_manifest_path=ov3_manifest_path,
    )

    # (D) Re-balanced class weights driven by v8/v8a CSV confusion analysis.
    cfg.loss.frame_class_loss_weights = list(_V9_CLASS_WEIGHTS)

    # (B) Hierarchical (ontology-aware) label smoothing.
    cfg.loss.frame_class_ontology_smoothing = 0.1
    cfg.loss.frame_class_ontology_groups = [list(g) for g in _V9_ONTOLOGY_GROUPS]

    # (F) Zero-gated MLP residual on the class head.
    cfg.model.use_class_head_mlp_residual = True
    cfg.model.class_head_mlp_hidden_multiplier = 2
    cfg.model.class_head_mlp_dropout = 0.1

    # (C) Zero-gated spectral demixing cross-attention on the class head.
    cfg.model.use_class_head_demixer = True
    cfg.model.class_head_demixer_layers = 1
    cfg.model.class_head_demixer_heads = 8
    cfg.model.class_head_demixer_dropout = 0.1

    # (E) Class head gets its own LR group.
    cfg.class_head_lr_scale = 0.3
    cfg.class_head_freeze_during_ramp_epochs = 4
    cfg.class_head_lr_scale_during_ramp = 0.0

    cfg.num_epochs = 12
    cfg.output_dir = "checkpoints/spatial_beats_ov1_local_spatial_v9_ov123_exp/03_ov123_top4"
    return cfg

10.2 Preset registration

Two places in train_spatial_beats.py:

  1. After the args.preset == "ov1_local_spatial_v8a_ov123_top4" branch, add:
    elif args.preset == "ov1_local_spatial_v9_ov123_top4":
        cfg = make_ov1_local_spatial_v9_ov123_top4_config(
            ov1_manifest_path=args.ov1_manifest,
            ov2_manifest_path=args.ov2_manifest,
            ov3_manifest_path=args.ov3_manifest,
        )
    
  2. In the choices=(...) list of the argparse --preset option, add "ov1_local_spatial_v9_ov123_top4" right after "ov1_local_spatial_v8a_ov123_top4".

10.3 run_ov1_v9_ov123_top4.sh

New file (chmod +x):

#!/usr/bin/env bash
set -euo pipefail

# v9_ov123_top4: v8a + class-first cleanup (Fix A..F from docs/0423.md analysis)
# All fixes are additive + zero-init; hot-starting from v8a best.pt gives epoch-0 outputs identical to v8a

GPUS="${GPUS:-8}"
BATCH_SIZE="${BATCH_SIZE:-8}"
NUM_WORKERS="${NUM_WORKERS:-8}"
SPATIAL_EPOCHS="${SPATIAL_EPOCHS:-12}"
SPATIAL_LR="${SPATIAL_LR:-1.5e-5}"
AMP="${AMP:-fp32}"

OV1_MANIFEST="${OV1_MANIFEST:-/apdcephfs_cq10/.../ov1_foa.jsonl}"
OV2_MANIFEST="${OV2_MANIFEST:-/apdcephfs_cq10/.../ov2_foa.jsonl}"
OV3_MANIFEST="${OV3_MANIFEST:-/apdcephfs_cq10/.../ov3_foa.jsonl}"

RESUME_CKPT="${RESUME_CKPT:-checkpoints/spatial_beats_ov1_local_spatial_v8a_ov123_exp/03_ov123_top4/best.pt}"
OUT_DIR="${OUT_DIR:-checkpoints/spatial_beats_ov1_local_spatial_v9_ov123_exp/03_ov123_top4}"

torchrun --nproc_per_node="${GPUS}" --master-port="${MASTER_PORT:-29557}" train_spatial_beats.py \
  --preset ov1_local_spatial_v9_ov123_top4 \
  --resume "${RESUME_CKPT}" \
  --output-dir "${OUT_DIR}" \
  --ov1-manifest "${OV1_MANIFEST}" \
  --ov2-manifest "${OV2_MANIFEST}" \
  --ov3-manifest "${OV3_MANIFEST}" \
  --batch-size "${BATCH_SIZE}" \
  --num-workers "${NUM_WORKERS}" \
  --num-epochs "${SPATIAL_EPOCHS}" \
  --learning-rate "${SPATIAL_LR}" \
  --amp "${AMP}" \
  --no-resume-optimizer \
  --reset-epoch-on-resume \
  --reset-best-on-resume

11. Correctness verification (completed)

11.1 Syntax check

All four .py files (train_spatial_beats.py / spatial_loss.py / spatial_beats.py / spatial_modules.py) and run_ov1_v9_ov123_top4.sh pass ast.parse / bash -n.

11.2 Module-level unit tests

FrameTrackPredictionHeads + ClassHeadSpectralDemixer

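  • With all v9 options enabled, max abs diff = 0.00e+00 relative to the same model without the v9 options;
  • With a padding mask, pred_class_logits contains no NaN;
  • After opening the MLP gate / demixer gate and slightly perturbing the weights, the logits differ by ~1e-2 (as expected).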
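The first and third checks follow from the zero-init residual design. The sketch below models only the MLP branch. It assumes the residual form `logits = W_cls x + gate * W2 relu(W1 x)` with the last layer `W2` zero-initialized and `gate = 1e-2`. The demixer branch and the real module layout are not modeled, and all names are illustrative. With `W2 = 0`, the residual adds exactly `0.0`, so the logits match the plain linear head bit for bit. Perturbing `W2` produces a small, gate-scaled difference.

```python
import random
from typing import List, Optional

Matrix = List[List[float]]


def matvec(W: Matrix, x: List[float]) -> List[float]:
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def class_logits(x, W_cls: Matrix, W1: Optional[Matrix], W2: Optional[Matrix], gate: float):
    """Linear head, optionally plus a gated residual MLP: W_cls x + gate * W2 relu(W1 x)."""
    base = matvec(W_cls, x)
    if W1 is None or W2 is None:  # v8a-style path: no residual branch
        return base
    hidden = [max(0.0, a) for a in matvec(W1, x)]
    residual = matvec(W2, hidden)
    return [b + gate * r for b, r in zip(base, residual)]


def random_matrix(rows: int, cols: int) -> Matrix:
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


random.seed(0)
D, H, C = 8, 16, 5
W_cls, W1 = random_matrix(C, D), random_matrix(H, D)
x = [random.gauss(0.0, 1.0) for _ in range(D)]

off = class_logits(x, W_cls, None, None, gate=0.0)
W2_zero = [[0.0] * H for _ in range(C)]  # zero-initialized last MLP layer
on = class_logits(x, W_cls, W1, W2_zero, gate=1e-2)
diff_zero = max(abs(a - b) for a, b in zip(on, off))  # exactly 0.0

# Perturbed last layer: the difference is nonzero but scaled down by the gate.
on_pert = class_logits(x, W_cls, W1, random_matrix(C, H), gate=1e-2)
diff_pert = max(abs(a - b) for a, b in zip(on_pert, off))
```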

11.3 End-to-end forward identity: v9 vs v8a

torch.manual_seed(42)
m_v8a = SpatialBEATs(make_ov1_local_spatial_v8a_ov123_top4_config().model).eval()
torch.manual_seed(42)
m_v9  = SpatialBEATs(make_ov1_local_spatial_v9_ov123_top4_config().model).eval()

# Both models load v8a/best.pt with strict=False
# and receive the same waveform input
max_abs_diff = (o_v8a - o_v9).abs().max()
# Measured: 0.00e+00

After loading the v8a ckpt, v9's pred_class_logits are element-wise identical (0.00e+00) to those of the v8a model loading the same ckpt.

11.4 v8a ckpt loading statistics

ckpt keys            = 425
loadable (v9 shapes) = 425
v9 missing params    = 18   ← class_head_mlp.* (7) + class_head_demixer.* (11)
ckpt unexpected      = 0

The 18 new parameters are absent from the ckpt, so under strict=False they keep their default initialization (the zero-out-proj + tiny-gate scheme of Fix C / F).
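The four counts above can be reproduced from the key/shape maps alone. The sketch below is a hypothetical helper, not code from the repo, and the toy key names are invented. It also reports shape mismatches. PyTorch's `load_state_dict` raises on size mismatches even with `strict=False`, so "loadable = ckpt keys" is the condition that actually matters for the hot start.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def load_report(ckpt: Dict[str, Shape], model: Dict[str, Shape]) -> Dict[str, int]:
    """Classify checkpoint/model keys the way a strict=False load sees them.
    PyTorch still raises on shape mismatches even with strict=False, so
    'shape_mismatch' must be 0 for the load to succeed."""
    shared = ckpt.keys() & model.keys()
    return {
        "ckpt_keys": len(ckpt),
        "loadable": sum(1 for k in shared if ckpt[k] == model[k]),
        "shape_mismatch": sum(1 for k in shared if ckpt[k] != model[k]),
        "missing": len(model.keys() - ckpt.keys()),     # keep their fresh init
        "unexpected": len(ckpt.keys() - model.keys()),  # ignored under strict=False
    }


# Toy reproduction of the v8a -> v9 numbers (key names are made up).
ckpt = {f"trunk.p{i}": (4,) for i in range(425)}
model = dict(ckpt)
model.update({f"class_head_mlp.p{i}": (4,) for i in range(7)})
model.update({f"class_head_demixer.p{i}": (4,) for i in range(11)})
report = load_report(ckpt, model)
```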

11.5 Gradient flow

Dummy CE loss + backward:

  • class_head.weight: grad_norm ≈ 9.06 (normal)
  • class_head_mlp.4.weight (last MLP layer, zero-init): grad_norm ≈ 1.32e-1 (nonzero, because gate = 1e-2)
  • class_head_mlp.0.weight (first MLP layer): grad_norm = 0 (expected: it only receives gradient once the last layer is nonzero, which happens after 1 step)
  • class_head_demixer.out_proj.weight: grad_norm ≈ 1.60e-2 (nonzero)
  • class_head_demixer.layers.0.in_proj_weight: grad_norm = 0 (same reason)
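The zero / nonzero pattern above follows from the chain rule. For `y = W_cls x + gate * W2 relu(W1 x)`, the gradient w.r.t. `W2` is `gate * dL/dy ⊗ h`, which is nonzero. The gradient reaching `W1` has to pass through `W2`, so it is exactly zero while `W2 = 0`. The sketch below computes this by hand for a single sample. It assumes the same toy residual form as the real MLP branch but does not use the real module, and all names are illustrative. After one SGD step on `W2`, the first layer starts receiving gradient.

```python
import math
import random


def softmax(z):
    m = max(z)
    e = [math.exp(a - m) for a in z]
    s = sum(e)
    return [a / s for a in e]


def mlp_grads(x, target, W_cls, W1, W2, gate):
    """Manual backward of CE(softmax(W_cls x + gate * W2 relu(W1 x)), target).
    Returns the gradients w.r.t. W2 (last MLP layer) and W1 (first MLP layer)."""
    pre = [sum(w * xi for w, xi in zip(row, x)) for row in W1]
    h = [max(0.0, a) for a in pre]
    logits = [
        sum(w * xi for w, xi in zip(rc, x)) + gate * sum(w * hj for w, hj in zip(r2, h))
        for rc, r2 in zip(W_cls, W2)
    ]
    dy = softmax(logits)
    dy[target] -= 1.0  # dCE/dlogits = p - onehot
    g_W2 = [[gate * dyc * hj for hj in h] for dyc in dy]
    # The gradient reaching the hidden layer passes through W2, so it is 0 while W2 == 0.
    dh = [gate * sum(W2[c][j] * dy[c] for c in range(len(dy))) for j in range(len(h))]
    g_W1 = [[dh[j] * (1.0 if pre[j] > 0 else 0.0) * xd for xd in x] for j in range(len(h))]
    return g_W2, g_W1


def frob(M):
    return math.sqrt(sum(v * v for row in M for v in row))


random.seed(0)
D, H, C, gate, lr = 8, 16, 5, 1e-2, 1.0
W_cls = [[random.gauss(0, 1) for _ in range(D)] for _ in range(C)]
W1 = [[random.gauss(0, 1) for _ in range(D)] for _ in range(H)]
W2 = [[0.0] * H for _ in range(C)]  # zero-init last layer
x = [random.gauss(0, 1) for _ in range(D)]

g_W2, g_W1 = mlp_grads(x, 2, W_cls, W1, W2, gate)
norm_W2_step0, norm_W1_step0 = frob(g_W2), frob(g_W1)

# One plain SGD step on W2 (W1's gradient is zero, so SGD would leave it unchanged anyway).
W2 = [[w - lr * g for w, g in zip(rw, rg)] for rw, rg in zip(W2, g_W2)]
_, g_W1 = mlp_grads(x, 2, W_cls, W1, W2, gate)
norm_W1_step1 = frob(g_W1)
```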

11.6 Optimizer param groups

group 0: name=trunk    lr=3.00e-06  n_params=238
group 1: name=spatial  lr=9.00e-06  n_params= 84
group 2: name=head     lr=3.00e-05  n_params= 91
group 3: name=cls_head lr=9.00e-06  n_params= 19

(The example above uses base_lr=3e-5. The v9 launch script sets SPATIAL_LR=1.5e-5, which gives cls_head lr = 0.3 × 1.5e-5 = 4.5e-6.)
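The grouping above can be reproduced with prefix matching plus per-group LR scales. The sketch below is illustrative only. `_CLASS_HEAD_PREFIXES` is named in this doc, but the tuple contents here are a guess. The trunk/spatial prefixes are hypothetical. The scales (trunk 0.1, spatial 0.3, head 1.0, cls_head 0.3) are simply read off the lr column against base_lr=3e-5. Parameter names stand in for tensors.

```python
from typing import Dict, List

# _CLASS_HEAD_PREFIXES is named in this doc; its contents here are a guess.
_CLASS_HEAD_PREFIXES = ("class_head.", "class_head_mlp.", "class_head_demixer.")
# Hypothetical prefixes for the other groups.
_TRUNK_PREFIXES = ("beats.",)
_SPATIAL_PREFIXES = ("spatial_",)
# Scales read off the lr column above (base_lr = 3e-5).
LR_SCALES = {"trunk": 0.1, "spatial": 0.3, "head": 1.0, "cls_head": 0.3}


def group_name_of(param_name: str) -> str:
    # Check the class head first so it is never swallowed by the generic "head" group.
    if param_name.startswith(_CLASS_HEAD_PREFIXES):
        return "cls_head"
    if param_name.startswith(_TRUNK_PREFIXES):
        return "trunk"
    if param_name.startswith(_SPATIAL_PREFIXES):
        return "spatial"
    return "head"


def build_param_groups(param_names: List[str], base_lr: float) -> List[Dict]:
    """Return torch-style param-group dicts, each carrying a group_name label."""
    buckets: Dict[str, List[str]] = {name: [] for name in LR_SCALES}
    for p in param_names:
        buckets[group_name_of(p)].append(p)
    return [
        {"group_name": g, "lr": base_lr * LR_SCALES[g], "params": ps}
        for g, ps in buckets.items()
        if ps  # skip empty groups
    ]
```

Checking `cls_head` before the generic head bucket matters: a class-head parameter must never fall through to the full-LR head group.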

12. Key training timeline (SPATIAL_EPOCHS=12)

Inherits v8a's frame_spatial_loss_warmup_epochs=3 and frame_spatial_loss_ramp_epochs=4, with v9's new cls_head LR schedule layered on top:

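| epoch | lambda_dir / dist | dir/dist match cost | cls_head lr | Notes |
|---|---|---|---|---|
| 0-2 | 0.0 | 0.0 | 4.5e-6 | Stage 1: class-only warmup; cls_head keeps fine-tuning at low LR (a direct continuation of v8a) |
| 3 | ramp 0.25 | 0.25 | 0.0 | Start of Stage 2; cls_head frozen |
| 4 | ramp 0.50 | 0.50 | 0.0 | cls_head frozen |
| 5 | ramp 0.75 | 0.75 | 0.0 | cls_head frozen |
| 6 | ramp 1.00 | 1.00 | 0.0 | cls_head frozen; DOA fully on |
| 7-11 | 1.0 | 1.0 | 4.5e-6 | Stage 3: everything unfrozen; cls_head fine-tuned at low LR |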
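The table can be generated by two small schedule functions. The sketch below assumes a linear ramp that reaches 1.0 on the last ramp epoch. It also assumes that the cls_head freeze window starts at the first ramp epoch and lasts `class_head_freeze_during_ramp_epochs` epochs. Both assumptions reproduce the rows above, but they are not taken from the actual training loop. `apply_cls_head_lr` is a hypothetical helper mimicking "write the cls_head LR right after the spatial loss schedule". It relies on torch param groups being plain dicts that may carry extra keys such as `group_name`.

```python
from typing import Dict, List


def spatial_loss_scale(epoch: int, warmup_epochs: int = 3, ramp_epochs: int = 4) -> float:
    """0 during warmup, then a linear ramp that reaches 1.0 on the last ramp epoch."""
    if epoch < warmup_epochs:
        return 0.0
    return min(1.0, (epoch - warmup_epochs + 1) / ramp_epochs)


def cls_head_lr(epoch: int, base_lr: float = 1.5e-5, lr_scale: float = 0.3,
                warmup_epochs: int = 3, freeze_epochs: int = 4,
                lr_scale_during_ramp: float = 0.0) -> float:
    """Low LR for the class head, except during the freeze window at the start of the ramp."""
    if warmup_epochs <= epoch < warmup_epochs + freeze_epochs:
        return base_lr * lr_scale_during_ramp
    return base_lr * lr_scale


def apply_cls_head_lr(param_groups: List[Dict], epoch: int) -> None:
    # Torch param groups are dicts; extra keys such as "group_name" are preserved.
    for group in param_groups:
        if group.get("group_name") == "cls_head":
            group["lr"] = cls_head_lr(epoch)
```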

13. Validation checklist / observation priority

After launch, check the metrics in this order:

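  1. Are the epoch 0 validation metrics equivalent to those of the v8a checkpoint used for the hot start (best.pt)?
    • Expected: val loss ≈ v8a best val loss (because the forward pass equals v8a's);
    • If they differ, the hot start is broken;
  2. Epochs 0-2 (stage 1)
    • Is ocls (oracle class acc) higher than v8a at the same point (ep2)? This is the key metric, driven mainly by Fix D + B + C + F;
    • If Fix D works: aircraft / vehicle are no longer at 0%, and frog is no longer 100% → bird;
    • If Fix B works: the smoothed CE loss falls faster than hard CE (starting from 1 - eps);
    • If Fix C/F works: after each epoch, class_head_mlp_gate / class_head_demixer.gate should grow slowly from 0.01;
  3. Epochs 3-6 (stage 2, cls_head frozen)
    • DOA metrics should improve (LE_CD and oazi decrease), and ocls should not regress (cls_head LR=0, so it receives no updates);
    • If ocls regresses: fusion/trunk updates may be shifting the class_head input distribution through the demixer; in that case consider freezing the demixer as well;
  4. Epoch 7+ (stage 3, everything unfrozen)
    • Does F20 break through v8a's ceiling (v7h baseline 0.246, v8a ~0.22-0.25)?
    • Does ov3 class_ok climb above 43%?
    • If ov2/ov3 class_ok both improve but ov3 still lags: Fix C's spectral demixing needs more layers or heads.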

14. Code diff overview

14.1 spatial_loss.py

  • Two new SpatialLossConfig fields: frame_class_ontology_smoothing and frame_class_ontology_groups;
  • __post_init__ handles frame_class_ontology_groups = None;
  • The CE branch in compute_frame_track_losses supports soft targets combined with class weights.
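One plausible reading of "ontology smoothing + class weights" is sketched below. It is not the actual compute_frame_track_losses code. The ground-truth class keeps `1 - eps`, and `eps` is spread uniformly over the other members of its ontology group. A class outside every group (such as the one class not covered by `_V9_ONTOLOGY_GROUPS`) keeps a hard one-hot target. The per-sample soft CE is then scaled by the ground-truth class weight. The uniform sibling split and the weighting by the ground-truth class are assumptions. `ontology_soft_target` and `weighted_soft_ce` are hypothetical names.

```python
import math
from typing import List, Sequence


def ontology_soft_target(y: int, num_classes: int,
                         groups: Sequence[Sequence[int]], eps: float) -> List[float]:
    """1 - eps on the true class, eps split evenly over its ontology siblings.
    Classes with no siblings (ungrouped or singleton groups) keep a one-hot target."""
    target = [0.0] * num_classes
    siblings = next(([c for c in g if c != y] for g in groups if y in g), [])
    if not siblings:
        target[y] = 1.0
        return target
    target[y] = 1.0 - eps
    for c in siblings:
        target[c] = eps / len(siblings)
    return target


def weighted_soft_ce(logits: Sequence[float], y: int, target: Sequence[float],
                     class_weights: Sequence[float]) -> float:
    """Soft-target cross entropy, scaled by the weight of the ground-truth class."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))  # stable log-sum-exp
    ce = -sum(t * (z - log_z) for t, z in zip(target, logits))
    return class_weights[y] * ce
```

For a batch, one common convention (the one used by torch's weighted cross entropy with `reduction="mean"`) divides the sum of these terms by the sum of the ground-truth weights rather than by the batch size.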

14.2 spatial_modules.py

  • FrameTrackPredictionHeads.__init__ gains 6 new kwargs;
  • FrameTrackPredictionHeads.forward gains 3 Optional kwargs (pre_pool_features / pre_pool_grid_size / pre_pool_time_mask) and internally combines the demixer with the MLP residual;
  • New ClassHeadSpectralDemixer module.

14.3 spatial_beats.py

  • SpatialBEATsConfig.__init__ gains 8 new self.* fields;
  • Both FrameTrackPredictionHeads(...) constructions pass all the new arguments;
  • Both self.frame_track_prediction_heads(...) calls pass the pre-pool arguments;
  • New instance method SpatialBEATs._derive_pre_pool_time_mask.
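The doc does not show what `_derive_pre_pool_time_mask` computes. The sketch below is a hypothetical version of the usual pattern. It reduces a frame-level padding mask (True = padded, the same convention as PyTorch's `key_padding_mask`) to one flag per time patch, so the demixer can ignore patches that contain only padding. The frame-level input, the `patch_frames` stride, and the rule "masked only if every frame is padding" are all assumptions.

```python
from typing import List


def derive_pre_pool_time_mask(frame_pad: List[bool], patch_frames: int,
                              n_time_patches: int) -> List[bool]:
    """Hypothetical: a time patch is masked (True) only if every frame it covers is padding.
    Patches that extend past the end of the input are treated as fully padded."""
    mask = []
    for t in range(n_time_patches):
        chunk = frame_pad[t * patch_frames:(t + 1) * patch_frames]
        mask.append(len(chunk) == 0 or all(chunk))
    return mask
```

Keeping partially padded patches unmasked guarantees that every non-empty clip has at least one valid key. That avoids the all-masked attention rows that typically produce NaN, which is the failure mode the padding-mask unit test in 11.2 guards against.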

14.4 train_spatial_beats.py

  • TrainSpatialBEATsConfig gains 3 new fields: class_head_lr_scale, class_head_freeze_during_ramp_epochs, class_head_lr_scale_during_ramp;
  • New constants _V9_CLASS_WEIGHTS (length 63) and _V9_ONTOLOGY_GROUPS (8 groups covering 62 of the 63 classes);
  • New preset function make_ov1_local_spatial_v9_ov123_top4_config;
  • build_optimizer adds _CLASS_HEAD_PREFIXES, a cls_head_params group, group_name labels, and an updated fast-path condition;
  • The epoch loop sets the cls_head LR dynamically (right after the spatial loss schedule);
  • ov1_local_spatial_v9_ov123_top4 is registered in the preset dispatch and in the --preset argparse choices.
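Since _V9_ONTOLOGY_GROUPS is meant to cover 62 of the 63 classes, a cheap sanity check when the constant is defined can catch typos such as an index that is out of range or a class listed twice. The sketch below is a hypothetical helper. It assumes the groups are meant to be disjoint, which a per-class sibling lookup would require. It returns the number of covered classes and the uncovered indices, so a reviewer can confirm that the "62/63" figure matches the data.

```python
from typing import Dict, List, Sequence, Tuple


def check_ontology_groups(groups: Sequence[Sequence[int]],
                          num_classes: int) -> Tuple[int, List[int]]:
    """Validate class-index groups; return (#covered classes, uncovered class indices).
    Assumes groups are meant to be disjoint, so a class in two groups is an error."""
    owner: Dict[int, int] = {}
    for gi, group in enumerate(groups):
        for c in group:
            if not 0 <= c < num_classes:
                raise ValueError(f"class index {c} in group {gi} is out of range")
            if c in owner:
                raise ValueError(f"class {c} appears in groups {owner[c]} and {gi}")
            owner[c] = gi
    uncovered = sorted(set(range(num_classes)) - set(owner))
    return len(owner), uncovered
```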

14.5 run_ov1_v9_ov123_top4.sh

  • New shell script; defaults: 8 GPUs / BS=8 / 12 epochs / LR=1.5e-5 / fp32 / master_port=29557;
  • RESUME_CKPT defaults to v8a best.pt;
  • Reuses the three switches --no-resume-optimizer, --reset-epoch-on-resume and --reset-best-on-resume.

15. Key files and code locations

| File | Main location of new code |
|---|---|
| spatial_loss.py | New SpatialLossConfig fields right after frame_class_loss_weights; CE branch rewritten in compute_frame_track_losses |
| spatial_modules.py | FrameTrackPredictionHeads rewritten as a whole; ClassHeadSpectralDemixer defined after FrameTrackPredictionHeads |
| spatial_beats.py | v9 fields of SpatialBEATsConfig.__init__ right after frame_accdoa_dropout; _derive_pre_pool_time_mask right after _build_patch_padding_mask |
| train_spatial_beats.py | _V9_CLASS_WEIGHTS right after _V7I_CLASS_WEIGHTS; _V9_ONTOLOGY_GROUPS and make_ov1_local_spatial_v9_ov123_top4_config at the end of the v7k series; class_head_*_scale fields right after spatial_lr_scale |

16. Next steps

If v9's ocls at ep3 still does not exceed v8a's at the same epoch, candidates for the next iteration (v9a/v9b):

  1. demixer 加层class_head_demixer_layers = 2 或增加 heads 到 16;
  2. demixer 冻结共进退:DOA ramp 期把 class_head_demixer.* 也当作 cls_head group 的一部分(目前 _CLASS_HEAD_PREFIXES 已经包含它,在 ramp 期也会冻结——已经生效,但需要验证效果);
  3. ontology smoothing eps 提升:0.1 → 0.2,进一步容忍 sibling collapse;
  4. source_query_decoder 加正交正则(前一版建议):L_orth = ||Q Q^T - I||_F^2,λ=0.01;
  5. 重新审视 matching:如果 cls 瓶颈解了但 F20 仍不涨,回到 docs/0422.md 的 track-dead / duplicate 诊断。
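Candidate 4 is cheap to prototype. The sketch below computes L_orth = ||Q Q^T - I||_F^2 for a K x D matrix whose rows are the source queries (K=4 in the top4 setup). It assumes the penalty is applied to the raw query parameters, and `orth_penalty` is a hypothetical name. As written, it pushes each query toward unit norm as well as toward mutual orthogonality. L2-normalizing the rows first would penalize only the angles between queries, which may be preferable if query norms carry meaning.

```python
from typing import List, Sequence


def orth_penalty(Q: Sequence[Sequence[float]]) -> float:
    """||Q Q^T - I||_F^2 for a K x D matrix whose rows are the source queries."""
    K = len(Q)
    total = 0.0
    for i in range(K):
        for j in range(K):
            gram = sum(a * b for a, b in zip(Q[i], Q[j]))  # entry (i, j) of Q Q^T
            target = 1.0 if i == j else 0.0
            total += (gram - target) ** 2
    return total


def regularized_loss(task_loss: float, Q: List[List[float]], lam: float = 0.01) -> float:
    # lam = 0.01 follows the value proposed above.
    return task_loss + lam * orth_penalty(Q)
```

Orthonormal queries give zero penalty, while two identical unit queries give 2.0 (the two off-diagonal Gram entries are 1). That duplicate-query case is exactly the collapse the regularizer targets.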

That said, v9 already addresses "the class head itself" fairly thoroughly, so it should run the full 12 epochs before any decision is made.