Spatial-BEATs / docs /spatial_beats_coding_guide.md
dieKarotte's picture
Add files using upload-large-folder tool
86cbd36 verified
|
Raw
History Blame Contribute Delete
24 kB
# Spatial-BEATs Coding Guide
## 1. 本文档的作用
本文档是 `Spatial-BEATs` 的最终代码实施指南,用于直接指导后续代码开发。
它基于当前已经确认的项目约束:
- FOA 采样率统一到 `16 kHz`
- 每个样本最大同时声源数约为 `4`
- 每个声源都有稳定的 `source-level class label`
- `distance` 使用连续回归
- `Spatial-BEATs` 拥有自己的 `projector`
- 不需要与原始语义 audio encoder 做表示对齐
- 目标 `spatial token rate` 约为 `2.5 Hz`
- 允许增加 `source class auxiliary head`
本文档应视为后续实现的主参考。
## 2. 最终设计结论
### 2.1 总体目标
构建一个独立的 `Spatial-BEATs`
- 输入完整 `FOA waveform`
-`FOA` 中计算空间特征
- 将完整空间特征送入 `BEATs backbone`
- 输出可输入 LLM 的 `spatial tokens`
注意:
- 不是 `W-only`
- 不是外挂小 adapter
- 不是在原有语义 encoder 内部混合空间分支
而是:
- 一个独立的 `Spatial Encoder`
- 最大化复用 `BEATs trunk`
- 最终输出自己的空间 token 序列
### 2.2 关键实现原则
1. **完整 FOA 特征经过 BEATs 主干**
2. **尽量不改 BEATs trunk 内部 Transformer**
3. **重做输入 stem**
4. **重做输出头和 token 生成方式**
5. **主训练目标是多源空间建模,不是 clip-level 分类**
## 3. 最终模型架构
推荐最终架构如下:
```text
FOA waveform [B, 4, T]
-> SpatialBEATsPreprocessor
-> FOA feature map [B, C_foa, T_f, F]
-> SpatialPatchEmbedding
-> BEATs trunk
-> Patch grid reshape
-> Temporal downsampler (to 2.5 Hz)
-> Slot query decoder
-> Source slot tokens [B, T_s, K, D]
-> Prediction heads
-> Spatial projector
-> LLM spatial tokens [B, N_keep, d_llm]
```
其中:
- `T_s` 是时间 token 数
- `K` 是每个时间步最大 source slot 数
- `D` 是 BEATs hidden dim
- `d_llm` 是 LLM hidden dim
## 4. 固定超参与默认取值
### 4.1 输入参数
- sample rate: `16000`
- mel bins: `128`
- frame length: `25 ms`
- frame shift: `10 ms`
### 4.2 token 相关参数
- token rate: `2.5 Hz`
- 对应时间间隔:`400 ms`
- 对于 `10 s` 样本:
- `T_s = 25`
### 4.3 source slot 参数
- 最大同时源数:`4`
- 默认 `K = 4`
说明:
- 第一版直接令 `K = 4`
- 不额外引入冗余 slot
- 如果后续发现数据中存在漏标、异常源或更复杂重叠,再考虑改成 `K = 5/6`
### 4.4 输入通道数
默认推荐:
- `W_logmel`
- `X_logmel`
- `Y_logmel`
- `Z_logmel`
- `IVx`
- `IVy`
- `IVz`
因此:
- `C_foa = 7`
## 5. 输入特征定义
### 5.1 推荐特征形式
第一版明确使用:
- `WXYZ log-mel`
- `IVx, IVy, IVz`
其中:
- `WXYZ` 提供 ambisonic 通道信息
- `IV` 提供显式方向 cue
### 5.2 IV 计算建议
建议在 STFT 域中计算 intensity vector,然后再映射到 mel 维:
```text
IVx ~ Re(conj(W) * X)
IVy ~ Re(conj(W) * Y)
IVz ~ Re(conj(W) * Z)
```
可再配合能量归一化:
```text
IV = IV / (|W|^2 + |X|^2 + |Y|^2 + |Z|^2 + eps)
```
实现时可以先得到频域 IV,再通过 mel filter bank 压到 `128` mel bins。
### 5.3 为什么不用 binaural IPD
当前任务是 `FOA`,不是 binaural。
Spatial-AST 的 `mel + IPD` 经验可借鉴其结构思路,但不能直接复用其输入表示。
本项目应优先使用:
- FOA 通道本身
- intensity vector
## 6. 对 BEATs 代码的具体改造
## 6.1 尽量保留的部分
建议完全复用:
- `TransformerEncoder`
- `TransformerSentenceEncoderLayer`
- `MultiheadAttention`
- `conv_pos`
- `post_extract_proj`
- trunk 中的 `LayerNorm / FFN / attention`
也就是说:
- [backbone.py](/apdcephfs_cq10/share_1603164/user/schmittzhu/code/unilm/beats/backbone.py) 尽量不改
### 6.2 需要重写的部分
必须重写:
1. `preprocess`
2. `patch_embedding`
3. `extract_features` 的输出形式
4. 原始 `predictor`
### 6.3 推荐新增文件
建议新增如下文件:
- `spatial_beats.py`
- `spatial_modules.py`
- `spatial_loss.py`
- `spatial_dataset.py`
- `train_spatial_beats.py`
- 可选 `infer_spatial_beats.py`
## 7. 预训练权重复用方案
## 7.1 推荐 checkpoint
默认推荐:
- `BEATs_iter3+ (AS2M) pre-trained`
不推荐第一版直接用 fine-tuned checkpoint 作为 trunk 初始化。
### 7.2 直接加载的层
建议直接加载:
- `post_extract_proj`
- `encoder.pos_conv`
- `encoder.layers.*`
- `encoder.layer_norm`
- `layer_norm`
这些层使用:
- `strict=False`
并打印缺失与不匹配项。
### 7.3 不能直接加载的层
以下层需要新初始化:
- 新的 `patch_embedding`
- `temporal downsampler`
- `slot query decoder`
- `prediction heads`
- `spatial projector`
### 7.4 新 patch stem 的初始化
原始 BEATs stem:
```text
Conv2d(1, embed_dim, kernel_size=patch, stride=patch)
```
新的 stem:
```text
Conv2d(7, embed_dim, kernel_size=patch, stride=patch)
```
推荐初始化方案:
- `W_logmel` 通道继承原 BEATs stem 权重
- `X/Y/Z/IVx/IVy/IVz` 通道初始化为较小随机值
推荐做法:
```text
new_weight[:, 0, :, :] = old_weight[:, 0, :, :]
new_weight[:, 1:, :, :] ~ N(0, 0.02 * std(old_weight))
```
不推荐全部复制 inflation 作为默认方案。
第一版优先稳定,而不是让所有通道一开始等价共享单通道语义滤波器。
## 8. 代码结构建议
## 8.1 `spatial_modules.py`
建议包含以下模块:
### `SpatialBEATsPreprocessor`
职责:
- 输入 `FOA waveform [B, 4, T]`
- 计算:
- `WXYZ logmel`
- `IVx, IVy, IVz`
- 输出:
- `foa_feat [B, 7, T_f, 128]`
建议接口:
```python
class SpatialBEATsPreprocessor(nn.Module):
def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
...
```
### `SpatialPatchEmbedding`
职责:
- 对 `foa_feat` 做多通道 patch embedding
建议接口:
```python
class SpatialPatchEmbedding(nn.Module):
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, tuple[int, int]]:
# returns:
# tokens: [B, N_p, D_in]
# grid_hw: (T_p, F_p)
...
```
### `TemporalDownsampler`
职责:
- 将 trunk 输出从 patch 时间分辨率下采样到 `2.5 Hz`
建议输入输出:
- 输入:`grid memory [B, T_p, F_p, D]`
- 先对 `F_p` 做平均或轻量 attention pooling
- 得到:`temporal memory [B, T_p, D]`
- 再用线性插值或 1D conv 下采样到:
- `slot memory [B, T_s, D]`
建议接口:
```python
class TemporalDownsampler(nn.Module):
def forward(self, grid_x: torch.Tensor, target_steps: int) -> torch.Tensor:
# grid_x: [B, T_p, F_p, D]
# out: [B, T_s, D]
...
```
默认推荐:
- 第一版使用 `freq-mean + linear interpolate`
原因:
- 简单
- 稳定
- 容易调试
### `SlotQueryDecoder`
职责:
- 对每个时间步生成 `K=4` 个 source slots
推荐设计:
- 为每个 slot 准备一个 learnable `slot embedding`
- 将时间 token `m_t` 与 slot embedding 相加,形成初始 query
- query 对 trunk memory 做 cross-attention
建议输出:
- `slot_tokens [B, T_s, K, D]`
建议接口:
```python
class SlotQueryDecoder(nn.Module):
def forward(
self,
temporal_memory: torch.Tensor,
encoder_memory: torch.Tensor,
) -> torch.Tensor:
# temporal_memory: [B, T_s, D]
# encoder_memory: [B, N_p, D]
# out: [B, T_s, K, D]
...
```
实现建议:
- 先用 `temporal_memory` 生成时间条件 query
- 再用 `2 层 TransformerDecoderLayer` 或自定义 cross-attn block
第一版推荐:
- `2 层 decoder`
- hidden dim 与 trunk 一致
### `SpatialPredictionHead`
职责:
-`slot_tokens` 预测各任务输出
建议输出:
- `pred_obj: [B, T_s, K]`
- `pred_azi_logits: [B, T_s, K, 360]`
- `pred_ele_logits: [B, T_s, K, 180]`
- `pred_dist: [B, T_s, K, 1]`
- `pred_class_logits: [B, T_s, K, C_cls]`
### `SpatialTokenProjector`
职责:
- 将 slot latent 与结构化坐标信息组合
- 投影到 LLM hidden size
输出:
- `llm_tokens [B, N_keep, d_llm]`
## 8.2 `spatial_beats.py`
建议定义:
### `SpatialBEATsConfig`
字段建议:
- `sample_rate=16000`
- `num_mel_bins=128`
- `token_rate=2.5`
- `max_sources=4`
- `foa_channels=7`
- `distance_max_m`
- `llm_hidden_size`
- `use_class_aux=True`
- `num_decoder_layers=2`
### `SpatialBEATs`
建议结构:
```python
class SpatialBEATs(nn.Module):
def __init__(self, cfg, beats_ckpt=None):
...
def extract_spatial_features(self, waveforms):
...
def extract_spatial_tokens(self, waveforms, audio_lengths=None):
...
def project_tokens_for_llm(self, slot_tokens, preds, keep_mask=None):
...
def forward(self, waveforms, audio_lengths=None, targets=None):
...
```
### `forward()` 推荐返回形式
返回字典:
```python
{
"encoder_memory": ...,
"slot_tokens": ...,
"pred_obj": ...,
"pred_azi_logits": ...,
"pred_ele_logits": ...,
"pred_dist": ...,
"pred_class_logits": ...,
"llm_tokens": ...,
"llm_token_mask": ...,
"token_meta": ...,
}
```
## 8.3 `spatial_loss.py`
建议定义:
### `HungarianMatcher`
输入:
- 预测输出
- GT targets
输出:
- 每个样本每个时间步的匹配索引
### `SpatialSetCriterion`
计算:
- objectness loss
- azimuth loss
- elevation loss
- distance regression loss
- class auxiliary loss
可选:
- temporal smoothness loss
## 8.4 `spatial_dataset.py`
建议数据格式:
```python
sample = {
"waveform": FloatTensor[4, T],
"duration_s": float,
"sources": [
{
"class_id": int,
"azimuth_deg": float,
"elevation_deg": float,
"distance_m": float,
"start_s": float,
"end_s": float,
"is_time_weak": bool,
"is_position_dynamic": bool,
"trajectory": optional,
},
...
]
}
```
当前推荐额外保留以下字段:
- `start_s`: 源开始进入场景的时间
- `end_s`: 若有则保留,否则可由 `start_s + length_s` 推出
- `length_s`: 原始 source clip 时长
- `is_time_weak`: 当前时间边界是否只是弱监督
- `is_position_dynamic`: 该源位置是否随时间变化
- `trajectory`: 若位置变化,则存储分段轨迹或逐帧轨迹
如果当前只有源进入时间和原始 source length,可统一转成:
- `start_s = start time`
- `end_s = start_s + length_s`
- `is_time_weak = True`
## 9. 时间建模、2.5 Hz 输出与弱时间监督
这是当前实现中最关键的新增约束之一。
### 9.1 token rate 的解释
这里定义:
- 每个 `source slot stream` 的输出速率为 `2.5 Hz`
即:
- 每个时间步间隔 `400 ms`
### 9.2 输出张量形状
对时长为 `L` 秒的样本:
```text
T_s = round(L * 2.5)
```
例如:
- `10 s -> 25`
最终 slot token 形状:
```text
[B, T_s, K, D]
```
在当前默认配置下:
- `K = 4`
也就是最多 `4` 条并行 source slot 流。
### 9.3 当前时间标注的含义
你当前可提供的时间信息是:
- 知道每个 source 的 `start time`
- 知道原始 `FSD50K source clip``length`
- 但这段 `length` 内不保证每一时刻都真的 active
因此,当前不应把 `[start_s, end_s]` 视为严格逐帧真值,而应视为:
- `weak temporal support window`
也就是:
- 源最可能出现的候选时间范围
- 不是精确的逐帧 activity annotation
### 9.4 第一版如何把弱时间标注映射到 2.5 Hz token
对于第 `t` 个时间步,其中心时刻记为 `tau_t`
对每个 GT source,定义:
- `candidate window = [start_s, end_s]`
第一版推荐构造三种 mask:
1. `pos_window_mask`
- `tau_t` 落在 `[start_s, end_s]`
2. `neg_window_mask`
- `tau_t` 明确落在窗口外
3. `ignore_mask`
- 可选,用于窗口边界附近或不确定区域
当前默认建议最简单实现:
- 窗口外:作为 objectness 负样本
- 窗口内:作为弱正样本候选
但不要对窗口内所有步都施加强监督位置 loss。
### 9.5 当前 loss 需要怎么改
为了适应弱时间标注,当前建议把 loss 拆成两层:
#### `L_obj`
- 对窗口外,正常做负样本监督
- 对窗口内,做弱正样本监督
推荐:
- 使用 `BCE` 或 `focal loss`
- 窗口内正样本权重降低
例如:
```text
w_obj_pos_weak = 0.3 ~ 0.5
w_obj_neg = 1.0
```
#### `L_azi / L_ele / L_dist / L_cls`
第一版不要在窗口内所有时间步都强制监督。
推荐只在以下位置监督:
- 与 GT source 匹配且 `pred_obj` 较高的 slot
- 或窗口内的 top-k 高置信时间步
更稳的第一版做法:
- 先对每个 GT source 在窗口内选择 `top-1` 或 `top-2` 个 objectness 最高的时间步参与坐标/类别监督
这样可以避免:
- 源在窗口内部分时间其实不 active
- 但模型被错误惩罚
### 9.6 推荐的第一版弱监督匹配策略
当前建议采用两阶段匹配:
1. 先按时间窗口过滤候选时间步
2. 再在候选时间步内做 slot matching
更具体地说:
- 对每个 GT source,只允许匹配其时间窗口内的 `slot tokens`
- 在这些候选中选出最优 `(t, k)`
这比直接对所有 `[T_s, K]` 位置做全局 Hungarian 更稳。
第一版推荐:
- `per-source best-of-window matching`
而不是:
- 全局 dense set matching
原因:
- 你当前时间标注是弱的
- 先用窗口约束大幅降低匹配歧义更现实
### 9.7 推理时 token 序列长度不需要改变
`2.5 Hz` 的 token rate 不需要变。
要改的是:
- 训练 supervision 的构造方式
- objectness 与坐标 loss 的作用范围
### 9.8 未来如果拿到更好的 activity 标注
如果后续可以拿到:
- energy-based active mask
- frame-level source activity
- VAD / source activation probability
则可把当前的弱时间监督替换成:
- `strong temporal supervision`
到时只需替换 target 构造和 criterion,不需要改主模型结构。
### 9.9 喂给 LLM 时的 token 数量
如果全部展开,理论最大 token 速率为:
```text
2.5 Hz * 4 = 10 tokens / second
```
但推理时可通过 `objectness` 做过滤,所以通常会低于这个上限。
### 9.10 LLM 展开顺序
建议按如下顺序展开:
- 先按时间排序
- 每个时间步内部按 `objectness` 从高到低排序
也就是:
```text
t1_s1, t1_s2, t1_s3, t1_s4, t2_s1, t2_s2, ...
```
然后再过滤低置信 slot。
## 10. 输出给 LLM 的 spatial token 形式
## 10.1 不直接喂原始 logits
不建议直接把:
- 方位分类 logits
- 类别 logits
- 距离标量
直接作为 token 输入给 LLM。
### 10.2 推荐 token 构造方式
每个 slot token `z_{t,k}` 最终形成一个结构化 token:
```text
s_{t,k} = Proj([z_{t,k} ; c_{t,k} ; u_{t,k} ; d_{t,k} ; o_{t,k}])
```
其中:
- `z_{t,k}`: slot latent
- `c_{t,k}`: source class context embedding
- `u_{t,k}`: 方向向量
- `d_{t,k}`: 连续距离 embedding
- `o_{t,k}`: objectness/confidence embedding
### 10.3 各项具体建议
#### `c_{t,k}`: 类别上下文
由 `pred_class_logits` 构造:
```text
p_cls = softmax(pred_class_logits)
c = p_cls @ E_cls
```
其中:
- `E_cls` 是一个可学习类别 embedding 表
作用:
- 给 spatial token 少量语义 grounding
- 但不需要与原始 audio encoder 对齐
#### `u_{t,k}`: 方向向量
先由:
- `pred_azi_logits`
- `pred_ele_logits`
得到预测角度,再转换成单位球坐标向量:
```text
u = [x, y, z]
```
推荐实现:
- 训练时用分布期望或 soft-argmax
- 推理时可用 argmax
#### `d_{t,k}`: 连续距离表示
由于 distance 是连续回归,建议:
-`pred_dist` 做归一化
- 再经一个小 MLP 变成 embedding
#### `o_{t,k}`: 置信度表示
由 `pred_obj` 经 sigmoid 得到 objectness,再做小 MLP 映射。
### 10.4 projector 的最终作用
`SpatialTokenProjector` 的任务是把:
- slot latent
- class context
- direction vector
- distance embedding
- objectness embedding
融合并投影到:
- `d_llm`
输出:
```text
llm_tokens: [B, N_keep, d_llm]
```
### 10.5 是否需要与原 audio encoder 对齐
当前结论:
- **不需要**
因此:
- 这个 projector 完全独立训练
- 只服务于 `Spatial-BEATs -> LLM`
## 11. Loss 设计
## 11.1 任务头与 loss
推荐 loss 组成:
```text
L_total =
lambda_obj * L_obj
+ lambda_azi * L_azi
+ lambda_ele * L_ele
+ lambda_dist * L_dist
+ lambda_cls * L_cls
```
### 11.2 各项定义
- `L_obj`: BCE 或 focal loss,支持弱正样本权重
- `L_azi`: cross entropy
- `L_ele`: cross entropy
- `L_dist`: SmoothL1 / Huber
- `L_cls`: cross entropy
### 11.3 distance 回归的实现
由于你已经明确要连续回归,推荐:
- head 输出 `pred_dist_norm in [0, 1]`
- 再乘以 `distance_max_m`
训练时使用:
- `SmoothL1Loss(pred_dist_norm, gt_dist_norm)`
优点:
- 比直接回归未归一化距离更稳
- 比 MSE 更抗异常值
### 11.4 推荐初始权重
建议第一版从以下权重起步:
```text
lambda_obj = 1.0
lambda_azi = 2.0
lambda_ele = 2.0
lambda_dist = 1.0
lambda_cls = 0.5
```
这里把 `class auxiliary` 权重从之前建议的 `0.25` 提到 `0.5`,因为现在你已经确认:
- 每个源有稳定的 `source-level class label`
- Spatial-BEATs 自身保留一定语义信息是可接受的
### 11.5 当前版本的匹配方式修订
由于当前时间 supervision 是弱的,第一版不建议直接做:
- 全局 `Hungarian matching` over `[T_s, K]`
更推荐:
1. 先根据 `source time window` 过滤候选时间步
2. 再在候选窗口内做匹配
推荐实现:
- `window-constrained matching`
可选两种方式:
#### 方案 A:推荐默认方案
- 对每个 GT source
- 在其 window 内所有 `(t, k)` 候选中,选择 cost 最小的一对
这本质上是:
- `best-of-window assignment`
优点:
- 简单
- 稳定
- 对弱时间监督更友好
#### 方案 B:后续增强方案
- 对每个时间步分别做 Hungarian
- 再加时间连续性正则
这更适合将来位置随时间变化时使用。
### 11.6 匹配 cost
Hungarian matching cost 建议:
```text
cost =
w_obj * cost_obj
+ w_azi * cost_azi
+ w_ele * cost_ele
+ w_dist * cost_dist
+ w_cls * cost_cls
```
推荐初值:
```text
w_obj = 1.0
w_azi = 2.0
w_ele = 2.0
w_dist = 1.0
w_cls = 1.0
```
## 12. 训练策略
### 12.1 第一阶段是否需要 SSL
当前明确结论:
- 第一版 **不做新的 BEATs 式 SSL**
理由:
- 已有空间 GT
- 已有 source class GT
- 已有强 trunk 预训练
- 当前主要目标是空间结构建模
### 12.2 推荐训练阶段
#### Stage A: Warmup
冻结:
- trunk 大部分层
训练:
- preprocessor
- patch stem
- temporal downsampler
- slot query decoder
- prediction heads
- projector
#### Stage B: Upper-trunk finetune
解冻:
- trunk 上层若干层
#### Stage C: Wider finetune
逐步解冻更多层,直到性能稳定。
### 12.3 训练时建议增加的正则项
当前位置在 clip 内固定,因此建议增加:
- `temporal consistency loss`
具体可对同一 source 在相邻时间步的预测加约束:
- objectness 平滑
- azimuth/elevation 分布平滑
- distance 平滑
第一版可选实现:
```text
L_temp =
smooth(pred_obj_t, pred_obj_{t+1})
+ smooth(pred_dist_t, pred_dist_{t+1})
```
由于当前位置固定,这类正则通常有利于稳定训练。
### 12.4 学习率建议
推荐:
```text
lr_trunk = 1e-5 ~ 5e-5
lr_new = 1e-4 ~ 5e-4
```
并使用:
- weight decay
- warmup
- layer-wise lr decay
## 13. 训练与推理输出格式
## 13.1 训练时 `forward()` 输出
建议 `forward()` 返回:
```python
{
"slot_tokens": FloatTensor[B, T_s, K, D],
"pred_obj": FloatTensor[B, T_s, K],
"pred_azi_logits": FloatTensor[B, T_s, K, 360],
"pred_ele_logits": FloatTensor[B, T_s, K, 180],
"pred_dist": FloatTensor[B, T_s, K, 1],
"pred_class_logits": FloatTensor[B, T_s, K, C_cls],
"llm_tokens": FloatTensor[B, N_keep, d_llm],
"llm_token_mask": BoolTensor[B, N_keep],
"token_meta": dict,
}
```
### 13.2 推理时建议额外输出
建议额外输出:
- `pred_azi_deg`
- `pred_ele_deg`
- `pred_dist_m`
- `pred_obj_prob`
- `pred_class_id`
便于后续可视化和调试。
## 14. 最小实现顺序
建议严格按以下顺序实现:
1.`SpatialBEATsPreprocessor`
2.`SpatialPatchEmbedding`
3. 完成 trunk checkpoint 加载
4.`TemporalDownsampler`
5.`SlotQueryDecoder`
6.`SpatialPredictionHead`
7.`SpatialTokenProjector`
8.`HungarianMatcher`
9.`SpatialSetCriterion`
10. 写 dataset 和训练脚本
## 15. 未来支持“位置随时间变化”时需要改什么
你已经说明:
- 当前 clip 内位置固定
- 后续会加入随时间变化的位置
这意味着当前模型结构基本可保留,但 target 和 decoder 训练方式需要升级。
### 15.1 当前结构哪些不用改
以下部分未来仍可直接保留:
- `FOA preprocessor`
- `patch embedding`
- `BEATs trunk`
- `TemporalDownsampler`
- `SlotQueryDecoder`
- `SpatialTokenProjector`
- `2.5 Hz` token rate
### 15.2 未来必须改的部分
未来位置动态化后,需要改:
1. `dataset target format`
2. `matching strategy`
3. `loss supervision`
### 15.3 数据结构怎么升级
当前静态位置:
```python
{
"azimuth_deg": float,
"elevation_deg": float,
"distance_m": float,
}
```
未来动态位置建议升级为:
```python
{
"trajectory": [
{
"time_s": float,
"azimuth_deg": float,
"elevation_deg": float,
"distance_m": float,
},
...
]
}
```
或直接存成与 `2.5 Hz` 对齐的逐步 target:
```python
{
"traj_azi_deg": FloatTensor[T_s],
"traj_ele_deg": FloatTensor[T_s],
"traj_dist_m": FloatTensor[T_s],
"traj_valid_mask": BoolTensor[T_s],
}
```
### 15.4 匹配怎么升级
当前位置固定时:
- `per-source best-of-window matching`
未来位置变化时:
- 更适合改为 `per-time-step matching`
-`track-level matching`
推荐未来版本:
- 每个 source 对应一条 slot track
- 在整个时间维上维持 slot identity
### 15.5 loss 怎么升级
未来动态位置时:
- `L_azi / L_ele / L_dist` 应按时间步计算
- `temporal consistency loss` 不能再强制“位置恒定”
- 应改成“速度平滑”或“轨迹平滑”
也就是从:
- `constant-position regularization`
升级成:
- `trajectory smoothness regularization`
### 15.6 代码层面建议现在就预留的接口
为了兼容未来动态位置,当前第一版建议在数据与 loss 接口里预留:
- `is_position_dynamic`
- `trajectory`
- `traj_valid_mask`
即使第一版不用,也建议把字段和分支接口预留出来。
## 16. 当前仍需要确认的问题
虽然核心方案已经足够落地,但还有一个关键问题最好在编码前确认:
当前核心方案已经足够编码。
如果后续继续推进,唯一还值得尽早确认的是:
- 是否能从原始 source waveform 自动提取更精细的 energy/activity mask
如果可以,第一版的弱时间监督会明显更稳。
## 17. 结论
当前可以直接进入代码实现的最终方案是:
- `16k FOA`
- `WXYZ + IV`
- `K=4`
- `2.5 Hz` slot token streams
- `distance` 连续回归
- `class auxiliary head` 开启
- `BEATs_iter3+ AS2M pre-trained` 作为 trunk 初始化
- `Spatial-BEATs` 拥有自己的 projector
- 最终输出自己的 LLM spatial tokens
- 当前时间 supervision 按 `weak temporal window` 处理
- 当前位置 supervision 按 `clip-level fixed position` 处理
- 未来动态位置仅需升级 target/matching/loss,不需要重写主干结构
这份文档已经足够作为第一版实现蓝图使用。