import torch
import torch.nn as nn
import math
from typing import Optional


class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding (as in "Attention Is All You Need").

    Args:
        d_model: embedding dimension
        max_len: maximum sequence length to precompute
        scale: scaling factor applied to the encoding table (default: 1.0)
        dropout: dropout rate applied after adding the encoding;
            set to 0 to disable (default: 0.1)
    """

    def __init__(
        self,
        d_model: int,
        max_len: int = 5000,
        scale: float = 1.0,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self.d_model = d_model
        self.max_len = max_len
        self.scale = scale

        # Precompute the full encoding table once: pe[position, dim].
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

        # Geometric frequency progression: 10000^(-2i / d_model).
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        # Even dimensions get sin(position * freq).
        pe[:, 0::2] = torch.sin(position * div_term)
        # Odd dimensions get cos(position * freq). With an odd d_model there
        # is one fewer odd column than there are frequencies, so the last
        # frequency is dropped for the cos pass.
        if d_model % 2 == 0:
            pe[:, 1::2] = torch.cos(position * div_term)
        else:
            pe[:, 1::2] = torch.cos(position * div_term[:-1])

        # Optional uniform scaling of the whole table.
        if scale != 1.0:
            pe = pe * scale

        # Buffer, not a parameter: follows .to()/.cuda() and is stored in
        # state_dict, but receives no gradients.
        self.register_buffer("pe", pe)

        # Dropout is disabled entirely when rate <= 0.
        self.dropout = nn.Dropout(p=dropout) if dropout > 0 else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the positional encoding to ``x``.

        Args:
            x: Tensor of shape [batch_size, seq_len, d_model]. A 2-D
                [seq_len, d_model] tensor is also accepted (the result
                gains a broadcast batch dimension of 1).

        Returns:
            Tensor with the positional encoding added, with dropout
            applied when configured.

        Raises:
            ValueError: if the sequence length exceeds ``max_len``.

        Note:
            A seq-first [seq_len, batch, d_model] tensor cannot be
            distinguished from a batch-first one here, so any 3-D input
            is treated as batch-first.
        """
        # Sequence axis is dim 1 for batch-first 3-D input, dim 0 otherwise.
        seq_len = x.size(1) if x.dim() == 3 else x.size(0)

        # Fail loudly instead of surfacing a cryptic view/indexing error.
        if seq_len > self.max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.max_len}"
            )

        # [1, seq_len, d_model] broadcasts across the batch dimension.
        x = x + self.pe[:seq_len, :].view(1, seq_len, self.d_model)

        if self.dropout is not None:
            x = self.dropout(x)
        return x

    def get_pe(self, seq_len: int) -> torch.Tensor:
        """Return a copy of the encoding for the first ``seq_len`` positions.

        Args:
            seq_len: sequence length

        Returns:
            Positional encoding tensor of shape [seq_len, d_model].
        """
        return self.pe[:seq_len, :].clone()


if __name__ == "__main__":
    # Test 1: basic operation.
    pe = PositionalEncoding(d_model=512, max_len=100)

    batch_size = 4
    seq_len = 50
    x = torch.randn(batch_size, seq_len, 512)

    output = pe(x)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Difference shape: {(output - x).shape}")

    # Test 2: a sequence longer than max_len must be rejected.
    x_long = torch.randn(2, 6000, 512)
    try:
        output_long = pe(x_long)
        print("\nHosszú szekvencia sikeres!")
    except (ValueError, RuntimeError):
        print("\nHosszú szekvencia túlmutat a max_len-en!")