File size: 783 Bytes
0cfefd2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
"""非视觉 token 的可学习位置编码。

ego(8) / det(1024) / ctrl(24) / extra(256) 各自维护一份独立的
``[N, D]`` 可学习参数,初始化 ``trunc_normal(std=0.02)``。
直接加到对应 token 上,不参与 RoPE。
"""

from __future__ import annotations

import torch
import torch.nn as nn


class LearnedTokenPE(nn.Module):
    """形状为 ``[N, D]`` 的可学习位置编码,前向时按 batch 广播加。"""

    def __init__(self, num_tokens: int, dim: int, init_std: float = 0.02) -> None:
        super().__init__()
        self.pe = nn.Parameter(torch.empty(num_tokens, dim))
        nn.init.trunc_normal_(self.pe, std=init_std)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, N, D]
        return x + self.pe.unsqueeze(0)