0413 当前 Spatial-BEATs 基线总结
本文档记录 2026-04-13 时点仓库内当前可复现的 Spatial-BEATs 基线状态。
重点不是回顾全部试错,而是明确:
- 当前应该以哪个 checkpoint / preset 为基线
- 当前模型架构到底是什么
- 当前训练和验证流程如何工作
- 这条线已经证明了什么,没证明什么
- 后续继续改时,哪些地方不要再误改
当前本地快照提交:
b782333 snapshot: restore 0410 local_spatial baseline state
1. 当前基线对象
当前恢复并确认的基线是:
train_spatial_beats.py --preset ov1_local_spatial
对应的历史参考 checkpoint 是:
checkpoints/spatial_beats_ov1_local_spatial_run1/best.pt
这条线对应的设计目标是:
- 用
W通道的 BEATs 路径提供语义时间序列 - 用
WXYZ + IVxyz的 local spatial branch 提供局部空间时间序列 - 两者在 token 时间轴上对齐后相加融合
- 用一个单路 fused temporal token 序列同时做:
- clip-level class
- direction
- distance
- 最终
llm_spatial_tokens
这条线不是:
semantic_aux_classifier双头版本- 原始 Kaldi fbank W semantic special path 版本
- class/spatial 分别从不同 token 读出的版本
这些后续尝试都已经从当前基线中移除。
2. 当前模型架构
当前 ov1_local_spatial 的主路径如下:
FOA waveform [B, 4, T]
-> SpatialBEATsPreprocessor
-> foa_feat [B, 7, T_f, F]
channels = [W, X, Y, Z, IVx, IVy, IVz]
-> extract_patch_tokens()
base path:
W_logmel -> original single-channel BEATs patch embedding
spatial delta path:
7ch foa_feat -> SpatialDeltaPatchAdapter -> delta_patch_tokens
patch_tokens = base_patch_tokens + delta_patch_tokens
-> BEATs trunk
-> frequency_pool
-> TemporalResampler(target_token_rate = 2.5 Hz)
-> ShallowTemporalReadout
-> semantic_embeddings [B, T_s, D]
-> LocalSpatialEncoder(foa_feat) -> local_patch_rate_tokens [B, T_f, D_s]
-> LocalSpatialResampler -> local_spatial_tokens [B, T_s, D_s]
-> local_spatial_proj: Linear(D_s, D)
-> fused_embeddings = LayerNorm(semantic_embeddings + local_update)
-> LocalSpatialPredictionHeads(fused_embeddings)
-> class pooled token
-> spatial pooled token
-> pred_class_logits [B, 65]
-> pred_direction [B, 3]
-> pred_distance [B, 1]
-> SpatialTokenProjector(fused_embeddings)
-> llm_spatial_tokens [B, T_s, d_llm]
2.1 FOA 前端
代码入口:
当前前端行为:
- 输入波形按 DCASE 存储顺序
[W, Y, Z, X] SpatialBEATsPreprocessor内部先重排到[W, X, Y, Z]- 计算:
WXYZ logmelIVx, IVy, IVz
- 输出
foa_feat [B, 7, T_f, F]
这里仍然是当前基线的一部分。
也就是说,当前基线不是完全复用原始 BEATs.py preprocess(),而是:
- 用自实现 STFT/mel 前端
- 再用 BEATs 统计量做归一化
这点非常重要,因为后续很多分类迁移问题都和这里有关。
2.2 Patch token 构造
当前 patch 输入仍然是“W base + 7ch delta”风格:
patch_embedding.proj只接收W_logmelSpatialDeltaPatchAdapter从完整foa_feat生成 patch-level residual update- 最终送入 trunk 的是:
patch_tokens = base_patch_tokens + delta_patch_tokens
当前 ov1_local_spatial preset 中:
train_patch_embedding_in_stage1 = Falsetrain_spatial_adapter_in_stage1 = Falsepatch_adapter_residual_alpha_init = 0.0patch_adapter_out_proj_scale_init = 0.0
也就是说,在这条恢复后的 run1 基线里:
- trunk 冻结
- patch embedding 冻结
- patch delta adapter 也不训练
真正训练的是后面的 local spatial 分支和 fused readout。
2.3 BEATs trunk 路径
当前 trunk 保持的是原 BEATs 主干:
layer_normpost_extract_projencoder.pos_convencoder.layers.*encoder.layer_norm
预训练初始化来自:
pretrain_ckpt/BEATs_iter3_plus_AS2M.pt/BEATs_iter3_plus_AS2M.pt
加载逻辑在:
当前日志行为应类似:
[SpatialBEATs] Reusing 201 compatible BEATs keys
[SpatialBEATs] Reusing original single-channel BEATs patch embedding
2.4 Local spatial 分支
当前 spatial branch 是当前基线的核心增量,结构为:
LocalSpatialEncoder- 2D CNN 提取局部多通道空间模式
- 频率池化
- temporal Transformer
TemporalResampler- 把 local spatial token 序列重采样到和 semantic token 同样的
T_s
- 把 local spatial token 序列重采样到和 semantic token 同样的
local_spatial_projLinear(D_s, D),小尺度初始化
local_spatial_fusion_normLayerNorm(semantic + local_update)
代码位置:
2.5 Prediction heads
当前 LocalSpatialPredictionHeads 是 0410 那条线真正使用的版本:
- 输入只有
fused_tokens - class 和 spatial 都从同一个 fused sequence 上做 attention pooling
- 不是 class 从 semantic token 读、spatial 从 fused token 读的后改版本
结构:
fused_tokens [B, T_s, D]
-> class_score -> class attention pooling -> class_token
-> spatial_score -> spatial attention pooling -> spatial_token
class_token -> class_head -> pred_class_logits
spatial_token -> direction_head -> pred_direction (normalized)
spatial_token -> distance_head -> pred_distance (softplus)
初始化细节:
class_score.weight/bias = 0spatial_score.weight/bias = 0
所以初始 attention pooling 等价于均匀 mean-pool。
这就是为什么它可以比较平滑地接入之前的 class head checkpoint。
3. 当前 supervision 与 loss
当前 ov1_local_spatial 使用:
cfg.loss.supervision_mode = "mono_ast"
对应 loss 计算在:
当前这条线不是 Hungarian slot matching,而是单源 clip-level supervision:
- 每个样本必须只有一个有效源
- 从第一个有效源读取:
- class
- azimuth
- elevation
- distance
loss 定义:
loss_cls_aux = cross_entropy(pred_class_logits, cls_target)
loss_direction = 1 - cos(pred_direction, gt_direction)
loss_dist = smooth_l1(pred_distance, gt_distance)
loss_total =
lambda_cls_aux * loss_cls_aux
+ lambda_direction * loss_direction
+ lambda_dist * loss_dist
当前 ov1_local_spatial 默认权重:
lambda_cls_aux = 1.0
lambda_direction = 12.0
lambda_dist = 2.0
valid metrics:
class_accazi_mae_degele_mae_degdist_mae
3.1 Active window mask
虽然是单源 clip-level readout,这条线仍然会构建弱时间窗 mask:
build_primary_source_window_mask()- 根据 source start/end time 映射到 token 轴
- 作为
active_window_mask提供给LocalSpatialPredictionHeads
这意味着 pooled class/spatial token 默认更关注标注的 active window,而不是整段所有 token。
4. 当前训练配置
当前恢复后的 ov1_local_spatial preset 在:
关键配置:
batch_size = 8
num_workers = 4
num_epochs = 20
learning_rate = 1e-4
weight_decay = 0.05
freeze_trunk_in_stage1 = True
unfreeze_full_trunk = False
train_patch_embedding_in_stage1 = False
train_spatial_adapter_in_stage1 = False
freeze_projector_by_default = True
dataset.max_clip_duration_seconds = 20.0
dataset.crop_mode = "start"
best_metric_name = "azi_mae_deg"
minimize_best_metric = True
class_finetuned_ckpt = checkpoints/beats_ov1_event_cls_head_only/best.pt
当前最重要的初始化来源有两个:
BEATs pretrained trunkcheckpoints/beats_ov1_event_cls_head_only/best.pt
第二个 checkpoint 的加载策略是:
beats.patch_embedding.weight -> patch_embedding.proj.weight
beats.patch_embedding.bias -> patch_embedding.proj.bias
beats.layer_norm.* -> layer_norm.*
beats.post_extract_proj.* -> post_extract_proj.*
beats.encoder.* -> encoder.*
classifier.weight -> local_spatial_prediction_heads.class_head.weight
classifier.bias -> local_spatial_prediction_heads.class_head.bias
所以当前 ov1_local_spatial 的初始分类能力,并不是随机的,而是:
- 先来自 W-channel 事件分类 baseline 的 class head
- 再接到 fused temporal readout 上
5. 当前 checkpoint 已证明的结果
参考 checkpoint:
checkpoints/spatial_beats_ov1_local_spatial_run1/best.pt
这条 run 的已知最佳结果是:
best epoch = 7
best metric = azi_mae_deg
train:
loss_total = 4.0454
loss_cls_aux = 1.3232
loss_direction = 0.1966
loss_dist = 0.1816
class_acc = 0.6184
azi_mae_deg = 24.68
ele_mae_deg = 10.07
dist_mae = 0.465
valid:
loss_total = 5.1220
loss_cls_aux = 2.5258
loss_direction = 0.1856
loss_dist = 0.1847
class_acc = 0.3985
azi_mae_deg = 23.52
ele_mae_deg = 9.65
dist_mae = 0.475
更关键的是 epoch_0000 的行为:
train class_acc ≈ 0.246
val class_acc ≈ 0.244
val azi_mae ≈ 71.49
val ele_mae ≈ 71.54
这说明 run1 的初始状态是:
- class 明显高于随机
- 空间一开始很差
这正是后续 0413 恢复工作所要回到的行为。
6. 0413 已排除掉的错误分支
为了避免后续再把这条基线改坏,下面这些都已经确认不是当前基线的一部分:
6.1 不是 semantic_aux_classifier 版本
后面曾尝试:
- 在训练期额外挂一个完全独立的 semantic auxiliary classifier
- 让 class loss 不再从 fused token 读
结果:
- 这条线改变了初始化行为
- 会让
Epoch 0class 接近随机 - 与原始 run1 不一致
现在已经从当前基线彻底移除。
6.2 不是 use_original_beats_semantic_frontend_for_local_spatial 版本
后面曾尝试:
- 让 local_spatial 路径里的 semantic branch 改走原始 W-only Kaldi fbank + BEATs
结果:
- 这条线会明显改变 local_spatial 的初始化分布
- 也是导致后续 “空间一开始很好、class 一开始很差” 的原因之一
现在已经从当前基线移除。
6.3 不是 class/spatial 分开读不同 token 的版本
后面曾尝试:
- class 从
semantic_tokens池化 - spatial 从
fused_tokens池化
这同样改变了 0410 run1 的行为。
当前已经恢复成:
- class 和 spatial 都从
fused_tokens读
7. 当前仓库里额外存在但不属于这条基线的内容
当前仓库除了 ov1_local_spatial 之外,还保留了一些后续研发分支:
ov1_astov1_pretrunk_ast_*ov123_local_spatial_slotov123_local_spatial_trackov123_local_spatial_accdoa
这些分支仍然在代码里,但它们不是当前这份文档关注的“恢复后的单源 local_spatial 基线”。
后续如果继续研究 frame-level / slot-level 多源建模,应当:
- 把它们视为新的实验线
- 不要再直接污染
ov1_local_spatial
8. 当前推荐的复现命令
恢复后的基线复现命令:
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --master_port=29531 \
train_spatial_beats.py \
--preset ov1_local_spatial \
--distributed \
--batch-size 8 \
--num-workers 24 \
--num-epochs 20 \
--output-dir checkpoints/spatial_beats_ov1_local_spatial_run1
如果只想做行为检查,先跑 1 个 epoch 就够:
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --master_port=29531 \
train_spatial_beats.py \
--preset ov1_local_spatial \
--distributed \
--batch-size 8 \
--num-workers 24 \
--num-epochs 1 \
--output-dir checkpoints/spatial_beats_ov1_local_spatial_repro_check
判断是否仍在 run1 轨道上的最直接标准:
Epoch 0 class_acc应该明显高于随机Epoch 0 azi/ele应该仍然很差
而不是:
Epoch 0 class接近 0Epoch 0 spatial一开始就很好
9. 目前这条线的结论
当前 ov1_local_spatial 基线已经证明:
- 单路 fused token 架构是能学到空间的。
- 这条线能在不动 trunk 的前提下,把 azimuth / elevation / distance 拉下来。
- 但它的 class 保持能力有限,最佳 valid
class_acc约0.40,明显低于纯 W-channel BEATs 分类 baseline。 - 所以这条线当前更适合作为:
- “可工作空间基线”
- 后续 frame-level spatial 扩展的 warm-start
它还没有证明:
- class 和 spatial 可以同时都做到强泛化。
- fused token 已经足够好,可以直接无脑接 LLM。
10. 后续继续迭代时的建议
如果以后继续改,不要再直接破坏 ov1_local_spatial 这条基线。更稳的方式是:
- 保留
ov1_local_spatial不动,作为 frozen baseline。 - 新实验单独开新 preset。
- 新实验优先从:
checkpoints/spatial_beats_ov1_local_spatial_run1/best.pt- 或
DEFAULT_OV1_LOCAL_SPATIAL_INITwarm-start。
- 任何涉及下面这些点的修改,都应视为“新架构”,不要再说和 run1 一样:
- semantic aux classifier
- 原始 W-BEATs special frontend
- class/spatial 分离 readout
- 新的 frame-level supervision path
当前最安全的角色定位是:
ov1_local_spatial = 已恢复并可复现的 0410 单源 local-spatial baseline