WJAD / src /wjad /modules /learned_pe.py
fuzirui's picture
Sync WJAD codebase
0cfefd2 verified
raw
history blame contribute delete
783 Bytes
"""非视觉 token 的可学习位置编码。
ego(8) / det(1024) / ctrl(24) / extra(256) 各自维护一份独立的
``[N, D]`` 可学习参数,初始化 ``trunc_normal(std=0.02)``。
直接加到对应 token 上,不参与 RoPE。
"""
from __future__ import annotations
import torch
import torch.nn as nn
class LearnedTokenPE(nn.Module):
"""形状为 ``[N, D]`` 的可学习位置编码,前向时按 batch 广播加。"""
def __init__(self, num_tokens: int, dim: int, init_std: float = 0.02) -> None:
super().__init__()
self.pe = nn.Parameter(torch.empty(num_tokens, dim))
nn.init.trunc_normal_(self.pe, std=init_std)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: [B, N, D]
return x + self.pe.unsqueeze(0)