| import torch |
| import math |
| from torch import nn |
class RotaryPositionalEncoding(nn.Module):
    """Rotary positional encoding (RoPE) with a "rotate-half" feature layout.

    The last dimension of the input is split into two halves ``x1`` and
    ``x2``. Feature ``i`` of ``x1`` is paired with feature ``i`` of ``x2``,
    and the pair is rotated by the angle ``pos * theta_i``, where
    ``theta_i = 10000 ** (-2 * i / d_model)``. Position 0 is the identity
    rotation.
    """

    def __init__(self, d_model, max_len=5000):
        """Initialize RotaryPositionalEncoding.

        Args:
            d_model (int): Feature dimension; must be even.
            max_len (int): Maximum supported ``offset + seq_len``.
        """
        super().__init__()

        assert d_model % 2 == 0, "d_model must be even for RotaryPositionalEncoding."

        half = d_model // 2
        position = torch.arange(0, max_len).float().unsqueeze(1)  # [max_len, 1]
        dim = torch.arange(0, half).float()
        # theta_i = 10000^(-i / half), i.e. 10000^(-2i / d_model).
        div_term = torch.exp(dim * -(math.log(10000.0) / half))

        angle = position * div_term  # [max_len, half]

        # The buffer layout is [sin | cos] along the last axis. It is kept this
        # way so state_dicts that already contain 'pe' remain loadable.
        pe = torch.cat([torch.sin(angle), torch.cos(angle)], dim=-1)
        pe = pe.unsqueeze(0).unsqueeze(0)  # [1, 1, max_len, d_model]
        self.register_buffer('pe', pe)

    def forward(self, x, offset=0):
        """Apply the rotary encoding to ``x``.

        Args:
            x (Tensor): Input of shape [batch_size, seq_len, d_model].
            offset (int): Position of the first element of ``x``; default 0.

        Returns:
            Tensor: ``x`` with each (x1_i, x2_i) feature pair rotated by its
            position angle; same shape as ``x``.

        Raises:
            ValueError: If the last dimension of ``x`` differs from ``d_model``,
                or if ``offset`` is negative or ``offset + seq_len`` exceeds
                ``max_len``.
        """
        seq_len = x.size(1)
        d_model = self.pe.size(-1)
        if x.size(-1) != d_model:
            raise ValueError(
                f"expected last dimension {d_model}, got {x.size(-1)}"
            )
        end = offset + seq_len
        if offset < 0 or end > self.pe.size(2):
            raise ValueError(
                f"positions [{offset}, {end}) are outside [0, {self.pe.size(2)})"
            )

        pe = self.pe[0, :, offset:end, :]  # [1, seq_len, d_model]
        half = d_model // 2
        # The buffer stores sin in the first half and cos in the second half.
        sin, cos = pe[..., :half], pe[..., half:]

        x1, x2 = x[..., :half], x[..., half:]

        # Standard 2-D rotation of each (x1_i, x2_i) pair:
        #   x1' = x1*cos - x2*sin,  x2' = x1*sin + x2*cos
        return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
