Spatial-BEATs / docs /v13d_full_hyperparameters.md
dieKarotte's picture
Add files using upload-large-folder tool
29ab2d0 verified
|
Raw
History Blame Contribute Delete
8.35 kB

Spatial-BEATs v13d 完整超参数与实现细节附录

适用于 NeurIPS 论文附录。本附录详尽地列出 v13d 模型的全部架构超参数、训练超参数、损失函数权重、数据预处理参数以及优化器配置。所有数值均与 train_spatial_beats.py::make_ov1_unified_v13d_config() 以及 run_ov1_unified_v13d.sh 中的代码一致。


A. 输入与特征提取

参数 取值 说明
采样率 16 kHz FOA 4 通道,顺序 [W, X, Y, Z]
单 clip 时长 10 s 输入波形形状 [B, 4, 160000]
STFT n_fft 400 Qwen-2.5-Omni 对齐
STFT hop_length 160 时间步长 10 ms
STFT win_length 400 窗长 25 ms
窗函数 Hann
Mel 滤波器组数 128 f_min=0, f_max=8000
时间帧数 T_f 1000 10 s × 100 帧/s
输入特征通道数 7 4 个 mel (W/X/Y/Z) + 3 个 IV (x/y/z)
IV 公式 IV_d = Re[W · conj(X_d)] / (|W|² + ε) ε=1e-8,IV 经 mel 投影后 clamp 到 ±10
W 通道归一化 mean=15.41663, std=6.55582 BEATs 预训练统计量
SpecAugment(仅 W 通道) 2 个时间 mask × 100 帧, 2 个频率 mask × 27 bin 训练时启用

B. 模型架构超参数

B.1 SpatialDeltaPatchAdapter (v1)

参数 取值
输入通道数 7
隐藏通道数 32
输出维度 512(patch embedding 维度)
Patch size (16, 16), stride=16
残差缩放 α 初始化 0.1(可学习)
结构 Conv2d(7→32, 1×1) → GELU → DWConv(32, 3×3) → GELU → Conv2d(32→512, 16×16, s=16)

B.2 SpatialPatchEmbedding(继承 BEATs)

  • 单通道(W)patch embedding,预训练权重不修改
  • 输出 token 数 = 496(10 s clip)
  • Hidden = 512,再投影至 768

B.3 BEATs Transformer Trunk

