"""Rotary Position Embeddings (RoPE) implementation.
Critical implementation details:
1. Apply RoPE only to Q and K, never to V
2. Use head_dim, not full model dimension
3. Ensure proper dimension pairing for rotation
"""
import torch
import torch.nn as nn
import math
from typing import Optional, Tuple
class RotaryPositionEmbeddings(nn.Module):
"""Rotary Position Embeddings (RoPE) for transformer models.
Based on the paper: 'RoFormer: Enhanced Transformer with Rotary Position Embedding'
https://arxiv.org/abs/2104.09864
"""
def __init__(
self,
head_dim: int,
max_seq_len: int = 2048,
base: int = 10000,
device: Optional[torch.device] = None,
):
super().__init__()
self.head_dim = head_dim
self.max_seq_len = max_seq_len
self.base = base
# CRITICAL: head_dim must be even for proper pairing
assert head_dim % 2 == 0, f"head_dim must be even, got {head_dim}"
# Precompute frequencies
self._precompute_freqs(device)
def _precompute_freqs(self, device: Optional[torch.device] = None):
"""Precompute the frequency tensor for RoPE."""
# Calculate theta frequencies
# theta_i = base^(-2i/d) for i in [0, 1, ..., d/2-1]
        theta = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2, device=device).float() / self.head_dim))
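        # Worked example (assuming head_dim=4, base=10000):
        #   theta = [10000**0, 10000**-0.5] = [1.0, 0.01]
        # so earlier dimension pairs rotate quickly with position and later
        # pairs rotate slowly, encoding both fine and coarse positions.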
# Create position indices
        positions = torch.arange(self.max_seq_len, device=device).float()
# Compute outer product: [seq_len, head_dim/2]
freqs = torch.einsum('i,j->ij', positions, theta)
# Convert to cos and sin for rotation
freqs_cos = torch.cos(freqs) # [seq_len, head_dim/2]
freqs_sin = torch.sin(freqs) # [seq_len, head_dim/2]
# Duplicate for full dimension coverage
# [seq_len, head_dim/2] -> [seq_len, head_dim]
freqs_cos = torch.cat([freqs_cos, freqs_cos], dim=-1)
freqs_sin = torch.cat([freqs_sin, freqs_sin], dim=-1)
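        # Note (added for clarity): the [cos | cos] / [sin | sin] layout pairs
        # dimension i with dimension i + head_dim/2, matching the half-split
        # convention assumed by rotate_half() below (not interleaved pairs).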
# Register as buffers (not trainable, moves with model to device)
self.register_buffer('freqs_cos', freqs_cos, persistent=False)
self.register_buffer('freqs_sin', freqs_sin, persistent=False)
def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
"""Rotate half the hidden dims of the input.
        CRITICAL: incorrect dimension pairing here is the most common RoPE bug.
        For input [1, 2, 3, 4] this returns [-3, -4, 1, 2] (half-split pairing).
"""
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat([-x2, x1], dim=-1)
def apply_rotary_pos_emb(
self,
q: torch.Tensor,
k: torch.Tensor,
position_ids: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply rotary position embeddings to query and key tensors.
Args:
q: Query tensor of shape [batch, seq_len, num_heads, head_dim]
k: Key tensor of shape [batch, seq_len, num_heads, head_dim]
position_ids: Optional custom position IDs
Returns:
Tuple of rotated (q, k) tensors
"""
seq_len = q.shape[1]
        # Get the frequency tensors for the current sequence length
        if position_ids is not None:
            # position_ids: [seq_len] or [batch, seq_len]
            freqs_cos = self.freqs_cos[position_ids]
            freqs_sin = self.freqs_sin[position_ids]
        else:
            freqs_cos = self.freqs_cos[:seq_len]
            freqs_sin = self.freqs_sin[:seq_len]
        # Reshape for broadcasting against [batch, seq_len, num_heads, head_dim]
        if freqs_cos.dim() == 2:
            # [seq_len, head_dim] -> [1, seq_len, 1, head_dim]
            freqs_cos = freqs_cos[None, :, None, :]
            freqs_sin = freqs_sin[None, :, None, :]
        else:
            # [batch, seq_len, head_dim] -> [batch, seq_len, 1, head_dim]
            freqs_cos = freqs_cos.unsqueeze(2)
            freqs_sin = freqs_sin.unsqueeze(2)
# Apply rotation using the formula:
# x_rotated = x * cos + rotate_half(x) * sin
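        # For each pair (x_i, x_{i+d/2}) this expands to the 2D rotation
        #   (x_i*cos - x_{i+d/2}*sin,  x_{i+d/2}*cos + x_i*sin)
        # by the position- and frequency-dependent angle.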
q_rotated = q * freqs_cos + self.rotate_half(q) * freqs_sin
k_rotated = k * freqs_cos + self.rotate_half(k) * freqs_sin
return q_rotated, k_rotated
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
position_ids: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass - apply RoPE to Q and K only.
CRITICAL: Never apply RoPE to V (value) tensor!
"""
return self.apply_rotary_pos_emb(q, k, position_ids)
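# Usage sketch (illustrative only; the attention wiring below is hypothetical
# and not part of this module): RoPE is applied to Q and K after the head
# reshape and before the attention scores are computed; V is left untouched.
#
#     rope = RotaryPositionEmbeddings(head_dim=64)
#     q, k, v = ...  # each [batch, seq_len, num_heads, head_dim]
#     q, k = rope(q, k)  # rotate Q and K only
#     scores = torch.einsum('bqhd,bkhd->bhqk', q, k) / math.sqrt(q.shape[-1])
#     attn = torch.softmax(scores, dim=-1)
#     out = torch.einsum('bhqk,bkhd->bqhd', attn, v)  # V is NOT rotated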
# Alternative implementation that caches the full cos/sin tables up front
class RotaryPositionEmbeddingsComplex(nn.Module):
    """Alternative RoPE implementation that precomputes and caches cos/sin tables.
    Despite the name, it uses real-valued cos/sin arithmetic rather than complex
    multiplication; the cached-table formulation is equivalent and can be more
    efficient on some hardware, but the cache must cover the longest sequence used.
    """
def __init__(
self,
head_dim: int,
max_seq_len: int = 2048,
base: int = 10000,
device: Optional[torch.device] = None,
):
super().__init__()
self.head_dim = head_dim
self.max_seq_len = max_seq_len
self.base = base
assert head_dim % 2 == 0, f"head_dim must be even, got {head_dim}"
# Precompute complex exponentials
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
        t = torch.arange(max_seq_len, device=device, dtype=inv_freq.dtype)
freqs = torch.einsum('i,j->ij', t, inv_freq)
# Store as cos/sin values
emb = torch.cat([freqs, freqs], dim=-1)
self.register_buffer('cos_cached', emb.cos()[None, :, None, :])
self.register_buffer('sin_cached', emb.sin()[None, :, None, :])
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
seq_len: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply RoPE using cached cos/sin values."""
if seq_len is None:
seq_len = q.shape[1]
# Apply rotation
q_embed = (q * self.cos_cached[:, :seq_len]) + \
(self.rotate_half(q) * self.sin_cached[:, :seq_len])
k_embed = (k * self.cos_cached[:, :seq_len]) + \
(self.rotate_half(k) * self.sin_cached[:, :seq_len])
return q_embed, k_embed
def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
"""Rotate half the hidden dims."""
x1, x2 = x.chunk(2, dim=-1)
return torch.cat([-x2, x1], dim=-1)
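# A minimal sketch (not part of the original API) of the "directly complex"
# variant that the class name above alludes to, using torch.polar and
# torch.view_as_complex. It assumes the interleaved pairing (dims 2i, 2i+1),
# which differs from the half-split pairing used by the two classes above,
# so it is not numerically interchangeable with them.
def apply_rope_complex_sketch(
    q: torch.Tensor,
    k: torch.Tensor,
    base: float = 10000.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply RoPE via complex multiplication. q, k: [batch, seq_len, num_heads, head_dim]."""
    seq_len, head_dim = q.shape[1], q.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=q.device).float() / head_dim))
    angles = torch.outer(torch.arange(seq_len, device=q.device).float(), inv_freq)
    # Unit complex numbers e^{i*theta}, broadcast to [1, seq_len, 1, head_dim/2]
    freqs_cis = torch.polar(torch.ones_like(angles), angles)[None, :, None, :]
    # View consecutive (even, odd) dims as one complex number, rotate, view back
    q_c = torch.view_as_complex(q.float().reshape(*q.shape[:-1], -1, 2))
    k_c = torch.view_as_complex(k.float().reshape(*k.shape[:-1], -1, 2))
    q_out = torch.view_as_real(q_c * freqs_cis).flatten(-2)
    k_out = torch.view_as_real(k_c * freqs_cis).flatten(-2)
    return q_out.type_as(q), k_out.type_as(k)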
# Test function for RoPE
def test_rope():
"""Test RoPE implementation."""
print("Testing RoPE implementation...")
batch_size = 2
seq_len = 128
n_heads = 12
head_dim = 64
# Create RoPE module
rope = RotaryPositionEmbeddings(head_dim=head_dim, max_seq_len=2048)
# Create dummy Q and K tensors
q = torch.randn(batch_size, seq_len, n_heads, head_dim)
k = torch.randn(batch_size, seq_len, n_heads, head_dim)
# Apply RoPE
q_rot, k_rot = rope(q, k)
# Check shapes
assert q_rot.shape == q.shape, f"Q shape mismatch: {q_rot.shape} != {q.shape}"
assert k_rot.shape == k.shape, f"K shape mismatch: {k_rot.shape} != {k.shape}"
# Check for NaN
assert not torch.isnan(q_rot).any(), "Q contains NaN after RoPE"
assert not torch.isnan(k_rot).any(), "K contains NaN after RoPE"
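    # Added sanity check: RoPE is a pure rotation, so per-vector norms should
    # be preserved (up to float32 rounding error).
    assert torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-4), \
        "RoPE changed vector norms - rotation should be norm-preserving"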
print("✓ RoPE test passed!")
print(f" Input shape: {q.shape}")
print(f" Output shape: {q_rot.shape}")
return True
if __name__ == "__main__":
    test_rope()