| """ |
| model/rope.py |
| |
| Rotary Position Embedding (RoPE) — Su et al. 2021 (RoFormer). |
| Used in LLaMA, Mistral, Gemma, etc. |
| |
Core idea:
    Instead of adding position embeddings to token vectors, we ROTATE
    the query and key vectors in attention using position-dependent angles.

    - Relative positions are encoded implicitly: the dot product of a rotated
      query and a rotated key depends only on the offset between their positions.
    - The rotation is defined for any position index, but this module only
      precomputes tables up to max_seq_len (quality beyond the training
      length is not guaranteed).
    - Only applied to Q and K, NOT V.
| |
Implementation:
    1. Precompute cos/sin tables for all positions up to max_seq_len.
       Shape: (max_seq_len, head_dim); the second half of the last axis
       repeats the first half (one angle per rotated pair of dims).

    2. At forward time, slice cos/sin to the current seq_len and
       apply rotation to Q and K.
| |
Rotation formula (pairs of dims):
    Given a vector x with dims [x0, x1, x2, x3, ...]:
    The original formulation pairs consecutive dims: (x0,x1), (x2,x3), ...
    and rotates each pair by angle theta_i * position:
        [x0*cos - x1*sin, x0*sin + x1*cos, ...]

Implementation used here (rotate_half, split-half pairing):
    Dim i is paired with dim i + head_dim/2 instead of its neighbour.
    This matches the formula above up to a fixed permutation of the dims,
    which is harmless as long as Q and K use the same layout.
        rotated = concat([-x_second_half, x_first_half])  # swapped halves
        out     = x * cos + rotated * sin
| """ |
|
|
| import torch |
| import torch.nn as nn |
| from typing import Tuple |
|
|
|
|
def precompute_rope_freqs(
    head_dim: int,
    max_seq_len: int,
    theta: float = 10_000.0,
    device: torch.device = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Precompute RoPE cosine and sine tables.

    Args:
        head_dim    : dimension of each attention head (must be even)
        max_seq_len : number of positions (0 .. max_seq_len - 1) to precompute
        theta       : RoPE base frequency (default 10_000, use 500_000 for long context)
        device      : torch device the tables are created on (None lets
                      torch pick its default device)

    Returns:
        cos : (max_seq_len, head_dim), float32
        sin : (max_seq_len, head_dim), float32
        In both tables column j and column j + head_dim // 2 hold the same
        angle, matching the split-half pairing used by rotate_half.

    Raises:
        ValueError: if head_dim is odd.
    """
    # Explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, and an odd head_dim would then silently produce tables
    # of width head_dim + 1.
    if head_dim % 2 != 0:
        raise ValueError(f"head_dim must be even, got {head_dim}")

    # One inverse frequency per rotated pair:
    #   inv_freq[k] = theta ** -(2k / head_dim),  k = 0 .. head_dim/2 - 1
    pair_index = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
    inv_freq = 1.0 / (theta ** (pair_index / head_dim))

    positions = torch.arange(max_seq_len, dtype=torch.float32, device=device)

    # angles[p, k] = p * inv_freq[k]  -> (max_seq_len, head_dim // 2)
    angles = torch.outer(positions, inv_freq)

    # Repeat along the last axis so both members of a pair (dim k and
    # dim k + head_dim // 2) see the same angle -> (max_seq_len, head_dim)
    angles = torch.cat([angles, angles], dim=-1)

    return angles.cos(), angles.sin()
|
|
|
|
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    Build the companion tensor used by the RoPE formula
    ``x * cos + rotate_half(x) * sin``.

    Splits the last dim into two equal halves, negates the second half,
    then swaps the halves:
        [a, b] -> [-b, a]      (a = x[..., :N/2], b = x[..., N/2:])

    Args:
        x: (..., head_dim), head_dim must be even
    Returns:
        rotated: (..., head_dim)
    Raises:
        ValueError: if the last dimension of x has odd size.
    """
    last_dim = x.shape[-1]
    # With an odd size the two halves differ in length, so the swapped
    # result would no longer line up element-wise with x.
    if last_dim % 2 != 0:
        raise ValueError(
            f"rotate_half needs an even last dimension, got {last_dim}"
        )

    half = last_dim // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat([-second, first], dim=-1)
|
|
|
|
def apply_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply RoPE rotation to query and key tensors.

    Args:
        q   : (B, n_heads, T, head_dim)
        k   : (B, n_heads, T, head_dim)
        cos : (T, head_dim) - precomputed from precompute_rope_freqs
        sin : (T, head_dim) - precomputed from precompute_rope_freqs

    Returns:
        q_rot, k_rot : same shapes AND dtypes as q and k respectively.
        The arithmetic runs in the dtype torch promotes to (e.g. float32
        when q/k are half precision and the tables are float32); the
        results are then cast back to each input's own dtype.
    """
    # Two leading singleton axes so the tables broadcast over batch and heads:
    # (T, head_dim) -> (1, 1, T, head_dim)
    cos = cos.unsqueeze(0).unsqueeze(0)
    sin = sin.unsqueeze(0).unsqueeze(0)

    q_rot = (q * cos) + (rotate_half(q) * sin)
    k_rot = (k * cos) + (rotate_half(k) * sin)

    # Without the cast, fp16/bf16 inputs would come back as float32 because
    # of type promotion against the float32 tables. `.to` is a no-op when
    # the dtype already matches.
    return q_rot.to(q.dtype), k_rot.to(k.dtype)
|
|
|
|
class RoPECache(nn.Module):
    """
    Module that holds the RoPE cos/sin cache as a buffer.
    Not a learnable module — just stores precomputed freqs and moves them
    to the right device automatically via register_buffer.
    """

    def __init__(self, head_dim: int, max_seq_len: int, theta: float = 10_000.0):
        """
        Args:
            head_dim    : dimension of each attention head
            max_seq_len : number of positions to cache
            theta       : RoPE base frequency
        """
        super().__init__()
        cos, sin = precompute_rope_freqs(head_dim, max_seq_len, theta)

        # Registered as persistent buffers, so both tables are part of
        # state_dict under the keys "cos" and "sin".
        self.register_buffer("cos", cos, persistent=True)
        self.register_buffer("sin", sin, persistent=True)

    def get(self, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Slice cos/sin to the current sequence length.

        Args:
            seq_len : number of leading positions to return
        Returns:
            cos, sin : each of shape (seq_len, head_dim)
        Raises:
            ValueError: if seq_len is negative or exceeds the cached length.
        """
        cached_len = self.cos.shape[0]
        # Plain slicing would silently hand back fewer rows than requested
        # (seq_len too large) or drop rows from the end (seq_len negative).
        if not 0 <= seq_len <= cached_len:
            raise ValueError(
                f"seq_len must be in [0, {cached_len}], got {seq_len}"
            )
        return self.cos[:seq_len], self.sin[:seq_len]
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    # Smoke test: rotate random Q/K, report shapes and norm preservation,
    # and only print PASS when every check actually holds.
    torch.manual_seed(0)

    B, n_heads, T, head_dim = 2, 12, 16, 64
    max_seq_len = 1024

    cos, sin = precompute_rope_freqs(head_dim, max_seq_len=max_seq_len)
    cos_T = cos[:T]
    sin_T = sin[:T]

    q = torch.randn(B, n_heads, T, head_dim)
    k = torch.randn(B, n_heads, T, head_dim)

    q_rot, k_rot = apply_rope(q, k, cos_T, sin_T)

    print(f"q shape : {q.shape}")
    print(f"q_rot shape : {q_rot.shape}")
    print(f"k_rot shape : {k_rot.shape}")

    # A rotation must leave the per-vector norm unchanged.
    q_norm = q.norm(dim=-1)
    q_rot_norm = q_rot.norm(dim=-1)
    norm_preserved = torch.allclose(q_norm, q_rot_norm, atol=1e-5)
    print(f"Norm preserved (q): {norm_preserved}")

    # The cache must serve the same tables as the free function.
    cache = RoPECache(head_dim=head_dim, max_seq_len=max_seq_len)
    c, s = cache.get(T)
    print(f"Cache cos shape: {c.shape}")

    failures = []
    if q_rot.shape != q.shape or k_rot.shape != k.shape:
        failures.append("rotated shapes differ from input shapes")
    if not norm_preserved:
        failures.append("rotation changed the norm of q")
    if not (torch.equal(c, cos_T) and torch.equal(s, sin_T)):
        failures.append("RoPECache tables differ from precompute_rope_freqs")

    # Exit non-zero instead of printing PASS when any check failed.
    if failures:
        raise SystemExit("FAIL: " + "; ".join(failures))
    print("PASS")
|
|