| """非视觉 token 的可学习位置编码。 |
| |
| ego(8) / det(1024) / ctrl(24) / extra(256) 各自维护一份独立的 |
| ``[N, D]`` 可学习参数,初始化 ``trunc_normal(std=0.02)``。 |
| 直接加到对应 token 上,不参与 RoPE。 |
| """ |
|
|
| from __future__ import annotations |
|
|
| import torch |
| import torch.nn as nn |
|
|
|
|
| class LearnedTokenPE(nn.Module): |
| """形状为 ``[N, D]`` 的可学习位置编码,前向时按 batch 广播加。""" |
|
|
| def __init__(self, num_tokens: int, dim: int, init_std: float = 0.02) -> None: |
| super().__init__() |
| self.pe = nn.Parameter(torch.empty(num_tokens, dim)) |
| nn.init.trunc_normal_(self.pe, std=init_std) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| |
| return x + self.pe.unsqueeze(0) |
|
|