File size: 2,423 Bytes
dbd79bd 6faa82b dbd79bd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# #
# This file was created by: Alberto Palomo Alonso #
# Universidad de Alcalá - Escuela Politécnica Superior #
# #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Import statements:
import torch
class PositionalEncoding(torch.nn.Module):
    """
    Sinusoidal positional encoding module for Transformer models.

    This module injects information about the absolute position of tokens in
    a sequence by adding fixed sinusoidal embeddings to the input embeddings.
    The encodings are non-learnable and follow the formulation introduced in
    the original Transformer architecture:

        PE[pos, 2i]     = sin(pos / 10000^(2i / emb_dim))
        PE[pos, 2i + 1] = cos(pos / 10000^(2i / emb_dim))

    The precomputed table is stored in the buffer ``positional_encoding`` with
    shape ``(1, max_len, emb_dim)``.
    """

    def __init__(self, emb_dim: int, max_len: int = 5000, **kwargs):
        """
        Initialize the positional encoding module.

        Parameters
        ----------
        emb_dim : int
            Dimensionality of the embedding space. Both even and odd values
            are supported; when odd, the last column holds a sine term.
        max_len : int, optional
            Maximum supported sequence length for which positional encodings
            are precomputed.
        **kwargs
            Forwarded unchanged to ``torch.nn.Module.__init__``.
        """
        super().__init__(**kwargs)
        # Create positional encodings:
        pe = torch.zeros(max_len, emb_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # One frequency per (sin, cos) column pair: 10000^(-2i / emb_dim).
        div_term = torch.exp(
            torch.arange(0, emb_dim, 2).float()
            * -(torch.log(torch.tensor(10000.0)) / emb_dim)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd emb_dim there are ceil(emb_dim / 2) sine columns but only
        # floor(emb_dim / 2) cosine columns, so trim the frequencies to fit.
        pe[:, 1::2] = torch.cos(position * div_term[: emb_dim // 2])
        # Keep a leading singleton dimension so the buffer layout (and hence
        # any saved state_dict) remains (1, max_len, emb_dim).
        pe = pe.unsqueeze(0)
        # Register as a buffer: it follows the module across .to()/.cuda()
        # and is saved in the state_dict, but is not a trainable parameter.
        self.register_buffer('positional_encoding', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Add positional encodings to the input embeddings.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor whose last two dimensions are
            ``(sequence_length, emb_dim)``, e.g.
            ``(batch_size, sequence_length, emb_dim)`` or, unbatched,
            ``(sequence_length, emb_dim)``. ``sequence_length`` must not
            exceed ``max_len``.

        Returns
        -------
        torch.Tensor
            Tensor of the same shape as the input with positional encodings
            added.
        """
        seq_len = x.size(-2)
        # Index away the buffer's singleton leading dimension so the
        # (seq_len, emb_dim) slice broadcasts over any leading dimensions of
        # ``x`` without adding an extra one (unbatched input stays 2-D).
        return x + self.positional_encoding[0, :seq_len, :]
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# END OF FILE #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
|