import torch import math from torch import nn class RotaryPositionalEncoding(nn.Module): def __init__(self, d_model, max_len=5000): """ 初始化 RotaryPositionalEncoding。 Args: d_model (int): 特征维度。 max_len (int): 支持的最大序列长度。 """ super(RotaryPositionalEncoding, self).__init__() # 确保特征维度为偶数 assert d_model % 2 == 0, "d_model must be even for RotaryPositionalEncoding." # 创建旋转位置编码矩阵 position = torch.arange(0, max_len).float().unsqueeze(1) # [max_len, 1] dim = torch.arange(0, d_model // 2).float() # [d_model // 2] div_term = torch.exp(dim * -(math.log(10000.0) / (d_model // 2))) # [d_model // 2] # 计算正弦和余弦部分 angle = position * div_term # [max_len, d_model // 2] sin_part = torch.sin(angle) # 正弦部分 cos_part = torch.cos(angle) # 余弦部分 # 将 sin 和 cos 部分堆叠 pe = torch.cat([sin_part, cos_part], dim=-1) # [max_len, d_model] pe = pe.unsqueeze(0).unsqueeze(0) # [1, 1, max_len, d_model] self.register_buffer('pe', pe) # 注册为非参数张量 def forward(self, x, offset=0): """ 前向传播。 Args: x (Tensor): 输入张量,形状为 [batch_size, seq_len, d_model]。 offset (int): 位置偏移量,默认为 0。 Returns: Tensor: 应用旋转位置编码的张量。 """ # 获取位置编码 seq_len = x.size(1) pe = self.pe[0, :, offset:offset + seq_len, :] # [1, seq_len, d_model] # 将输入张量拆分为两部分:前半部分与后半部分 x1, x2 = x[..., :x.size(-1)//2], x[..., x.size(-1)//2:] # [batch_size, seq_len, d_model//2] # 应用旋转操作 x_rotated = torch.cat([ x1 * pe[..., :x.size(-1)//2] - x2 * pe[..., x.size(-1)//2:], x1 * pe[..., x.size(-1)//2:] + x2 * pe[..., :x.size(-1)//2] ], dim=-1) # [batch_size, seq_len, d_model] return x_rotated class ReRoPE: def __init__(self, dim: int): """ 初始化 ReRoPE 编码器。 Args: dim (int): 特征向量的维度(必须为偶数)。 """ assert dim % 2 == 0, "Dimension must be even for ReRoPE." self.dim = dim self.theta = self._compute_base_theta(dim) @staticmethod def _compute_base_theta(dim: int): """ 计算基本的 θ 值,用于旋转位置编码。 Args: dim (int): 特征向量的维度。 Returns: torch.Tensor: θ 值的张量。 """ theta = torch.tensor([10000 ** (-2 * (i // 2) / dim) for i in range(dim)]) return theta def forward(self, pos: torch.Tensor): """ 计算给定位置的 ReRoPE 编码。 Args: pos (torch.Tensor): 位置索引的张量,形状为 [seq_len] 或 [batch_size, seq_len]。 Returns: torch.Tensor: ReRoPE 编码,形状为 [seq_len, dim] 或 [batch_size, seq_len, dim]。 """ seq_len = pos.size(-1) # 获取正弦和余弦部分 angles = pos.unsqueeze(-1) * self.theta sinusoidal_embedding = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1) return sinusoidal_embedding @staticmethod def apply_rotary_embedding(query, key, sincos): """ 应用旋转位置编码到查询和键。 Args: query (torch.Tensor): 查询向量,形状为 [batch_size, seq_len, dim]. key (torch.Tensor): 键向量,形状为 [batch_size, seq_len, dim]. sincos (torch.Tensor): ReRoPE 编码,形状为 [seq_len, dim] 或 [batch_size, seq_len, dim]. Returns: Tuple[torch.Tensor, torch.Tensor]: 编码后的查询和键。 """ sin, cos = sincos[..., :query.size(-1)], sincos[..., query.size(-1):] query_rotated = query * cos + torch.roll(query, shifts=1, dims=-1) * sin key_rotated = key * cos + torch.roll(key, shifts=1, dims=-1) * sin return query_rotated, key_rotated class LearnablePositionalEmbedding(nn.Module): def __init__(self, d_model, max_len=5000): super(LearnablePositionalEmbedding, self).__init__() # Compute the positional encodings once in log space. self.pe = nn.Parameter(torch.zeros( 1, 1, max_len, d_model), requires_grad=True) pe = torch.zeros(max_len, d_model).float() position = torch.arange(0, max_len).float().unsqueeze(1) div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp() pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0).unsqueeze(0) self.pe.data.copy_(pe.float()) del pe def forward(self, x, offset=0): return self.pe[0, :, offset:offset+x.size(1), :] class SinusoidalPositionalEncoding(nn.Module): def __init__(self, d_model, max_len=5000): """ 初始化 SinusoidalPositionalEncoding。 Args: d_model (int): 特征维度。 max_len (int): 支持的最大序列长度。 """ super(SinusoidalPositionalEncoding, self).__init__() self.d_model = d_model # 创建一个固定的位置编码矩阵 pe = self._build_encoding(0, max_len, device=torch.device("cpu")) # 将矩阵增加批量维度,方便在 forward 中直接使用 pe = pe.unsqueeze(0).unsqueeze(0) # [1, 1, max_len, d_model] self.register_buffer('pe', pe) # 注册为非参数张量 def _build_encoding(self, start, end, device): position = torch.arange(start, end, device=device, dtype=torch.float32).unsqueeze(1) div_term = torch.exp( torch.arange(0, self.d_model, 2, device=device, dtype=torch.float32) * -(math.log(10000.0) / self.d_model) ) pe = torch.zeros(end - start, self.d_model, device=device, dtype=torch.float32) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) return pe def forward(self, x, offset=0): """ 前向传播。 Args: x (Tensor): 输入张量,形状为 [batch_size, seq_len, d_model] offset (int): 位置偏移量,默认为 0。 Returns: Tensor: 添加了位置编码的张量。 """ end = offset + x.size(1) if end <= self.pe.size(2): encoding = self.pe[0, :, offset:end, :] else: # Some TSQA questions exceed the original fixed 5,000 positions. # Generate the required range on demand without changing the # checkpoint buffer shape. encoding = self._build_encoding(offset, end, x.device).unsqueeze(0) return encoding.to(device=x.device, dtype=x.dtype) if __name__ == "__main__": d_model = 64 seq_len = 128 batch_size = 32 pos_encoder = LearnablePositionalEmbedding(d_model=d_model, max_len=5000) x = torch.randn(batch_size, seq_len, d_model) # 随机输入张量 pos_encoded_x = pos_encoder(x) print("Shape of Positional Encoded Output:", pos_encoded_x.shape)