# 0413 当前 Spatial-BEATs 基线总结 本文档记录 2026-04-13 时点仓库内当前可复现的 `Spatial-BEATs` 基线状态。 重点不是回顾全部试错,而是明确: - 当前应该以哪个 checkpoint / preset 为基线 - 当前模型架构到底是什么 - 当前训练和验证流程如何工作 - 这条线已经证明了什么,没证明什么 - 后续继续改时,哪些地方不要再误改 当前本地快照提交: ```text b782333 snapshot: restore 0410 local_spatial baseline state ``` ## 1. 当前基线对象 当前恢复并确认的基线是: ```text train_spatial_beats.py --preset ov1_local_spatial ``` 对应的历史参考 checkpoint 是: ```text checkpoints/spatial_beats_ov1_local_spatial_run1/best.pt ``` 这条线对应的设计目标是: - 用 `W` 通道的 BEATs 路径提供语义时间序列 - 用 `WXYZ + IVxyz` 的 local spatial branch 提供局部空间时间序列 - 两者在 token 时间轴上对齐后相加融合 - 用一个单路 fused temporal token 序列同时做: - clip-level class - direction - distance - 最终 `llm_spatial_tokens` 这条线不是: - `semantic_aux_classifier` 双头版本 - 原始 Kaldi fbank W semantic special path 版本 - class/spatial 分别从不同 token 读出的版本 这些后续尝试都已经从当前基线中移除。 ## 2. 当前模型架构 当前 `ov1_local_spatial` 的主路径如下: ```text FOA waveform [B, 4, T] -> SpatialBEATsPreprocessor -> foa_feat [B, 7, T_f, F] channels = [W, X, Y, Z, IVx, IVy, IVz] -> extract_patch_tokens() base path: W_logmel -> original single-channel BEATs patch embedding spatial delta path: 7ch foa_feat -> SpatialDeltaPatchAdapter -> delta_patch_tokens patch_tokens = base_patch_tokens + delta_patch_tokens -> BEATs trunk -> frequency_pool -> TemporalResampler(target_token_rate = 2.5 Hz) -> ShallowTemporalReadout -> semantic_embeddings [B, T_s, D] -> LocalSpatialEncoder(foa_feat) -> local_patch_rate_tokens [B, T_f, D_s] -> LocalSpatialResampler -> local_spatial_tokens [B, T_s, D_s] -> local_spatial_proj: Linear(D_s, D) -> fused_embeddings = LayerNorm(semantic_embeddings + local_update) -> LocalSpatialPredictionHeads(fused_embeddings) -> class pooled token -> spatial pooled token -> pred_class_logits [B, 65] -> pred_direction [B, 3] -> pred_distance [B, 1] -> SpatialTokenProjector(fused_embeddings) -> llm_spatial_tokens [B, T_s, d_llm] ``` ### 2.1 FOA 前端 代码入口: - [spatial_modules.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/spatial_modules.py) - [spatial_beats.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/spatial_beats.py) 当前前端行为: - 输入波形按 DCASE 存储顺序 `[W, Y, Z, X]` - `SpatialBEATsPreprocessor` 内部先重排到 `[W, X, Y, Z]` - 计算: - `WXYZ logmel` - `IVx, IVy, IVz` - 输出 `foa_feat [B, 7, T_f, F]` 这里仍然是当前基线的一部分。 也就是说,当前基线不是完全复用原始 `BEATs.py preprocess()`,而是: - 用自实现 STFT/mel 前端 - 再用 BEATs 统计量做归一化 这点非常重要,因为后续很多分类迁移问题都和这里有关。 ### 2.2 Patch token 构造 当前 patch 输入仍然是“W base + 7ch delta”风格: - `patch_embedding.proj` 只接收 `W_logmel` - `SpatialDeltaPatchAdapter` 从完整 `foa_feat` 生成 patch-level residual update - 最终送入 trunk 的是: ```text patch_tokens = base_patch_tokens + delta_patch_tokens ``` 当前 `ov1_local_spatial` preset 中: - `train_patch_embedding_in_stage1 = False` - `train_spatial_adapter_in_stage1 = False` - `patch_adapter_residual_alpha_init = 0.0` - `patch_adapter_out_proj_scale_init = 0.0` 也就是说,在这条恢复后的 run1 基线里: - trunk 冻结 - patch embedding 冻结 - patch delta adapter 也不训练 真正训练的是后面的 local spatial 分支和 fused readout。 ### 2.3 BEATs trunk 路径 当前 trunk 保持的是原 BEATs 主干: - `layer_norm` - `post_extract_proj` - `encoder.pos_conv` - `encoder.layers.*` - `encoder.layer_norm` 预训练初始化来自: ```text pretrain_ckpt/BEATs_iter3_plus_AS2M.pt/BEATs_iter3_plus_AS2M.pt ``` 加载逻辑在: - [spatial_beats.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/spatial_beats.py) 当前日志行为应类似: ```text [SpatialBEATs] Reusing 201 compatible BEATs keys [SpatialBEATs] Reusing original single-channel BEATs patch embedding ``` ### 2.4 Local spatial 分支 当前 spatial branch 是当前基线的核心增量,结构为: - `LocalSpatialEncoder` - 2D CNN 提取局部多通道空间模式 - 频率池化 - temporal Transformer - `TemporalResampler` - 把 local spatial token 序列重采样到和 semantic token 同样的 `T_s` - `local_spatial_proj` - `Linear(D_s, D)`,小尺度初始化 - `local_spatial_fusion_norm` - `LayerNorm(semantic + local_update)` 代码位置: - [spatial_modules.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/spatial_modules.py) - [spatial_beats.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/spatial_beats.py) ### 2.5 Prediction heads 当前 `LocalSpatialPredictionHeads` 是 0410 那条线真正使用的版本: - 输入只有 `fused_tokens` - class 和 spatial 都从同一个 fused sequence 上做 attention pooling - 不是 class 从 semantic token 读、spatial 从 fused token 读的后改版本 结构: ```text fused_tokens [B, T_s, D] -> class_score -> class attention pooling -> class_token -> spatial_score -> spatial attention pooling -> spatial_token class_token -> class_head -> pred_class_logits spatial_token -> direction_head -> pred_direction (normalized) spatial_token -> distance_head -> pred_distance (softplus) ``` 初始化细节: - `class_score.weight/bias = 0` - `spatial_score.weight/bias = 0` 所以初始 attention pooling 等价于均匀 mean-pool。 这就是为什么它可以比较平滑地接入之前的 class head checkpoint。 ## 3. 当前 supervision 与 loss 当前 `ov1_local_spatial` 使用: ```text cfg.loss.supervision_mode = "mono_ast" ``` 对应 loss 计算在: - [spatial_loss.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/spatial_loss.py) 当前这条线不是 Hungarian slot matching,而是单源 clip-level supervision: - 每个样本必须只有一个有效源 - 从第一个有效源读取: - class - azimuth - elevation - distance loss 定义: ```text loss_cls_aux = cross_entropy(pred_class_logits, cls_target) loss_direction = 1 - cos(pred_direction, gt_direction) loss_dist = smooth_l1(pred_distance, gt_distance) loss_total = lambda_cls_aux * loss_cls_aux + lambda_direction * loss_direction + lambda_dist * loss_dist ``` 当前 `ov1_local_spatial` 默认权重: ```text lambda_cls_aux = 1.0 lambda_direction = 12.0 lambda_dist = 2.0 ``` valid metrics: - `class_acc` - `azi_mae_deg` - `ele_mae_deg` - `dist_mae` ### 3.1 Active window mask 虽然是单源 clip-level readout,这条线仍然会构建弱时间窗 mask: - `build_primary_source_window_mask()` - 根据 source start/end time 映射到 token 轴 - 作为 `active_window_mask` 提供给 `LocalSpatialPredictionHeads` 这意味着 pooled class/spatial token 默认更关注标注的 active window,而不是整段所有 token。 ## 4. 当前训练配置 当前恢复后的 `ov1_local_spatial` preset 在: - [train_spatial_beats.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/train_spatial_beats.py) 关键配置: ```text batch_size = 8 num_workers = 4 num_epochs = 20 learning_rate = 1e-4 weight_decay = 0.05 freeze_trunk_in_stage1 = True unfreeze_full_trunk = False train_patch_embedding_in_stage1 = False train_spatial_adapter_in_stage1 = False freeze_projector_by_default = True dataset.max_clip_duration_seconds = 20.0 dataset.crop_mode = "start" best_metric_name = "azi_mae_deg" minimize_best_metric = True class_finetuned_ckpt = checkpoints/beats_ov1_event_cls_head_only/best.pt ``` 当前最重要的初始化来源有两个: 1. `BEATs pretrained trunk` 2. `checkpoints/beats_ov1_event_cls_head_only/best.pt` 第二个 checkpoint 的加载策略是: ```text beats.patch_embedding.weight -> patch_embedding.proj.weight beats.patch_embedding.bias -> patch_embedding.proj.bias beats.layer_norm.* -> layer_norm.* beats.post_extract_proj.* -> post_extract_proj.* beats.encoder.* -> encoder.* classifier.weight -> local_spatial_prediction_heads.class_head.weight classifier.bias -> local_spatial_prediction_heads.class_head.bias ``` 所以当前 `ov1_local_spatial` 的初始分类能力,并不是随机的,而是: - 先来自 W-channel 事件分类 baseline 的 class head - 再接到 fused temporal readout 上 ## 5. 当前 checkpoint 已证明的结果 参考 checkpoint: ```text checkpoints/spatial_beats_ov1_local_spatial_run1/best.pt ``` 这条 run 的已知最佳结果是: ```text best epoch = 7 best metric = azi_mae_deg train: loss_total = 4.0454 loss_cls_aux = 1.3232 loss_direction = 0.1966 loss_dist = 0.1816 class_acc = 0.6184 azi_mae_deg = 24.68 ele_mae_deg = 10.07 dist_mae = 0.465 valid: loss_total = 5.1220 loss_cls_aux = 2.5258 loss_direction = 0.1856 loss_dist = 0.1847 class_acc = 0.3985 azi_mae_deg = 23.52 ele_mae_deg = 9.65 dist_mae = 0.475 ``` 更关键的是 `epoch_0000` 的行为: ```text train class_acc ≈ 0.246 val class_acc ≈ 0.244 val azi_mae ≈ 71.49 val ele_mae ≈ 71.54 ``` 这说明 run1 的初始状态是: - class 明显高于随机 - 空间一开始很差 这正是后续 0413 恢复工作所要回到的行为。 ## 6. 0413 已排除掉的错误分支 为了避免后续再把这条基线改坏,下面这些都已经确认不是当前基线的一部分: ### 6.1 不是 `semantic_aux_classifier` 版本 后面曾尝试: - 在训练期额外挂一个完全独立的 semantic auxiliary classifier - 让 class loss 不再从 fused token 读 结果: - 这条线改变了初始化行为 - 会让 `Epoch 0` class 接近随机 - 与原始 run1 不一致 现在已经从当前基线彻底移除。 ### 6.2 不是 `use_original_beats_semantic_frontend_for_local_spatial` 版本 后面曾尝试: - 让 local_spatial 路径里的 semantic branch 改走原始 W-only Kaldi fbank + BEATs 结果: - 这条线会明显改变 local_spatial 的初始化分布 - 也是导致后续 “空间一开始很好、class 一开始很差” 的原因之一 现在已经从当前基线移除。 ### 6.3 不是 class/spatial 分开读不同 token 的版本 后面曾尝试: - class 从 `semantic_tokens` 池化 - spatial 从 `fused_tokens` 池化 这同样改变了 0410 run1 的行为。 当前已经恢复成: - class 和 spatial 都从 `fused_tokens` 读 ## 7. 当前仓库里额外存在但不属于这条基线的内容 当前仓库除了 `ov1_local_spatial` 之外,还保留了一些后续研发分支: - `ov1_ast` - `ov1_pretrunk_ast_*` - `ov123_local_spatial_slot` - `ov123_local_spatial_track` - `ov123_local_spatial_accdoa` 这些分支仍然在代码里,但它们不是当前这份文档关注的“恢复后的单源 local_spatial 基线”。 后续如果继续研究 frame-level / slot-level 多源建模,应当: - 把它们视为新的实验线 - 不要再直接污染 `ov1_local_spatial` ## 8. 当前推荐的复现命令 恢复后的基线复现命令: ```bash CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --master_port=29531 \ train_spatial_beats.py \ --preset ov1_local_spatial \ --distributed \ --batch-size 8 \ --num-workers 24 \ --num-epochs 20 \ --output-dir checkpoints/spatial_beats_ov1_local_spatial_run1 ``` 如果只想做行为检查,先跑 1 个 epoch 就够: ```bash CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --master_port=29531 \ train_spatial_beats.py \ --preset ov1_local_spatial \ --distributed \ --batch-size 8 \ --num-workers 24 \ --num-epochs 1 \ --output-dir checkpoints/spatial_beats_ov1_local_spatial_repro_check ``` 判断是否仍在 run1 轨道上的最直接标准: - `Epoch 0 class_acc` 应该明显高于随机 - `Epoch 0 azi/ele` 应该仍然很差 而不是: - `Epoch 0 class` 接近 0 - `Epoch 0 spatial` 一开始就很好 ## 9. 目前这条线的结论 当前 `ov1_local_spatial` 基线已经证明: 1. 单路 fused token 架构是能学到空间的。 2. 这条线能在不动 trunk 的前提下,把 azimuth / elevation / distance 拉下来。 3. 但它的 class 保持能力有限,最佳 valid `class_acc` 约 `0.40`,明显低于纯 W-channel BEATs 分类 baseline。 4. 所以这条线当前更适合作为: - “可工作空间基线” - 后续 frame-level spatial 扩展的 warm-start 它还没有证明: 1. class 和 spatial 可以同时都做到强泛化。 2. fused token 已经足够好,可以直接无脑接 LLM。 ## 10. 后续继续迭代时的建议 如果以后继续改,不要再直接破坏 `ov1_local_spatial` 这条基线。更稳的方式是: 1. 保留 `ov1_local_spatial` 不动,作为 frozen baseline。 2. 新实验单独开新 preset。 3. 新实验优先从: - `checkpoints/spatial_beats_ov1_local_spatial_run1/best.pt` - 或 `DEFAULT_OV1_LOCAL_SPATIAL_INIT` warm-start。 4. 任何涉及下面这些点的修改,都应视为“新架构”,不要再说和 run1 一样: - semantic aux classifier - 原始 W-BEATs special frontend - class/spatial 分离 readout - 新的 frame-level supervision path 当前最安全的角色定位是: ```text ov1_local_spatial = 已恢复并可复现的 0410 单源 local-spatial baseline ```