"""Rotary Position Embeddings (RoPE) implementation.
Critical implementation details:
1. Apply RoPE only to Q and K, never to V
2. Use head_dim, not full model dimension
3. Ensure proper dimension pairing for rotation
"""
import torch
import torch.nn as nn
import math
from typing import Optional, Tuple
class RotaryPositionEmbeddings(nn.Module):
"""Rotary Position Embeddings (RoPE) for transformer models.
Based on the paper: 'RoFormer: Enhanced Transformer with Rotary Position Embedding'
https://arxiv.org/abs/2104.09864
"""
def __init__(
self,
head_dim: int,
max_seq_len: int = 2048,
base: int = 10000,
device: Optional[torch.device] = None,
):
super().__init__()
self.head_dim = head_dim
self.max_seq_len = max_seq_len
self.base = base
# CRITICAL: head_dim must be even for proper pairing
assert head_dim % 2 == 0, f"head_dim must be even, got {head_dim}"
# Precompute frequencies
self._precompute_freqs(device)
def _precompute_freqs(self, device: Optional[torch.device] = None):
"""Precompute the frequency tensor for RoPE."""
# Calculate theta frequencies
# theta_i = base^(-2i/d) for i in [0, 1, ..., d/2-1]
        theta = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2, device=device).float() / self.head_dim))
        # Create position indices
        positions = torch.arange(self.max_seq_len, device=device).float()
# Compute outer product: [seq_len, head_dim/2]
freqs = torch.einsum('i,j->ij', positions, theta)
# Convert to cos and sin for rotation
freqs_cos = torch.cos(freqs) # [seq_len, head_dim/2]
freqs_sin = torch.sin(freqs) # [seq_len, head_dim/2]
# Duplicate for full dimension coverage
# [seq_len, head_dim/2] -> [seq_len, head_dim]
freqs_cos = torch.cat([freqs_cos, freqs_cos], dim=-1)
freqs_sin = torch.cat([freqs_sin, freqs_sin], dim=-1)
# Register as buffers (not trainable, moves with model to device)
self.register_buffer('freqs_cos', freqs_cos, persistent=False)
self.register_buffer('freqs_sin', freqs_sin, persistent=False)
def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
"""Rotate half the hidden dims of the input.
CRITICAL: This is the most common bug - incorrect dimension pairing.
For input [1, 2, 3, 4], output should be [-3, -4, 1, 2].
"""
x1 = x[..., :x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2:]
return torch.cat([-x2, x1], dim=-1)
def apply_rotary_pos_emb(
self,
q: torch.Tensor,
k: torch.Tensor,
position_ids: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply rotary position embeddings to query and key tensors.
Args:
q: Query tensor of shape [batch, seq_len, num_heads, head_dim]
k: Key tensor of shape [batch, seq_len, num_heads, head_dim]
position_ids: Optional custom position IDs
Returns:
Tuple of rotated (q, k) tensors
"""
seq_len = q.shape[1]
        # Look up the cos/sin tables for the current positions
        if position_ids is not None:
            # position_ids: [batch, seq_len] -> tables: [batch, seq_len, head_dim];
            # insert a head axis so they broadcast against [batch, seq_len, num_heads, head_dim]
            freqs_cos = self.freqs_cos[position_ids][:, :, None, :]
            freqs_sin = self.freqs_sin[position_ids][:, :, None, :]
        else:
            # [seq_len, head_dim] -> [1, seq_len, 1, head_dim] for broadcasting
            freqs_cos = self.freqs_cos[:seq_len][None, :, None, :]
            freqs_sin = self.freqs_sin[:seq_len][None, :, None, :]
# Apply rotation using the formula:
# x_rotated = x * cos + rotate_half(x) * sin
q_rotated = q * freqs_cos + self.rotate_half(q) * freqs_sin
k_rotated = k * freqs_cos + self.rotate_half(k) * freqs_sin
return q_rotated, k_rotated
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
position_ids: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass - apply RoPE to Q and K only.
CRITICAL: Never apply RoPE to V (value) tensor!
"""
return self.apply_rotary_pos_emb(q, k, position_ids)
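# A minimal sketch of how this module slots into an attention block
# (hypothetical integration -- the projection layers are stand-ins, not part
# of this file). RoPE is applied after the heads are split out and before
# attention scores are computed, and only to Q and K, never to V.
def _example_attention_with_rope(
    x: torch.Tensor,
    q_proj: nn.Linear,
    k_proj: nn.Linear,
    v_proj: nn.Linear,
    rope: RotaryPositionEmbeddings,
    n_heads: int,
) -> torch.Tensor:
    """Illustrative only: scaled dot-product attention with RoPE on Q/K."""
    batch, seq_len, _ = x.shape
    head_dim = q_proj.out_features // n_heads
    q = q_proj(x).view(batch, seq_len, n_heads, head_dim)
    k = k_proj(x).view(batch, seq_len, n_heads, head_dim)
    v = v_proj(x).view(batch, seq_len, n_heads, head_dim)
    q, k = rope(q, k)  # rotate Q and K; V stays untouched
    # [batch, seq_len, n_heads, head_dim] -> [batch, n_heads, seq_len, head_dim]
    q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(head_dim)
    out = torch.softmax(scores, dim=-1) @ v
    return out.transpose(1, 2).reshape(batch, seq_len, -1)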
# Alternative implementation with fully precomputed cos/sin tables
class RotaryPositionEmbeddingsComplex(nn.Module):
    """Alternative RoPE implementation derived from the complex-exponential form.
    Despite the "Complex" name, the rotation is carried out with real-valued
    cos/sin tables cached at construction time; precomputing the full tables
    up front can be more efficient on some hardware.
    """
def __init__(
self,
head_dim: int,
max_seq_len: int = 2048,
base: int = 10000,
device: Optional[torch.device] = None,
):
super().__init__()
self.head_dim = head_dim
self.max_seq_len = max_seq_len
self.base = base
assert head_dim % 2 == 0, f"head_dim must be even, got {head_dim}"
        # Precompute the frequency table (real-valued; derived from the
        # complex-exponential form e^(i * m * theta_j))
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
        t = torch.arange(max_seq_len, dtype=inv_freq.dtype, device=device)
        freqs = torch.einsum('i,j->ij', t, inv_freq)
        # Duplicate to cover the full head dimension, then cache cos/sin with a
        # broadcast-ready shape [1, seq_len, 1, head_dim]
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer('cos_cached', emb.cos()[None, :, None, :], persistent=False)
        self.register_buffer('sin_cached', emb.sin()[None, :, None, :], persistent=False)
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
seq_len: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply RoPE using cached cos/sin values."""
if seq_len is None:
seq_len = q.shape[1]
        # Apply rotation: x_rotated = x * cos + rotate_half(x) * sin
        q_embed = (q * self.cos_cached[:, :seq_len]
                   + self.rotate_half(q) * self.sin_cached[:, :seq_len])
        k_embed = (k * self.cos_cached[:, :seq_len]
                   + self.rotate_half(k) * self.sin_cached[:, :seq_len])
return q_embed, k_embed
def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
"""Rotate half the hidden dims."""
x1, x2 = x.chunk(2, dim=-1)
return torch.cat([-x2, x1], dim=-1)
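# For reference, a sketch of a variant that really does use torch complex
# tensors (LLaMA-style; an assumption here, not something in the original
# file). Note it pairs *adjacent* dims (2i, 2i+1) rather than (i, i + d/2),
# so for the same inputs it does NOT reproduce the rotate_half outputs; the
# two conventions differ by a fixed permutation of head dims and are not
# interchangeable within one model.
def apply_rope_complex(
    q: torch.Tensor,
    k: torch.Tensor,
    base: int = 10000,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Illustrative complex-tensor RoPE for [batch, seq, heads, head_dim] inputs."""
    seq_len, head_dim = q.shape[1], q.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    freqs = torch.outer(torch.arange(seq_len).float(), inv_freq)
    # e^(i * m * theta_j) as a complex tensor: [seq_len, head_dim/2]
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    # View real pairs (..., 2i, 2i+1) as complex numbers and rotate
    q_c = torch.view_as_complex(q.float().reshape(*q.shape[:-1], -1, 2))
    k_c = torch.view_as_complex(k.float().reshape(*k.shape[:-1], -1, 2))
    freqs_cis = freqs_cis[None, :, None, :]  # broadcast over batch and heads
    q_out = torch.view_as_real(q_c * freqs_cis).flatten(-2)
    k_out = torch.view_as_real(k_c * freqs_cis).flatten(-2)
    return q_out.type_as(q), k_out.type_as(k)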
# Test function for RoPE
def test_rope():
"""Test RoPE implementation."""
print("Testing RoPE implementation...")
batch_size = 2
seq_len = 128
n_heads = 12
head_dim = 64
# Create RoPE module
rope = RotaryPositionEmbeddings(head_dim=head_dim, max_seq_len=2048)
# Create dummy Q and K tensors
q = torch.randn(batch_size, seq_len, n_heads, head_dim)
k = torch.randn(batch_size, seq_len, n_heads, head_dim)
# Apply RoPE
q_rot, k_rot = rope(q, k)
# Check shapes
assert q_rot.shape == q.shape, f"Q shape mismatch: {q_rot.shape} != {q.shape}"
assert k_rot.shape == k.shape, f"K shape mismatch: {k_rot.shape} != {k.shape}"
# Check for NaN
assert not torch.isnan(q_rot).any(), "Q contains NaN after RoPE"
assert not torch.isnan(k_rot).any(), "K contains NaN after RoPE"
print("✓ RoPE test passed!")
print(f" Input shape: {q.shape}")
print(f" Output shape: {q_rot.shape}")
return True
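# An additional property check (an assumption about expected behaviour, not
# part of the original tests): RoPE should make attention scores depend only
# on relative position, so shifting every position by the same offset must
# leave the per-head Q.K dot products unchanged.
def test_rope_relative_shift():
    """Check shift invariance of RoPE-rotated dot products."""
    torch.manual_seed(0)
    rope = RotaryPositionEmbeddings(head_dim=64, max_seq_len=2048)
    q = torch.randn(1, 8, 4, 64)
    k = torch.randn(1, 8, 4, 64)
    positions = torch.arange(8)[None, :]  # [1, seq_len]
    q0, k0 = rope.apply_rotary_pos_emb(q, k, position_ids=positions)
    q1, k1 = rope.apply_rotary_pos_emb(q, k, position_ids=positions + 100)
    # Per-head score matrices: [batch, heads, seq, seq]
    scores0 = torch.einsum('bqhd,bkhd->bhqk', q0, k0)
    scores1 = torch.einsum('bqhd,bkhd->bhqk', q1, k1)
    assert torch.allclose(scores0, scores1, atol=1e-3), "RoPE scores are not shift-invariant"
    print("✓ RoPE relative-shift test passed!")
    return True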
if __name__ == "__main__":
test_rope()
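    # Also run the shift-invariance check defined above
    test_rope_relative_shift()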