File size: 5,572 Bytes
7f974df | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 | """
model/rope.py
Rotary Position Embedding (RoPE) — Su et al. 2021 (RoFormer).
Used in LLaMA, Mistral, Gemma, etc.
Core idea:
Instead of adding position embeddings to token vectors, we ROTATE
the query and key vectors in attention using position-dependent angles.
- Relative positions are encoded implicitly via dot-product invariance.
- Works for any sequence length (extrapolates beyond training length).
- Only applied to Q and K, NOT V.
Implementation:
1. Precompute cos/sin tables for all positions up to max_seq_len.
Shape: (max_seq_len, head_dim)
2. At forward time, slice cos/sin to the current seq_len and
apply rotation to Q and K.
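Rotation formula (pairs of dims):
Given a vector x with dims [x0, x1, x2, x3, ...]:
The original paper pairs each consecutive two dims: (x0,x1), (x2,x3), ...
and rotates each pair by angle theta_i * position:
[x0*cos - x1*sin, x0*sin + x1*cos, ...]
Implementation used in this file (rotate_half, "half-split" pairing):
Dim i is paired with dim i + head_dim/2 instead of with its neighbour,
i.e. the pairs are (x0, x_{d/2}), (x1, x_{d/2+1}), ...
rotated = concat([-x_second_half, x_first_half]) # swapped halves
out = x * cos + rotated * sin
This is the same rotation up to a fixed permutation of the dims. Attention
scores are unchanged as long as Q and K use the same convention, but the
cos/sin tables must use the matching layout [theta_0..theta_{d/2-1}, theta_0..theta_{d/2-1}].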
"""
import torch
import torch.nn as nn
from typing import Tuple
def precompute_rope_freqs(
    head_dim: int,
    max_seq_len: int,
    theta: float = 10_000.0,
    device: torch.device = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Precompute RoPE cosine and sine tables.

    The tables use the "half-split" column layout: column i and column
    i + head_dim // 2 hold the same angle position * inv_freq[i], i.e. each
    row is [a_0, ..., a_{d/2-1}, a_0, ..., a_{d/2-1}]. This is the layout
    needed when dim i is rotated together with dim i + head_dim // 2.

    Args:
        head_dim    : dimension of each attention head (must be even)
        max_seq_len : number of positions to precompute (rows of the tables)
        theta       : RoPE base frequency (default 10_000, use 500_000 for long context)
        device      : torch device the tables are created on (None = torch default)

    Returns:
        cos : (max_seq_len, head_dim), float32
        sin : (max_seq_len, head_dim), float32

    Raises:
        AssertionError: if head_dim is odd.
    """
    # Raised explicitly instead of via `assert` so the check is not stripped
    # under `python -O` (an odd head_dim would otherwise silently produce
    # tables with head_dim + 1 columns). The exception type and message are
    # kept identical to the previous assert.
    if head_dim % 2 != 0:
        raise AssertionError(f"head_dim must be even, got {head_dim}")

    # Inverse frequencies, shape (head_dim // 2,):
    #   inv_freq[j] = 1 / theta^(2j / head_dim)
    even_idx = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
    inv_freq = 1.0 / (theta ** (even_idx / head_dim))

    # Position indices, shape (max_seq_len,)
    positions = torch.arange(max_seq_len, dtype=torch.float32, device=device)

    # Outer product -> angle per (position, frequency): (max_seq_len, head_dim // 2)
    angles = torch.outer(positions, inv_freq)

    # Repeat the whole block along the last dim to reach head_dim columns:
    #   [a_0, ..., a_{d/2-1}] -> [a_0, ..., a_{d/2-1}, a_0, ..., a_{d/2-1}]
    # (NOT interleaved as [a_0, a_0, a_1, a_1, ...]).
    angles = torch.cat([angles, angles], dim=-1)
    return angles.cos(), angles.sin()
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    Swap the two halves of the last axis, negating the half moved to the front.

    With the last axis split at index N // 2 into (front, back), the result
    is concat(-back, front):
        [x_0 .. x_{N/2-1}, x_{N/2} .. x_{N-1}] -> [-x_{N/2} .. -x_{N-1}, x_0 .. x_{N/2-1}]

    Args:
        x: (..., head_dim)
    Returns:
        rotated: (..., head_dim)
    """
    split_at = x.size(-1) // 2
    front, back = x[..., :split_at], x[..., split_at:]
    return torch.cat((back.neg(), front), dim=-1)
def apply_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply RoPE rotation to query and key tensors.

    Each tensor is rotated as  x * cos + rotate_half(x) * sin.

    The arithmetic runs in the promoted dtype of the input and the tables
    (e.g. float32 when q is bfloat16 and the tables are float32). The result
    is then cast back to the input's own dtype, so q_rot / k_rot can be fed to
    attention alongside V without a dtype mismatch.

    Args:
        q   : (B, n_heads, T, head_dim)
        k   : (B, n_heads, T, head_dim)
        cos : (T, head_dim) - precomputed from precompute_rope_freqs
        sin : (T, head_dim) - precomputed from precompute_rope_freqs
    Returns:
        q_rot, k_rot : same shapes as inputs; for floating-point inputs also
                       the same dtype as the respective input
    """
    # Broadcast cos/sin from (T, head_dim) to (1, 1, T, head_dim)
    cos = cos.unsqueeze(0).unsqueeze(0)
    sin = sin.unsqueeze(0).unsqueeze(0)

    def _rotate(x: torch.Tensor) -> torch.Tensor:
        rotated = (x * cos) + (rotate_half(x) * sin)
        # Undo the silent upcast from type promotion, but only for floating
        # inputs: casting back to an integer dtype would truncate the result.
        if x.is_floating_point():
            rotated = rotated.to(x.dtype)
        return rotated

    return _rotate(q), _rotate(k)
class RoPECache(nn.Module):
    """
    Module that holds the RoPE cos/sin cache as buffers.

    Not a learnable module: it only stores the tables returned by
    precompute_rope_freqs. Because they are registered with register_buffer,
    they follow the module on .to(device) and, being persistent, are
    included in the module's state_dict.
    """

    def __init__(self, head_dim: int, max_seq_len: int, theta: float = 10_000.0):
        """
        Args:
            head_dim    : dimension of each attention head (must be even)
            max_seq_len : number of positions to precompute
            theta       : RoPE base frequency
        """
        super().__init__()
        cos, sin = precompute_rope_freqs(head_dim, max_seq_len, theta)
        # register_buffer: not a parameter, but moves with .to(device)
        self.register_buffer("cos", cos, persistent=True)
        self.register_buffer("sin", sin, persistent=True)

    def get(self, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Slice cos/sin to the current sequence length.

        Args:
            seq_len : number of leading positions to return
        Returns:
            cos, sin : each of shape (seq_len, head_dim)
        Raises:
            ValueError: if seq_len is negative or exceeds the number of
                        cached positions.
        """
        cached_len = self.cos.shape[0]
        # Plain slicing would silently return fewer rows than requested when
        # seq_len > cached_len, and would drop trailing rows for a negative
        # seq_len. Fail here with a clear message instead.
        if not 0 <= seq_len <= cached_len:
            raise ValueError(
                f"seq_len must be in [0, {cached_len}] (cached positions), got {seq_len}"
            )
        return self.cos[:seq_len], self.sin[:seq_len]
# ------------------------------------------------------------------ #
# QUICK CHECK
# ------------------------------------------------------------------ #
if __name__ == "__main__":
    # Smoke test: run the public helpers on random data and verify the basic
    # invariants. Exits with a non-zero status (and no "PASS") on any failure.
    torch.manual_seed(0)
    B, n_heads, T, head_dim = 2, 12, 16, 64

    cos, sin = precompute_rope_freqs(head_dim, max_seq_len=1024)
    cos_T = cos[:T]
    sin_T = sin[:T]

    q = torch.randn(B, n_heads, T, head_dim)
    k = torch.randn(B, n_heads, T, head_dim)
    q_rot, k_rot = apply_rope(q, k, cos_T, sin_T)

    print(f"q shape : {q.shape}")
    print(f"q_rot shape : {q_rot.shape}")
    print(f"k_rot shape : {k_rot.shape}")

    # Verify: rotation should preserve norm (|x| = |Rx|)
    q_norm_ok = torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5)
    k_norm_ok = torch.allclose(k.norm(dim=-1), k_rot.norm(dim=-1), atol=1e-5)
    print(f"Norm preserved (q): {q_norm_ok}")

    # Test RoPECache
    cache = RoPECache(head_dim=64, max_seq_len=1024)
    c, s = cache.get(T)
    print(f"Cache cos shape: {c.shape}")

    # Collect failures explicitly (not via `assert`, which -O would strip) so
    # "PASS" is only printed when every check actually held.
    failures = []
    if q_rot.shape != q.shape or k_rot.shape != k.shape:
        failures.append("rotated tensors changed shape")
    if not q_norm_ok:
        failures.append("q norm not preserved")
    if not k_norm_ok:
        failures.append("k norm not preserved")
    if not (torch.equal(c, cos_T) and torch.equal(s, sin_T)):
        failures.append("RoPECache tables differ from precompute_rope_freqs")
    if failures:
        raise SystemExit("FAIL: " + "; ".join(failures))
    print("PASS")
|