Spatial-BEATs / docs /0413.md
dieKarotte's picture
Add files using upload-large-folder tool
29ab2d0 verified
|
Raw
History Blame Contribute Delete
14 kB

0413 当前 Spatial-BEATs 基线总结

本文档记录 2026-04-13 时点仓库内当前可复现的 Spatial-BEATs 基线状态。
重点不是回顾全部试错,而是明确:

  • 当前应该以哪个 checkpoint / preset 为基线
  • 当前模型架构到底是什么
  • 当前训练和验证流程如何工作
  • 这条线已经证明了什么,没证明什么
  • 后续继续改时,哪些地方不要再误改

当前本地快照提交:

b782333 snapshot: restore 0410 local_spatial baseline state

1. 当前基线对象

当前恢复并确认的基线是:

train_spatial_beats.py --preset ov1_local_spatial

对应的历史参考 checkpoint 是:

checkpoints/spatial_beats_ov1_local_spatial_run1/best.pt

这条线对应的设计目标是:

  • W 通道的 BEATs 路径提供语义时间序列
  • WXYZ + IVxyz 的 local spatial branch 提供局部空间时间序列
  • 两者在 token 时间轴上对齐后相加融合
  • 用一个单路 fused temporal token 序列同时做:
    • clip-level class
    • direction
    • distance
    • 最终 llm_spatial_tokens

这条线不是:

  • semantic_aux_classifier 双头版本
  • 原始 Kaldi fbank W semantic special path 版本
  • class/spatial 分别从不同 token 读出的版本

这些后续尝试都已经从当前基线中移除。

2. 当前模型架构

当前 ov1_local_spatial 的主路径如下:

FOA waveform [B, 4, T]
  -> SpatialBEATsPreprocessor
  -> foa_feat [B, 7, T_f, F]
     channels = [W, X, Y, Z, IVx, IVy, IVz]

  -> extract_patch_tokens()
     base path:
       W_logmel -> original single-channel BEATs patch embedding
     spatial delta path:
       7ch foa_feat -> SpatialDeltaPatchAdapter -> delta_patch_tokens
     patch_tokens = base_patch_tokens + delta_patch_tokens

  -> BEATs trunk
  -> frequency_pool
  -> TemporalResampler(target_token_rate = 2.5 Hz)
  -> ShallowTemporalReadout
  -> semantic_embeddings [B, T_s, D]

  -> LocalSpatialEncoder(foa_feat) -> local_patch_rate_tokens [B, T_f, D_s]
  -> LocalSpatialResampler -> local_spatial_tokens [B, T_s, D_s]
  -> local_spatial_proj: Linear(D_s, D)
  -> fused_embeddings = LayerNorm(semantic_embeddings + local_update)

  -> LocalSpatialPredictionHeads(fused_embeddings)
     -> class pooled token
     -> spatial pooled token
     -> pred_class_logits [B, 65]
     -> pred_direction [B, 3]
     -> pred_distance [B, 1]

  -> SpatialTokenProjector(fused_embeddings)
     -> llm_spatial_tokens [B, T_s, d_llm]

2.1 FOA 前端

代码入口:

当前前端行为:

  • 输入波形按 DCASE 存储顺序 [W, Y, Z, X]
  • SpatialBEATsPreprocessor 内部先重排到 [W, X, Y, Z]
  • 计算:
    • WXYZ logmel
    • IVx, IVy, IVz
  • 输出 foa_feat [B, 7, T_f, F]

这里仍然是当前基线的一部分。
也就是说,当前基线不是完全复用原始 BEATs.py preprocess(),而是:

  • 用自实现 STFT/mel 前端
  • 再用 BEATs 统计量做归一化

这点非常重要,因为后续很多分类迁移问题都和这里有关。

2.2 Patch token 构造

当前 patch 输入仍然是“W base + 7ch delta”风格:

  • patch_embedding.proj 只接收 W_logmel
  • SpatialDeltaPatchAdapter 从完整 foa_feat 生成 patch-level residual update
  • 最终送入 trunk 的是:
patch_tokens = base_patch_tokens + delta_patch_tokens

当前 ov1_local_spatial preset 中:

  • train_patch_embedding_in_stage1 = False
  • train_spatial_adapter_in_stage1 = False
  • patch_adapter_residual_alpha_init = 0.0
  • patch_adapter_out_proj_scale_init = 0.0

也就是说,在这条恢复后的 run1 基线里:

  • trunk 冻结
  • patch embedding 冻结
  • patch delta adapter 也不训练

真正训练的是后面的 local spatial 分支和 fused readout。

2.3 BEATs trunk 路径

当前 trunk 保持的是原 BEATs 主干:

  • layer_norm
  • post_extract_proj
  • encoder.pos_conv
  • encoder.layers.*
  • encoder.layer_norm

预训练初始化来自:

pretrain_ckpt/BEATs_iter3_plus_AS2M.pt/BEATs_iter3_plus_AS2M.pt

加载逻辑在:

