""" model/rope.py Rotary Position Embedding (RoPE) — Su et al. 2021 (RoFormer). Used in LLaMA, Mistral, Gemma, etc. Core idea: Instead of adding position embeddings to token vectors, we ROTATE the query and key vectors in attention using position-dependent angles. - Relative positions are encoded implicitly via dot-product invariance. - Works for any sequence length (extrapolates beyond training length). - Only applied to Q and K, NOT V. Implementation: 1. Precompute cos/sin tables for all positions up to max_seq_len. Shape: (max_seq_len, head_dim) 2. At forward time, slice cos/sin to the current seq_len and apply rotation to Q and K. Rotation formula (pairs of dims): Given a vector x with dims [x0, x1, x2, x3, ...]: Pair each consecutive two dims: (x0,x1), (x2,x3), ... Rotate each pair by angle theta_i * position: [x0*cos - x1*sin, x0*sin + x1*cos, ...] Equivalent implementation using rotate_half: rotated = concat([-x_second_half, x_first_half]) # swapped halves out = x * cos + rotated * sin """ import torch import torch.nn as nn from typing import Tuple def precompute_rope_freqs( head_dim: int, max_seq_len: int, theta: float = 10_000.0, device: torch.device = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Precompute RoPE cosine and sine tables. Args: head_dim : dimension of each attention head (must be even) max_seq_len : max sequence length to precompute theta : RoPE base frequency (default 10_000, use 500_000 for long context) device : torch device Returns: cos : (max_seq_len, head_dim) sin : (max_seq_len, head_dim) """ assert head_dim % 2 == 0, f"head_dim must be even, got {head_dim}" # Inverse frequencies: shape (head_dim // 2,) # inv_freq[i] = 1 / theta^(2i / head_dim) i = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device) inv_freq = 1.0 / (theta ** (i / head_dim)) # Position indices: shape (max_seq_len,) positions = torch.arange(max_seq_len, dtype=torch.float32, device=device) # Outer product: (max_seq_len, head_dim // 2) freqs = torch.outer(positions, inv_freq) # Duplicate along last dim to match head_dim: # (max_seq_len, head_dim // 2) -> (max_seq_len, head_dim) # cos/sin applied to [x0,x1,x2,x3,...] as [theta0,theta0, theta1,theta1, ...] freqs = torch.cat([freqs, freqs], dim=-1) return freqs.cos(), freqs.sin() def rotate_half(x: torch.Tensor) -> torch.Tensor: """ Rotates pairs of dimensions in the last axis. Splits last dim in half, negates the second half, then swaps: [x0..xN/2, xN/2..xN] -> [-xN/2..xN, x0..xN/2] Args: x: (..., head_dim) Returns: rotated: (..., head_dim) """ half = x.shape[-1] // 2 x1 = x[..., :half] # first half x2 = x[..., half:] # second half return torch.cat([-x2, x1], dim=-1) def apply_rope( q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Apply RoPE rotation to query and key tensors. Args: q : (B, n_heads, T, head_dim) k : (B, n_heads, T, head_dim) cos : (T, head_dim) - precomputed from precompute_rope_freqs sin : (T, head_dim) - precomputed from precompute_rope_freqs Returns: q_rot, k_rot : same shapes as inputs """ # Broadcast cos/sin from (T, head_dim) to (1, 1, T, head_dim) cos = cos.unsqueeze(0).unsqueeze(0) sin = sin.unsqueeze(0).unsqueeze(0) q_rot = (q * cos) + (rotate_half(q) * sin) k_rot = (k * cos) + (rotate_half(k) * sin) return q_rot, k_rot class RoPECache(nn.Module): """ Module that holds the RoPE cos/sin cache as a buffer. Not a learnable module — just stores precomputed freqs and moves them to the right device automatically via register_buffer. """ def __init__(self, head_dim: int, max_seq_len: int, theta: float = 10_000.0): super().__init__() cos, sin = precompute_rope_freqs(head_dim, max_seq_len, theta) # register_buffer: not a parameter, but moves with .to(device) self.register_buffer("cos", cos, persistent=True) self.register_buffer("sin", sin, persistent=True) def get(self, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]: """Slice cos/sin to current sequence length.""" return self.cos[:seq_len], self.sin[:seq_len] # ------------------------------------------------------------------ # # QUICK CHECK # ------------------------------------------------------------------ # if __name__ == "__main__": torch.manual_seed(0) B, n_heads, T, head_dim = 2, 12, 16, 64 cos, sin = precompute_rope_freqs(head_dim, max_seq_len=1024) cos_T = cos[:T] sin_T = sin[:T] q = torch.randn(B, n_heads, T, head_dim) k = torch.randn(B, n_heads, T, head_dim) q_rot, k_rot = apply_rope(q, k, cos_T, sin_T) print(f"q shape : {q.shape}") print(f"q_rot shape : {q_rot.shape}") print(f"k_rot shape : {k_rot.shape}") # Verify: rotation should preserve norm (|x| = |Rx|) q_norm = q.norm(dim=-1) q_rot_norm = q_rot.norm(dim=-1) print(f"Norm preserved (q): {torch.allclose(q_norm, q_rot_norm, atol=1e-5)}") # Test RoPECache cache = RoPECache(head_dim=64, max_seq_len=1024) c, s = cache.get(T) print(f"Cache cos shape: {c.shape}") print("PASS")