| # 0413 当前 Spatial-BEATs 基线总结 |
|
|
| 本文档记录 2026-04-13 时点仓库内当前可复现的 `Spatial-BEATs` 基线状态。 |
| 重点不是回顾全部试错,而是明确: |
|
|
| - 当前应该以哪个 checkpoint / preset 为基线 |
| - 当前模型架构到底是什么 |
| - 当前训练和验证流程如何工作 |
| - 这条线已经证明了什么,没证明什么 |
| - 后续继续改时,哪些地方不要再误改 |
|
|
| 当前本地快照提交: |
|
|
| ```text |
| b782333 snapshot: restore 0410 local_spatial baseline state |
| ``` |
|
|
| ## 1. 当前基线对象 |
|
|
| 当前恢复并确认的基线是: |
|
|
| ```text |
| train_spatial_beats.py --preset ov1_local_spatial |
| ``` |
|
|
| 对应的历史参考 checkpoint 是: |
|
|
| ```text |
| checkpoints/spatial_beats_ov1_local_spatial_run1/best.pt |
| ``` |
|
|
| 这条线对应的设计目标是: |
|
|
| - 用 `W` 通道的 BEATs 路径提供语义时间序列 |
| - 用 `WXYZ + IVxyz` 的 local spatial branch 提供局部空间时间序列 |
| - 两者在 token 时间轴上对齐后相加融合 |
| - 用一个单路 fused temporal token 序列同时做: |
| - clip-level class |
| - direction |
| - distance |
| - 最终 `llm_spatial_tokens` |
|
|
| 这条线不是: |
|
|
| - `semantic_aux_classifier` 双头版本 |
| - 原始 Kaldi fbank W semantic special path 版本 |
| - class/spatial 分别从不同 token 读出的版本 |
|
|
| 这些后续尝试都已经从当前基线中移除。 |
|
|
| ## 2. 当前模型架构 |
|
|
| 当前 `ov1_local_spatial` 的主路径如下: |
|
|
| ```text |
| FOA waveform [B, 4, T] |
| -> SpatialBEATsPreprocessor |
| -> foa_feat [B, 7, T_f, F] |
| channels = [W, X, Y, Z, IVx, IVy, IVz] |
| |
| -> extract_patch_tokens() |
| base path: |
| W_logmel -> original single-channel BEATs patch embedding |
| spatial delta path: |
| 7ch foa_feat -> SpatialDeltaPatchAdapter -> delta_patch_tokens |
| patch_tokens = base_patch_tokens + delta_patch_tokens |
| |
| -> BEATs trunk |
| -> frequency_pool |
| -> TemporalResampler(target_token_rate = 2.5 Hz) |
| -> ShallowTemporalReadout |
| -> semantic_embeddings [B, T_s, D] |
| |
| -> LocalSpatialEncoder(foa_feat) -> local_patch_rate_tokens [B, T_f, D_s] |
| -> LocalSpatialResampler -> local_spatial_tokens [B, T_s, D_s] |
| -> local_spatial_proj: Linear(D_s, D) |
| -> fused_embeddings = LayerNorm(semantic_embeddings + local_update) |
| |
| -> LocalSpatialPredictionHeads(fused_embeddings) |
| -> class pooled token |
| -> spatial pooled token |
| -> pred_class_logits [B, 65] |
| -> pred_direction [B, 3] |
| -> pred_distance [B, 1] |
| |
| -> SpatialTokenProjector(fused_embeddings) |
| -> llm_spatial_tokens [B, T_s, d_llm] |
| ``` |
|
|
| ### 2.1 FOA 前端 |
|
|
| 代码入口: |
|
|
| - [spatial_modules.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/spatial_modules.py) |
| - [spatial_beats.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/spatial_beats.py) |
|
|
| 当前前端行为: |
|
|
| - 输入波形按 DCASE 存储顺序 `[W, Y, Z, X]` |
| - `SpatialBEATsPreprocessor` 内部先重排到 `[W, X, Y, Z]` |
| - 计算: |
| - `WXYZ logmel` |
| - `IVx, IVy, IVz` |
| - 输出 `foa_feat [B, 7, T_f, F]` |
|
|
| 这里仍然是当前基线的一部分。 |
| 也就是说,当前基线不是完全复用原始 `BEATs.py preprocess()`,而是: |
|
|
| - 用自实现 STFT/mel 前端 |
| - 再用 BEATs 统计量做归一化 |
|
|
| 这点非常重要,因为后续很多分类迁移问题都和这里有关。 |
|
|
| ### 2.2 Patch token 构造 |
|
|
| 当前 patch 输入仍然是“W base + 7ch delta”风格: |
|
|
| - `patch_embedding.proj` 只接收 `W_logmel` |
| - `SpatialDeltaPatchAdapter` 从完整 `foa_feat` 生成 patch-level residual update |
| - 最终送入 trunk 的是: |
|
|
| ```text |
| patch_tokens = base_patch_tokens + delta_patch_tokens |
| ``` |
|
|
| 当前 `ov1_local_spatial` preset 中: |
|
|
| - `train_patch_embedding_in_stage1 = False` |
| - `train_spatial_adapter_in_stage1 = False` |
| - `patch_adapter_residual_alpha_init = 0.0` |
| - `patch_adapter_out_proj_scale_init = 0.0` |
|
|
| 也就是说,在这条恢复后的 run1 基线里: |
|
|
| - trunk 冻结 |
| - patch embedding 冻结 |
| - patch delta adapter 也不训练 |
|
|
| 真正训练的是后面的 local spatial 分支和 fused readout。 |
|
|
| ### 2.3 BEATs trunk 路径 |
|
|
| 当前 trunk 保持的是原 BEATs 主干: |
|
|
| - `layer_norm` |
| - `post_extract_proj` |
| - `encoder.pos_conv` |
| - `encoder.layers.*` |
| - `encoder.layer_norm` |
|
|
| 预训练初始化来自: |
|
|
| ```text |
| pretrain_ckpt/BEATs_iter3_plus_AS2M.pt/BEATs_iter3_plus_AS2M.pt |
| ``` |
|
|
| 加载逻辑在: |
|
|
| - [spatial_beats.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/spatial_beats.py) |
|
|
| 当前日志行为应类似: |
|
|
| ```text |
| [SpatialBEATs] Reusing 201 compatible BEATs keys |
| [SpatialBEATs] Reusing original single-channel BEATs patch embedding |
| ``` |
|
|
| ### 2.4 Local spatial 分支 |
|
|
| 当前 spatial branch 是当前基线的核心增量,结构为: |
|
|
| - `LocalSpatialEncoder` |
| - 2D CNN 提取局部多通道空间模式 |
| - 频率池化 |
| - temporal Transformer |
| - `TemporalResampler` |
| - 把 local spatial token 序列重采样到和 semantic token 同样的 `T_s` |
| - `local_spatial_proj` |
| - `Linear(D_s, D)`,小尺度初始化 |
| - `local_spatial_fusion_norm` |
| - `LayerNorm(semantic + local_update)` |
|
|
| 代码位置: |
|
|
| - [spatial_modules.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/spatial_modules.py) |
| - [spatial_beats.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/spatial_beats.py) |
|
|
| ### 2.5 Prediction heads |
|
|
| 当前 `LocalSpatialPredictionHeads` 是 0410 那条线真正使用的版本: |
|
|
| - 输入只有 `fused_tokens` |
| - class 和 spatial 都从同一个 fused sequence 上做 attention pooling |
| - 不是 class 从 semantic token 读、spatial 从 fused token 读的后改版本 |
|
|
| 结构: |
|
|
| ```text |
| fused_tokens [B, T_s, D] |
| -> class_score -> class attention pooling -> class_token |
| -> spatial_score -> spatial attention pooling -> spatial_token |
| |
| class_token -> class_head -> pred_class_logits |
| spatial_token -> direction_head -> pred_direction (normalized) |
| spatial_token -> distance_head -> pred_distance (softplus) |
| ``` |
|
|
| 初始化细节: |
|
|
| - `class_score.weight/bias = 0` |
| - `spatial_score.weight/bias = 0` |
|
|
| 所以初始 attention pooling 等价于均匀 mean-pool。 |
| 这就是为什么它可以比较平滑地接入之前的 class head checkpoint。 |
|
|
| ## 3. 当前 supervision 与 loss |
|
|
| 当前 `ov1_local_spatial` 使用: |
|
|
| ```text |
| cfg.loss.supervision_mode = "mono_ast" |
| ``` |
|
|
| 对应 loss 计算在: |
|
|
| - [spatial_loss.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/spatial_loss.py) |
|
|
| 当前这条线不是 Hungarian slot matching,而是单源 clip-level supervision: |
|
|
| - 每个样本必须只有一个有效源 |
| - 从第一个有效源读取: |
| - class |
| - azimuth |
| - elevation |
| - distance |
|
|
| loss 定义: |
|
|
| ```text |
| loss_cls_aux = cross_entropy(pred_class_logits, cls_target) |
| loss_direction = 1 - cos(pred_direction, gt_direction) |
| loss_dist = smooth_l1(pred_distance, gt_distance) |
| |
| loss_total = |
| lambda_cls_aux * loss_cls_aux |
| + lambda_direction * loss_direction |
| + lambda_dist * loss_dist |
| ``` |
|
|
| 当前 `ov1_local_spatial` 默认权重: |
|
|
| ```text |
| lambda_cls_aux = 1.0 |
| lambda_direction = 12.0 |
| lambda_dist = 2.0 |
| ``` |
|
|
| valid metrics: |
|
|
| - `class_acc` |
| - `azi_mae_deg` |
| - `ele_mae_deg` |
| - `dist_mae` |
|
|
| ### 3.1 Active window mask |
|
|
| 虽然是单源 clip-level readout,这条线仍然会构建弱时间窗 mask: |
|
|
| - `build_primary_source_window_mask()` |
| - 根据 source start/end time 映射到 token 轴 |
| - 作为 `active_window_mask` 提供给 `LocalSpatialPredictionHeads` |
|
|
| 这意味着 pooled class/spatial token 默认更关注标注的 active window,而不是整段所有 token。 |
|
|
| ## 4. 当前训练配置 |
|
|
| 当前恢复后的 `ov1_local_spatial` preset 在: |
|
|
| - [train_spatial_beats.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/train_spatial_beats.py) |
|
|
| 关键配置: |
|
|
| ```text |
| batch_size = 8 |
| num_workers = 4 |
| num_epochs = 20 |
| learning_rate = 1e-4 |
| weight_decay = 0.05 |
| |
| freeze_trunk_in_stage1 = True |
| unfreeze_full_trunk = False |
| train_patch_embedding_in_stage1 = False |
| train_spatial_adapter_in_stage1 = False |
| freeze_projector_by_default = True |
| |
| dataset.max_clip_duration_seconds = 20.0 |
| dataset.crop_mode = "start" |
| |
| best_metric_name = "azi_mae_deg" |
| minimize_best_metric = True |
| class_finetuned_ckpt = checkpoints/beats_ov1_event_cls_head_only/best.pt |
| ``` |
|
|
| 当前最重要的初始化来源有两个: |
|
|
| 1. `BEATs pretrained trunk` |
| 2. `checkpoints/beats_ov1_event_cls_head_only/best.pt` |
|
|
| 第二个 checkpoint 的加载策略是: |
|
|
| ```text |
| beats.patch_embedding.weight -> patch_embedding.proj.weight |
| beats.patch_embedding.bias -> patch_embedding.proj.bias |
| beats.layer_norm.* -> layer_norm.* |
| beats.post_extract_proj.* -> post_extract_proj.* |
| beats.encoder.* -> encoder.* |
| classifier.weight -> local_spatial_prediction_heads.class_head.weight |
| classifier.bias -> local_spatial_prediction_heads.class_head.bias |
| ``` |
|
|
| 所以当前 `ov1_local_spatial` 的初始分类能力,并不是随机的,而是: |
|
|
| - 先来自 W-channel 事件分类 baseline 的 class head |
| - 再接到 fused temporal readout 上 |
|
|
| ## 5. 当前 checkpoint 已证明的结果 |
|
|
| 参考 checkpoint: |
|
|
| ```text |
| checkpoints/spatial_beats_ov1_local_spatial_run1/best.pt |
| ``` |
|
|
| 这条 run 的已知最佳结果是: |
|
|
| ```text |
| best epoch = 7 |
| best metric = azi_mae_deg |
| |
| train: |
| loss_total = 4.0454 |
| loss_cls_aux = 1.3232 |
| loss_direction = 0.1966 |
| loss_dist = 0.1816 |
| class_acc = 0.6184 |
| azi_mae_deg = 24.68 |
| ele_mae_deg = 10.07 |
| dist_mae = 0.465 |
| |
| valid: |
| loss_total = 5.1220 |
| loss_cls_aux = 2.5258 |
| loss_direction = 0.1856 |
| loss_dist = 0.1847 |
| class_acc = 0.3985 |
| azi_mae_deg = 23.52 |
| ele_mae_deg = 9.65 |
| dist_mae = 0.475 |
| ``` |
|
|
| 更关键的是 `epoch_0000` 的行为: |
|
|
| ```text |
| train class_acc ≈ 0.246 |
| val class_acc ≈ 0.244 |
| val azi_mae ≈ 71.49 |
| val ele_mae ≈ 71.54 |
| ``` |
|
|
| 这说明 run1 的初始状态是: |
|
|
| - class 明显高于随机 |
| - 空间一开始很差 |
|
|
| 这正是后续 0413 恢复工作所要回到的行为。 |
|
|
| ## 6. 0413 已排除掉的错误分支 |
|
|
| 为了避免后续再把这条基线改坏,下面这些都已经确认不是当前基线的一部分: |
|
|
| ### 6.1 不是 `semantic_aux_classifier` 版本 |
|
|
| 后面曾尝试: |
|
|
| - 在训练期额外挂一个完全独立的 semantic auxiliary classifier |
| - 让 class loss 不再从 fused token 读 |
|
|
| 结果: |
|
|
| - 这条线改变了初始化行为 |
| - 会让 `Epoch 0` class 接近随机 |
| - 与原始 run1 不一致 |
|
|
| 现在已经从当前基线彻底移除。 |
|
|
| ### 6.2 不是 `use_original_beats_semantic_frontend_for_local_spatial` 版本 |
| |
| 后面曾尝试: |
| |
| - 让 local_spatial 路径里的 semantic branch 改走原始 W-only Kaldi fbank + BEATs |
|
|
| 结果: |
|
|
| - 这条线会明显改变 local_spatial 的初始化分布 |
| - 也是导致后续 “空间一开始很好、class 一开始很差” 的原因之一 |
| |
| 现在已经从当前基线移除。 |
| |
| ### 6.3 不是 class/spatial 分开读不同 token 的版本 |
| |
| 后面曾尝试: |
| |
| - class 从 `semantic_tokens` 池化 |
| - spatial 从 `fused_tokens` 池化 |
|
|
| 这同样改变了 0410 run1 的行为。 |
| 当前已经恢复成: |
|
|
| - class 和 spatial 都从 `fused_tokens` 读 |
|
|
| ## 7. 当前仓库里额外存在但不属于这条基线的内容 |
|
|
| 当前仓库除了 `ov1_local_spatial` 之外,还保留了一些后续研发分支: |
|
|
| - `ov1_ast` |
| - `ov1_pretrunk_ast_*` |
| - `ov123_local_spatial_slot` |
| - `ov123_local_spatial_track` |
| - `ov123_local_spatial_accdoa` |
|
|
| 这些分支仍然在代码里,但它们不是当前这份文档关注的“恢复后的单源 local_spatial 基线”。 |
| |
| 后续如果继续研究 frame-level / slot-level 多源建模,应当: |
| |
| - 把它们视为新的实验线 |
| - 不要再直接污染 `ov1_local_spatial` |
| |
| ## 8. 当前推荐的复现命令 |
| |
| 恢复后的基线复现命令: |
| |
| ```bash |
| CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --master_port=29531 \ |
| train_spatial_beats.py \ |
| --preset ov1_local_spatial \ |
| --distributed \ |
| --batch-size 8 \ |
| --num-workers 24 \ |
| --num-epochs 20 \ |
| --output-dir checkpoints/spatial_beats_ov1_local_spatial_run1 |
| ``` |
| |
| 如果只想做行为检查,先跑 1 个 epoch 就够: |
| |
| ```bash |
| CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --master_port=29531 \ |
| train_spatial_beats.py \ |
| --preset ov1_local_spatial \ |
| --distributed \ |
| --batch-size 8 \ |
| --num-workers 24 \ |
| --num-epochs 1 \ |
| --output-dir checkpoints/spatial_beats_ov1_local_spatial_repro_check |
| ``` |
| |
| 判断是否仍在 run1 轨道上的最直接标准: |
| |
| - `Epoch 0 class_acc` 应该明显高于随机 |
| - `Epoch 0 azi/ele` 应该仍然很差 |
| |
| 而不是: |
| |
| - `Epoch 0 class` 接近 0 |
| - `Epoch 0 spatial` 一开始就很好 |
| |
| ## 9. 目前这条线的结论 |
| |
| 当前 `ov1_local_spatial` 基线已经证明: |
| |
| 1. 单路 fused token 架构是能学到空间的。 |
| 2. 这条线能在不动 trunk 的前提下,把 azimuth / elevation / distance 拉下来。 |
| 3. 但它的 class 保持能力有限,最佳 valid `class_acc` 约 `0.40`,明显低于纯 W-channel BEATs 分类 baseline。 |
| 4. 所以这条线当前更适合作为: |
| - “可工作空间基线” |
| - 后续 frame-level spatial 扩展的 warm-start |
| |
| 它还没有证明: |
| |
| 1. class 和 spatial 可以同时都做到强泛化。 |
| 2. fused token 已经足够好,可以直接无脑接 LLM。 |
| |
| ## 10. 后续继续迭代时的建议 |
| |
| 如果以后继续改,不要再直接破坏 `ov1_local_spatial` 这条基线。更稳的方式是: |
| |
| 1. 保留 `ov1_local_spatial` 不动,作为 frozen baseline。 |
| 2. 新实验单独开新 preset。 |
| 3. 新实验优先从: |
| - `checkpoints/spatial_beats_ov1_local_spatial_run1/best.pt` |
| - 或 `DEFAULT_OV1_LOCAL_SPATIAL_INIT` |
| warm-start。 |
| 4. 任何涉及下面这些点的修改,都应视为“新架构”,不要再说和 run1 一样: |
| - semantic aux classifier |
| - 原始 W-BEATs special frontend |
| - class/spatial 分离 readout |
| - 新的 frame-level supervision path |
| |
| 当前最安全的角色定位是: |
| |
| ```text |
| ov1_local_spatial = 已恢复并可复现的 0410 单源 local-spatial baseline |
| ``` |
| |