Spatial-BEATs 最终实施指南 (Reference Implementation Guide)
本文档定义了 Spatial-BEATs 的模型架构、特征工程与训练流程的最终技术细节,作为代码实现的唯一参照。
1. 模型架构细节 (Architecture Specification)
1.1 输入前端 (Stem)
- 输入特征图: $7 \times 128 \times 1024$ (Channels $\times$ Mel-bins $\times$ Time-frames)。
- 通道定义:
[0:4]: W, X, Y, Z 的 Log-mel。[4:7]: IVx, IVy, IVz (Intensity Vector),按时间/频率对齐。
- Patch Embedding:
- 结构:
nn.Conv2d(7, embed_dim, kernel_size=16, stride=16)。 - 初始化: 通道 0 (W) 复用 BEATs 预训练权重,通道 1-6 随机初始化。
- 结构:
1.2 空间 Token 提取 (Source Queries)
- Token 数量 ($K$): 4 个。
- 实现方式:
- 定义
nn.Parameter(torch.randn(1, 4, embed_dim))作为 Source Queries。 - 使用 2 层 Transformer Decoder 层。
- Query: Source Queries。
- Key/Value: BEATs Trunk 的输出序列 (Dense Patch Tokens)。
- 定义
- 输出: 4 个维度为
embed_dim的Spatial Tokens。
1.3 预测头 (Prediction Heads)
每个 Spatial Token 独立连接以下 MLP 层:
- Objectness:
Linear -> Sigmoid(1 unit)。 - Azimuth:
Linear -> tanh(2 units: $\sin, \cos$)。计算角度使用atan2。 - Elevation:
Linear -> tanh(2 units: $\sin, \cos$)。 - Distance:
Linear(1 unit, 单位:Centimeters)。 - Class:
Linear -> Softmax(N units, 对应 FSD50k 类别)。
2. 坐标系与物理特征 (Spatial Physics)
2.1 坐标系 (DCASE Standard)
- 轴向: +x 前, +y 左, +z 上。
- 方位角 (Azimuth): $[-180, 180]$,逆时针增加。+90 度为左,-90 度为右。
- 仰角 (Elevation): $[-90, 90]$,向上增加。
- 距离 (Distance): 以 厘米 (cm) 为单位进行回归。
2.2 IV 计算 (Intensity Vector)
在特征提取阶段,按以下逻辑计算 IV:
- $I_x = \text{Re}{W^* \cdot X}$
- $I_y = \text{Re}{W^* \cdot Y}$
- $I_z = \text{Re}{W^* \cdot Z}$
- 所有的 $I$ 均通过 Mel 滤波器组进行映射,以匹配 Log-mel 的分辨率。
3. 训练策略 (Training Recipe)
3.1 损失函数 (Hungarian Loss)
- 匹配算法: 使用
scipy.optimize.linear_sum_assignment(Hungarian Matching) 匹配 4 个预测 Token 与 $N$ 个 GT 声源 ($N \le 4$)。 - 匹配代价 (Matching Cost): 综合位置误差 (Az/El/Dist)、类别误差和 Objectness 分数。
- 总损失:
- 对匹配成功的 Token:计算 $L_{MSE}(pos) + L_{BCE}(obj) + L_{CrossEntropy}(cls)$。
- 对未匹配的 Token:计算 $L_{BCE}(obj, 0)$。
3.2 训练阶段
- Stage 1 (Stem & Head Warmup): 冻结 BEATs Trunk (Transformer 层),仅训练新 Patch Embedding 和 Spatial Decoder/Heads。
- Stage 2 (Joint Fine-tuning): 以 $1 \times 10^{-5}$ 的低学习率解冻整个 Trunk 进行微调。
4. LLM 接入接口 (LLM Interface)
- 提取后的 4 个
Spatial Tokens将通过一个Linear投影层对齐到 LLM 的隐藏层空间。 - 在 Prompt 中,这 4 个 tokens 将按 object-wise 顺序排列,代表音频中的空间实体。