sllm / model /rope.py
geeteshcodes's picture
Initial commit
7f974df verified
"""
model/rope.py
Rotary Position Embedding (RoPE) — Su et al. 2021 (RoFormer).
Used in LLaMA, Mistral, Gemma, etc.
Core idea:
Instead of adding position embeddings to token vectors, we ROTATE
the query and key vectors in attention using position-dependent angles.
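- Relative positions are encoded implicitly: after rotation, the dot product
between a query at position m and a key at position n depends only on m - n.
- The rotation is defined for any position, but quality typically degrades
beyond the training context length unless theta / position scaling is adjusted.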
- Only applied to Q and K, NOT V.
Implementation:
1. Precompute cos/sin tables for all positions up to max_seq_len.
Shape: (max_seq_len, head_dim)
2. At forward time, slice cos/sin to the current seq_len and
apply rotation to Q and K.
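Rotation formula (pairs of dims):
Given a vector x with dims [x0, x1, x2, x3, ...], group the dims into
head_dim / 2 pairs and rotate pair i by angle theta_i * position, where
theta_i = 1 / base^(2i / head_dim) and base is the `theta` argument:
(a, b) -> (a*cos - b*sin, a*sin + b*cos)
The RoFormer paper pairs consecutive dims: (x0,x1), (x2,x3), ...
This file uses the rotate_half layout instead (as in GPT-NeoX and the
Hugging Face LLaMA implementation), which pairs dim j with dim j + head_dim/2:
rotated = concat([-x_second_half, x_first_half]) # swapped halves
out = x * cos + rotated * sin
The two layouts are equivalent up to a fixed permutation of the head dims
(which the learned Q/K projections absorb), but pretrained weights must be
used with the layout they were trained with.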
"""
from typing import Optional, Tuple

import torch
import torch.nn as nn
def precompute_rope_freqs(
    head_dim: int,
    max_seq_len: int,
    theta: float = 10_000.0,
    device: Optional[torch.device] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Precompute RoPE cosine and sine tables.

    The tables use the "rotate_half" layout expected by apply_rope: column j
    and column j + head_dim // 2 share the same frequency, i.e. along the last
    dim the angles are [f0, f1, ..., f_{d/2-1}, f0, f1, ..., f_{d/2-1}]
    (tiled, not interleaved).

    Args:
        head_dim    : dimension of each attention head (positive and even)
        max_seq_len : number of positions to precompute (rows 0 .. max_seq_len-1)
        theta       : RoPE base frequency (default 10_000, use 500_000 for long context)
        device      : device to allocate the tables on (None -> torch default)

    Returns:
        cos : (max_seq_len, head_dim) float32
        sin : (max_seq_len, head_dim) float32

    Raises:
        ValueError: if head_dim is not a positive even integer.
    """
    # Explicit raise instead of assert so the check survives `python -O`.
    if head_dim <= 0 or head_dim % 2 != 0:
        raise ValueError(f"head_dim must be a positive even integer, got {head_dim}")

    # Even dim indices 0, 2, ..., head_dim - 2 -> shape (head_dim // 2,).
    even_dims = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
    # For pair index k (even dim 2k): inv_freq[k] = 1 / theta^(2k / head_dim).
    inv_freq = 1.0 / (theta ** (even_dims / head_dim))

    # Position indices: shape (max_seq_len,).
    positions = torch.arange(max_seq_len, dtype=torch.float32, device=device)
    # angles[p, k] = p * inv_freq[k] -> (max_seq_len, head_dim // 2).
    angles = torch.outer(positions, inv_freq)

    # Tile along the last dim -> (max_seq_len, head_dim). Column j and column
    # j + head_dim // 2 get the same angle, matching rotate_half's pairing of
    # dim j with dim j + head_dim // 2.
    angles = torch.cat([angles, angles], dim=-1)
    return angles.cos(), angles.sin()
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    Swap the two halves of the last axis, negating the half that moves first.

    With h = x.shape[-1] // 2 the result is concat(-x[..., h:], x[..., :h]):
        [x0 .. x_{h-1}, x_h .. x_{n-1}] -> [-x_h .. -x_{n-1}, x0 .. x_{h-1}]
    For an even last dim this maps each pair (x_j, x_{j+h}) to
    (-x_{j+h}, x_j), a 90-degree rotation of that pair.

    Args:
        x: (..., head_dim)

    Returns:
        rotated: (..., head_dim)
    """
    width = x.shape[-1]
    lead_len = width // 2
    lead, trail = torch.split(x, [lead_len, width - lead_len], dim=-1)
    return torch.cat((trail.neg(), lead), dim=-1)
def apply_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply RoPE rotation to query and key tensors.

    Args:
        q   : (B, n_heads, T, head_dim)
        k   : (B, n_kv_heads, T, head_dim) - cos/sin broadcast over the head
              axis, so k's head count need not equal q's
        cos : (T, head_dim) as sliced from precompute_rope_freqs, or
              (B, T, head_dim), or a tensor already broadcastable against
              q/k such as (B, 1, T, head_dim)
        sin : same shape as cos

    Returns:
        q_rot, k_rot : same shapes as q and k and, for floating-point inputs,
        the same dtypes. The arithmetic runs in the promoted dtype of the
        input and the tables (e.g. float32 for bfloat16 q with float32 tables)
        and is cast back afterwards, so half/bfloat16 activations keep their
        dtype instead of being silently upcast.
    """
    # Add broadcast axes for batch/heads: (T, D) -> (1, 1, T, D) and
    # (B, T, D) -> (B, 1, T, D). Higher-rank tables are used as given.
    if cos.dim() == 2:
        cos, sin = cos[None, None], sin[None, None]
    elif cos.dim() == 3:
        cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)

    def _rotate(x: torch.Tensor) -> torch.Tensor:
        # x * cos + rotate_half(x) * sin, computed in the promoted dtype.
        work = x.to(torch.promote_types(x.dtype, cos.dtype))
        out = (work * cos) + (rotate_half(work) * sin)
        # Hand back the caller's dtype (only meaningful for float inputs).
        return out.to(x.dtype) if x.is_floating_point() else out

    return _rotate(q), _rotate(k)
class RoPECache(nn.Module):
    """
    Module that holds the RoPE cos/sin tables as buffers.

    Not a learnable module: it only stores the tables produced by
    precompute_rope_freqs. Because they are registered as buffers they move
    with .to(device) and, with persistent=True, are saved in state_dict.
    NOTE: Module.to(dtype) / .half() / .bfloat16() also casts floating-point
    buffers, which lowers the precision of these tables.
    """

    def __init__(self, head_dim: int, max_seq_len: int, theta: float = 10_000.0):
        super().__init__()
        cos, sin = precompute_rope_freqs(head_dim, max_seq_len, theta)
        # register_buffer: not a parameter, but moves with .to(device)
        # and is included in state_dict (persistent=True).
        self.register_buffer("cos", cos, persistent=True)
        self.register_buffer("sin", sin, persistent=True)

    def get(self, seq_len: int, offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Return the cos/sin rows for positions offset .. offset + seq_len - 1.

        Args:
            seq_len : number of positions to return
            offset  : first position to return (default 0), e.g. the number
                      of tokens already processed when decoding incrementally

        Returns:
            cos, sin : each (seq_len, head_dim)

        Raises:
            ValueError: if seq_len or offset is negative, or if the requested
                range extends past the precomputed max_seq_len. Without this
                check, slicing would silently return a shorter table.
        """
        max_len = self.cos.shape[0]
        if seq_len < 0 or offset < 0:
            raise ValueError(
                f"seq_len and offset must be non-negative, got seq_len={seq_len}, offset={offset}"
            )
        end = offset + seq_len
        if end > max_len:
            raise ValueError(
                f"requested RoPE positions up to {end}, but the cache only holds "
                f"{max_len}; increase max_seq_len"
            )
        return self.cos[offset:end], self.sin[offset:end]
# ------------------------------------------------------------------ #
# QUICK CHECK
# ------------------------------------------------------------------ #
if __name__ == "__main__":
    # Self-check: exits non-zero (SystemExit) if any property check fails.
    torch.manual_seed(0)
    B, n_heads, T, head_dim = 2, 12, 16, 64

    cos, sin = precompute_rope_freqs(head_dim, max_seq_len=1024)
    cos_T = cos[:T]
    sin_T = sin[:T]

    q = torch.randn(B, n_heads, T, head_dim)
    k = torch.randn(B, n_heads, T, head_dim)
    q_rot, k_rot = apply_rope(q, k, cos_T, sin_T)
    print(f"q shape : {q.shape}")
    print(f"q_rot shape : {q_rot.shape}")
    print(f"k_rot shape : {k_rot.shape}")
    shape_ok = q_rot.shape == q.shape and k_rot.shape == k.shape

    # 1) Rotation should preserve norm (|x| = |Rx|).
    norm_ok = torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5)
    print(f"Norm preserved (q): {norm_ok}")

    # 2) Relative-position property: with the same q/k vector at every
    #    position, score[m, n] depends only on m - n, so the (T, T) score
    #    matrix must be constant along its diagonals.
    q_vec = torch.randn(head_dim).expand(1, 1, T, head_dim)
    k_vec = torch.randn(head_dim).expand(1, 1, T, head_dim)
    q_pos, k_pos = apply_rope(q_vec, k_vec, cos_T, sin_T)
    scores = (q_pos @ k_pos.transpose(-2, -1))[0, 0]
    relative_ok = torch.allclose(scores[1:, 1:], scores[:-1, :-1], atol=1e-4)
    print(f"Scores depend only on m - n: {relative_ok}")

    # 3) RoPECache must serve the same tables as precompute_rope_freqs.
    cache = RoPECache(head_dim=head_dim, max_seq_len=1024)
    c, s = cache.get(T)
    print(f"Cache cos shape: {c.shape}")
    cache_ok = torch.equal(c, cos_T) and torch.equal(s, sin_T)
    print(f"Cache matches tables: {cache_ok}")

    # Previously "PASS" was printed unconditionally; gate it on the checks.
    if not (shape_ok and norm_ok and relative_ok and cache_ok):
        raise SystemExit("FAIL")
    print("PASS")