# Source added by duycse1603 (commit 6163604).
import torch
from torch import nn
# Public API of this module.
__all__ = ['WordPosEnc']
class WordPosEnc(nn.Module):
    """Fixed sinusoidal positional encoding added to word embeddings.

    Precomputes a ``(max_len, d_model)`` table where even feature
    dimensions hold ``sin(pos / temperature^(2i/d_model))`` and odd
    dimensions hold the matching ``cos`` values (the scheme from
    "Attention Is All You Need"). The table is stored as a buffer, so
    it moves with the module across devices but is never trained.
    """

    def __init__(
        self, d_model: int = 512, max_len: int = 500, temperature: float = 10000.0
    ) -> None:
        super().__init__()
        # One inverse frequency per even feature index.
        positions = torch.arange(0, max_len, dtype=torch.float)
        even_idx = torch.arange(0, d_model, 2, dtype=torch.float)
        inv_freq = 1.0 / (temperature ** (even_idx / d_model))
        # Outer product: angles[p, i] = positions[p] * inv_freq[i].
        angles = torch.outer(positions, inv_freq)
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = angles.sin()
        table[:, 1::2] = angles.cos()
        self.register_buffer("pe", table)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encodings to ``x``.

        Args:
            x: Embeddings of shape ``(batch, seq_len, d_model)``.

        Returns:
            Tensor of the same shape with the first ``seq_len`` rows of
            the encoding table broadcast-added over the batch dimension.
        """
        _, seq_len, _ = x.size()
        return x + self.pe[:seq_len, :].unsqueeze(0)