Spatial-BEATs / docs /0413.md
dieKarotte's picture
Add files using upload-large-folder tool
29ab2d0 verified
|
Raw
History Blame Contribute Delete
14 kB
# 0413 当前 Spatial-BEATs 基线总结
本文档记录 2026-04-13 时点仓库内当前可复现的 `Spatial-BEATs` 基线状态。
重点不是回顾全部试错,而是明确:
- 当前应该以哪个 checkpoint / preset 为基线
- 当前模型架构到底是什么
- 当前训练和验证流程如何工作
- 这条线已经证明了什么,没证明什么
- 后续继续改时,哪些地方不要再误改
当前本地快照提交:
```text
b782333 snapshot: restore 0410 local_spatial baseline state
```
## 1. 当前基线对象
当前恢复并确认的基线是:
```text
train_spatial_beats.py --preset ov1_local_spatial
```
对应的历史参考 checkpoint 是:
```text
checkpoints/spatial_beats_ov1_local_spatial_run1/best.pt
```
这条线对应的设计目标是:
-`W` 通道的 BEATs 路径提供语义时间序列
-`WXYZ + IVxyz` 的 local spatial branch 提供局部空间时间序列
- 两者在 token 时间轴上对齐后相加融合
- 用一个单路 fused temporal token 序列同时做:
- clip-level class
- direction
- distance
- 最终 `llm_spatial_tokens`
这条线不是:
- `semantic_aux_classifier` 双头版本
- 原始 Kaldi fbank W semantic special path 版本
- class/spatial 分别从不同 token 读出的版本
这些后续尝试都已经从当前基线中移除。
## 2. 当前模型架构
当前 `ov1_local_spatial` 的主路径如下:
```text
FOA waveform [B, 4, T]
-> SpatialBEATsPreprocessor
-> foa_feat [B, 7, T_f, F]
channels = [W, X, Y, Z, IVx, IVy, IVz]
-> extract_patch_tokens()
base path:
W_logmel -> original single-channel BEATs patch embedding
spatial delta path:
7ch foa_feat -> SpatialDeltaPatchAdapter -> delta_patch_tokens
patch_tokens = base_patch_tokens + delta_patch_tokens
-> BEATs trunk
-> frequency_pool
-> TemporalResampler(target_token_rate = 2.5 Hz)
-> ShallowTemporalReadout
-> semantic_embeddings [B, T_s, D]
-> LocalSpatialEncoder(foa_feat) -> local_patch_rate_tokens [B, T_f, D_s]
-> LocalSpatialResampler -> local_spatial_tokens [B, T_s, D_s]
-> local_spatial_proj: Linear(D_s, D)
-> fused_embeddings = LayerNorm(semantic_embeddings + local_update)
-> LocalSpatialPredictionHeads(fused_embeddings)
-> class pooled token
-> spatial pooled token
-> pred_class_logits [B, 65]
-> pred_direction [B, 3]
-> pred_distance [B, 1]
-> SpatialTokenProjector(fused_embeddings)
-> llm_spatial_tokens [B, T_s, d_llm]
```
### 2.1 FOA 前端
代码入口:
- [spatial_modules.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/spatial_modules.py)
- [spatial_beats.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/spatial_beats.py)
当前前端行为:
- 输入波形按 DCASE 存储顺序 `[W, Y, Z, X]`
- `SpatialBEATsPreprocessor` 内部先重排到 `[W, X, Y, Z]`
- 计算:
- `WXYZ logmel`
- `IVx, IVy, IVz`
- 输出 `foa_feat [B, 7, T_f, F]`
这里仍然是当前基线的一部分。
也就是说,当前基线不是完全复用原始 `BEATs.py preprocess()`,而是:
- 用自实现 STFT/mel 前端
- 再用 BEATs 统计量做归一化
这点非常重要,因为后续很多分类迁移问题都和这里有关。
### 2.2 Patch token 构造
当前 patch 输入仍然是“W base + 7ch delta”风格:
- `patch_embedding.proj` 只接收 `W_logmel`
- `SpatialDeltaPatchAdapter` 从完整 `foa_feat` 生成 patch-level residual update
- 最终送入 trunk 的是:
```text
patch_tokens = base_patch_tokens + delta_patch_tokens
```
当前 `ov1_local_spatial` preset 中:
- `train_patch_embedding_in_stage1 = False`
- `train_spatial_adapter_in_stage1 = False`
- `patch_adapter_residual_alpha_init = 0.0`
- `patch_adapter_out_proj_scale_init = 0.0`
也就是说,在这条恢复后的 run1 基线里:
- trunk 冻结
- patch embedding 冻结
- patch delta adapter 也不训练
真正训练的是后面的 local spatial 分支和 fused readout。
### 2.3 BEATs trunk 路径
当前 trunk 保持的是原 BEATs 主干:
- `layer_norm`
- `post_extract_proj`
- `encoder.pos_conv`
- `encoder.layers.*`
- `encoder.layer_norm`
预训练初始化来自:
```text
pretrain_ckpt/BEATs_iter3_plus_AS2M.pt/BEATs_iter3_plus_AS2M.pt
```
加载逻辑在:
- [spatial_beats.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/spatial_beats.py)
当前日志行为应类似:
```text
[SpatialBEATs] Reusing 201 compatible BEATs keys
[SpatialBEATs] Reusing original single-channel BEATs patch embedding
```
### 2.4 Local spatial 分支
当前 spatial branch 是当前基线的核心增量,结构为:
- `LocalSpatialEncoder`
- 2D CNN 提取局部多通道空间模式
- 频率池化
- temporal Transformer
- `TemporalResampler`
- 把 local spatial token 序列重采样到和 semantic token 同样的 `T_s`
- `local_spatial_proj`
- `Linear(D_s, D)`,小尺度初始化
- `local_spatial_fusion_norm`
- `LayerNorm(semantic + local_update)`
代码位置:
- [spatial_modules.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/spatial_modules.py)
- [spatial_beats.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/spatial_beats.py)
### 2.5 Prediction heads
当前 `LocalSpatialPredictionHeads` 是 0410 那条线真正使用的版本:
- 输入只有 `fused_tokens`
- class 和 spatial 都从同一个 fused sequence 上做 attention pooling
- 不是 class 从 semantic token 读、spatial 从 fused token 读的后改版本
结构:
```text
fused_tokens [B, T_s, D]
-> class_score -> class attention pooling -> class_token
-> spatial_score -> spatial attention pooling -> spatial_token
class_token -> class_head -> pred_class_logits
spatial_token -> direction_head -> pred_direction (normalized)
spatial_token -> distance_head -> pred_distance (softplus)
```
初始化细节:
- `class_score.weight/bias = 0`
- `spatial_score.weight/bias = 0`
所以初始 attention pooling 等价于均匀 mean-pool。
这就是为什么它可以比较平滑地接入之前的 class head checkpoint。
## 3. 当前 supervision 与 loss
当前 `ov1_local_spatial` 使用:
```text
cfg.loss.supervision_mode = "mono_ast"
```
对应 loss 计算在:
- [spatial_loss.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/spatial_loss.py)
当前这条线不是 Hungarian slot matching,而是单源 clip-level supervision:
- 每个样本必须只有一个有效源
- 从第一个有效源读取:
- class
- azimuth
- elevation
- distance
loss 定义:
```text
loss_cls_aux = cross_entropy(pred_class_logits, cls_target)
loss_direction = 1 - cos(pred_direction, gt_direction)
loss_dist = smooth_l1(pred_distance, gt_distance)
loss_total =
lambda_cls_aux * loss_cls_aux
+ lambda_direction * loss_direction
+ lambda_dist * loss_dist
```
当前 `ov1_local_spatial` 默认权重:
```text
lambda_cls_aux = 1.0
lambda_direction = 12.0
lambda_dist = 2.0
```
valid metrics:
- `class_acc`
- `azi_mae_deg`
- `ele_mae_deg`
- `dist_mae`
### 3.1 Active window mask
虽然是单源 clip-level readout,这条线仍然会构建弱时间窗 mask:
- `build_primary_source_window_mask()`
- 根据 source start/end time 映射到 token 轴
- 作为 `active_window_mask` 提供给 `LocalSpatialPredictionHeads`
这意味着 pooled class/spatial token 默认更关注标注的 active window,而不是整段所有 token。
## 4. 当前训练配置
当前恢复后的 `ov1_local_spatial` preset 在:
- [train_spatial_beats.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/train_spatial_beats.py)
关键配置:
```text
batch_size = 8
num_workers = 4
num_epochs = 20
learning_rate = 1e-4
weight_decay = 0.05
freeze_trunk_in_stage1 = True
unfreeze_full_trunk = False
train_patch_embedding_in_stage1 = False
train_spatial_adapter_in_stage1 = False
freeze_projector_by_default = True
dataset.max_clip_duration_seconds = 20.0
dataset.crop_mode = "start"
best_metric_name = "azi_mae_deg"
minimize_best_metric = True
class_finetuned_ckpt = checkpoints/beats_ov1_event_cls_head_only/best.pt
```
当前最重要的初始化来源有两个:
1. `BEATs pretrained trunk`
2. `checkpoints/beats_ov1_event_cls_head_only/best.pt`
第二个 checkpoint 的加载策略是:
```text
beats.patch_embedding.weight -> patch_embedding.proj.weight
beats.patch_embedding.bias -> patch_embedding.proj.bias
beats.layer_norm.* -> layer_norm.*
beats.post_extract_proj.* -> post_extract_proj.*
beats.encoder.* -> encoder.*
classifier.weight -> local_spatial_prediction_heads.class_head.weight
classifier.bias -> local_spatial_prediction_heads.class_head.bias
```
所以当前 `ov1_local_spatial` 的初始分类能力,并不是随机的,而是:
- 先来自 W-channel 事件分类 baseline 的 class head
- 再接到 fused temporal readout 上
## 5. 当前 checkpoint 已证明的结果
参考 checkpoint:
```text
checkpoints/spatial_beats_ov1_local_spatial_run1/best.pt
```
这条 run 的已知最佳结果是:
```text
best epoch = 7
best metric = azi_mae_deg
train:
loss_total = 4.0454
loss_cls_aux = 1.3232
loss_direction = 0.1966
loss_dist = 0.1816
class_acc = 0.6184
azi_mae_deg = 24.68
ele_mae_deg = 10.07
dist_mae = 0.465
valid:
loss_total = 5.1220
loss_cls_aux = 2.5258
loss_direction = 0.1856
loss_dist = 0.1847
class_acc = 0.3985
azi_mae_deg = 23.52
ele_mae_deg = 9.65
dist_mae = 0.475
```
更关键的是 `epoch_0000` 的行为:
```text
train class_acc ≈ 0.246
val class_acc ≈ 0.244
val azi_mae ≈ 71.49
val ele_mae ≈ 71.54
```
这说明 run1 的初始状态是:
- class 明显高于随机
- 空间一开始很差
这正是后续 0413 恢复工作所要回到的行为。
## 6. 0413 已排除掉的错误分支
为了避免后续再把这条基线改坏,下面这些都已经确认不是当前基线的一部分:
### 6.1 不是 `semantic_aux_classifier` 版本
后面曾尝试:
- 在训练期额外挂一个完全独立的 semantic auxiliary classifier
- 让 class loss 不再从 fused token 读
结果:
- 这条线改变了初始化行为
- 会让 `Epoch 0` class 接近随机
- 与原始 run1 不一致
现在已经从当前基线彻底移除。
### 6.2 不是 `use_original_beats_semantic_frontend_for_local_spatial` 版本
后面曾尝试:
- 让 local_spatial 路径里的 semantic branch 改走原始 W-only Kaldi fbank + BEATs
结果:
- 这条线会明显改变 local_spatial 的初始化分布
- 也是导致后续 “空间一开始很好、class 一开始很差” 的原因之一
现在已经从当前基线移除。
### 6.3 不是 class/spatial 分开读不同 token 的版本
后面曾尝试:
- class 从 `semantic_tokens` 池化
- spatial 从 `fused_tokens` 池化
这同样改变了 0410 run1 的行为。
当前已经恢复成:
- class 和 spatial 都从 `fused_tokens`
## 7. 当前仓库里额外存在但不属于这条基线的内容
当前仓库除了 `ov1_local_spatial` 之外,还保留了一些后续研发分支:
- `ov1_ast`
- `ov1_pretrunk_ast_*`
- `ov123_local_spatial_slot`
- `ov123_local_spatial_track`
- `ov123_local_spatial_accdoa`
这些分支仍然在代码里,但它们不是当前这份文档关注的“恢复后的单源 local_spatial 基线”。
后续如果继续研究 frame-level / slot-level 多源建模,应当:
- 把它们视为新的实验线
- 不要再直接污染 `ov1_local_spatial`
## 8. 当前推荐的复现命令
恢复后的基线复现命令:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --master_port=29531 \
train_spatial_beats.py \
--preset ov1_local_spatial \
--distributed \
--batch-size 8 \
--num-workers 24 \
--num-epochs 20 \
--output-dir checkpoints/spatial_beats_ov1_local_spatial_run1
```
如果只想做行为检查,先跑 1 个 epoch 就够:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --master_port=29531 \
train_spatial_beats.py \
--preset ov1_local_spatial \
--distributed \
--batch-size 8 \
--num-workers 24 \
--num-epochs 1 \
--output-dir checkpoints/spatial_beats_ov1_local_spatial_repro_check
```
判断是否仍在 run1 轨道上的最直接标准:
- `Epoch 0 class_acc` 应该明显高于随机
- `Epoch 0 azi/ele` 应该仍然很差
而不是:
- `Epoch 0 class` 接近 0
- `Epoch 0 spatial` 一开始就很好
## 9. 目前这条线的结论
当前 `ov1_local_spatial` 基线已经证明:
1. 单路 fused token 架构是能学到空间的。
2. 这条线能在不动 trunk 的前提下,把 azimuth / elevation / distance 拉下来。
3. 但它的 class 保持能力有限,最佳 valid `class_acc` 约 `0.40`,明显低于纯 W-channel BEATs 分类 baseline。
4. 所以这条线当前更适合作为:
- “可工作空间基线”
- 后续 frame-level spatial 扩展的 warm-start
它还没有证明:
1. class 和 spatial 可以同时都做到强泛化。
2. fused token 已经足够好,可以直接无脑接 LLM。
## 10. 后续继续迭代时的建议
如果以后继续改,不要再直接破坏 `ov1_local_spatial` 这条基线。更稳的方式是:
1. 保留 `ov1_local_spatial` 不动,作为 frozen baseline。
2. 新实验单独开新 preset。
3. 新实验优先从:
- `checkpoints/spatial_beats_ov1_local_spatial_run1/best.pt`
- 或 `DEFAULT_OV1_LOCAL_SPATIAL_INIT`
warm-start。
4. 任何涉及下面这些点的修改,都应视为“新架构”,不要再说和 run1 一样:
- semantic aux classifier
- 原始 W-BEATs special frontend
- class/spatial 分离 readout
- 新的 frame-level supervision path
当前最安全的角色定位是:
```text
ov1_local_spatial = 已恢复并可复现的 0410 单源 local-spatial baseline
```