"""非视觉 token 的可学习位置编码。 ego(8) / det(1024) / ctrl(24) / extra(256) 各自维护一份独立的 ``[N, D]`` 可学习参数,初始化 ``trunc_normal(std=0.02)``。 直接加到对应 token 上,不参与 RoPE。 """ from __future__ import annotations import torch import torch.nn as nn class LearnedTokenPE(nn.Module): """形状为 ``[N, D]`` 的可学习位置编码,前向时按 batch 广播加。""" def __init__(self, num_tokens: int, dim: int, init_std: float = 0.02) -> None: super().__init__() self.pe = nn.Parameter(torch.empty(num_tokens, dim)) nn.init.trunc_normal_(self.pe, std=init_std) def forward(self, x: torch.Tensor) -> torch.Tensor: # x: [B, N, D] return x + self.pe.unsqueeze(0)