import torch
import torch.nn as nn
import math
from typing import Optional
class PositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding (Vaswani et al., "Attention Is All You
    Need", 2017).

    Args:
        d_model: embedding dimension
        max_len: maximum sequence length to precompute
        scale: scaling factor applied to the precomputed table (default: 1.0)
        dropout: dropout rate applied after adding the encoding;
            0 disables dropout entirely (default: 0.1)
    """
    def __init__(
        self,
        d_model: int,
        max_len: int = 5000,
        scale: float = 1.0,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.d_model = d_model
        self.max_len = max_len
        self.scale = scale
        # Precompute the full [max_len, d_model] encoding table once.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # Frequency term 10000^(-2i/d_model) for each even index 2i.
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        # sin on even columns, cos on odd columns. For odd d_model the cos
        # half has one fewer column than div_term, so drop the last frequency.
        pe[:, 0::2] = torch.sin(position * div_term)
        if d_model % 2 == 0:
            pe[:, 1::2] = torch.cos(position * div_term)
        else:
            pe[:, 1::2] = torch.cos(position * div_term[:-1])
        if scale != 1.0:
            pe = pe * scale
        # Buffer: saved in state_dict and moved by .to()/.cuda(), but not
        # a trainable parameter.
        self.register_buffer("pe", pe)
        # Optional dropout (disabled when dropout <= 0).
        self.dropout = nn.Dropout(p=dropout) if dropout > 0 else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Add positional encoding to the input tensor.

        Args:
            x: Tensor of shape [batch_size, seq_len, d_model] (any 3-D input
               is treated as batch-first — a seq-first [S, B, D] tensor is
               indistinguishable by shape alone, so it is NOT supported),
               or [seq_len, d_model] for unbatched input.

        Returns:
            Tensor of the same shape with positional encoding added (and
            dropout applied, if configured).

        Raises:
            ValueError: if the sequence length exceeds ``max_len``.
        """
        # A 3-D tensor is assumed batch-first; otherwise dim 0 is the
        # sequence dimension.
        seq_len = x.size(1) if x.dim() == 3 else x.size(0)
        # Fail loudly instead of letting a short slice trigger a cryptic
        # shape error downstream.
        if seq_len > self.max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.max_len}"
            )
        if x.dim() == 3:
            # [B, S, D] + [1, S, D] broadcasts over the batch dimension.
            x = x + self.pe[:seq_len, :].unsqueeze(0)
        else:
            # [S, D] + [S, D]: direct add, output shape is preserved
            # (the old view(1, S, D) silently produced a [1, S, D] result).
            x = x + self.pe[:seq_len, :]
        if self.dropout is not None:
            x = self.dropout(x)
        return x

    def get_pe(self, seq_len: int) -> torch.Tensor:
        """
        Get positional encoding for a given sequence length.

        Args:
            seq_len: sequence length

        Returns:
            Positional encoding tensor of shape [seq_len, d_model]
            (a copy, safe for the caller to mutate).
        """
        return self.pe[:seq_len, :].clone()
if __name__ == "__main__":
# Teszt 1: Alap működés
pe = PositionalEncoding(d_model=512, max_len=100)
# Teszt input
batch_size = 4
seq_len = 50
x = torch.randn(batch_size, seq_len, 512)
# Forward pass
output = pe(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Difference shape: {(output - x).shape}")
# Teszt 2: Hosszabb szekvencia mint max_len
x_long = torch.randn(2, 6000, 512)
try:
output_long = pe(x_long)
print("\nHosszú szekvencia sikeres!")
except:
print("\nHosszú szekvencia túlmutat a max_len-en!")
|