File size: 3,448 Bytes
900b898
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import torch
import torch.nn as nn
import math
from typing import Optional


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding (Vaswani et al., 2017).

    Adds fixed sine/cosine position signals to token embeddings so a
    permutation-invariant model can use sequence order.

    Args:
        d_model: embedding dimension.
        max_len: maximum sequence length to precompute.
        scale: multiplicative scaling factor applied to the encoding
            (default: 1.0).
        dropout: dropout rate applied after adding the encoding; set to
            0 to disable (default: 0.1).
        batch_first: if True (default), ``forward`` expects input of
            shape ``[batch, seq, d_model]``; if False, ``[seq, batch,
            d_model]``.
    """

    def __init__(
        self,
        d_model: int,
        max_len: int = 5000,
        scale: float = 1.0,
        dropout: float = 0.1,
        batch_first: bool = True,
    ):
        super().__init__()
        self.d_model = d_model
        self.max_len = max_len
        self.scale = scale
        self.batch_first = batch_first

        # Precompute the [max_len, d_model] encoding table once.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

        # Frequencies 1 / 10000^(2i / d_model), computed in log space for
        # numerical stability.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        # Even columns get sin, odd columns get cos.  For odd d_model the
        # cos half has one fewer column, so the last frequency is dropped.
        pe[:, 0::2] = torch.sin(position * div_term)
        if d_model % 2 == 0:
            pe[:, 1::2] = torch.cos(position * div_term)
        else:
            pe[:, 1::2] = torch.cos(position * div_term[:-1])

        if scale != 1.0:
            pe = pe * scale

        # Buffer: saved in state_dict and moved by .to(device)/.cuda(),
        # but never trained.
        self.register_buffer("pe", pe)

        # Optional dropout (disabled when rate is 0).
        self.dropout = nn.Dropout(p=dropout) if dropout > 0 else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encoding to ``x``.

        Args:
            x: ``[batch, seq, d_model]`` when ``batch_first`` (default),
               else ``[seq, batch, d_model]``.

        Returns:
            Tensor of the same shape with the encoding added (and
            dropout applied, if configured).

        Raises:
            ValueError: if the sequence length exceeds ``max_len``.
        """
        # NOTE: the original distinguished layouts by x.dim(), but both
        # layouts are 3-dimensional, so the seq-first branch was dead
        # code; the layout is now an explicit constructor flag.
        seq_dim = 1 if self.batch_first else 0
        seq_len = x.size(seq_dim)

        # Fail with a clear message instead of an opaque shape error.
        if seq_len > self.max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.max_len}"
            )

        if self.batch_first:
            # [1, seq, d_model] broadcasts over the batch dimension.
            x = x + self.pe[:seq_len].unsqueeze(0)
        else:
            # [seq, 1, d_model] broadcasts over the batch dimension.
            x = x + self.pe[:seq_len].unsqueeze(1)

        if self.dropout is not None:
            x = self.dropout(x)

        return x

    def get_pe(self, seq_len: int) -> torch.Tensor:
        """Return the precomputed encoding for a given sequence length.

        Args:
            seq_len: sequence length (at most ``max_len``).

        Returns:
            A detached copy of shape ``[seq_len, d_model]``.
        """
        return self.pe[:seq_len, :].clone()


if __name__ == "__main__":
    # Test 1: basic operation
    pe = PositionalEncoding(d_model=512, max_len=100)

    # Test input
    batch_size = 4
    seq_len = 50
    x = torch.randn(batch_size, seq_len, 512)

    # Forward pass
    output = pe(x)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Difference shape: {(output - x).shape}")

    # Test 2: sequence longer than max_len
    x_long = torch.randn(2, 6000, 512)
    try:
        output_long = pe(x_long)
        print("\nHosszú szekvencia sikeres!")
    # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
    # are not silently swallowed.
    except Exception:
        print("\nHosszú szekvencia túlmutat a max_len-en!")