Spatial-BEATs / docs /gemini.md
dieKarotte's picture
Add files using upload-large-folder tool
bf04039 verified
|
Raw
History Blame Contribute Delete
3.18 kB

Spatial-BEATs 最终实施指南 (Reference Implementation Guide)

本文档定义了 Spatial-BEATs 的模型架构、特征工程与训练流程的最终技术细节,作为代码实现的唯一参照。

1. 模型架构细节 (Architecture Specification)

1.1 输入前端 (Stem)

  • 输入特征图: $7 \times 128 \times 1024$ (Channels $\times$ Mel-bins $\times$ Time-frames)。
  • 通道定义:
    • [0:4]: W, X, Y, Z 的 Log-mel。
    • [4:7]: IVx, IVy, IVz (Intensity Vector),按时间/频率对齐。
  • Patch Embedding:
    • 结构: nn.Conv2d(7, embed_dim, kernel_size=16, stride=16)
    • 初始化: 通道 0 (W) 复用 BEATs 预训练权重,通道 1-6 随机初始化。

1.2 空间 Token 提取 (Source Queries)

  • Token 数量 ($K$): 4 个。
  • 实现方式:
    • 定义 nn.Parameter(torch.randn(1, 4, embed_dim)) 作为 Source Queries。
    • 使用 2 层 Transformer Decoder 层。
    • Query: Source Queries。
    • Key/Value: BEATs Trunk 的输出序列 (Dense Patch Tokens)。
  • 输出: 4 个维度为 embed_dimSpatial Tokens

1.3 预测头 (Prediction Heads)

每个 Spatial Token 独立连接以下 MLP 层:

  • Objectness: Linear -> Sigmoid (1 unit)。
  • Azimuth: Linear -> tanh (2 units: $\sin, \cos$)。计算角度使用 atan2
  • Elevation: Linear -> tanh (2 units: $\sin, \cos$)。
  • Distance: Linear (1 unit, 单位:Centimeters)。
  • Class: Linear -> Softmax (N units, 对应 FSD50k 类别)。

2. 坐标系与物理特征 (Spatial Physics)

2.1 坐标系 (DCASE Standard)

  • 轴向: +x 前, +y 左, +z 上。
  • 方位角 (Azimuth): $[-180, 180]$,逆时针增加。+90 度为左,-90 度为右。
  • 仰角 (Elevation): $[-90, 90]$,向上增加。
  • 距离 (Distance): 以 厘米 (cm) 为单位进行回归。

2.2 IV 计算 (Intensity Vector)

在特征提取阶段,按以下逻辑计算 IV:

  • $I_x = \text{Re}{W^* \cdot X}$
  • $I_y = \text{Re}{W^* \cdot Y}$
  • $I_z = \text{Re}{W^* \cdot Z}$
  • 所有的 $I$ 均通过 Mel 滤波器组进行映射,以匹配 Log-mel 的分辨率。

3. 训练策略 (Training Recipe)

3.1 损失函数 (Hungarian Loss)

  • 匹配算法: 使用 scipy.optimize.linear_sum_assignment (Hungarian Matching) 匹配 4 个预测 Token 与 $N$ 个 GT 声源 ($N \le 4$)。
  • 匹配代价 (Matching Cost): 综合位置误差 (Az/El/Dist)、类别误差和 Objectness 分数。
  • 总损失:
    • 对匹配成功的 Token:计算 $L_{MSE}(pos) + L_{BCE}(obj) + L_{CrossEntropy}(cls)$。
    • 对未匹配的 Token:计算 $L_{BCE}(obj, 0)$。

3.2 训练阶段

  1. Stage 1 (Stem & Head Warmup): 冻结 BEATs Trunk (Transformer 层),仅训练新 Patch Embedding 和 Spatial Decoder/Heads。
  2. Stage 2 (Joint Fine-tuning): 以 $1 \times 10^{-5}$ 的低学习率解冻整个 Trunk 进行微调。

4. LLM 接入接口 (LLM Interface)

  • 提取后的 4 个 Spatial Tokens 将通过一个 Linear 投影层对齐到 LLM 的隐藏层空间。
  • 在 Prompt 中,这 4 个 tokens 将按 object-wise 顺序排列,代表音频中的空间实体。