Spatial-BEATs Coding Guide
1. 本文档的作用
本文档是 Spatial-BEATs 的最终代码实施指南,用于直接指导后续代码开发。
它基于当前已经确认的项目约束:
- FOA 采样率统一到
16 kHz - 每个样本最大同时声源数约为
4 - 每个声源都有稳定的
source-level class label distance使用连续回归Spatial-BEATs拥有自己的projector- 不需要与原始语义 audio encoder 做表示对齐
- 目标
spatial token rate约为2.5 Hz - 允许增加
source class auxiliary head
本文档应视为后续实现的主参考。
2. 最终设计结论
2.1 总体目标
构建一个独立的 Spatial-BEATs:
- 输入完整
FOA waveform - 从
FOA中计算空间特征 - 将完整空间特征送入
BEATs backbone - 输出可输入 LLM 的
spatial tokens
注意:
- 不是
W-only - 不是外挂小 adapter
- 不是在原有语义 encoder 内部混合空间分支
而是:
- 一个独立的
Spatial Encoder - 最大化复用
BEATs trunk - 最终输出自己的空间 token 序列
2.2 关键实现原则
- 完整 FOA 特征经过 BEATs 主干
- 尽量不改 BEATs trunk 内部 Transformer
- 重做输入 stem
- 重做输出头和 token 生成方式
- 主训练目标是多源空间建模,不是 clip-level 分类
3. 最终模型架构
推荐最终架构如下:
FOA waveform [B, 4, T]
-> SpatialBEATsPreprocessor
-> FOA feature map [B, C_foa, T_f, F]
-> SpatialPatchEmbedding
-> BEATs trunk
-> Patch grid reshape
-> Temporal downsampler (to 2.5 Hz)
-> Slot query decoder
-> Source slot tokens [B, T_s, K, D]
-> Prediction heads
-> Spatial projector
-> LLM spatial tokens [B, N_keep, d_llm]
其中:
T_s是时间 token 数K是每个时间步最大 source slot 数D是 BEATs hidden dimd_llm是 LLM hidden dim
4. 固定超参与默认取值
4.1 输入参数
- sample rate:
16000 - mel bins:
128 - frame length:
25 ms - frame shift:
10 ms
4.2 token 相关参数
- token rate:
2.5 Hz - 对应时间间隔:
400 ms - 对于
10 s样本:T_s = 25
4.3 source slot 参数
- 最大同时源数:
4 - 默认
K = 4
说明:
- 第一版直接令
K = 4 - 不额外引入冗余 slot
- 如果后续发现数据中存在漏标、异常源或更复杂重叠,再考虑改成
K = 5/6
4.4 输入通道数
默认推荐:
W_logmelX_logmelY_logmelZ_logmelIVxIVyIVz
因此:
C_foa = 7
5. 输入特征定义
5.1 推荐特征形式
第一版明确使用:
WXYZ log-melIVx, IVy, IVz
其中:
WXYZ提供 ambisonic 通道信息IV提供显式方向 cue
5.2 IV 计算建议
建议在 STFT 域中计算 intensity vector,然后再映射到 mel 维:
IVx ~ Re(conj(W) * X)
IVy ~ Re(conj(W) * Y)
IVz ~ Re(conj(W) * Z)
可再配合能量归一化:
IV = IV / (|W|^2 + |X|^2 + |Y|^2 + |Z|^2 + eps)
实现时可以先得到频域 IV,再通过 mel filter bank 压到 128 mel bins。
5.3 为什么不用 binaural IPD
当前任务是 FOA,不是 binaural。
Spatial-AST 的 mel + IPD 经验可借鉴其结构思路,但不能直接复用其输入表示。
本项目应优先使用:
- FOA 通道本身
- intensity vector
6. 对 BEATs 代码的具体改造
6.1 尽量保留的部分
建议完全复用:
TransformerEncoderTransformerSentenceEncoderLayerMultiheadAttentionconv_pospost_extract_proj- trunk 中的
LayerNorm / FFN / attention
也就是说:
- backbone.py 尽量不改
6.2 需要重写的部分
必须重写:
preprocesspatch_embeddingextract_features的输出形式- 原始
predictor
6.3 推荐新增文件
建议新增如下文件:
spatial_beats.pyspatial_modules.pyspatial_loss.pyspatial_dataset.pytrain_spatial_beats.py- 可选
infer_spatial_beats.py
7. 预训练权重复用方案
7.1 推荐 checkpoint
默认推荐:
BEATs_iter3+ (AS2M) pre-trained
不推荐第一版直接用 fine-tuned checkpoint 作为 trunk 初始化。
7.2 直接加载的层
建议直接加载:
post_extract_projencoder.pos_convencoder.layers.*encoder.layer_normlayer_norm
这些层使用:
strict=False
并打印缺失与不匹配项。
7.3 不能直接加载的层
以下层需要新初始化:
- 新的
patch_embedding temporal downsamplerslot query decoderprediction headsspatial projector
7.4 新 patch stem 的初始化
原始 BEATs stem:
Conv2d(1, embed_dim, kernel_size=patch, stride=patch)
新的 stem:
Conv2d(7, embed_dim, kernel_size=patch, stride=patch)
推荐初始化方案:
W_logmel通道继承原 BEATs stem 权重X/Y/Z/IVx/IVy/IVz通道初始化为较小随机值
推荐做法:
new_weight[:, 0, :, :] = old_weight[:, 0, :, :]
new_weight[:, 1:, :, :] ~ N(0, 0.02 * std(old_weight))
不推荐全部复制 inflation 作为默认方案。
第一版优先稳定,而不是让所有通道一开始等价共享单通道语义滤波器。
8. 代码结构建议
8.1 spatial_modules.py
建议包含以下模块:
SpatialBEATsPreprocessor
职责:
- 输入
FOA waveform [B, 4, T] - 计算:
WXYZ logmelIVx, IVy, IVz
- 输出:
foa_feat [B, 7, T_f, 128]
建议接口:
class SpatialBEATsPreprocessor(nn.Module):
def forward(self, waveforms: torch.Tensor) -> torch.Tensor:
...
SpatialPatchEmbedding
职责:
- 对
foa_feat做多通道 patch embedding
建议接口:
class SpatialPatchEmbedding(nn.Module):
def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, tuple[int, int]]:
# returns:
# tokens: [B, N_p, D_in]
# grid_hw: (T_p, F_p)
...
TemporalDownsampler
职责:
- 将 trunk 输出从 patch 时间分辨率下采样到
2.5 Hz
建议输入输出:
- 输入:
grid memory [B, T_p, F_p, D] - 先对
F_p做平均或轻量 attention pooling - 得到:
temporal memory [B, T_p, D] - 再用线性插值或 1D conv 下采样到:
slot memory [B, T_s, D]
建议接口:
class TemporalDownsampler(nn.Module):
def forward(self, grid_x: torch.Tensor, target_steps: int) -> torch.Tensor:
# grid_x: [B, T_p, F_p, D]
# out: [B, T_s, D]
...
默认推荐:
- 第一版使用
freq-mean + linear interpolate
原因:
- 简单
- 稳定
- 容易调试
SlotQueryDecoder
职责:
- 对每个时间步生成
K=4个 source slots
推荐设计:
- 为每个 slot 准备一个 learnable
slot embedding - 将时间 token
m_t与 slot embedding 相加,形成初始 query - query 对 trunk memory 做 cross-attention
建议输出:
slot_tokens [B, T_s, K, D]
建议接口:
class SlotQueryDecoder(nn.Module):
def forward(
self,
temporal_memory: torch.Tensor,
encoder_memory: torch.Tensor,
) -> torch.Tensor:
# temporal_memory: [B, T_s, D]
# encoder_memory: [B, N_p, D]
# out: [B, T_s, K, D]
...
实现建议:
- 先用
temporal_memory生成时间条件 query - 再用
2 层 TransformerDecoderLayer或自定义 cross-attn block
第一版推荐:
2 层 decoder- hidden dim 与 trunk 一致
SpatialPredictionHead
职责:
- 对
slot_tokens预测各任务输出
建议输出:
pred_obj: [B, T_s, K]pred_azi_logits: [B, T_s, K, 360]pred_ele_logits: [B, T_s, K, 180]pred_dist: [B, T_s, K, 1]pred_class_logits: [B, T_s, K, C_cls]
SpatialTokenProjector
职责:
- 将 slot latent 与结构化坐标信息组合
- 投影到 LLM hidden size
输出:
llm_tokens [B, N_keep, d_llm]
8.2 spatial_beats.py
建议定义:
SpatialBEATsConfig
字段建议:
sample_rate=16000num_mel_bins=128token_rate=2.5max_sources=4foa_channels=7distance_max_mllm_hidden_sizeuse_class_aux=Truenum_decoder_layers=2
SpatialBEATs
建议结构:
class SpatialBEATs(nn.Module):
def __init__(self, cfg, beats_ckpt=None):
...
def extract_spatial_features(self, waveforms):
...
def extract_spatial_tokens(self, waveforms, audio_lengths=None):
...
def project_tokens_for_llm(self, slot_tokens, preds, keep_mask=None):
...
def forward(self, waveforms, audio_lengths=None, targets=None):
...
forward() 推荐返回形式
返回字典:
{
"encoder_memory": ...,
"slot_tokens": ...,
"pred_obj": ...,
"pred_azi_logits": ...,
"pred_ele_logits": ...,
"pred_dist": ...,
"pred_class_logits": ...,
"llm_tokens": ...,
"llm_token_mask": ...,
"token_meta": ...,
}
8.3 spatial_loss.py
建议定义:
HungarianMatcher
输入:
- 预测输出
- GT targets
输出:
- 每个样本每个时间步的匹配索引
SpatialSetCriterion
计算:
- objectness loss
- azimuth loss
- elevation loss
- distance regression loss
- class auxiliary loss
可选:
- temporal smoothness loss
8.4 spatial_dataset.py
建议数据格式:
sample = {
"waveform": FloatTensor[4, T],
"duration_s": float,
"sources": [
{
"class_id": int,
"azimuth_deg": float,
"elevation_deg": float,
"distance_m": float,
"start_s": float,
"end_s": float,
"is_time_weak": bool,
"is_position_dynamic": bool,
"trajectory": optional,
},
...
]
}
当前推荐额外保留以下字段:
start_s: 源开始进入场景的时间end_s: 若有则保留,否则可由start_s + length_s推出length_s: 原始 source clip 时长is_time_weak: 当前时间边界是否只是弱监督is_position_dynamic: 该源位置是否随时间变化trajectory: 若位置变化,则存储分段轨迹或逐帧轨迹
如果当前只有源进入时间和原始 source length,可统一转成:
start_s = start timeend_s = start_s + length_sis_time_weak = True
9. 时间建模、2.5 Hz 输出与弱时间监督
这是当前实现中最关键的新增约束之一。
9.1 token rate 的解释
这里定义:
- 每个
source slot stream的输出速率为2.5 Hz
即:
- 每个时间步间隔
400 ms
9.2 输出张量形状
对时长为 L 秒的样本:
T_s = round(L * 2.5)
例如:
10 s -> 25
最终 slot token 形状:
[B, T_s, K, D]
在当前默认配置下:
K = 4
也就是最多 4 条并行 source slot 流。
9.3 当前时间标注的含义
你当前可提供的时间信息是:
- 知道每个 source 的
start time - 知道原始
FSD50K source clip的length - 但这段
length内不保证每一时刻都真的 active
因此,当前不应把 [start_s, end_s] 视为严格逐帧真值,而应视为:
weak temporal support window
也就是:
- 源最可能出现的候选时间范围
- 不是精确的逐帧 activity annotation
9.4 第一版如何把弱时间标注映射到 2.5 Hz token
对于第 t 个时间步,其中心时刻记为 tau_t。
对每个 GT source,定义:
candidate window = [start_s, end_s]
第一版推荐构造三种 mask:
pos_window_masktau_t落在[start_s, end_s]内
neg_window_masktau_t明确落在窗口外
ignore_mask- 可选,用于窗口边界附近或不确定区域
当前默认建议最简单实现:
- 窗口外:作为 objectness 负样本
- 窗口内:作为弱正样本候选
但不要对窗口内所有步都施加强监督位置 loss。
9.5 当前 loss 需要怎么改
为了适应弱时间标注,当前建议把 loss 拆成两层:
L_obj
- 对窗口外,正常做负样本监督
- 对窗口内,做弱正样本监督
推荐:
- 使用
BCE或focal loss - 窗口内正样本权重降低
例如:
w_obj_pos_weak = 0.3 ~ 0.5
w_obj_neg = 1.0
L_azi / L_ele / L_dist / L_cls
第一版不要在窗口内所有时间步都强制监督。
推荐只在以下位置监督:
- 与 GT source 匹配且
pred_obj较高的 slot - 或窗口内的 top-k 高置信时间步
更稳的第一版做法:
- 先对每个 GT source 在窗口内选择
top-1或top-2个 objectness 最高的时间步参与坐标/类别监督
这样可以避免:
- 源在窗口内部分时间其实不 active
- 但模型被错误惩罚
9.6 推荐的第一版弱监督匹配策略
当前建议采用两阶段匹配:
- 先按时间窗口过滤候选时间步
- 再在候选时间步内做 slot matching
更具体地说:
- 对每个 GT source,只允许匹配其时间窗口内的
slot tokens - 在这些候选中选出最优
(t, k)
这比直接对所有 [T_s, K] 位置做全局 Hungarian 更稳。
第一版推荐:
per-source best-of-window matching
而不是:
- 全局 dense set matching
原因:
- 你当前时间标注是弱的
- 先用窗口约束大幅降低匹配歧义更现实
9.7 推理时 token 序列长度不需要改变
2.5 Hz 的 token rate 不需要变。
要改的是:
- 训练 supervision 的构造方式
- objectness 与坐标 loss 的作用范围
9.8 未来如果拿到更好的 activity 标注
如果后续可以拿到:
- energy-based active mask
- frame-level source activity
- VAD / source activation probability
则可把当前的弱时间监督替换成:
strong temporal supervision
到时只需替换 target 构造和 criterion,不需要改主模型结构。
9.9 喂给 LLM 时的 token 数量
如果全部展开,理论最大 token 速率为:
2.5 Hz * 4 = 10 tokens / second
但推理时可通过 objectness 做过滤,所以通常会低于这个上限。
9.10 LLM 展开顺序
建议按如下顺序展开:
- 先按时间排序
- 每个时间步内部按
objectness从高到低排序
也就是:
t1_s1, t1_s2, t1_s3, t1_s4, t2_s1, t2_s2, ...
然后再过滤低置信 slot。
10. 输出给 LLM 的 spatial token 形式
10.1 不直接喂原始 logits
不建议直接把:
- 方位分类 logits
- 类别 logits
- 距离标量
直接作为 token 输入给 LLM。
10.2 推荐 token 构造方式
每个 slot token z_{t,k} 最终形成一个结构化 token:
s_{t,k} = Proj([z_{t,k} ; c_{t,k} ; u_{t,k} ; d_{t,k} ; o_{t,k}])
其中:
z_{t,k}: slot latentc_{t,k}: source class context embeddingu_{t,k}: 方向向量d_{t,k}: 连续距离 embeddingo_{t,k}: objectness/confidence embedding
10.3 各项具体建议
c_{t,k}: 类别上下文
由 pred_class_logits 构造:
p_cls = softmax(pred_class_logits)
c = p_cls @ E_cls
其中:
E_cls是一个可学习类别 embedding 表
作用:
- 给 spatial token 少量语义 grounding
- 但不需要与原始 audio encoder 对齐
u_{t,k}: 方向向量
先由:
pred_azi_logitspred_ele_logits
得到预测角度,再转换成单位球坐标向量:
u = [x, y, z]
推荐实现:
- 训练时用分布期望或 soft-argmax
- 推理时可用 argmax
d_{t,k}: 连续距离表示
由于 distance 是连续回归,建议:
- 对
pred_dist做归一化 - 再经一个小 MLP 变成 embedding
o_{t,k}: 置信度表示
由 pred_obj 经 sigmoid 得到 objectness,再做小 MLP 映射。
10.4 projector 的最终作用
SpatialTokenProjector 的任务是把:
- slot latent
- class context
- direction vector
- distance embedding
- objectness embedding
融合并投影到:
d_llm
输出:
llm_tokens: [B, N_keep, d_llm]
10.5 是否需要与原 audio encoder 对齐
当前结论:
- 不需要
因此:
- 这个 projector 完全独立训练
- 只服务于
Spatial-BEATs -> LLM
11. Loss 设计
11.1 任务头与 loss
推荐 loss 组成:
L_total =
lambda_obj * L_obj
+ lambda_azi * L_azi
+ lambda_ele * L_ele
+ lambda_dist * L_dist
+ lambda_cls * L_cls
11.2 各项定义
L_obj: BCE 或 focal loss,支持弱正样本权重L_azi: cross entropyL_ele: cross entropyL_dist: SmoothL1 / HuberL_cls: cross entropy
11.3 distance 回归的实现
由于你已经明确要连续回归,推荐:
- head 输出
pred_dist_norm in [0, 1] - 再乘以
distance_max_m
训练时使用:
SmoothL1Loss(pred_dist_norm, gt_dist_norm)
优点:
- 比直接回归未归一化距离更稳
- 比 MSE 更抗异常值
11.4 推荐初始权重
建议第一版从以下权重起步:
lambda_obj = 1.0
lambda_azi = 2.0
lambda_ele = 2.0
lambda_dist = 1.0
lambda_cls = 0.5
这里把 class auxiliary 权重从之前建议的 0.25 提到 0.5,因为现在你已经确认:
- 每个源有稳定的
source-level class label - Spatial-BEATs 自身保留一定语义信息是可接受的
11.5 当前版本的匹配方式修订
由于当前时间 supervision 是弱的,第一版不建议直接做:
- 全局
Hungarian matchingover[T_s, K]
更推荐:
- 先根据
source time window过滤候选时间步 - 再在候选窗口内做匹配
推荐实现:
window-constrained matching
可选两种方式:
方案 A:推荐默认方案
- 对每个 GT source
- 在其 window 内所有
(t, k)候选中,选择 cost 最小的一对
这本质上是:
best-of-window assignment
优点:
- 简单
- 稳定
- 对弱时间监督更友好
方案 B:后续增强方案
- 对每个时间步分别做 Hungarian
- 再加时间连续性正则
这更适合将来位置随时间变化时使用。
11.6 匹配 cost
Hungarian matching cost 建议:
cost =
w_obj * cost_obj
+ w_azi * cost_azi
+ w_ele * cost_ele
+ w_dist * cost_dist
+ w_cls * cost_cls
推荐初值:
w_obj = 1.0
w_azi = 2.0
w_ele = 2.0
w_dist = 1.0
w_cls = 1.0
12. 训练策略
12.1 第一阶段是否需要 SSL
当前明确结论:
- 第一版 不做新的 BEATs 式 SSL
理由:
- 已有空间 GT
- 已有 source class GT
- 已有强 trunk 预训练
- 当前主要目标是空间结构建模
12.2 推荐训练阶段
Stage A: Warmup
冻结:
- trunk 大部分层
训练:
- preprocessor
- patch stem
- temporal downsampler
- slot query decoder
- prediction heads
- projector
Stage B: Upper-trunk finetune
解冻:
- trunk 上层若干层
Stage C: Wider finetune
逐步解冻更多层,直到性能稳定。
12.3 训练时建议增加的正则项
当前位置在 clip 内固定,因此建议增加:
temporal consistency loss
具体可对同一 source 在相邻时间步的预测加约束:
- objectness 平滑
- azimuth/elevation 分布平滑
- distance 平滑
第一版可选实现:
L_temp =
smooth(pred_obj_t, pred_obj_{t+1})
+ smooth(pred_dist_t, pred_dist_{t+1})
由于当前位置固定,这类正则通常有利于稳定训练。
12.4 学习率建议
推荐:
lr_trunk = 1e-5 ~ 5e-5
lr_new = 1e-4 ~ 5e-4
并使用:
- weight decay
- warmup
- layer-wise lr decay
13. 训练与推理输出格式
13.1 训练时 forward() 输出
建议 forward() 返回:
{
"slot_tokens": FloatTensor[B, T_s, K, D],
"pred_obj": FloatTensor[B, T_s, K],
"pred_azi_logits": FloatTensor[B, T_s, K, 360],
"pred_ele_logits": FloatTensor[B, T_s, K, 180],
"pred_dist": FloatTensor[B, T_s, K, 1],
"pred_class_logits": FloatTensor[B, T_s, K, C_cls],
"llm_tokens": FloatTensor[B, N_keep, d_llm],
"llm_token_mask": BoolTensor[B, N_keep],
"token_meta": dict,
}
13.2 推理时建议额外输出
建议额外输出:
pred_azi_degpred_ele_degpred_dist_mpred_obj_probpred_class_id
便于后续可视化和调试。
14. 最小实现顺序
建议严格按以下顺序实现:
- 写
SpatialBEATsPreprocessor - 写
SpatialPatchEmbedding - 完成 trunk checkpoint 加载
- 写
TemporalDownsampler - 写
SlotQueryDecoder - 写
SpatialPredictionHead - 写
SpatialTokenProjector - 写
HungarianMatcher - 写
SpatialSetCriterion - 写 dataset 和训练脚本
15. 未来支持“位置随时间变化”时需要改什么
你已经说明:
- 当前 clip 内位置固定
- 后续会加入随时间变化的位置
这意味着当前模型结构基本可保留,但 target 和 decoder 训练方式需要升级。
15.1 当前结构哪些不用改
以下部分未来仍可直接保留:
FOA preprocessorpatch embeddingBEATs trunkTemporalDownsamplerSlotQueryDecoderSpatialTokenProjector2.5 Hztoken rate
15.2 未来必须改的部分
未来位置动态化后,需要改:
dataset target formatmatching strategyloss supervision
15.3 数据结构怎么升级
当前静态位置:
{
"azimuth_deg": float,
"elevation_deg": float,
"distance_m": float,
}
未来动态位置建议升级为:
{
"trajectory": [
{
"time_s": float,
"azimuth_deg": float,
"elevation_deg": float,
"distance_m": float,
},
...
]
}
或直接存成与 2.5 Hz 对齐的逐步 target:
{
"traj_azi_deg": FloatTensor[T_s],
"traj_ele_deg": FloatTensor[T_s],
"traj_dist_m": FloatTensor[T_s],
"traj_valid_mask": BoolTensor[T_s],
}
15.4 匹配怎么升级
当前位置固定时:
per-source best-of-window matching
未来位置变化时:
- 更适合改为
per-time-step matching - 或
track-level matching
推荐未来版本:
- 每个 source 对应一条 slot track
- 在整个时间维上维持 slot identity
15.5 loss 怎么升级
未来动态位置时:
L_azi / L_ele / L_dist应按时间步计算temporal consistency loss不能再强制“位置恒定”- 应改成“速度平滑”或“轨迹平滑”
也就是从:
constant-position regularization
升级成:
trajectory smoothness regularization
15.6 代码层面建议现在就预留的接口
为了兼容未来动态位置,当前第一版建议在数据与 loss 接口里预留:
is_position_dynamictrajectorytraj_valid_mask
即使第一版不用,也建议把字段和分支接口预留出来。
16. 当前仍需要确认的问题
虽然核心方案已经足够落地,但还有一个关键问题最好在编码前确认:
当前核心方案已经足够编码。
如果后续继续推进,唯一还值得尽早确认的是:
- 是否能从原始 source waveform 自动提取更精细的 energy/activity mask
如果可以,第一版的弱时间监督会明显更稳。
17. 结论
当前可以直接进入代码实现的最终方案是:
16k FOAWXYZ + IVK=42.5 Hzslot token streamsdistance连续回归class auxiliary head开启BEATs_iter3+ AS2M pre-trained作为 trunk 初始化Spatial-BEATs拥有自己的 projector- 最终输出自己的 LLM spatial tokens
- 当前时间 supervision 按
weak temporal window处理 - 当前位置 supervision 按
clip-level fixed position处理 - 未来动态位置仅需升级 target/matching/loss,不需要重写主干结构
这份文档已经足够作为第一版实现蓝图使用。