File size: 2,423 Bytes
dbd79bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6faa82b
dbd79bd
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                                                           #
#   This file was created by: Alberto Palomo Alonso         #
# Universidad de Alcalá - Escuela Politécnica Superior      #
#                                                           #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Import statements:
import torch


class PositionalEncoding(torch.nn.Module):
    """
    Sinusoidal positional encoding module for Transformer models.

    This module injects information about the position of tokens in a
    sequence by adding fixed sinusoidal embeddings to the input embeddings.
    The encodings are non-learnable: they are precomputed once at
    construction time and stored in the ``positional_encoding`` buffer of
    shape ``(1, max_len, emb_dim)``. Even feature indices hold sines and odd
    feature indices hold cosines, following the formulation introduced in the
    original Transformer architecture.
    """
    def __init__(self, emb_dim: int, max_len: int = 5000, **kwargs):
        """
        Initialize the positional encoding module.

        Parameters
        ----------
        emb_dim : int
            Dimensionality of the embedding space. Both even and odd values
            are supported.
        max_len : int, optional
            Maximum supported sequence length for which positional encodings
            are precomputed.
        **kwargs
            Forwarded unchanged to ``torch.nn.Module.__init__``.
        """
        super().__init__(**kwargs)

        # Create positional encodings:
        pe = torch.zeros(max_len, emb_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, emb_dim, 2).float() * -(torch.log(torch.tensor(10000.0)) / emb_dim))
        # One column of angles per sin/cos pair: shape (max_len, ceil(emb_dim / 2)).
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # With an odd emb_dim there is one fewer cosine column than sine
        # columns, so the angle table is truncated to avoid a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, :emb_dim // 2])
        pe = pe.unsqueeze(0)

        # Register as a buffer (follows the module across devices/dtypes and
        # is stored in the state dict, but is not a trainable parameter):
        self.register_buffer('positional_encoding', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Add positional encodings to the input embeddings.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, sequence_length, emb_dim).

        Returns
        -------
        torch.Tensor
            Tensor of the same shape as the input with positional encodings added.

        Raises
        ------
        RuntimeError
            If the sequence length of ``x`` exceeds the ``max_len`` the module
            was built with.
        """
        seq_len = x.size(-2)
        max_len = self.positional_encoding.size(1)
        # Slicing past the end of the buffer is silently clamped, which would
        # either fail with an obscure broadcasting error or (for max_len == 1)
        # broadcast a single position over the whole sequence.
        if seq_len > max_len:
            raise RuntimeError(
                f'Sequence length {seq_len} exceeds the maximum supported length {max_len}.'
            )
        return x + self.positional_encoding[:, :seq_len, :]
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                        END OF FILE                        #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #