Spaces:
Runtime error
Runtime error
| import torch | |
| from torch import nn | |
| __all__ = ['WordPosEnc'] | |
class WordPosEnc(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence tensor.

    A table ``pe`` of shape ``(max_len, d_model)`` is precomputed once and
    stored as a buffer.  Even-indexed feature columns hold
    ``sin(pos / temperature ** (2i / d_model))`` and odd-indexed columns hold
    the matching ``cos`` values, where ``pos`` is the row index and ``2i`` is
    the even column index of the pair.

    Args:
        d_model: Number of feature columns in the encoding table.  Both even
            and odd values are supported.
        max_len: Number of positions (rows) precomputed in the table.
        temperature: Base of the geometric frequency progression.
    """

    def __init__(
        self, d_model: int = 512, max_len: int = 500, temperature: float = 10000.0
    ) -> None:
        super().__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float)
        dim_t = torch.arange(0, d_model, 2, dtype=torch.float)
        div_term = 1.0 / (temperature ** (dim_t / d_model))

        # Outer product: angles[p, i] = position[p] * div_term[i].
        angles = position[:, None] * div_term[None, :]

        pe[:, 0::2] = angles.sin()
        # For an odd d_model there is one fewer odd-indexed column than
        # even-indexed ones, so drop the surplus cosine column instead of
        # failing with a shape mismatch.  For even d_model this is a no-op.
        pe[:, 1::2] = angles.cos()[:, : d_model // 2]

        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with the positional encoding added.

        Args:
            x: A 3-D tensor whose second dimension is the sequence length and
                whose last dimension must broadcast against ``d_model``
                (the leading dimension is presumably the batch).

        Returns:
            ``x + pe[:seq_len]``, with the encoding broadcast over the
            leading dimension.

        Raises:
            RuntimeError: If the sequence length exceeds the number of
                precomputed positions.
        """
        _, seq_len, _ = x.size()

        max_len = self.pe.size(0)
        if seq_len > max_len:
            # Without this check the slice below is silently truncated; the
            # addition then fails with an opaque broadcast error or, when
            # max_len == 1, broadcasts position 0 over every time step.
            # RuntimeError matches what the shape mismatch used to raise.
            raise RuntimeError(
                f"sequence length {seq_len} exceeds max_len {max_len} "
                "of the positional encoding table"
            )

        emb = self.pe[:seq_len, :]
        return x + emb[None, :, :]