|
|
class ReRoPE:
    """Rotary position embedding helper using an interleaved pair layout.

    Features ``(2i, 2i + 1)`` form a pair that is rotated by the angle
    ``pos * theta_i``, where ``theta_i = 10000 ** (-2i / dim)``.

    Note: despite the name, this class performs no windowing or clipping of
    relative positions. Positions are used exactly as passed to
    :meth:`forward`. This is a plain class (not an ``nn.Module``), so call
    :meth:`forward` explicitly.
    """

    def __init__(self, dim: int):
        """Initialize the ReRoPE encoder.

        Args:
            dim (int): Feature dimension; must be even.
        """
        assert dim % 2 == 0, "Dimension must be even for ReRoPE."
        self.dim = dim
        self.theta = self._compute_base_theta(dim)

    @staticmethod
    def _compute_base_theta(dim: int) -> torch.Tensor:
        """Compute the per-feature rotation frequencies.

        Args:
            dim (int): Feature dimension.

        Returns:
            torch.Tensor: Tensor of length ``dim``. Each frequency appears
            twice, so it lines up with both members of an interleaved
            (even, odd) feature pair.
        """
        return torch.tensor([10000 ** (-2 * (i // 2) / dim) for i in range(dim)])

    def forward(self, pos: torch.Tensor) -> torch.Tensor:
        """Compute the sin/cos table for the given positions.

        Args:
            pos (torch.Tensor): Position indices, shape [seq_len] or
                [batch_size, seq_len].

        Returns:
            torch.Tensor: Shape [..., seq_len, 2 * dim]. The first ``dim``
            columns are sin values and the last ``dim`` columns are cos values.
        """
        # Move theta to pos's device so non-CPU positions work. This is a
        # no-op when they are already on the same device.
        angles = pos.unsqueeze(-1) * self.theta.to(pos.device)
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    @staticmethod
    def _rotate_pairs(x: torch.Tensor) -> torch.Tensor:
        """Map each interleaved feature pair (a, b) to (-b, a)."""
        even = x[..., 0::2]
        odd = x[..., 1::2]
        return torch.stack((-odd, even), dim=-1).flatten(-2)

    @staticmethod
    def apply_rotary_embedding(query, key, sincos):
        """Apply the rotary encoding to query and key.

        Args:
            query (torch.Tensor): Shape [batch_size, seq_len, dim].
            key (torch.Tensor): Shape [batch_size, seq_len, dim].
            sincos (torch.Tensor): Output of :meth:`forward`, shape
                [seq_len, 2 * dim] or [batch_size, seq_len, 2 * dim].

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Rotated query and key.
        """
        sin, cos = sincos[..., :query.size(-1)], sincos[..., query.size(-1):]
        # Pair-wise 2-D rotation: (a, b) -> (a*cos - b*sin, b*cos + a*sin).
        query_rotated = query * cos + ReRoPE._rotate_pairs(query) * sin
        key_rotated = key * cos + ReRoPE._rotate_pairs(key) * sin
        return query_rotated, key_rotated
| |
class LearnablePositionalEmbedding(nn.Module):
    """Trainable absolute position embedding initialized with sinusoids.

    The parameter ``pe`` has shape [1, 1, max_len, d_model]. Even feature
    columns start as ``sin(pos * w_k)`` and odd columns as ``cos(pos * w_k)``,
    with ``w_k = 10000 ** (-2k / d_model)``.
    """

    def __init__(self, d_model, max_len=5000):
        """Initialize the embedding table.

        Args:
            d_model (int): Feature dimension (odd values are supported).
            max_len (int): Number of positions in the table.
        """
        super().__init__()

        position = torch.arange(0, max_len).float().unsqueeze(1)  # [max_len, 1]
        div_term = (torch.arange(0, d_model, 2).float()
                    * -(math.log(10000.0) / d_model)).exp()
        angles = position * div_term  # [max_len, ceil(d_model / 2)]

        pe = torch.zeros(max_len, d_model).float()
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than frequencies, so
        # trim the frequencies to make the assignment shapes match.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        self.pe = nn.Parameter(pe.unsqueeze(0).unsqueeze(0), requires_grad=True)

    def forward(self, x, offset=0):
        """Return the embedding rows for positions [offset, offset + seq_len).

        Args:
            x (Tensor): Input whose dimension 1 is the sequence length.
            offset (int): Position of the first element of ``x``; default 0.

        Returns:
            Tensor: Shape [1, seq_len, d_model]; it is not added to ``x``.

        Raises:
            ValueError: If ``offset`` is negative or ``offset + seq_len``
                exceeds ``max_len``. The learned table cannot be extended, so
                silently returning fewer rows would hide the error.
        """
        end = offset + x.size(1)
        if offset < 0 or end > self.pe.size(2):
            raise ValueError(
                f"positions [{offset}, {end}) are outside [0, {self.pe.size(2)})"
            )
        return self.pe[0, :, offset:end, :]
|
|
class SinusoidalPositionalEncoding(nn.Module):
    """Fixed (non-trainable) sinusoidal absolute position encoding.

    Even feature columns hold ``sin(pos * w_k)`` and odd columns hold
    ``cos(pos * w_k)``, with ``w_k = 10000 ** (-2k / d_model)``.
    """

    def __init__(self, d_model, max_len=5000):
        """Initialize SinusoidalPositionalEncoding.

        Args:
            d_model (int): Feature dimension (odd values are supported).
            max_len (int): Number of positions to precompute. Longer requests
                are computed on the fly in ``forward``.
        """
        super().__init__()
        self.d_model = d_model

        # Precompute the table on CPU. As a registered buffer it follows the
        # module's .to()/.cuda() calls.
        pe = self._build_encoding(0, max_len, device=torch.device("cpu"))
        self.register_buffer('pe', pe.unsqueeze(0).unsqueeze(0))  # [1, 1, max_len, d_model]

    def _build_encoding(self, start, end, device):
        """Compute encodings for positions [start, end).

        Returns:
            Tensor: float32 tensor of shape [end - start, d_model] on ``device``.
        """
        position = torch.arange(start, end, device=device, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, device=device, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        angles = position * div_term  # [end - start, ceil(d_model / 2)]
        pe = torch.zeros(end - start, self.d_model, device=device, dtype=torch.float32)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than frequencies.
        pe[:, 1::2] = torch.cos(angles[:, : self.d_model // 2])
        return pe

    def forward(self, x, offset=0):
        """Return the encoding for positions [offset, offset + seq_len).

        Args:
            x (Tensor): Input of shape [batch_size, seq_len, d_model].
            offset (int): Position of the first element of ``x``; default 0.

        Returns:
            Tensor: Encoding of shape [1, seq_len, d_model], cast to ``x``'s
            device and dtype. It is NOT added to ``x``; the caller does that.
        """
        end = offset + x.size(1)
        if end <= self.pe.size(2):
            encoding = self.pe[0, :, offset:end, :]
        else:
            # The request extends past the precomputed table, so compute the
            # needed positions directly on x's device.
            encoding = self._build_encoding(offset, end, x.device).unsqueeze(0)
        return encoding.to(device=x.device, dtype=x.dtype)
|
|
if __name__ == "__main__":
    # Smoke test: run a random batch through LearnablePositionalEmbedding and
    # report the shape of the position-embedding slice it returns.
    d_model, seq_len, batch_size = 64, 128, 32

    encoder = LearnablePositionalEmbedding(d_model=d_model, max_len=5000)
    sample = torch.randn(batch_size, seq_len, d_model)
    encoded = encoder(sample)
    print("Shape of Positional Encoded Output:", encoded.shape)
|
|