from typing import Union

import torch
from torch import nn, Tensor


class LearnedPositionEmbeddings(nn.Module):
    def __init__(self, seq_len: int, model_dim: int, init: float = 0.02):
        super().__init__()
        self.emb = nn.Embedding(seq_len, model_dim)
        # Small-std normal init so position embeddings start near zero.
        self.emb.weight.data.normal_(mean=0.0, std=init)

    def forward(self, x: Tensor) -> Tensor:
        """
        Returns positional embeddings for indices 0 up to the sequence length
        of x, shape (seq_len, model_dim); this broadcasts over the batch
        dimension when added to a (batch, seq_len, model_dim) tensor.
        """
        sl = x.shape[1]
        return self.emb(torch.arange(0, sl, device=x.device))

    def get_fixed_embedding(self, idx: Union[int, Tensor]) -> Tensor:
        """
        Args:
            idx: scalar int or an integer tensor of shape (T,) or (B, T)

        Returns:
            positional embeddings for the given indices, shape (B, T, dim),
            i.e. (1, 1, dim) for int input
        """
        device = self.emb.weight.device
        # Accept plain ints as well as tensors, and move indices onto the
        # embedding table's device.
        idx = idx.to(device) if torch.is_tensor(idx) else torch.tensor(idx, device=device)
        # Promote scalar or (T,) indices to (1, T) so the output is always
        # batched as (B, T, dim).
        idx = torch.atleast_2d(idx)
        assert idx.ndim == 2
        return self.emb(idx)
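

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): shows the two
# intended call patterns — broadcast-adding position embeddings to a batch of
# token embeddings during training, and fetching a single position's embedding
# via get_fixed_embedding during incremental decoding. The seq_len, model_dim,
# and batch values below are arbitrary illustration choices.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    seq_len, model_dim, batch = 32, 64, 4
    pos = LearnedPositionEmbeddings(seq_len, model_dim)

    # Training-time use: (10, model_dim) position embeddings broadcast-add
    # onto (batch, 10, model_dim) token embeddings.
    tokens = torch.randn(batch, 10, model_dim)
    h = tokens + pos(tokens)
    assert h.shape == (batch, 10, model_dim)

    # Decoding-time use: the embedding for one position, already batched.
    step = pos.get_fixed_embedding(9)
    assert step.shape == (1, 1, model_dim)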