参数 取值
Layer 数 12
Hidden 维度 768
注意力头数 12
FFN 维度 3072
相对位置偏置 sinusoidal + GRU gating
Trunk adapter 1 层 spectral demixer,零门控初始化(继承 v11a 的 use_spatial_head_demixer=True

B.4 LocalSpatialEncoder(并行空间分支)

参数 取值
输入 7 通道 FOA 特征 [B, 7, T_f, 128]
CNN block 1 Conv2d(7→64, 3×3) + GroupNorm(8) + GELU
CNN block 2 Conv2d(64→128, 3×3, stride=(1,2)) + GroupNorm(8) + GELU
CNN block 3 Conv2d(128→256, 3×3, stride=(1,2)) + GroupNorm(16) + GELU
频率维度处理 在最终 GN 后对频率轴做 mean → [B, T_f, 256]
Transformer 层数 2
Transformer hidden 256
Transformer heads 4
Norm 顺序 norm_first = True (pre-LN)
Dropout 0.1
输出投影 Linear(256 → 768)

B.5 FrequencyPool + TemporalResampler

  • FrequencyPool:reshape [B, 496, 768] → [B, 62, 8, 768],频率轴均值 → [B, 62, 768]
  • TemporalResampler:线性插值到 10 Hz 网格 → [B, T_s=100, 768]
  • Token 频率 = 10 Hz(继承自 v9_real_balanced_10hz)

B.6 LocalSpatialCrossFuser(语义-空间融合)

参数 取值
模式 cross_attn_gated
层数 2
Embed 维度 768
注意力头数 8
Gate bias -2.0(即 sigmoid(-2.0)≈0.119 初始化)
Direct gate bias -1.5(sigmoid≈0.182)
ShallowTemporalReadout 1 层 Transformer + LayerNorm
输出 fused_tokens [B, T_s=100, 768]

B.7 SourceQueryDecoder(多源解耦)

参数 取值
Track query 数 K 4
Stage-1 层数 2(TransformerDecoder)
Stage-2 层数 1(per-frame refinement + LN)
注意力头数 8
FFN 维度 3072
时间位置编码 可学习 [T_s, 768]
输出 [B, K=4, T_s=100, 768]

B.8 FrameTrackPredictionHeads(每个 (track, frame) 4 个预测头)

Head 结构 输出
Activity LayerNorm + Linear(768→1) logit ℓ ∈ ℝ
Class MLP + 残差 + spectral demixer 63 类 logits
Direction MLP(768→768→3) + L2 normalize 单位向量 ∈ ℝ³
Distance MLP(768→768→1) + softplus 距离(米)

C. 损失函数与权重

C.1 损失项与权重

损失项 权重 备注
lambda_frame_class 1.0 63 类 cross-entropy
lambda_frame_activity 1.0 Top-K rank loss(v13d 核心改动)
lambda_frame_direction 1.0 1 - cos(pred, gt)
lambda_frame_distance 1.0 smooth-L1
lambda_frame_hemisphere 1.0 半球 BCE(继承 v11a)

C.2 Top-K Rank Activity Loss(D-2)

Lrank=1P(i,j)Pmax(0,m+ji),Lact=Lrank+0.1LBCE\mathcal{L}_{\text{rank}} = \frac{1}{|P|}\sum_{(i,j)\in P}\max(0, m + \ell_j - \ell_i),\quad \mathcal{L}_{\text{act}} = \mathcal{L}_{\text{rank}} + 0.1 \cdot \mathcal{L}_{\text{BCE}}

超参数 取值
frame_activity_loss_type topk_rank
margin m 2.0
BCE anchor 权重 0.1

C.3 Spatial loss warmup / ramp(D-1)

阶段 Epoch 范围 空间 loss 权重
cls-only warmup 0 – 7(共 8 ep) 0
linear ramp 8 – 9(共 2 ep) 0 → 1
full joint training 10 – 24 1

对应 cfg:frame_spatial_loss_warmup_epochs=8, frame_spatial_loss_ramp_epochs=2.


D. 训练超参数

D.1 优化器

参数 取值
Optimizer AdamW
β₁, β₂ 0.9, 0.999
ε 1e-8
Weight decay 0.01
Gradient clipping 1.0(global L2 norm)
Resume optimizer state True(D-5:从 v12 best.pt 继承 Adam momentum)

D.2 学习率(Cosine schedule,D-1)

参数 取值
Peak LR 1.5e-5
Linear warmup epochs 3(LR 从 0 → peak)
Cosine decay epochs 22(peak → peak × min_ratio)
Min LR ratio 0.05(最低 LR = 7.5e-7)
use_cosine_lr True

D.3 训练规模

参数 取值
总 epoch 数 25
GPUs 8 × A100
单 GPU batch size 8
等效 batch size 64
数据并行 torchrun + DDP
精度 fp32
Num workers 8 / GPU
Hot-start checkpoint v12 best.pt(strict=False,missing=0/unexpected=0)

D.4 EMA shadow weights(D-6)

参数 取值
use_ema True
EMA decay 0.9995
启动 epoch 3(避开 LR warmup 噪声)
应用范围 验证、保存 best.pt 时使用 EMA 权重;训练 forward/backward 仍用原权重
实现方式 swap → evaluate → restore(不污染训练梯度)

E. 数据集与采样

参数 取值
训练 manifest unified_spatial_foa_fsd63_all/train.jsonl
训练样本总数 约 329 K
- sim_static 304 K
- dcase_real 20 K
- qa_sim 74 K
Manifest replication (1,)(v13d 不做真实数据加权)
词表 FSD50K 衍生 63 类(final_vocabulary.csv
验证集 ov1/ov2/ov3 sim + ov1/ov2/ov3 real + dcase_starss_valid + unified_valid(约 35 K)
数据增广 仅 W 通道 SpecAugment;不开启 v13b 的 random gain / channel dropout / lowpass

F. Hungarian 匹配与推理

参数 取值
匹配粒度 段级(segment-level,相同 active set 窗口内稳定分配)
匹配代价 activity + class CE + direction cosine + distance L1 加权和
推理时活跃 track 选择 top-K̂(DCASE SELD evaluator 统一标准),与训练 Top-K rank loss 对齐

G. 实际训练曲线(参考)

Epoch F20 oracle_cls azi MAE
0 0.311 0.650 28.6°
7(cls warmup 末) 0.193 0.786 31.0°
8(spatial 启动) 0.397 0.876 18.5°
10(当前最佳) 0.402 0.864 17.2°
25(预期) 0.43 ~ 0.46 ~0.88 17~19°

ep1→ep7 期间 F20 下降是 预期行为:cls warmup 中 trunk 逐步适配类别学习,但空间梯度被 mask 为 0,方向头无监督信号导致 azi 漂移。ep8 空间 loss 解锁后 F20 单 epoch 跃升 +107%(0.193 → 0.397),证明 D-1 ~ D-6 的训练机制改造工作正常。


H. 复现命令

# 默认 8 GPU、bs=8/GPU、peak LR=1.5e-5、25 epochs
GPUS=8 BATCH_SIZE=8 SPATIAL_EPOCHS=25 SPATIAL_LR=1.5e-5 \
  RESUME_CKPT=checkpoints/spatial_beats_ov1_unified_v12_exp/03_ov123_top4/best.pt \
  ./run_ov1_unified_v13d.sh

所有改动通过 cfg flag 控制,默认 False,因此 v12 / v13b / v13c 实验不受影响。