| # Spatial-BEATs Coding Guide |
|
|
| ## 1. 本文档的作用 |
|
|
| 本文档是 `Spatial-BEATs` 的最终代码实施指南,用于直接指导后续代码开发。 |
|
|
| 它基于当前已经确认的项目约束: |
|
|
| - FOA 采样率统一到 `16 kHz` |
| - 每个样本最大同时声源数约为 `4` |
| - 每个声源都有稳定的 `source-level class label` |
| - `distance` 使用连续回归 |
| - `Spatial-BEATs` 拥有自己的 `projector` |
| - 不需要与原始语义 audio encoder 做表示对齐 |
| - 目标 `spatial token rate` 约为 `2.5 Hz` |
| - 允许增加 `source class auxiliary head` |
|
|
| 本文档应视为后续实现的主参考。 |
|
|
| ## 2. 最终设计结论 |
|
|
| ### 2.1 总体目标 |
|
|
| 构建一个独立的 `Spatial-BEATs`: |
|
|
| - 输入完整 `FOA waveform` |
| - 从 `FOA` 中计算空间特征 |
| - 将完整空间特征送入 `BEATs backbone` |
| - 输出可输入 LLM 的 `spatial tokens` |
|
|
| 注意: |
|
|
| - 不是 `W-only` |
| - 不是外挂小 adapter |
| - 不是在原有语义 encoder 内部混合空间分支 |
|
|
| 而是: |
|
|
| - 一个独立的 `Spatial Encoder` |
| - 最大化复用 `BEATs trunk` |
| - 最终输出自己的空间 token 序列 |
|
|
| ### 2.2 关键实现原则 |
|
|
| 1. **完整 FOA 特征经过 BEATs 主干** |
| 2. **尽量不改 BEATs trunk 内部 Transformer** |
| 3. **重做输入 stem** |
| 4. **重做输出头和 token 生成方式** |
| 5. **主训练目标是多源空间建模,不是 clip-level 分类** |
|
|
| ## 3. 最终模型架构 |
|
|
| 推荐最终架构如下: |
|
|
| ```text |
| FOA waveform [B, 4, T] |
| -> SpatialBEATsPreprocessor |
| -> FOA feature map [B, C_foa, T_f, F] |
| -> SpatialPatchEmbedding |
| -> BEATs trunk |
| -> Patch grid reshape |
| -> Temporal downsampler (to 2.5 Hz) |
| -> Slot query decoder |
| -> Source slot tokens [B, T_s, K, D] |
| -> Prediction heads |
| -> Spatial projector |
| -> LLM spatial tokens [B, N_keep, d_llm] |
| ``` |
|
|
| 其中: |
|
|
| - `T_s` 是时间 token 数 |
| - `K` 是每个时间步最大 source slot 数 |
| - `D` 是 BEATs hidden dim |
| - `d_llm` 是 LLM hidden dim |
|
|
| ## 4. 固定超参与默认取值 |
|
|
| ### 4.1 输入参数 |
|
|
| - sample rate: `16000` |
| - mel bins: `128` |
| - frame length: `25 ms` |
| - frame shift: `10 ms` |
|
|
| ### 4.2 token 相关参数 |
|
|
| - token rate: `2.5 Hz` |
| - 对应时间间隔:`400 ms` |
| - 对于 `10 s` 样本: |
| - `T_s = 25` |
|
|
| ### 4.3 source slot 参数 |
|
|
| - 最大同时源数:`4` |
| - 默认 `K = 4` |
|
|
| 说明: |
|
|
| - 第一版直接令 `K = 4` |
| - 不额外引入冗余 slot |
| - 如果后续发现数据中存在漏标、异常源或更复杂重叠,再考虑改成 `K = 5/6` |
|
|
| ### 4.4 输入通道数 |
|
|
| 默认推荐: |
|
|
| - `W_logmel` |
| - `X_logmel` |
| - `Y_logmel` |
| - `Z_logmel` |
| - `IVx` |
| - `IVy` |
| - `IVz` |
|
|
| 因此: |
|
|
| - `C_foa = 7` |
|
|
| ## 5. 输入特征定义 |
|
|
| ### 5.1 推荐特征形式 |
|
|
| 第一版明确使用: |
|
|
| - `WXYZ log-mel` |
| - `IVx, IVy, IVz` |
|
|
| 其中: |
|
|
| - `WXYZ` 提供 ambisonic 通道信息 |
| - `IV` 提供显式方向 cue |
|
|
| ### 5.2 IV 计算建议 |
|
|
| 建议在 STFT 域中计算 intensity vector,然后再映射到 mel 维: |
|
|
| ```text |
| IVx ~ Re(conj(W) * X) |
| IVy ~ Re(conj(W) * Y) |
| IVz ~ Re(conj(W) * Z) |
| ``` |
|
|
| 可再配合能量归一化: |
|
|
| ```text |
| IV = IV / (|W|^2 + |X|^2 + |Y|^2 + |Z|^2 + eps) |
| ``` |
|
|
| 实现时可以先得到频域 IV,再通过 mel filter bank 压到 `128` mel bins。 |
|
|
| ### 5.3 为什么不用 binaural IPD |
|
|
| 当前任务是 `FOA`,不是 binaural。 |
|
|
| Spatial-AST 的 `mel + IPD` 经验可借鉴其结构思路,但不能直接复用其输入表示。 |
|
|
| 本项目应优先使用: |
|
|
| - FOA 通道本身 |
| - intensity vector |
|
|
| ## 6. 对 BEATs 代码的具体改造 |
|
|
| ## 6.1 尽量保留的部分 |
|
|
| 建议完全复用: |
|
|
| - `TransformerEncoder` |
| - `TransformerSentenceEncoderLayer` |
| - `MultiheadAttention` |
| - `conv_pos` |
| - `post_extract_proj` |
| - trunk 中的 `LayerNorm / FFN / attention` |
|
|
| 也就是说: |
|
|
| - [backbone.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/backbone.py) 尽量不改 |
|
|
| ### 6.2 需要重写的部分 |
|
|
| 必须重写: |
|
|
| 1. `preprocess` |
| 2. `patch_embedding` |
| 3. `extract_features` 的输出形式 |
| 4. 原始 `predictor` |
|
|
| ### 6.3 推荐新增文件 |
|
|
| 建议新增如下文件: |
|
|
| - `spatial_beats.py` |
| - `spatial_modules.py` |
| - `spatial_loss.py` |
| - `spatial_dataset.py` |
| - `train_spatial_beats.py` |
| - 可选 `infer_spatial_beats.py` |
|
|
| ## 7. 预训练权重复用方案 |
|
|
| ## 7.1 推荐 checkpoint |
|
|
| 默认推荐: |
|
|
| - `BEATs_iter3+ (AS2M) pre-trained` |
|
|
| 不推荐第一版直接用 fine-tuned checkpoint 作为 trunk 初始化。 |
|
|
| ### 7.2 直接加载的层 |
|
|
| 建议直接加载: |
|
|
| - `post_extract_proj` |
| - `encoder.pos_conv` |
| - `encoder.layers.*` |
| - `encoder.layer_norm` |
| - `layer_norm` |
|
|
| 这些层使用: |
|
|
| - `strict=False` |
|
|
| 并打印缺失与不匹配项。 |
|
|
| ### 7.3 不能直接加载的层 |
|
|
| 以下层需要新初始化: |
|
|
| - 新的 `patch_embedding` |
| - `temporal downsampler` |
| - `slot query decoder` |
| - `prediction heads` |
| - `spatial projector` |
|
|
| ### 7.4 新 patch stem 的初始化 |
|
|
| 原始 BEATs stem: |
|
|
| ```text |
| Conv2d(1, embed_dim, kernel_size=patch, stride=patch) |
| ``` |
|
|
| 新的 stem: |
|
|
| ```text |
| Conv2d(7, embed_dim, kernel_size=patch, stride=patch) |
| ``` |
|
|
| 推荐初始化方案: |
|
|
| - `W_logmel` 通道继承原 BEATs stem 权重 |
| - `X/Y/Z/IVx/IVy/IVz` 通道初始化为较小随机值 |
|
|
| 推荐做法: |
|
|
| ```text |
| new_weight[:, 0, :, :] = old_weight[:, 0, :, :] |
| new_weight[:, 1:, :, :] ~ N(0, 0.02 * std(old_weight)) |
| ``` |
|
|
| 不推荐全部复制 inflation 作为默认方案。 |
| 第一版优先稳定,而不是让所有通道一开始等价共享单通道语义滤波器。 |
|
|
| ## 8. 代码结构建议 |
|
|
| ## 8.1 `spatial_modules.py` |
| |
| 建议包含以下模块: |
| |
| ### `SpatialBEATsPreprocessor` |
| |
| 职责: |
| |
| - 输入 `FOA waveform [B, 4, T]` |
| - 计算: |
| - `WXYZ logmel` |
| - `IVx, IVy, IVz` |
| - 输出: |
| - `foa_feat [B, 7, T_f, 128]` |
| |
| 建议接口: |
| |
| ```python |
| class SpatialBEATsPreprocessor(nn.Module): |
| def forward(self, waveforms: torch.Tensor) -> torch.Tensor: |
| ... |
| ``` |
| |
| ### `SpatialPatchEmbedding` |
| |
| 职责: |
| |
| - 对 `foa_feat` 做多通道 patch embedding |
|
|
| 建议接口: |
|
|
| ```python |
| class SpatialPatchEmbedding(nn.Module): |
| def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, tuple[int, int]]: |
| # returns: |
| # tokens: [B, N_p, D_in] |
| # grid_hw: (T_p, F_p) |
| ... |
| ``` |
|
|
| ### `TemporalDownsampler` |
|
|
| 职责: |
|
|
| - 将 trunk 输出从 patch 时间分辨率下采样到 `2.5 Hz` |
|
|
| 建议输入输出: |
|
|
| - 输入:`grid memory [B, T_p, F_p, D]` |
| - 先对 `F_p` 做平均或轻量 attention pooling |
| - 得到:`temporal memory [B, T_p, D]` |
| - 再用线性插值或 1D conv 下采样到: |
| - `slot memory [B, T_s, D]` |
|
|
| 建议接口: |
|
|
| ```python |
| class TemporalDownsampler(nn.Module): |
| def forward(self, grid_x: torch.Tensor, target_steps: int) -> torch.Tensor: |
| # grid_x: [B, T_p, F_p, D] |
| # out: [B, T_s, D] |
| ... |
| ``` |
|
|
| 默认推荐: |
|
|
| - 第一版使用 `freq-mean + linear interpolate` |
|
|
| 原因: |
|
|
| - 简单 |
| - 稳定 |
| - 容易调试 |
|
|
| ### `SlotQueryDecoder` |
|
|
| 职责: |
|
|
| - 对每个时间步生成 `K=4` 个 source slots |
|
|
| 推荐设计: |
|
|
| - 为每个 slot 准备一个 learnable `slot embedding` |
| - 将时间 token `m_t` 与 slot embedding 相加,形成初始 query |
| - query 对 trunk memory 做 cross-attention |
|
|
| 建议输出: |
|
|
| - `slot_tokens [B, T_s, K, D]` |
|
|
| 建议接口: |
|
|
| ```python |
| class SlotQueryDecoder(nn.Module): |
| def forward( |
| self, |
| temporal_memory: torch.Tensor, |
| encoder_memory: torch.Tensor, |
| ) -> torch.Tensor: |
| # temporal_memory: [B, T_s, D] |
| # encoder_memory: [B, N_p, D] |
| # out: [B, T_s, K, D] |
| ... |
| ``` |
|
|
| 实现建议: |
|
|
| - 先用 `temporal_memory` 生成时间条件 query |
| - 再用 `2 层 TransformerDecoderLayer` 或自定义 cross-attn block |
|
|
| 第一版推荐: |
|
|
| - `2 层 decoder` |
| - hidden dim 与 trunk 一致 |
|
|
| ### `SpatialPredictionHead` |
|
|
| 职责: |
|
|
| - 对 `slot_tokens` 预测各任务输出 |
|
|
| 建议输出: |
|
|
| - `pred_obj: [B, T_s, K]` |
| - `pred_azi_logits: [B, T_s, K, 360]` |
| - `pred_ele_logits: [B, T_s, K, 180]` |
| - `pred_dist: [B, T_s, K, 1]` |
| - `pred_class_logits: [B, T_s, K, C_cls]` |
|
|
| ### `SpatialTokenProjector` |
|
|
| 职责: |
|
|
| - 将 slot latent 与结构化坐标信息组合 |
| - 投影到 LLM hidden size |
|
|
| 输出: |
|
|
| - `llm_tokens [B, N_keep, d_llm]` |
|
|
| ## 8.2 `spatial_beats.py` |
| |
| 建议定义: |
| |
| ### `SpatialBEATsConfig` |
| |
| 字段建议: |
| |
| - `sample_rate=16000` |
| - `num_mel_bins=128` |
| - `token_rate=2.5` |
| - `max_sources=4` |
| - `foa_channels=7` |
| - `distance_max_m` |
| - `llm_hidden_size` |
| - `use_class_aux=True` |
| - `num_decoder_layers=2` |
|
|
| ### `SpatialBEATs` |
|
|
| 建议结构: |
|
|
| ```python |
| class SpatialBEATs(nn.Module): |
| def __init__(self, cfg, beats_ckpt=None): |
| ... |
| |
| def extract_spatial_features(self, waveforms): |
| ... |
| |
| def extract_spatial_tokens(self, waveforms, audio_lengths=None): |
| ... |
| |
| def project_tokens_for_llm(self, slot_tokens, preds, keep_mask=None): |
| ... |
| |
| def forward(self, waveforms, audio_lengths=None, targets=None): |
| ... |
| ``` |
|
|
| ### `forward()` 推荐返回形式 |
|
|
| 返回字典: |
|
|
| ```python |
| { |
| "encoder_memory": ..., |
| "slot_tokens": ..., |
| "pred_obj": ..., |
| "pred_azi_logits": ..., |
| "pred_ele_logits": ..., |
| "pred_dist": ..., |
| "pred_class_logits": ..., |
| "llm_tokens": ..., |
| "llm_token_mask": ..., |
| "token_meta": ..., |
| } |
| ``` |
|
|
| ## 8.3 `spatial_loss.py` |
| |
| 建议定义: |
| |
| ### `HungarianMatcher` |
| |
| 输入: |
| |
| - 预测输出 |
| - GT targets |
| |
| 输出: |
| |
| - 每个样本每个时间步的匹配索引 |
| |
| ### `SpatialSetCriterion` |
| |
| 计算: |
| |
| - objectness loss |
| - azimuth loss |
| - elevation loss |
| - distance regression loss |
| - class auxiliary loss |
| |
| 可选: |
| |
| - temporal smoothness loss |
| |
| ## 8.4 `spatial_dataset.py` |
|
|
| 建议数据格式: |
|
|
| ```python |
| sample = { |
| "waveform": FloatTensor[4, T], |
| "duration_s": float, |
| "sources": [ |
| { |
| "class_id": int, |
| "azimuth_deg": float, |
| "elevation_deg": float, |
| "distance_m": float, |
| "start_s": float, |
| "end_s": float, |
| "is_time_weak": bool, |
| "is_position_dynamic": bool, |
| "trajectory": optional, |
| }, |
| ... |
| ] |
| } |
| ``` |
|
|
| 当前推荐额外保留以下字段: |
|
|
| - `start_s`: 源开始进入场景的时间 |
| - `end_s`: 若有则保留,否则可由 `start_s + length_s` 推出 |
| - `length_s`: 原始 source clip 时长 |
| - `is_time_weak`: 当前时间边界是否只是弱监督 |
| - `is_position_dynamic`: 该源位置是否随时间变化 |
| - `trajectory`: 若位置变化,则存储分段轨迹或逐帧轨迹 |
|
|
| 如果当前只有源进入时间和原始 source length,可统一转成: |
|
|
| - `start_s = start time` |
| - `end_s = start_s + length_s` |
| - `is_time_weak = True` |
|
|
| ## 9. 时间建模、2.5 Hz 输出与弱时间监督 |
|
|
| 这是当前实现中最关键的新增约束之一。 |
|
|
| ### 9.1 token rate 的解释 |
|
|
| 这里定义: |
|
|
| - 每个 `source slot stream` 的输出速率为 `2.5 Hz` |
|
|
| 即: |
|
|
| - 每个时间步间隔 `400 ms` |
|
|
| ### 9.2 输出张量形状 |
|
|
| 对时长为 `L` 秒的样本: |
|
|
| ```text |
| T_s = round(L * 2.5) |
| ``` |
|
|
| 例如: |
|
|
| - `10 s -> 25` |
|
|
| 最终 slot token 形状: |
|
|
| ```text |
| [B, T_s, K, D] |
| ``` |
|
|
| 在当前默认配置下: |
|
|
| - `K = 4` |
|
|
| 也就是最多 `4` 条并行 source slot 流。 |
|
|
| ### 9.3 当前时间标注的含义 |
|
|
| 你当前可提供的时间信息是: |
|
|
| - 知道每个 source 的 `start time` |
| - 知道原始 `FSD50K source clip` 的 `length` |
| - 但这段 `length` 内不保证每一时刻都真的 active |
|
|
| 因此,当前不应把 `[start_s, end_s]` 视为严格逐帧真值,而应视为: |
|
|
| - `weak temporal support window` |
|
|
| 也就是: |
|
|
| - 源最可能出现的候选时间范围 |
| - 不是精确的逐帧 activity annotation |
|
|
| ### 9.4 第一版如何把弱时间标注映射到 2.5 Hz token |
|
|
| 对于第 `t` 个时间步,其中心时刻记为 `tau_t`。 |
|
|
| 对每个 GT source,定义: |
|
|
| - `candidate window = [start_s, end_s]` |
|
|
| 第一版推荐构造三种 mask: |
|
|
| 1. `pos_window_mask` |
| - `tau_t` 落在 `[start_s, end_s]` 内 |
| 2. `neg_window_mask` |
| - `tau_t` 明确落在窗口外 |
| 3. `ignore_mask` |
| - 可选,用于窗口边界附近或不确定区域 |
|
|
| 当前默认建议最简单实现: |
|
|
| - 窗口外:作为 objectness 负样本 |
| - 窗口内:作为弱正样本候选 |
|
|
| 但不要对窗口内所有步都施加强监督位置 loss。 |
|
|
| ### 9.5 当前 loss 需要怎么改 |
|
|
| 为了适应弱时间标注,当前建议把 loss 拆成两层: |
|
|
| #### `L_obj` |
| |
| - 对窗口外,正常做负样本监督 |
| - 对窗口内,做弱正样本监督 |
| |
| 推荐: |
| |
| - 使用 `BCE` 或 `focal loss` |
| - 窗口内正样本权重降低 |
| |
| 例如: |
| |
| ```text |
| w_obj_pos_weak = 0.3 ~ 0.5 |
| w_obj_neg = 1.0 |
| ``` |
| |
| #### `L_azi / L_ele / L_dist / L_cls` |
| |
| 第一版不要在窗口内所有时间步都强制监督。 |
| 推荐只在以下位置监督: |
| |
| - 与 GT source 匹配且 `pred_obj` 较高的 slot |
| - 或窗口内的 top-k 高置信时间步 |
| |
| 更稳的第一版做法: |
| |
| - 先对每个 GT source 在窗口内选择 `top-1` 或 `top-2` 个 objectness 最高的时间步参与坐标/类别监督 |
| |
| 这样可以避免: |
| |
| - 源在窗口内部分时间其实不 active |
| - 但模型被错误惩罚 |
| |
| ### 9.6 推荐的第一版弱监督匹配策略 |
| |
| 当前建议采用两阶段匹配: |
| |
| 1. 先按时间窗口过滤候选时间步 |
| 2. 再在候选时间步内做 slot matching |
| |
| 更具体地说: |
| |
| - 对每个 GT source,只允许匹配其时间窗口内的 `slot tokens` |
| - 在这些候选中选出最优 `(t, k)` |
| |
| 这比直接对所有 `[T_s, K]` 位置做全局 Hungarian 更稳。 |
| |
| 第一版推荐: |
| |
| - `per-source best-of-window matching` |
| |
| 而不是: |
| |
| - 全局 dense set matching |
| |
| 原因: |
| |
| - 你当前时间标注是弱的 |
| - 先用窗口约束大幅降低匹配歧义更现实 |
| |
| ### 9.7 推理时 token 序列长度不需要改变 |
| |
| `2.5 Hz` 的 token rate 不需要变。 |
| 要改的是: |
| |
| - 训练 supervision 的构造方式 |
| - objectness 与坐标 loss 的作用范围 |
| |
| ### 9.8 未来如果拿到更好的 activity 标注 |
| |
| 如果后续可以拿到: |
| |
| - energy-based active mask |
| - frame-level source activity |
| - VAD / source activation probability |
| |
| 则可把当前的弱时间监督替换成: |
| |
| - `strong temporal supervision` |
| |
| 到时只需替换 target 构造和 criterion,不需要改主模型结构。 |
| |
| ### 9.9 喂给 LLM 时的 token 数量 |
| |
| 如果全部展开,理论最大 token 速率为: |
| |
| ```text |
| 2.5 Hz * 4 = 10 tokens / second |
| ``` |
| |
| 但推理时可通过 `objectness` 做过滤,所以通常会低于这个上限。 |
| |
| ### 9.10 LLM 展开顺序 |
| |
| 建议按如下顺序展开: |
| |
| - 先按时间排序 |
| - 每个时间步内部按 `objectness` 从高到低排序 |
| |
| 也就是: |
| |
| ```text |
| t1_s1, t1_s2, t1_s3, t1_s4, t2_s1, t2_s2, ... |
| ``` |
| |
| 然后再过滤低置信 slot。 |
| |
| ## 10. 输出给 LLM 的 spatial token 形式 |
| |
| ## 10.1 不直接喂原始 logits |
| |
| 不建议直接把: |
| |
| - 方位分类 logits |
| - 类别 logits |
| - 距离标量 |
| |
| 直接作为 token 输入给 LLM。 |
| |
| ### 10.2 推荐 token 构造方式 |
| |
| 每个 slot token `z_{t,k}` 最终形成一个结构化 token: |
| |
| ```text |
| s_{t,k} = Proj([z_{t,k} ; c_{t,k} ; u_{t,k} ; d_{t,k} ; o_{t,k}]) |
| ``` |
| |
| 其中: |
| |
| - `z_{t,k}`: slot latent |
| - `c_{t,k}`: source class context embedding |
| - `u_{t,k}`: 方向向量 |
| - `d_{t,k}`: 连续距离 embedding |
| - `o_{t,k}`: objectness/confidence embedding |
| |
| ### 10.3 各项具体建议 |
| |
| #### `c_{t,k}`: 类别上下文 |
| |
| 由 `pred_class_logits` 构造: |
| |
| ```text |
| p_cls = softmax(pred_class_logits) |
| c = p_cls @ E_cls |
| ``` |
| |
| 其中: |
| |
| - `E_cls` 是一个可学习类别 embedding 表 |
|
|
| 作用: |
|
|
| - 给 spatial token 少量语义 grounding |
| - 但不需要与原始 audio encoder 对齐 |
|
|
| #### `u_{t,k}`: 方向向量 |
| |
| 先由: |
| |
| - `pred_azi_logits` |
| - `pred_ele_logits` |
| |
| 得到预测角度,再转换成单位球坐标向量: |
| |
| ```text |
| u = [x, y, z] |
| ``` |
| |
| 推荐实现: |
| |
| - 训练时用分布期望或 soft-argmax |
| - 推理时可用 argmax |
| |
| #### `d_{t,k}`: 连续距离表示 |
|
|
| 由于 distance 是连续回归,建议: |
|
|
| - 对 `pred_dist` 做归一化 |
| - 再经一个小 MLP 变成 embedding |
|
|
| #### `o_{t,k}`: 置信度表示 |
| |
| 由 `pred_obj` 经 sigmoid 得到 objectness,再做小 MLP 映射。 |
|
|
| ### 10.4 projector 的最终作用 |
|
|
| `SpatialTokenProjector` 的任务是把: |
|
|
| - slot latent |
| - class context |
| - direction vector |
| - distance embedding |
| - objectness embedding |
|
|
| 融合并投影到: |
|
|
| - `d_llm` |
|
|
| 输出: |
|
|
| ```text |
| llm_tokens: [B, N_keep, d_llm] |
| ``` |
|
|
| ### 10.5 是否需要与原 audio encoder 对齐 |
|
|
| 当前结论: |
|
|
| - **不需要** |
|
|
| 因此: |
|
|
| - 这个 projector 完全独立训练 |
| - 只服务于 `Spatial-BEATs -> LLM` |
|
|
| ## 11. Loss 设计 |
|
|
| ## 11.1 任务头与 loss |
|
|
| 推荐 loss 组成: |
|
|
| ```text |
| L_total = |
| lambda_obj * L_obj |
| + lambda_azi * L_azi |
| + lambda_ele * L_ele |
| + lambda_dist * L_dist |
| + lambda_cls * L_cls |
| ``` |
|
|
| ### 11.2 各项定义 |
|
|
| - `L_obj`: BCE 或 focal loss,支持弱正样本权重 |
| - `L_azi`: cross entropy |
| - `L_ele`: cross entropy |
| - `L_dist`: SmoothL1 / Huber |
| - `L_cls`: cross entropy |
|
|
| ### 11.3 distance 回归的实现 |
|
|
| 由于你已经明确要连续回归,推荐: |
|
|
| - head 输出 `pred_dist_norm in [0, 1]` |
| - 再乘以 `distance_max_m` |
|
|
| 训练时使用: |
|
|
| - `SmoothL1Loss(pred_dist_norm, gt_dist_norm)` |
|
|
| 优点: |
|
|
| - 比直接回归未归一化距离更稳 |
| - 比 MSE 更抗异常值 |
|
|
| ### 11.4 推荐初始权重 |
|
|
| 建议第一版从以下权重起步: |
|
|
| ```text |
| lambda_obj = 1.0 |
| lambda_azi = 2.0 |
| lambda_ele = 2.0 |
| lambda_dist = 1.0 |
| lambda_cls = 0.5 |
| ``` |
|
|
| 这里把 `class auxiliary` 权重从之前建议的 `0.25` 提到 `0.5`,因为现在你已经确认: |
|
|
| - 每个源有稳定的 `source-level class label` |
| - Spatial-BEATs 自身保留一定语义信息是可接受的 |
|
|
| ### 11.5 当前版本的匹配方式修订 |
|
|
| 由于当前时间 supervision 是弱的,第一版不建议直接做: |
|
|
| - 全局 `Hungarian matching` over `[T_s, K]` |
|
|
| 更推荐: |
|
|
| 1. 先根据 `source time window` 过滤候选时间步 |
| 2. 再在候选窗口内做匹配 |
|
|
| 推荐实现: |
|
|
| - `window-constrained matching` |
|
|
| 可选两种方式: |
|
|
| #### 方案 A:推荐默认方案 |
|
|
| - 对每个 GT source |
| - 在其 window 内所有 `(t, k)` 候选中,选择 cost 最小的一对 |
|
|
| 这本质上是: |
|
|
| - `best-of-window assignment` |
|
|
| 优点: |
|
|
| - 简单 |
| - 稳定 |
| - 对弱时间监督更友好 |
|
|
| #### 方案 B:后续增强方案 |
|
|
| - 对每个时间步分别做 Hungarian |
| - 再加时间连续性正则 |
|
|
| 这更适合将来位置随时间变化时使用。 |
|
|
| ### 11.6 匹配 cost |
|
|
| Hungarian matching cost 建议: |
|
|
| ```text |
| cost = |
| w_obj * cost_obj |
| + w_azi * cost_azi |
| + w_ele * cost_ele |
| + w_dist * cost_dist |
| + w_cls * cost_cls |
| ``` |
|
|
| 推荐初值: |
|
|
| ```text |
| w_obj = 1.0 |
| w_azi = 2.0 |
| w_ele = 2.0 |
| w_dist = 1.0 |
| w_cls = 1.0 |
| ``` |
|
|
| ## 12. 训练策略 |
|
|
| ### 12.1 第一阶段是否需要 SSL |
|
|
| 当前明确结论: |
|
|
| - 第一版 **不做新的 BEATs 式 SSL** |
|
|
| 理由: |
|
|
| - 已有空间 GT |
| - 已有 source class GT |
| - 已有强 trunk 预训练 |
| - 当前主要目标是空间结构建模 |
|
|
| ### 12.2 推荐训练阶段 |
|
|
| #### Stage A: Warmup |
|
|
| 冻结: |
|
|
| - trunk 大部分层 |
|
|
| 训练: |
|
|
| - preprocessor |
| - patch stem |
| - temporal downsampler |
| - slot query decoder |
| - prediction heads |
| - projector |
|
|
| #### Stage B: Upper-trunk finetune |
|
|
| 解冻: |
|
|
| - trunk 上层若干层 |
|
|
| #### Stage C: Wider finetune |
|
|
| 逐步解冻更多层,直到性能稳定。 |
|
|
| ### 12.3 训练时建议增加的正则项 |
|
|
| 当前位置在 clip 内固定,因此建议增加: |
|
|
| - `temporal consistency loss` |
|
|
| 具体可对同一 source 在相邻时间步的预测加约束: |
|
|
| - objectness 平滑 |
| - azimuth/elevation 分布平滑 |
| - distance 平滑 |
|
|
| 第一版可选实现: |
|
|
| ```text |
| L_temp = |
| smooth(pred_obj_t, pred_obj_{t+1}) |
| + smooth(pred_dist_t, pred_dist_{t+1}) |
| ``` |
|
|
| 由于当前位置固定,这类正则通常有利于稳定训练。 |
|
|
| ### 12.4 学习率建议 |
|
|
| 推荐: |
|
|
| ```text |
| lr_trunk = 1e-5 ~ 5e-5 |
| lr_new = 1e-4 ~ 5e-4 |
| ``` |
|
|
| 并使用: |
|
|
| - weight decay |
| - warmup |
| - layer-wise lr decay |
|
|
| ## 13. 训练与推理输出格式 |
|
|
| ## 13.1 训练时 `forward()` 输出 |
|
|
| 建议 `forward()` 返回: |
|
|
| ```python |
| { |
| "slot_tokens": FloatTensor[B, T_s, K, D], |
| "pred_obj": FloatTensor[B, T_s, K], |
| "pred_azi_logits": FloatTensor[B, T_s, K, 360], |
| "pred_ele_logits": FloatTensor[B, T_s, K, 180], |
| "pred_dist": FloatTensor[B, T_s, K, 1], |
| "pred_class_logits": FloatTensor[B, T_s, K, C_cls], |
| "llm_tokens": FloatTensor[B, N_keep, d_llm], |
| "llm_token_mask": BoolTensor[B, N_keep], |
| "token_meta": dict, |
| } |
| ``` |
|
|
| ### 13.2 推理时建议额外输出 |
|
|
| 建议额外输出: |
|
|
| - `pred_azi_deg` |
| - `pred_ele_deg` |
| - `pred_dist_m` |
| - `pred_obj_prob` |
| - `pred_class_id` |
|
|
| 便于后续可视化和调试。 |
|
|
| ## 14. 最小实现顺序 |
|
|
| 建议严格按以下顺序实现: |
|
|
| 1. 写 `SpatialBEATsPreprocessor` |
| 2. 写 `SpatialPatchEmbedding` |
| 3. 完成 trunk checkpoint 加载 |
| 4. 写 `TemporalDownsampler` |
| 5. 写 `SlotQueryDecoder` |
| 6. 写 `SpatialPredictionHead` |
| 7. 写 `SpatialTokenProjector` |
| 8. 写 `HungarianMatcher` |
| 9. 写 `SpatialSetCriterion` |
| 10. 写 dataset 和训练脚本 |
|
|
| ## 15. 未来支持“位置随时间变化”时需要改什么 |
|
|
| 你已经说明: |
|
|
| - 当前 clip 内位置固定 |
| - 后续会加入随时间变化的位置 |
|
|
| 这意味着当前模型结构基本可保留,但 target 和 decoder 训练方式需要升级。 |
|
|
| ### 15.1 当前结构哪些不用改 |
|
|
| 以下部分未来仍可直接保留: |
|
|
| - `FOA preprocessor` |
| - `patch embedding` |
| - `BEATs trunk` |
| - `TemporalDownsampler` |
| - `SlotQueryDecoder` |
| - `SpatialTokenProjector` |
| - `2.5 Hz` token rate |
|
|
| ### 15.2 未来必须改的部分 |
|
|
| 未来位置动态化后,需要改: |
|
|
| 1. `dataset target format` |
| 2. `matching strategy` |
| 3. `loss supervision` |
|
|
| ### 15.3 数据结构怎么升级 |
|
|
| 当前静态位置: |
|
|
| ```python |
| { |
| "azimuth_deg": float, |
| "elevation_deg": float, |
| "distance_m": float, |
| } |
| ``` |
|
|
| 未来动态位置建议升级为: |
|
|
| ```python |
| { |
| "trajectory": [ |
| { |
| "time_s": float, |
| "azimuth_deg": float, |
| "elevation_deg": float, |
| "distance_m": float, |
| }, |
| ... |
| ] |
| } |
| ``` |
|
|
| 或直接存成与 `2.5 Hz` 对齐的逐步 target: |
|
|
| ```python |
| { |
| "traj_azi_deg": FloatTensor[T_s], |
| "traj_ele_deg": FloatTensor[T_s], |
| "traj_dist_m": FloatTensor[T_s], |
| "traj_valid_mask": BoolTensor[T_s], |
| } |
| ``` |
|
|
| ### 15.4 匹配怎么升级 |
|
|
| 当前位置固定时: |
|
|
| - `per-source best-of-window matching` |
|
|
| 未来位置变化时: |
|
|
| - 更适合改为 `per-time-step matching` |
| - 或 `track-level matching` |
|
|
| 推荐未来版本: |
|
|
| - 每个 source 对应一条 slot track |
| - 在整个时间维上维持 slot identity |
|
|
| ### 15.5 loss 怎么升级 |
|
|
| 未来动态位置时: |
|
|
| - `L_azi / L_ele / L_dist` 应按时间步计算 |
| - `temporal consistency loss` 不能再强制“位置恒定” |
| - 应改成“速度平滑”或“轨迹平滑” |
|
|
| 也就是从: |
|
|
| - `constant-position regularization` |
|
|
| 升级成: |
|
|
| - `trajectory smoothness regularization` |
|
|
| ### 15.6 代码层面建议现在就预留的接口 |
|
|
| 为了兼容未来动态位置,当前第一版建议在数据与 loss 接口里预留: |
|
|
| - `is_position_dynamic` |
| - `trajectory` |
| - `traj_valid_mask` |
|
|
| 即使第一版不用,也建议把字段和分支接口预留出来。 |
|
|
| ## 16. 当前仍需要确认的问题 |
|
|
| 虽然核心方案已经足够落地,但还有一个关键问题最好在编码前确认: |
|
|
| 当前核心方案已经足够编码。 |
| 如果后续继续推进,唯一还值得尽早确认的是: |
|
|
| - 是否能从原始 source waveform 自动提取更精细的 energy/activity mask |
|
|
| 如果可以,第一版的弱时间监督会明显更稳。 |
|
|
| ## 17. 结论 |
|
|
| 当前可以直接进入代码实现的最终方案是: |
|
|
| - `16k FOA` |
| - `WXYZ + IV` |
| - `K=4` |
| - `2.5 Hz` slot token streams |
| - `distance` 连续回归 |
| - `class auxiliary head` 开启 |
| - `BEATs_iter3+ AS2M pre-trained` 作为 trunk 初始化 |
| - `Spatial-BEATs` 拥有自己的 projector |
| - 最终输出自己的 LLM spatial tokens |
| - 当前时间 supervision 按 `weak temporal window` 处理 |
| - 当前位置 supervision 按 `clip-level fixed position` 处理 |
| - 未来动态位置仅需升级 target/matching/loss,不需要重写主干结构 |
|
|
| 这份文档已经足够作为第一版实现蓝图使用。 |
|
|