"""
Rotary Position Embedding (RoPE) implementation.

Applied to Q and K only, with fixed base (no dynamic scaling).
"""

import torch
import torch.nn as nn
from typing import Tuple


class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE).

    RoPE encodes position information by rotating the query and key vectors.

    Key properties:
    - Parameter-free (no learnable embeddings; only non-persistent buffers)
    - Naturally encodes relative positions
    - The cos/sin cache grows on demand when a position beyond the
      currently cached length is requested

    Reference: https://arxiv.org/abs/2104.09864
    """

    def __init__(
        self,
        dim: int,
        max_position_embeddings: int = 1024,
        base: float = 10000.0,
    ):
        """Initialize RoPE.

        Args:
            dim: Dimension of the rotary embedding (usually head_dim).
                Must be even, because dimensions are rotated in pairs.
            max_position_embeddings: Number of positions to precompute.
            base: Base for the frequency computation.

        Raises:
            ValueError: If ``dim`` is odd.
        """
        super().__init__()
        # An odd dim would yield a cos/sin table of width dim + 1, which only
        # fails later inside forward() with an opaque broadcasting error.
        if dim % 2 != 0:
            raise ValueError(f"RoPE dim must be even, got {dim}")

        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # Precompute inverse frequencies: base^(-2i/dim) for i in [0, dim/2)
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Precompute cos and sin for all positions
        self._set_cos_sin_cache(max_position_embeddings)

    def _set_cos_sin_cache(self, seq_len: int):
        """Precompute cos and sin values for positions ``0 .. seq_len - 1``.

        Overwrites the ``cos_cached`` / ``sin_cached`` buffers and records the
        cached length in ``max_seq_len_cached``.

        Args:
            seq_len: Number of positions to cache.
        """
        # Keep the bookkeeping attribute a plain Python int even if a caller
        # hands in a 0-dim tensor.
        seq_len = int(seq_len)
        self.max_seq_len_cached = seq_len

        t = torch.arange(
            seq_len,
            device=self.inv_freq.device,
            dtype=self.inv_freq.dtype,
        )

        # Outer product: [seq_len] x [dim/2] -> [seq_len, dim/2]
        freqs = torch.outer(t, self.inv_freq)

        # Concatenate to get [seq_len, dim]; both halves share the same angles
        # so they line up with the split used by _rotate_half.
        emb = torch.cat((freqs, freqs), dim=-1)

        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary embeddings to query and key tensors.

        Args:
            q: Query tensor of shape [batch, num_heads, seq_len, head_dim]
            k: Key tensor of shape [batch, num_heads, seq_len, head_dim]
            position_ids: Position indices of shape [batch, seq_len]

        Returns:
            Tuple of (rotated_q, rotated_k) with the same shapes and dtypes
            as the corresponding inputs.
        """
        # Extend cache if needed. An empty position tensor has no max(), and
        # needs no cache growth, so it is skipped here.
        if position_ids.numel() > 0:
            # Convert to a Python int so the cached length never becomes a
            # tensor attribute.
            needed_len = int(position_ids.max()) + 1
            if needed_len > self.max_seq_len_cached:
                self._set_cos_sin_cache(needed_len)

        # Get cos and sin for the positions
        # Shape: [batch, seq_len, dim]
        cos = self.cos_cached[position_ids]
        sin = self.sin_cached[position_ids]

        # Add head dimension: [batch, 1, seq_len, dim]
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)

        # Apply rotation. With half-precision inputs the float32 tables
        # promote the arithmetic to float32; the result is cast back so the
        # outputs keep the input dtype.
        q_embed = (q * cos) + (self._rotate_half(q) * sin)
        k_embed = (k * cos) + (self._rotate_half(k) * sin)

        return q_embed.to(q.dtype), k_embed.to(k.dtype)

    @staticmethod
    def _rotate_half(x: torch.Tensor) -> torch.Tensor:
        """Rotate half the hidden dims of the input.

        Splits the last dimension into two halves and rotates:
        [x1, x2, x3, x4] -> [-x3, -x4, x1, x2]
        """
        half = x.shape[-1] // 2
        x1 = x[..., :half]
        x2 = x[..., half:]
        return torch.cat((-x2, x1), dim=-1)