当前日志行为应类似:

[SpatialBEATs] Reusing 201 compatible BEATs keys
[SpatialBEATs] Reusing original single-channel BEATs patch embedding

2.4 Local spatial 分支

当前 spatial branch 是当前基线的核心增量,结构为:

  • LocalSpatialEncoder
    • 2D CNN 提取局部多通道空间模式
    • 频率池化
    • temporal Transformer
  • TemporalResampler
    • 把 local spatial token 序列重采样到和 semantic token 同样的 T_s
  • local_spatial_proj
    • Linear(D_s, D),小尺度初始化
  • local_spatial_fusion_norm
    • LayerNorm(semantic + local_update)

代码位置:

2.5 Prediction heads

当前 LocalSpatialPredictionHeads 是 0410 那条线真正使用的版本:

  • 输入只有 fused_tokens
  • class 和 spatial 都从同一个 fused sequence 上做 attention pooling
  • 不是 class 从 semantic token 读、spatial 从 fused token 读的后改版本

结构:

fused_tokens [B, T_s, D]
  -> class_score -> class attention pooling -> class_token
  -> spatial_score -> spatial attention pooling -> spatial_token

class_token   -> class_head -> pred_class_logits
spatial_token -> direction_head -> pred_direction (normalized)
spatial_token -> distance_head  -> pred_distance (softplus)

初始化细节:

  • class_score.weight/bias = 0
  • spatial_score.weight/bias = 0

所以初始 attention pooling 等价于均匀 mean-pool。
这就是为什么它可以比较平滑地接入之前的 class head checkpoint。

3. 当前 supervision 与 loss

当前 ov1_local_spatial 使用:

cfg.loss.supervision_mode = "mono_ast"

对应 loss 计算在:

当前这条线不是 Hungarian slot matching,而是单源 clip-level supervision:

  • 每个样本必须只有一个有效源
  • 从第一个有效源读取:
    • class
    • azimuth
    • elevation
    • distance

loss 定义:

loss_cls_aux = cross_entropy(pred_class_logits, cls_target)
loss_direction = 1 - cos(pred_direction, gt_direction)
loss_dist = smooth_l1(pred_distance, gt_distance)

loss_total =
    lambda_cls_aux * loss_cls_aux
  + lambda_direction * loss_direction
  + lambda_dist * loss_dist

当前 ov1_local_spatial 默认权重:

lambda_cls_aux = 1.0
lambda_direction = 12.0
lambda_dist = 2.0

valid metrics:

  • class_acc
  • azi_mae_deg
  • ele_mae_deg
  • dist_mae

3.1 Active window mask

虽然是单源 clip-level readout,这条线仍然会构建弱时间窗 mask:

  • build_primary_source_window_mask()
  • 根据 source start/end time 映射到 token 轴
  • 作为 active_window_mask 提供给 LocalSpatialPredictionHeads

这意味着 pooled class/spatial token 默认更关注标注的 active window,而不是整段所有 token。

4. 当前训练配置

当前恢复后的 ov1_local_spatial preset 在:

关键配置:

batch_size = 8
num_workers = 4
num_epochs = 20
learning_rate = 1e-4
weight_decay = 0.05

freeze_trunk_in_stage1 = True
unfreeze_full_trunk = False
train_patch_embedding_in_stage1 = False
train_spatial_adapter_in_stage1 = False
freeze_projector_by_default = True

dataset.max_clip_duration_seconds = 20.0
dataset.crop_mode = "start"

best_metric_name = "azi_mae_deg"
minimize_best_metric = True
class_finetuned_ckpt = checkpoints/beats_ov1_event_cls_head_only/best.pt

当前最重要的初始化来源有两个:

  1. BEATs pretrained trunk
  2. checkpoints/beats_ov1_event_cls_head_only/best.pt

第二个 checkpoint 的加载策略是:

beats.patch_embedding.weight -> patch_embedding.proj.weight
beats.patch_embedding.bias   -> patch_embedding.proj.bias
beats.layer_norm.*           -> layer_norm.*
beats.post_extract_proj.*    -> post_extract_proj.*
beats.encoder.*              -> encoder.*
classifier.weight            -> local_spatial_prediction_heads.class_head.weight
classifier.bias              -> local_spatial_prediction_heads.class_head.bias

所以当前 ov1_local_spatial 的初始分类能力,并不是随机的,而是:

  • 先来自 W-channel 事件分类 baseline 的 class head
  • 再接到 fused temporal readout 上

5. 当前 checkpoint 已证明的结果

参考 checkpoint:

checkpoints/spatial_beats_ov1_local_spatial_run1/best.pt

这条 run 的已知最佳结果是:

best epoch = 7
best metric = azi_mae_deg

train:
  loss_total      = 4.0454
  loss_cls_aux    = 1.3232
  loss_direction  = 0.1966
  loss_dist       = 0.1816
  class_acc       = 0.6184
  azi_mae_deg     = 24.68
  ele_mae_deg     = 10.07
  dist_mae        = 0.465

