Spatial-BEATs / docs /0423.md
dieKarotte's picture
Add files using upload-large-folder tool
29ab2d0 verified
|
Raw
History Blame Contribute Delete
45.5 kB
# 2026-04-23 — v9_ov123_top4:class-first 针对性修复(A→D→E→B→C→F)
## 0. 背景与本轮任务
v8 / v8a 在 `03_ov123_top4` 上完成了 10+ 个 epoch 的训练,整体 spatial 指标比 v7h 有小幅改善,但 `class_ok` 一直卡在 **~45%** 左右。用户要求:
1. 专注 frame 级预测,不引入任何 clip 级监督;
2. 分析 v8 / v8a 的 val CSV(epoch 9 / epoch 11),找出 class 准确率上不去的真正瓶颈;
3. 在不破坏现有代码框架和训练逻辑的前提下,按顺序落地 Fix A→D→E→B→C→F;
4. 全部作为 v9 新 preset + 新脚本,从 v8a `best.pt` 热启动(保证 epoch0 前向输出与 v8a 完全相同)。
本文档记录 2026-04-23 这一轮的 **所有代码修改细节**,以及每个修改背后的 CSV 诊断证据。
相关旧文档:
- `docs/0422.md`:v7h → v8 架构升级(cross-attention fusion)
- `docs/0422_v7h_v7j.md`:v7h / v7j / v7i 的 per-frame 多源诊断与 class-weighted CE 的引入
- `docs/0421.md`:v7f → ov123 per-frame 扩展与 frame-track CSV dump
---
## 1. CSV 诊断:v8 ep9 / v8a ep11 的错误模式
### 1.1 整体准确率(按 "activity ≥ 0.5 的 active track 里 DOA 最近的一条" 取 class)
| 指标 | v8 ep9 | v8a ep11 |
|---|---|---|
| overall `cls_ok` | 933/2051 = **45.5%** | 934/2051 = **45.5%** |
| DOA≤20° 命中内 `cls_ok` | 48.1% | 51.9% |
| **Oracle cls**(无视 activity,取 DOA 最近 track 的 class) | **38.7%** | **46.4%** |
两个关键数字几乎相等:`cls_ok(nearest active) ≈ oracle_cls`。这说明:
> **activity 选的 track 和 DOA 选的 track 语义上一致**,K=4 的 track binding 是 ok 的;真正错的是 **"那条被选中的 track 自己的 class_logits 就预测错了"**。
换句话说:**瓶颈不在 matching,不在 activity,而在 class head 本身的输出分布**。
### 1.2 按 ov 分组
v8 ep9:
```
ov2 : gt= 682 DOA_ok=51.6% cls_ok=37.1% any_track_cls_ok=38.9% cls&DOA=23.9%
ov3 : gt=1158 DOA_ok=42.5% cls_ok=47.1% any_track_cls_ok=59.6% cls&DOA=18.0%
other: gt= 211 DOA_ok=97.6% cls_ok=64.0% any_track_cls_ok=64.0% cls&DOA=63.0%
```
v8a ep11 类似,ov3 略退:
```
ov2 : cls_ok=44.0% cls&DOA=33.7%
ov3 : cls_ok=43.1% cls&DOA=14.2%
other: cls_ok=64.0% cls&DOA=63.0%
```
`other`(近似 ov1 单源)是 64%,ov2 是 37-44%,ov3 降到 43% 且 `cls&DOA` 只有 14%。**多源重叠帧的 class 比单源帧低 20 个点**,这是 ov3 demixing 失败的直接证据。
### 1.3 每类错误模式(Oracle:无视 activity 取 DOA 最近)
```
aircraft (n=100): speech(50%) human_vocalization(50%) → 永远 0% 正确
vehicle (n= 68): machine(74%) aircraft(13%) train(10%) → 永远 0% 正确
frog (n= 50): bird(98~100%) → 永远预测成 bird
speech (n= 53): human_vocalization(94%) breathing(6%) → 永远 0% 正确
crackle (n= 51): rain(53%) machine(41%) → 永远 0% 正确
wind (n= 16): fire(50~88%) vehicle/wood → 永远 0% 正确
drawer_cab(n= 50): tool(74~84%) → 永远 0% 正确
tape (n= 50): human_vocalization(88%) typing(46%) → 永远 0% 正确
knock (n= 36): home_sound(64%) human_vocalization → 永远 0% 正确
train (n=115): train(43%) crushing(37%) vehicle(15%) → ~10%
```
错误全都是 **同父类 "sibling collapse"**:
- `aircraft / vehicle / train` 同属 transportation;
- `speech / human_vocalization / breathing / laughter` 同属 human-voice;
- `frog / bird / insect` 同属 animal-vocal;
- `drawer_cabinet / tool / home_sound / door` 同属 indoor-mechanical。
### 1.4 Label bug 复核:frog→bird 是不是数据映射错了?
查了 `build_ov{1,23}_foa_dataset.py`、`final_vocabulary.csv`、`spatial_dataset.py:_resolve_class_index`,以及 jsonl 里具体样本:
```
ov1 train frog: 113 ov1 train bird: 3270 → 1:29
ov1 valid frog: 9 ov1 valid bird: 56
ov1 test frog: 3 ov1 test bird: 51
ov2 train frog:1078 ov2 train bird: 1255
ov3 train frog:1384 ov3 train bird: 1432
```
**frog 标签完全正确**,不是 label bug。100% 预测成 bird 的根因是 ov1 训练集里 bird 比 frog 多 **29 倍**,纯粹的 class imbalance;其他两个 manifest (ov2/ov3) 已经平衡,但 ov1 训练数据量占比仍然不小,整体 prior 偏向 bird。
## 2. 六个 Fix 的整体设计原则
1. **全部 additive**:v9 只 **增加** 新 config 字段与新子模块,从不删除或覆盖 v8a 已有的路径;
2. **全部 zero-init 或 identity-init**:新参数在 ckpt 加载时贡献 0(可验证 `v9(v8a.pt) - v8a(v8a.pt)` 前向最大 abs diff = 0);
3. **strict=False 兼容**:v9 的 18 个新参数在 v8a ckpt 中缺失,走现有 `_load_spatial_init_checkpoint` / `load_state_dict(..., strict=False)` 流程自然初始化;
4. **不改 forward signature 的兼容性**`FrameTrackPredictionHeads.forward` 新增的几个关键字参数默认 `None`,旧调用链(如 v7 / v8 非-track readout)不受影响;
5. **所有 LR / schedule 默认行为等价于 v8a**`class_head_lr_scale=1.0` 时走原 3-group 快速路径,`frame_class_ontology_smoothing=0.0` 时走原 `F.cross_entropy`
## 3. 文件级修改清单
| 文件 | 修改内容 |
|---|---|
| `spatial_loss.py` | 新增 `frame_class_ontology_smoothing``frame_class_ontology_groups` 字段;CE 分支支持本体软标签 |
| `spatial_modules.py` | `FrameTrackPredictionHeads` 新增 6 个构造参数 + 新增 `ClassHeadSpectralDemixer` 模块 |
| `spatial_beats.py` | `SpatialBEATsConfig` 新增 8 个字段;两处 `FrameTrackPredictionHeads` 构造传入新参数;forward 两处 call 传入 `pre_pool_features`;新增 `_derive_pre_pool_time_mask` helper |
| `train_spatial_beats.py` | 新增 `class_head_lr_scale``class_head_freeze_during_ramp_epochs``class_head_lr_scale_during_ramp` 字段;`build_optimizer` 拆出 `cls_head` group;epoch loop 动态写 cls_head LR;新增 `_V9_CLASS_WEIGHTS`、`_V9_ONTOLOGY_GROUPS`、`make_ov1_local_spatial_v9_ov123_top4_config`;preset dispatch 和 `--preset` choices 注册 `v9_ov123_top4` |
| `run_ov1_v9_ov123_top4.sh` | 新增启动脚本(chmod +x) |
---
## 4. Fix A —— frog/稀有类的诊断(无代码改动)
Fix A 是诊断性质的,没有代码落地。结论:frog 不是 label bug,是 ov1 train 的 29:1 imbalance。解决方案被折进 Fix D:
- `bird` 权重:1.0 → **0.6**(压 catch-all)
- `frog` 权重:1.0 → **3.0**(提 minority)
- `insect` 权重:从 v7I 的 4.0 → **1.0**(v8/v8a 已经 100%,不需要加权)
## 5. Fix D —— 重新设计 class 权重 `_V9_CLASS_WEIGHTS`
### 5.1 数据证据
v7I 的权重表给了 aircraft / vehicle / insect **4×** 权重:
```
v8 ep9: aircraft 0%, vehicle 0%, insect 100%
v8a ep11: aircraft 0%, vehicle 0%, insect 100%, printer 100%→59%(退化!)
```
- aircraft / vehicle 的 4× 没有任何效果,它们的声学特征与 speech / machine 真实接近;
- 4× 让误分类(aircraft→speech)的 loss 放大 4 倍,模型为了降 loss 把 **speech 的预测分布也拉保守**,导致 speech 也掉到 0%(双输);
- v8a ep11 还有一个异常:printer 从 v8 ep9 的 100% 掉到 59%,原因见 Fix E。
### 5.2 设计规则
- **"catch-all" 类**(GT frame 上被当万金油预测的类)→ 权重下调:
- `human_vocalization` 0.4(被 speech / tape / typing / knock 当靶子)
- `bird` 0.6(frog 几乎全部坍缩到 bird)
- `machine` 0.5(vehicle / crackle 坍缩到 machine)
- `rain` 0.5(crackle / singing 的 collapse 目标)
- `breathing` 0.5(speech 94% 错分到这里)
- `home_sound` 0.6(knock / printer 的 collapse)
- **易被 collapse 的稀有类** → 权重上调(但不过分,≤ 3×):
- `frog` 3.0
- `crackle` 2.0、`tape` 2.0、`knock` 2.0、`drawer_cabinet` 2.0、`speech` 2.0
- `aircraft` / `vehicle` / `train` 回到 1.0-1.5(4× 已证实无效且伤 sibling)
- **过度自信的稳健类** → 略压:
- `singing` 0.7、`printer` 0.7(在 ov123 任务里吸收错 FP)
### 5.3 代码改动
位置:`train_spatial_beats.py`(紧挨 `_V7I_CLASS_WEIGHTS` 之后)
```python
_V9_CLASS_WEIGHTS: List[float] = [
# 0 wind_instrument 1 string_instrument 2 guitar 3 body_sound
1.0, 1.0, 1.0, 1.0,
# 4 drum 5 water 6 human_vocalization 7 keyboard_instrument
1.0, 1.0, 0.4, 1.0,
# 8 bird 9 tool 10 machine 11 war_sound
0.6, 1.0, 0.5, 1.0,
# 12 metal_clink 13 breathing 14 laughter 15 percussion
1.0, 0.5, 1.0, 1.0,
# 16 speech 17 bell 18 dog 19 vehicle
2.0, 1.0, 1.0, 1.5,
# 20 alarm 21 footsteps 22 train 23 telephone_alarm
1.0, 1.0, 1.5, 1.0,
# 24 glass 25 wind 26 kitchenware 27 animal
1.0, 1.5, 1.0, 1.0,
# 28 musical_instrument 29 thunderstorm 30 door 31 male_speech
1.0, 1.0, 1.0, 1.0,
# 32 female_speech 33 cat 34 home_sound 35 insect
1.0, 1.0, 0.6, 1.0,
# 36 typing 37 zipper 38 camera 39 clock
1.0, 1.0, 1.0, 1.0,
# 40 fire 41 singing 42 tearing 43 writing
1.0, 0.7, 1.0, 1.0,
# 44 car 45 rain 46 scratch 47 gong
1.0, 0.5, 1.0, 1.0,
# 48 appliance 49 paper 50 drawer_cabinet 51 ocean
1.0, 1.0, 2.0, 1.0,
# 52 knock 53 crackle 54 finger_snapping 55 aircraft
2.0, 2.0, 1.0, 1.0,
# 56 crushing 57 printer 58 tape 59 wood
1.0, 0.7, 2.0, 1.0,
# 60 crack 61 cooking 62 frog
1.0, 1.0, 3.0,
]
assert len(_V9_CLASS_WEIGHTS) == 63
```
v9 preset 里:`cfg.loss.frame_class_loss_weights = list(_V9_CLASS_WEIGHTS)`
## 6. Fix E —— class head LR 单独分组 + DOA ramp 期冻结
### 6.1 数据证据
v8a ep5 vs ep11(stage 2 DOA ramp 打开前后):
```
ep5 ov3: ocls=44.4% printer=100%
ep11 ov3: ocls=44.4% printer=59% ← class 被 DOA 梯度扰动掉了
```
打开 dir/dist 的 Hungarian cost 和 loss 之后,class binding 被扰动。传统做法是整体降 LR,但那会连带把 trunk / decoder 的学习也拖慢。更好的做法是 **单独把 class_head 从 optimizer 拉出来,DOA ramp 期间冻结**。
### 6.2 代码改动
#### 6.2.1 新增 config 字段(`train_spatial_beats.py`)
```python
trunk_lr_scale: float = 1.0
spatial_lr_scale: float = 1.0
# v9: isolated LR multiplier for the class_head inside
# frame_track_prediction_heads. When < 1.0 the class head is put in its
# own param group with lr = base_lr * class_head_lr_scale. Used during
# DOA ramp (stage 2) to prevent class binding from being perturbed by
# the newly-unlocked dir/dist gradients. 1.0 = legacy behaviour.
class_head_lr_scale: float = 1.0
# Optional epoch-range override that further scales the class head LR
# specifically during the DOA ramp. When set, between
# frame_spatial_loss_warmup_epochs and frame_spatial_loss_warmup_epochs
# + class_head_freeze_during_ramp_epochs the class head LR is set to
# class_head_lr_scale_during_ramp (defaults to 0.0 = frozen). After the
# ramp window the LR returns to class_head_lr_scale.
class_head_freeze_during_ramp_epochs: int = 0
class_head_lr_scale_during_ramp: float = 0.0
```
#### 6.2.2 `build_optimizer` 拆出 cls_head group
- 新增 `_CLASS_HEAD_PREFIXES`:
```python
_CLASS_HEAD_PREFIXES = (
"frame_track_prediction_heads.class_head.",
"frame_track_prediction_heads.class_head_mlp.",
"frame_track_prediction_heads.class_head_demixer.",
)
```
- fast path 条件从 `trunk_scale == 1.0 and spatial_scale == 1.0` 改为再加 `and cls_head_scale == 1.0`;
- 当 `cls_head_scale != 1.0` 时,匹配 `_CLASS_HEAD_PREFIXES` 的参数从 head_params 抽出,放入 `cls_head_params`;
- 每个 param group 增加 `group_name` 字段(`"trunk"` / `"spatial"` / `"head"` / `"cls_head"`),便于 epoch loop 按 name 定位。
关键代码片段:
```python
for name, param in model.named_parameters():
if not param.requires_grad:
continue
if name.startswith(_CLASS_HEAD_PREFIXES) and cls_head_scale != 1.0:
cls_head_params.append(param)
elif name.startswith(_TRUNK_PREFIXES):
trunk_params.append(param)
elif name.startswith(_SPATIAL_PREFIXES):
spatial_params.append(param)
else:
head_params.append(param)
param_groups = []
if trunk_params: param_groups.append({"params": trunk_params, "lr": base_lr * trunk_scale, "weight_decay": wd, "group_name": "trunk"})
if spatial_params: param_groups.append({"params": spatial_params, "lr": base_lr * spatial_scale, "weight_decay": wd, "group_name": "spatial"})
if head_params: param_groups.append({"params": head_params, "lr": base_lr, "weight_decay": wd, "group_name": "head"})
if cls_head_params:param_groups.append({"params": cls_head_params,"lr": base_lr * cls_head_scale, "weight_decay": wd, "group_name": "cls_head"})
```
#### 6.2.3 epoch loop 动态写 cls_head LR
紧跟 spatial loss schedule 之后(在 `_log(f"[Epoch {epoch}] start")` 之前):
```python
_cls_ramp_len = int(train_cfg.class_head_freeze_during_ramp_epochs)
if _cls_ramp_len > 0 and _sp_warmup > 0 and train_cfg.class_head_lr_scale != 1.0:
in_ramp = _sp_warmup <= epoch < _sp_warmup + _cls_ramp_len
if in_ramp:
_cls_scale = train_cfg.class_head_lr_scale_during_ramp
else:
_cls_scale = train_cfg.class_head_lr_scale
for _g in optimizer.param_groups:
if _g.get("group_name") == "cls_head":
_g["lr"] = train_cfg.learning_rate * _cls_scale
_log(
f"[Epoch {epoch}] cls_head_lr scale={_cls_scale:.3f} "
f"lr={train_cfg.learning_rate * _cls_scale:.2e} "
f"(ramp_window={_sp_warmup}..{_sp_warmup + _cls_ramp_len - 1})"
)
```
#### 6.2.4 v9 preset 设定
```python
cfg.class_head_lr_scale = 0.3
cfg.class_head_freeze_during_ramp_epochs = 4
cfg.class_head_lr_scale_during_ramp = 0.0
```
v9 的 `SPATIAL_LR=1.5e-5`,各阶段 cls_head LR:
| 阶段 | epoch | lambda_dir | cls_head_lr |
|---|---|---|---|
| stage 1(class-only warmup) | 0-2 | 0.0 | 4.5e-6 |
| stage 2(DOA ramp,cls_head 冻结) | 3-6 | ramp 0→1 | **0.0** |
| stage 3(全放开) | 7+ | 1.0 | 4.5e-6 |
## 7. Fix B —— 本体论(hierarchical)标签平滑
### 7.1 核心思想
v8/v8a 的 class 错误 **≥70% 是 sibling collapse**(同 AudioSet 父类)。Hard CE 把 `frog→bird``frog→aircraft` 一视同仁地惩罚满 log loss,这不合理:
- 对下游 LLM 而言 `frog↔bird` 混淆可以靠语义上下文恢复;
- `aircraft↔speech` 这种跨域错误则完全荒谬。
希望 loss 的惩罚强度 **匹配错误的 "语义距离"**。最简单的实现:在同父类内部做 label smoothing,跨父类保持硬 CE。
### 7.2 代码改动
#### 7.2.1 `SpatialLossConfig` 新字段(`spatial_loss.py`)
```python
# v9 hierarchical label smoothing for the frame-track class head.
# When frame_class_ontology_smoothing > 0, the CE target becomes a soft
# label distribution:
# target[c_gt] = 1 - eps
# target[c_sib] = eps / |siblings| (for each sibling in same ontology
# group as c_gt; excludes c_gt itself)
# target[c_other] = 0
# ...
frame_class_ontology_smoothing: float = 0.0
# Parallel list of sibling groups. Each entry is a list of class indices
# belonging to the same AudioSet ontology parent. A class may appear in
# only one group. Empty list = no hierarchical smoothing.
frame_class_ontology_groups: List[List[int]] = None
```
`__post_init__``None` 替换为 `[]` 以防意外。
#### 7.2.2 CE 分支改写(`spatial_loss.py:compute_frame_track_losses`)
原有分支:
```python
loss_class = F.cross_entropy(
class_logits_flat, class_target_flat, weight=_cls_weights
)
```
改为:
```python
eps_onto = float(config.frame_class_ontology_smoothing)
onto_groups = config.frame_class_ontology_groups
if eps_onto > 0.0 and onto_groups:
num_classes = class_logits_flat.size(-1)
# Build a [C, C] soft-target "mixing" table on first use and
# cache it on the config object to avoid per-batch rebuild.
if (
not hasattr(config, "_onto_mixing_table")
or config._onto_mixing_table is None
or config._onto_mixing_table.shape[0] != num_classes
or config._onto_mixing_table.dtype != class_logits_flat.dtype
or config._onto_mixing_table.device != device
):
table = torch.zeros((num_classes, num_classes), dtype=class_logits_flat.dtype, device=device)
table.fill_diagonal_(1.0)
for group in onto_groups:
members = [int(c) for c in group if 0 <= int(c) < num_classes]
if len(members) < 2:
continue
for c in members:
siblings = [s for s in members if s != c]
table[c].zero_()
table[c, c] = 1.0 - eps_onto
sib_mass = eps_onto / len(siblings)
for s in siblings:
table[c, s] = sib_mass
config._onto_mixing_table = table
soft_target = config._onto_mixing_table[class_target_flat]
log_probs = F.log_softmax(class_logits_flat, dim=-1)
if _cls_weights is not None:
sample_w = _cls_weights[class_target_flat]
per_sample_loss = -(soft_target * log_probs).sum(dim=-1)
loss_class = (per_sample_loss * sample_w).sum() / sample_w.sum().clamp_min(1e-8)
else:
loss_class = -(soft_target * log_probs).sum(dim=-1).mean)
else:
loss_class = F.cross_entropy(class_logits_flat, class_target_flat, weight=_cls_weights)
```
实现要点:
- **mixing table 缓存**:表存在 `config` 对象上(不是模块),第一次构建并缓存,之后按形状/device/dtype 重用,开销可忽略;
- **per-class weight 兼容**:当同时启用 ontology smoothing 和 class weights 时,正类权重照常生效(按 GT hard class idx 加权每样本 loss);
- **不在 group 里的类**:table 默认对角线 = 1,所以未列入任何 group 的类自然退化为 hard one-hot,零影响;
- **完全开关**`frame_class_ontology_smoothing=0``frame_class_ontology_groups=[]` 都会走原 `F.cross_entropy`,v8/v8a 训练复现不受影响。
### 7.3 Ontology groups(`_V9_ONTOLOGY_GROUPS`)
位置:`train_spatial_beats.py`,v9 preset 之前。
```python
_V9_ONTOLOGY_GROUPS: List[List[int]] = [
# transportation: aircraft, vehicle, train, car
[55, 19, 22, 44],
# human voice (non-singing): speech, human_vocalization, male_speech,
# female_speech, breathing, laughter
[16, 6, 31, 32, 13, 14],
# animal vocal: bird, frog, insect, dog, cat, animal
[8, 62, 35, 18, 33, 27],
# indoor mechanical + appliances: tool, machine, appliance, printer,
# home_sound, door, drawer_cabinet, kitchenware, camera, clock, typing,
# zipper, tape, cooking
[9, 10, 48, 57, 34, 30, 50, 26, 38, 39, 36, 37, 58, 61],
# percussive / impact: knock, footsteps, crack, crackle, crushing,
# scratch, finger_snapping, tearing, writing, paper
[52, 21, 60, 53, 56, 46, 54, 42, 43, 49],
# weather / water / ambience: wind, rain, thunderstorm, ocean, water,
# fire, glass, metal_clink, wood
[25, 45, 29, 51, 5, 40, 24, 12, 59],
# musical instruments: wind_instrument, string_instrument, guitar, drum,
# keyboard_instrument, percussion, musical_instrument, gong, bell,
# singing
[0, 1, 2, 4, 7, 15, 28, 47, 17, 41],
# alarms / signals: alarm, telephone_alarm, war_sound
[20, 23, 11],
]
```
**8 组,覆盖 62/63 个 class**(只有 `body_sound` 没归组,走 hard CE)。每个 class 仅出现在一个组中。
v9 preset:
```python
cfg.loss.frame_class_ontology_smoothing = 0.1
cfg.loss.frame_class_ontology_groups = [list(g) for g in _V9_ONTOLOGY_GROUPS]
```
## 8. Fix C —— 频谱级 demixing cross-attention
### 8.1 动机
- `track_time_features[B, K, T_s, D]``SourceQueryDecoder``fused_embeddings[B, T_s, D]` 里 decode 出来的;`fused_embeddings``frequency_pool` 之后,**频率维已经被池化掉了**
- DOA 能在多源重叠帧 demix(IV 通道物理上就编码了方向),但 class 没有等价物理解混通路;
- 需要让每个 track latent 能"回看"池化前的 trunk 输出,从 F_p 个频率 token 里挑自己负责的那部分。
### 8.2 模型结构
- 输入:
- `track_time_features: [B, K, T_s, D]`(已经过 `input_norm`)
- `pre_pool_features: [B, T_p * F_p, D]`(BEATs trunk 输出,无 task tokens,未经 frequency_pool)
- `pre_pool_grid_size: (T_p, F_p)`
- `pre_pool_time_mask: [B, T_p]`(True = 有效时间步)
- 时间对齐:frame `t ∈ [0, T_s)` → trunk 时间步 `t_p = round(t * T_p / T_s)`,clip 到 `[0, T_p-1]`
- KV 构造:`kv_grid[:, t_p, :, :]``[B, T_s, F_p, D]`,然后在 K 轴 expand 到 `[B, K, T_s, F_p, D]`,flatten 成 `[B*K*T_s, F_p, D]`
- Query:`track_time_features` reshape 成 `[B*K*T_s, 1, D]`
- 经 1 层 `nn.MultiheadAttention``num_heads=8`, `dropout=0.1``batch_first=True`),再 `out_proj(Linear 768→768)`,乘以标量 `gate`,加回 class head 的 `class_input`
### 8.3 初始化策略(关键)
- `out_proj.weight` 全零、`out_proj.bias` 全零 → **加载时 demixer 输出 = 0**,class_logits 与 v8a 完全相同;
- `gate = 1e-2`(**不是 0**!)→ 前向 = `gate * 0 = 0`(身份保证),但 `∂L/∂out_proj.weight = gate × ...` **非零**,梯度从 step 0 就能流进 demixer 的 attention 权重;这是关键的"gradient-warmup trick",否则两端 zero 会把 demixer 永久冻在 0。
### 8.4 `ClassHeadSpectralDemixer` 源码(`spatial_modules.py`)
```python
class ClassHeadSpectralDemixer(nn.Module):
def __init__(self, embed_dim=768, num_layers=1, num_heads=8, dropout=0.1):
super().__init__()
self.embed_dim = embed_dim
self.num_layers = max(1, int(num_layers))
self.kv_norm = nn.LayerNorm(embed_dim)
self.q_norm = nn.LayerNorm(embed_dim)
self.layers = nn.ModuleList([
nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads,
dropout=dropout, batch_first=True)
for _ in range(self.num_layers)
])
self.out_proj = nn.Linear(embed_dim, embed_dim)
nn.init.zeros_(self.out_proj.weight)
nn.init.zeros_(self.out_proj.bias)
self.gate = nn.Parameter(torch.full((1,), 1e-2))
def forward(self, track_time_features, pre_pool_features,
pre_pool_grid_size, pre_pool_time_mask=None):
B, K, T_s, D = track_time_features.shape
T_p, F_p = int(pre_pool_grid_size[0]), int(pre_pool_grid_size[1])
if pre_pool_features.size(-1) != D:
raise ValueError(...)
expected = T_p * F_p
if pre_pool_features.size(1) != expected:
# Fall back gracefully — demixer is additive & zero-gated.
return track_time_features.new_zeros(track_time_features.shape)
kv_grid = pre_pool_features.view(B, T_p, F_p, D)
if T_s > 0 and T_p > 0:
time_idx = torch.arange(T_s, device=kv_grid.device).float() * (T_p / max(1, T_s))
time_idx = time_idx.round().clamp_(0, T_p - 1).long()
else:
time_idx = torch.zeros((T_s,), dtype=torch.long, device=kv_grid.device)
kv_per_frame = kv_grid[:, time_idx, :, :]
kv_per_frame = kv_per_frame.unsqueeze(1).expand(B, K, T_s, F_p, D).contiguous()
kv_flat = kv_per_frame.view(B * K * T_s, F_p, D)
kv_flat = self.kv_norm(kv_flat)
q_flat = track_time_features.reshape(B * K * T_s, 1, D)
q_flat = self.q_norm(q_flat)
key_padding_mask = None
if pre_pool_time_mask is not None:
per_frame_valid = pre_pool_time_mask[:, time_idx]
per_frame_valid = per_frame_valid.unsqueeze(1).expand(B, K, T_s).reshape(-1)
if not per_frame_valid.all():
ignore = ~per_frame_valid
key_padding_mask = ignore.unsqueeze(1).expand(-1, F_p).contiguous()
attn_out = q_flat
for layer in self.layers:
attn_out, _ = layer(attn_out, kv_flat, kv_flat,
key_padding_mask=key_padding_mask, need_weights=False)
residual = self.out_proj(attn_out).view(B, K, T_s, D)
return residual * self.gate
```
### 8.5 `_derive_pre_pool_time_mask` helper(`spatial_beats.py`)
紧跟 `_build_patch_padding_mask` 之后新增:
```python
def _derive_pre_pool_time_mask(
self,
patch_padding_mask: Optional[Tensor],
grid_size: Tuple[int, int],
) -> Optional[Tensor]:
"""Return a [B, T_p] boolean mask where True marks *valid* trunk time
steps. Used by the v9 class-head spectral demixer to ignore padded
tail frames."""
if patch_padding_mask is None:
return None
t_p, f_p = grid_size
B = patch_padding_mask.size(0)
pad_grid = patch_padding_mask.view(B, t_p, f_p)
time_valid = ~pad_grid.all(dim=-1)
return time_valid
```
`patch_padding_mask` 的语义是 `True = padded`,time-valid 是 "某时间步还有非 padded 频率位置" → `~pad_grid.all(dim=-1)`
### 8.6 forward 两处 call 更新(`spatial_beats.py`)
两处调用 `self.frame_track_prediction_heads(...)` 都扩展:
```python
_pre_pool_time_mask = self._derive_pre_pool_time_mask(
patch_padding_mask=patch_padding_mask,
grid_size=grid_size,
)
frame_track_prediction_output = self.frame_track_prediction_heads(
track_time_features=track_time_features,
track_latents=track_latents,
pre_pool_features=encoder_memory,
pre_pool_grid_size=grid_size,
pre_pool_time_mask=_pre_pool_time_mask,
)
```
- 第 1 处在 `readout_scheme == "local_spatial"` 分支下 frame-track parallel 路径;
- 第 2 处在 `readout_scheme == "local_spatial_track"` 分支下纯 per-frame 路径(v9 走这里)。
两处都用到的上下文变量:
- `encoder_memory``self.encode_patches(...)` 的输出,`[B, T_p*F_p, D]`**已经过 trunk 但未 frequency_pool**,正好是 demixer 需要的 pre_pool features;
- `grid_size`:`self.extract_patch_tokens(...)` 返回的 `(T_p, F_p)`;
- `patch_padding_mask`:`self._build_patch_padding_mask(...)` 返回。
## 9. Fix F —— 2-layer MLP 残差
### 9.1 动机
当前 class head 是 `nn.Linear(768, 63)`,对多源混合 token 的表达能力可能不够。加一个 2-layer MLP 残差 branch:
```
class_logits = class_head(x) + gate * class_head_mlp(x)
class_head_mlp = Linear(768, 1536) → GELU → Dropout → LayerNorm → Linear(1536, 63)
```
### 9.2 初始化策略
与 Fix C 的 demixer 同构:
- `class_head_mlp[-1].weight / bias` 全零 → 残差输出 = 0;
- `class_head_mlp_gate = 1e-2` → 前向仍然 = 0,但梯度可流进 MLP 最后一层(非零);
- 经过 1-2 个 step MLP 最后一层有非零权重后,前一层(GELU 前)也开始获得梯度。
### 9.3 `FrameTrackPredictionHeads` 重构(`spatial_modules.py`)
构造签名扩展:
```python
def __init__(
self,
embed_dim: int = 768,
num_classes: int = 63,
dropout: float = 0.1,
use_class_head_mlp_residual: bool = False,
class_head_mlp_hidden_multiplier: int = 2,
class_head_mlp_dropout: float = 0.1,
use_class_head_demixer: bool = False,
class_head_demixer_layers: int = 1,
class_head_demixer_heads: int = 8,
class_head_demixer_dropout: float = 0.1,
) -> None:
```
构造体内新增:
```python
self.use_class_head_mlp_residual = bool(use_class_head_mlp_residual)
if self.use_class_head_mlp_residual:
hidden = embed_dim * max(1, int(class_head_mlp_hidden_multiplier))
self.class_head_mlp = nn.Sequential(
nn.Linear(embed_dim, hidden),
nn.GELU(),
nn.Dropout(class_head_mlp_dropout),
nn.LayerNorm(hidden),
nn.Linear(hidden, num_classes),
)
nn.init.zeros_(self.class_head_mlp[-1].weight)
nn.init.zeros_(self.class_head_mlp[-1].bias)
self.class_head_mlp_gate = nn.Parameter(torch.full((1,), 1e-2))
else:
self.class_head_mlp = None
self.class_head_mlp_gate = None
self.use_class_head_demixer = bool(use_class_head_demixer)
if self.use_class_head_demixer:
self.class_head_demixer = ClassHeadSpectralDemixer(
embed_dim=embed_dim,
num_layers=class_head_demixer_layers,
num_heads=class_head_demixer_heads,
dropout=class_head_demixer_dropout,
)
else:
self.class_head_demixer = None
```
forward 改写:
```python
def forward(
self,
track_time_features: Tensor,
track_latents: Tensor,
pre_pool_features: Optional[Tensor] = None,
pre_pool_grid_size: Optional[Tuple[int, int]] = None,
pre_pool_time_mask: Optional[Tensor] = None,
) -> FrameTrackPredictionOutput:
...
x = self.input_norm(track_time_features)
activity = self.activity_head(x).squeeze(-1)
class_input = x
if (
self.class_head_demixer is not None
and pre_pool_features is not None
and pre_pool_grid_size is not None
):
demix_residual = self.class_head_demixer(
track_time_features=x,
pre_pool_features=pre_pool_features,
pre_pool_grid_size=pre_pool_grid_size,
pre_pool_time_mask=pre_pool_time_mask,
)
class_input = class_input + demix_residual
class_logits = self.class_head(class_input)
if self.class_head_mlp is not None and self.class_head_mlp_gate is not None:
class_logits = class_logits + self.class_head_mlp_gate * self.class_head_mlp(class_input)
direction = F.normalize(self.direction_head(x), dim=-1)
distance = F.softplus(self.distance_head(x)).squeeze(-1)
...
```
**关键细节**:
- demixer 的 residual 加到 `class_input` 上(即 class_head 的输入),而不是加到 logits 上;这样 demixer 得到的"纠偏"信号先经过 `class_head` 的共享投影再生成 logits;
- MLP 分支也看 `class_input`(包含 demixer residual),这样 demixer 的信息同样能进入 MLP 分支;
- `direction_head` 和 `distance_head` 的输入 `x` 不加 demixer(demixer 是 class-specific),保持 DOA head 对 v8a ckpt 的完全等价。
### 9.4 `SpatialBEATsConfig` 新字段(`spatial_beats.py`)
```python
self.frame_track_dropout: float = 0.1
self.frame_accdoa_hidden_dim: int = 256
self.frame_accdoa_dropout: float = 0.1
# v9: optional zero-initialised MLP residual branch inside
# FrameTrackPredictionHeads. ...
self.use_class_head_mlp_residual: bool = False
self.class_head_mlp_hidden_multiplier: int = 2
self.class_head_mlp_dropout: float = 0.1
# v9: optional spectral demixing cross-attention branch. ...
self.use_class_head_demixer: bool = False
self.class_head_demixer_layers: int = 1
self.class_head_demixer_heads: int = 8
self.class_head_demixer_dropout: float = 0.1
```
### 9.5 两处构造都传入新参数
`spatial_beats.py` 两处 `FrameTrackPredictionHeads(...)` 调用都改为:
```python
self.frame_track_prediction_heads = FrameTrackPredictionHeads(
embed_dim=cfg.encoder_embed_dim,
num_classes=cfg.source_num_classes,
dropout=cfg.frame_track_dropout,
use_class_head_mlp_residual=cfg.use_class_head_mlp_residual,
class_head_mlp_hidden_multiplier=cfg.class_head_mlp_hidden_multiplier,
class_head_mlp_dropout=cfg.class_head_mlp_dropout,
use_class_head_demixer=cfg.use_class_head_demixer,
class_head_demixer_layers=cfg.class_head_demixer_layers,
class_head_demixer_heads=cfg.class_head_demixer_heads,
class_head_demixer_dropout=cfg.class_head_demixer_dropout,
)
```
## 10. v9 preset 与启动脚本
### 10.1 `make_ov1_local_spatial_v9_ov123_top4_config`(`train_spatial_beats.py`)
位置:v7k 系列的最后一个 preset 之后,`v3: top-8 unfreeze` 注释分割线之前。
```python
def make_ov1_local_spatial_v9_ov123_top4_config(
ov1_manifest_path: str = DEFAULT_OV1_MANIFEST,
ov2_manifest_path: str = DEFAULT_OV2_MANIFEST,
ov3_manifest_path: str = DEFAULT_OV3_MANIFEST,
) -> TrainSpatialBEATsConfig:
"""v9 = v8a + class-first cleanup (fixes A..F).
Inherits v8a (cross-attn fusion + segment matching + 4-epoch DOA ramp),
applies:
- _V9_CLASS_WEIGHTS (suppress catch-all classes, boost frog/crackle/tape)
- ontology-aware label smoothing (eps=0.1)
- class head residual MLP + spectral demixer (both zero-init)
- class_head_lr_scale=0.3 with full freeze during the 4-epoch DOA ramp
Frontend / trunk / source_query_decoder / activity / dir / dist heads
are unchanged. Hot-start from v8a best.pt works with strict=False.
"""
cfg = make_ov1_local_spatial_v8a_ov123_top4_config(
ov1_manifest_path=ov1_manifest_path,
ov2_manifest_path=ov2_manifest_path,
ov3_manifest_path=ov3_manifest_path,
)
# (D) Re-balanced class weights driven by v8/v8a CSV confusion analysis.
cfg.loss.frame_class_loss_weights = list(_V9_CLASS_WEIGHTS)
# (B) Hierarchical (ontology-aware) label smoothing.
cfg.loss.frame_class_ontology_smoothing = 0.1
cfg.loss.frame_class_ontology_groups = [list(g) for g in _V9_ONTOLOGY_GROUPS]
# (F) Zero-gated MLP residual on the class head.
cfg.model.use_class_head_mlp_residual = True
cfg.model.class_head_mlp_hidden_multiplier = 2
cfg.model.class_head_mlp_dropout = 0.1
# (C) Zero-gated spectral demixing cross-attention on the class head.
cfg.model.use_class_head_demixer = True
cfg.model.class_head_demixer_layers = 1
cfg.model.class_head_demixer_heads = 8
cfg.model.class_head_demixer_dropout = 0.1
# (E) Class head gets its own LR group.
cfg.class_head_lr_scale = 0.3
cfg.class_head_freeze_during_ramp_epochs = 4
cfg.class_head_lr_scale_during_ramp = 0.0
cfg.num_epochs = 12
cfg.output_dir = "checkpoints/spatial_beats_ov1_local_spatial_v9_ov123_exp/03_ov123_top4"
return cfg
```
### 10.2 preset 注册
`train_spatial_beats.py` 中两处:
1. 在 `args.preset == "ov1_local_spatial_v8a_ov123_top4"` 分支之后,新增:
```python
elif args.preset == "ov1_local_spatial_v9_ov123_top4":
cfg = make_ov1_local_spatial_v9_ov123_top4_config(
ov1_manifest_path=args.ov1_manifest,
ov2_manifest_path=args.ov2_manifest,
ov3_manifest_path=args.ov3_manifest,
)
```
2. argparse `--preset` 的 `choices=(...)` 列表里在 `"ov1_local_spatial_v8a_ov123_top4"` 之后加入 `"ov1_local_spatial_v9_ov123_top4"`。
### 10.3 `run_ov1_v9_ov123_top4.sh`
新建文件(chmod +x):
```bash
#!/usr/bin/env bash
set -euo pipefail
# v9_ov123_top4: v8a + class-first cleanup (Fix A..F from docs/0423.md analysis)
# 所有 fix 都是 additive + zero-init,从 v8a best.pt 热启动 epoch-0 输出与 v8a 完全相同
GPUS="${GPUS:-8}"
BATCH_SIZE="${BATCH_SIZE:-8}"
NUM_WORKERS="${NUM_WORKERS:-8}"
SPATIAL_EPOCHS="${SPATIAL_EPOCHS:-12}"
SPATIAL_LR="${SPATIAL_LR:-1.5e-5}"
AMP="${AMP:-fp32}"
OV1_MANIFEST="${OV1_MANIFEST:-/apdcephfs_cq10/.../ov1_foa.jsonl}"
OV2_MANIFEST="${OV2_MANIFEST:-/apdcephfs_cq10/.../ov2_foa.jsonl}"
OV3_MANIFEST="${OV3_MANIFEST:-/apdcephfs_cq10/.../ov3_foa.jsonl}"
RESUME_CKPT="${RESUME_CKPT:-checkpoints/spatial_beats_ov1_local_spatial_v8a_ov123_exp/03_ov123_top4/best.pt}"
OUT_DIR="${OUT_DIR:-checkpoints/spatial_beats_ov1_local_spatial_v9_ov123_exp/03_ov123_top4}"
torchrun --nproc_per_node="${GPUS}" --master-port="${MASTER_PORT:-29557}" train_spatial_beats.py \
--preset ov1_local_spatial_v9_ov123_top4 \
--resume "${RESUME_CKPT}" \
--output-dir "${OUT_DIR}" \
--ov1-manifest "${OV1_MANIFEST}" \
--ov2-manifest "${OV2_MANIFEST}" \
--ov3-manifest "${OV3_MANIFEST}" \
--batch-size "${BATCH_SIZE}" \
--num-workers "${NUM_WORKERS}" \
--num-epochs "${SPATIAL_EPOCHS}" \
--learning-rate "${SPATIAL_LR}" \
--amp "${AMP}" \
--no-resume-optimizer \
--reset-epoch-on-resume \
--reset-best-on-resume
```
---
## 11. 正确性验证(已完成)
### 11.1 语法检查
四个 py 文件(`train_spatial_beats.py` / `spatial_loss.py` / `spatial_beats.py` / `spatial_modules.py`)+ `run_ov1_v9_ov123_top4.sh` 全部通过 ast.parse / bash -n。
### 11.2 模块级单测
`FrameTrackPredictionHeads` + `ClassHeadSpectralDemixer`:
- 开启所有 v9 选项后,**max abs diff = 0.00e+00** 相对于无 v9 选项的同一模型;
- 带 padding mask 时 `pred_class_logits` 无 NaN;
- 打开 MLP gate / demixer gate + 轻微扰动权重后,logits 差异 ~1e-2(符合预期)。
### 11.3 v9 vs v8a 端到端前向恒等
```python
torch.manual_seed(42)
m_v8a = SpatialBEATs(make_ov1_local_spatial_v8a_ov123_top4_config().model).eval()
torch.manual_seed(42)
m_v9 = SpatialBEATs(make_ov1_local_spatial_v9_ov123_top4_config().model).eval()
# 两个模型都 load v8a/best.pt,strict=False
# 同一 waveform 输入
max_abs_diff = (o_v8a - o_v9).abs().max()
# 实测:0.00e+00
```
**v9 加载 v8a ckpt 后的 `pred_class_logits` 与 v8a 模型加载同一 ckpt 的输出逐元素完全相等**(0.00e+00)。
### 11.4 v8a ckpt 加载统计
```
ckpt keys = 425
loadable (v9 shapes) = 425
v9 missing params = 18 ← class_head_mlp.* (7 个) + class_head_demixer.* (11 个)
ckpt unexpected = 0
```
18 个新参数通过 `strict=False` 默认初始化(按 Fix C / F 的 zero-out-proj + tiny-gate 策略)。
### 11.5 梯度流
假 CE loss + backward:
- `class_head.weight`:grad_norm ≈ 9.06(正常)
- `class_head_mlp.4.weight`(MLP 最后一层,zero-init):grad_norm ≈ 1.32e-1(**非零,因为 gate = 1e-2**)
- `class_head_mlp.0.weight`(MLP 第一层):grad_norm = 0(正常,要等最后一层非零后才有梯度,1 step 内开解)
- `class_head_demixer.out_proj.weight`:grad_norm ≈ 1.60e-2(**非零**)
- `class_head_demixer.layers.0.in_proj_weight`:grad_norm = 0(同理)
### 11.6 Optimizer 分组
```
group 0: name=trunk lr=3.00e-06 n_params=238
group 1: name=spatial lr=9.00e-06 n_params= 84
group 2: name=head lr=3.00e-05 n_params= 91
group 3: name=cls_head lr=9.00e-06 n_params= 19
```
(上例 base_lr=3e-5;v9 启动脚本里 SPATIAL_LR=1.5e-5,对应 cls_head lr = 4.5e-6。)
## 12. 关键训练时间线(SPATIAL_EPOCHS=12)
继承 v8a 的 `frame_spatial_loss_warmup_epochs=3`、`frame_spatial_loss_ramp_epochs=4`,叠加 v9 新的 cls_head LR 调度:
| epoch | lambda_dir / dist | dir/dist match cost | cls_head lr | 说明 |
|---|---|---|---|---|
| 0-2 | 0.0 | 0.0 | 4.5e-6 | Stage 1:class-only warmup,cls_head 低 LR 继续微调(但本来就 v8a 延续) |
| 3 | ramp 0.25 | 0.25 | **0.0** | Stage 2 起点,cls_head 冻结 |
| 4 | ramp 0.50 | 0.50 | **0.0** | cls_head 冻结 |
| 5 | ramp 0.75 | 0.75 | **0.0** | cls_head 冻结 |
| 6 | ramp 1.00 | 1.00 | **0.0** | cls_head 冻结,DOA 全开 |
| 7-11 | 1.0 | 1.0 | 4.5e-6 | Stage 3:全部放开,cls_head 低 LR 微调 |
## 13. 验证清单 / 观察优先级
启动后按以下顺序观察指标:
1. **epoch 0 validation metrics 是否与 v8a 的最后一个 epoch 等价**
- 预期 val loss ≈ v8a best val loss(因为前向 = v8a)
- 若不等价,说明 hot-start 出问题;
2. **epoch 0-2(stage 1)**
- `ocls`(oracle class acc)是否比 v8a ep2 同期高?关键指标,主要受 Fix D + B + C + F 驱动;
- 若 Fix D 生效:aircraft / vehicle 不再 0%,frog 不再 100%→bird;
- 若 Fix B 生效:smooth CE loss 比 hard CE 降得更快(从 1 - eps 开始);
- 若 Fix C/F 生效:每 epoch 训完后 `class_head_mlp_gate` / `class_head_demixer.gate` 应从 0.01 慢慢增长;
3. **epoch 3-6(stage 2,cls_head 冻结)**:
- DOA 指标改善(`LE_CD` 下降,`oazi` 下降),`ocls` **不应退化**(cls_head LR=0,梯度不回传);
- 若 `ocls` 退化:可能是 fusion/trunk 的梯度通过 demixer 影响了 class_head 输入分布,此时需要考虑把 demixer 也一起冻;
4. **epoch 7+(stage 3,全开)**:
- `F20` 是否突破 v8a 的上限(v7h 基线 0.246,v8a ~ 0.22-0.25);
- ov3 class_ok 是否从 43% 往上走;
- 若 ov2/ov3 class_ok 都有改善但 ov3 仍然落后:Fix C 的频谱 demixing 还需要再加层数或 heads。
## 14. 代码 diff 总览
### 14.1 `spatial_loss.py`
- 新增 2 个 `SpatialLossConfig` 字段:`frame_class_ontology_smoothing`、`frame_class_ontology_groups`;
- `__post_init__` 处理 `frame_class_ontology_groups = None`;
- `compute_frame_track_losses` 里 CE 分支支持 soft target + class weight 复合。
### 14.2 `spatial_modules.py`
- `FrameTrackPredictionHeads.__init__` 新增 6 个 kwargs;
- `FrameTrackPredictionHeads.forward` 新增 3 个 Optional kwargs(`pre_pool_features` / `pre_pool_grid_size` / `pre_pool_time_mask`),内部整合 demixer + MLP residual;
- 新增 `ClassHeadSpectralDemixer` 模块。
### 14.3 `spatial_beats.py`
- `SpatialBEATsConfig.__init__` 新增 8 个 `self.*` 字段;
- 两处 `FrameTrackPredictionHeads(...)` 构造传入全部新参数;
- 两处 `self.frame_track_prediction_heads(...)` 调用传入 pre-pool 参数;
- 新增 `SpatialBEATs._derive_pre_pool_time_mask` 实例方法。
### 14.4 `train_spatial_beats.py`
- `TrainSpatialBEATsConfig` 新增 3 个字段:`class_head_lr_scale`、`class_head_freeze_during_ramp_epochs`、`class_head_lr_scale_during_ramp`;
- 新增常量 `_V9_CLASS_WEIGHTS`(63 长度)+ `_V9_ONTOLOGY_GROUPS`(8 组,覆盖 62/63 class);
- 新增 preset 函数 `make_ov1_local_spatial_v9_ov123_top4_config`;
- `build_optimizer` 新增 `_CLASS_HEAD_PREFIXES`、`cls_head_params` 分组、`group_name` 标签、fast-path 条件更新;
- epoch loop 动态写 cls_head LR(紧跟 spatial loss schedule 后);
- preset dispatch 和 `--preset` argparse choices 注册 `ov1_local_spatial_v9_ov123_top4`。
### 14.5 `run_ov1_v9_ov123_top4.sh`
- 新建 shell 脚本,默认 8 GPU / BS=8 / 12 epoch / LR=1.5e-5 / fp32 / master_port=29557;
- 默认 RESUME_CKPT 指向 v8a best.pt;
- 复用 `--no-resume-optimizer --reset-epoch-on-resume --reset-best-on-resume` 三开关。
## 15. 关键文件与行数
| 文件 | 新增代码主要位置 |
|---|---|
| `spatial_loss.py` | `SpatialLossConfig` 新增字段紧跟 `frame_class_loss_weights` 之后;CE 分支改写在 `compute_frame_track_losses` 内 |
| `spatial_modules.py` | `FrameTrackPredictionHeads` 整块重写;`ClassHeadSpectralDemixer` 定义在 `FrameTrackPredictionHeads` 之后 |
| `spatial_beats.py` | `SpatialBEATsConfig.__init__` 的 v9 字段紧跟 `frame_accdoa_dropout`;`_derive_pre_pool_time_mask` 紧跟 `_build_patch_padding_mask` |
| `train_spatial_beats.py` | `_V9_CLASS_WEIGHTS` 紧跟 `_V7I_CLASS_WEIGHTS` 之后;`_V9_ONTOLOGY_GROUPS` 和 `make_ov1_local_spatial_v9_ov123_top4_config` 紧跟 v7k 系列结束处;`class_head_*_scale` 字段紧跟 `spatial_lr_scale` 之后 |
## 16. 后续计划
如果 v9 ep3 `ocls` 仍不超过 v8a 同期,下一步(v9a/v9b)候选:
1. **demixer 加层**:`class_head_demixer_layers = 2` 或增加 heads 到 16;
2. **demixer 冻结共进退**:DOA ramp 期把 `class_head_demixer.*` 也当作 cls_head group 的一部分(目前 `_CLASS_HEAD_PREFIXES` 已经包含它,在 ramp 期也会冻结——已经生效,但需要验证效果);
3. **ontology smoothing eps 提升**:0.1 → 0.2,进一步容忍 sibling collapse;
4. **source_query_decoder 加正交正则**(前一版建议):`L_orth = ||Q Q^T - I||_F^2`,λ=0.01;
5. **重新审视 matching**:如果 cls 瓶颈解了但 F20 仍不涨,回到 `docs/0422.md` 的 track-dead / duplicate 诊断。
但当前 v9 已经把 "class head 本身" 这一块做得比较彻底,应该先跑满 12 epoch 再决定。