| # Spatial-BEATs 最终实施指南 (Reference Implementation Guide) |
|
|
| 本文档定义了 `Spatial-BEATs` 的模型架构、特征工程与训练流程的最终技术细节,作为代码实现的唯一参照。 |
|
|
| ## 1. 模型架构细节 (Architecture Specification) |
|
|
| ### 1.1 输入前端 (Stem) |
| - **输入特征图**: $7 \times 128 \times 1024$ (Channels $\times$ Mel-bins $\times$ Time-frames)。 |
| - **通道定义**: |
| - `[0:4]`: W, X, Y, Z 的 Log-mel。 |
| - `[4:7]`: IVx, IVy, IVz (Intensity Vector),按时间/频率对齐。 |
| - **Patch Embedding**: |
| - 结构: `nn.Conv2d(7, embed_dim, kernel_size=16, stride=16)`。 |
| - 初始化: 通道 0 (W) 复用 BEATs 预训练权重,通道 1-6 随机初始化。 |
|
|
| ### 1.2 空间 Token 提取 (Source Queries) |
| - **Token 数量 ($K$)**: 4 个。 |
| - **实现方式**: |
| - 定义 `nn.Parameter(torch.randn(1, 4, embed_dim))` 作为 Source Queries。 |
| - 使用 2 层 Transformer Decoder 层。 |
| - **Query**: Source Queries。 |
| - **Key/Value**: BEATs Trunk 的输出序列 (Dense Patch Tokens)。 |
| - **输出**: 4 个维度为 `embed_dim` 的 `Spatial Tokens`。 |
|
|
| ### 1.3 预测头 (Prediction Heads) |
| 每个 Spatial Token 独立连接以下 MLP 层: |
| - **Objectness**: `Linear -> Sigmoid` (1 unit)。 |
| - **Azimuth**: `Linear -> tanh` (2 units: $\sin, \cos$)。计算角度使用 `atan2`。 |
| - **Elevation**: `Linear -> tanh` (2 units: $\sin, \cos$)。 |
| - **Distance**: `Linear` (1 unit, 单位:**Centimeters**)。 |
| - **Class**: `Linear -> Softmax` (N units, 对应 FSD50k 类别)。 |
|
|
| ## 2. 坐标系与物理特征 (Spatial Physics) |
|
|
| ### 2.1 坐标系 (DCASE Standard) |
| - **轴向**: +x 前, +y 左, +z 上。 |
| - **方位角 (Azimuth)**: $[-180, 180]$,逆时针增加。+90 度为左,-90 度为右。 |
| - **仰角 (Elevation)**: $[-90, 90]$,向上增加。 |
| - **距离 (Distance)**: 以 **厘米 (cm)** 为单位进行回归。 |
|
|
| ### 2.2 IV 计算 (Intensity Vector) |
| 在特征提取阶段,按以下逻辑计算 IV: |
| - $I_x = \text{Re}\{W^* \cdot X\}$ |
| - $I_y = \text{Re}\{W^* \cdot Y\}$ |
| - $I_z = \text{Re}\{W^* \cdot Z\}$ |
| - 所有的 $I$ 均通过 Mel 滤波器组进行映射,以匹配 Log-mel 的分辨率。 |
| |
| ## 3. 训练策略 (Training Recipe) |
| |
| ### 3.1 损失函数 (Hungarian Loss) |
| - **匹配算法**: 使用 `scipy.optimize.linear_sum_assignment` (Hungarian Matching) 匹配 4 个预测 Token 与 $N$ 个 GT 声源 ($N \le 4$)。 |
| - **匹配代价 (Matching Cost)**: 综合位置误差 (Az/El/Dist)、类别误差和 Objectness 分数。 |
| - **总损失**: |
| - 对匹配成功的 Token:计算 $L_{MSE}(pos) + L_{BCE}(obj) + L_{CrossEntropy}(cls)$。 |
| - 对未匹配的 Token:计算 $L_{BCE}(obj, 0)$。 |
| |
| ### 3.2 训练阶段 |
| 1. **Stage 1 (Stem & Head Warmup)**: 冻结 BEATs Trunk (Transformer 层),仅训练新 Patch Embedding 和 Spatial Decoder/Heads。 |
| 2. **Stage 2 (Joint Fine-tuning)**: 以 $1 \times 10^{-5}$ 的低学习率解冻整个 Trunk 进行微调。 |
| |
| ## 4. LLM 接入接口 (LLM Interface) |
| - 提取后的 4 个 `Spatial Tokens` 将通过一个 `Linear` 投影层对齐到 LLM 的隐藏层空间。 |
| - 在 Prompt 中,这 4 个 tokens 将按 object-wise 顺序排列,代表音频中的空间实体。 |
| |