valid:
  loss_total      = 5.1220
  loss_cls_aux    = 2.5258
  loss_direction  = 0.1856
  loss_dist       = 0.1847
  class_acc       = 0.3985
  azi_mae_deg     = 23.52
  ele_mae_deg     = 9.65
  dist_mae        = 0.475

更关键的是 epoch_0000 的行为:

train class_acc ≈ 0.246
val   class_acc ≈ 0.244
val   azi_mae   ≈ 71.49
val   ele_mae   ≈ 71.54

这说明 run1 的初始状态是:

  • class 明显高于随机
  • 空间一开始很差

这正是后续 0413 恢复工作所要回到的行为。

6. 0413 已排除掉的错误分支

为了避免后续再把这条基线改坏,下面这些都已经确认不是当前基线的一部分:

6.1 不是 semantic_aux_classifier 版本

后面曾尝试:

  • 在训练期额外挂一个完全独立的 semantic auxiliary classifier
  • 让 class loss 不再从 fused token 读

结果:

  • 这条线改变了初始化行为
  • 会让 Epoch 0 class 接近随机
  • 与原始 run1 不一致

现在已经从当前基线彻底移除。

6.2 不是 use_original_beats_semantic_frontend_for_local_spatial 版本

后面曾尝试:

  • 让 local_spatial 路径里的 semantic branch 改走原始 W-only Kaldi fbank + BEATs

结果:

  • 这条线会明显改变 local_spatial 的初始化分布
  • 也是导致后续 “空间一开始很好、class 一开始很差” 的原因之一

现在已经从当前基线移除。

6.3 不是 class/spatial 分开读不同 token 的版本

后面曾尝试:

  • class 从 semantic_tokens 池化
  • spatial 从 fused_tokens 池化

这同样改变了 0410 run1 的行为。
当前已经恢复成:

  • class 和 spatial 都从 fused_tokens

7. 当前仓库里额外存在但不属于这条基线的内容

当前仓库除了 ov1_local_spatial 之外,还保留了一些后续研发分支:

  • ov1_ast
  • ov1_pretrunk_ast_*
  • ov123_local_spatial_slot
  • ov123_local_spatial_track
  • ov123_local_spatial_accdoa

这些分支仍然在代码里,但它们不是当前这份文档关注的“恢复后的单源 local_spatial 基线”。

后续如果继续研究 frame-level / slot-level 多源建模,应当:

  • 把它们视为新的实验线
  • 不要再直接污染 ov1_local_spatial

8. 当前推荐的复现命令

恢复后的基线复现命令:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --master_port=29531 \
  train_spatial_beats.py \
  --preset ov1_local_spatial \
  --distributed \
  --batch-size 8 \
  --num-workers 24 \
  --num-epochs 20 \
  --output-dir checkpoints/spatial_beats_ov1_local_spatial_run1

如果只想做行为检查,先跑 1 个 epoch 就够:

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc_per_node=8 --master_port=29531 \
  train_spatial_beats.py \
  --preset ov1_local_spatial \
  --distributed \
  --batch-size 8 \
  --num-workers 24 \
  --num-epochs 1 \
  --output-dir checkpoints/spatial_beats_ov1_local_spatial_repro_check

判断是否仍在 run1 轨道上的最直接标准:

  • Epoch 0 class_acc 应该明显高于随机
  • Epoch 0 azi/ele 应该仍然很差

而不是:

  • Epoch 0 class 接近 0
  • Epoch 0 spatial 一开始就很好

9. 目前这条线的结论

当前 ov1_local_spatial 基线已经证明:

  1. 单路 fused token 架构是能学到空间的。
  2. 这条线能在不动 trunk 的前提下,把 azimuth / elevation / distance 拉下来。
  3. 但它的 class 保持能力有限,最佳 valid class_acc0.40,明显低于纯 W-channel BEATs 分类 baseline。
  4. 所以这条线当前更适合作为:
    • “可工作空间基线”
    • 后续 frame-level spatial 扩展的 warm-start

它还没有证明:

  1. class 和 spatial 可以同时都做到强泛化。
  2. fused token 已经足够好,可以直接无脑接 LLM。

10. 后续继续迭代时的建议

如果以后继续改,不要再直接破坏 ov1_local_spatial 这条基线。更稳的方式是:

  1. 保留 ov1_local_spatial 不动,作为 frozen baseline。
  2. 新实验单独开新 preset。
  3. 新实验优先从:
    • checkpoints/spatial_beats_ov1_local_spatial_run1/best.pt
    • DEFAULT_OV1_LOCAL_SPATIAL_INIT warm-start。
  4. 任何涉及下面这些点的修改,都应视为“新架构”,不要再说和 run1 一样:
    • semantic aux classifier
    • 原始 W-BEATs special frontend
    • class/spatial 分离 readout
    • 新的 frame-level supervision path

当前最安全的角色定位是:

ov1_local_spatial = 已恢复并可复现的 0410 单源 local-spatial baseline