"""Rotary Position Embeddings (RoPE) implementation.

Critical implementation details:

1. Apply RoPE only to Q and K, never to V
2. Use head_dim, not the full model dimension
3. Ensure proper dimension pairing for the rotation

A usage sketch showing these details in a self-attention block follows the
RotaryPositionEmbeddings class below.
"""

import torch
import torch.nn as nn
from typing import Optional, Tuple
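
# Background, for reference: RoPE rotates each two-dimensional pair of query/key
# features at position m by an angle m * theta_i, with per-pair frequencies
# theta_i = base^(-2i / head_dim). Rotating q by its position m and k by its
# position n changes their dot product exactly as a single rotation by (n - m)
# would, so attention scores become a function of relative position only.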


class RotaryPositionEmbeddings(nn.Module):
    """Rotary Position Embeddings (RoPE) for transformer models.

    Based on the paper: 'RoFormer: Enhanced Transformer with Rotary Position Embedding'
    https://arxiv.org/abs/2104.09864
    """

    def __init__(
        self,
        head_dim: int,
        max_seq_len: int = 2048,
        base: int = 10000,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.base = base

        assert head_dim % 2 == 0, f"head_dim must be even, got {head_dim}"

        self._precompute_freqs(device)

    def _precompute_freqs(self, device: Optional[torch.device] = None):
        """Precompute and cache the cos/sin tables used by RoPE."""
        # Per-pair frequencies: theta_i = base^(-2i / head_dim), i = 0 .. head_dim/2 - 1.
        theta = 1.0 / (
            self.base ** (torch.arange(0, self.head_dim, 2, device=device).float() / self.head_dim)
        )

        # Outer product of positions and frequencies: freqs[m, i] = m * theta_i.
        positions = torch.arange(self.max_seq_len, device=device).float()
        freqs = torch.einsum('i,j->ij', positions, theta)

        freqs_cos = torch.cos(freqs)
        freqs_sin = torch.sin(freqs)

        # Duplicate along the feature dimension so the tables span the full head_dim,
        # matching the half-split pairing used by rotate_half below.
        freqs_cos = torch.cat([freqs_cos, freqs_cos], dim=-1)
        freqs_sin = torch.cat([freqs_sin, freqs_sin], dim=-1)

        self.register_buffer('freqs_cos', freqs_cos, persistent=False)
        self.register_buffer('freqs_sin', freqs_sin, persistent=False)

    def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        """Rotate half the hidden dims of the input.

        CRITICAL: This is the most common source of bugs - incorrect dimension pairing.
        For input [1, 2, 3, 4], the output should be [-3, -4, 1, 2].
        """
        # Element i is paired with element i + head_dim // 2; this must match how the
        # cos/sin tables are duplicated in _precompute_freqs.
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat([-x2, x1], dim=-1)

    def apply_rotary_pos_emb(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary position embeddings to query and key tensors.

        Args:
            q: Query tensor of shape [batch, seq_len, num_heads, head_dim]
            k: Key tensor of shape [batch, seq_len, num_heads, head_dim]
            position_ids: Optional position IDs of shape [batch, seq_len]; defaults
                to positions 0 .. seq_len - 1 when omitted.

        Returns:
            Tuple of rotated (q, k) tensors with the same shapes as the inputs.
        """
        seq_len = q.shape[1]

        # Select the cached angles and reshape them to broadcast against
        # [batch, seq_len, num_heads, head_dim].
        if position_ids is not None:
            # Gather per-example positions: [batch, seq_len, head_dim], then add a heads axis.
            freqs_cos = self.freqs_cos[position_ids][:, :, None, :]
            freqs_sin = self.freqs_sin[position_ids][:, :, None, :]
        else:
            # Contiguous positions: [seq_len, head_dim], then add batch and heads axes.
            freqs_cos = self.freqs_cos[:seq_len][None, :, None, :]
            freqs_sin = self.freqs_sin[:seq_len][None, :, None, :]

        # Match the input dtype (the cached tables are float32).
        freqs_cos = freqs_cos.to(dtype=q.dtype)
        freqs_sin = freqs_sin.to(dtype=q.dtype)

        q_rotated = q * freqs_cos + self.rotate_half(q) * freqs_sin
        k_rotated = k * freqs_cos + self.rotate_half(k) * freqs_sin

        return q_rotated, k_rotated
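
    # Example of the position_ids path (illustrative; not called anywhere in this
    # file): during incremental decoding with a KV cache, q and k hold only the
    # newly generated token, but it must still be rotated by its absolute position:
    #
    #     position_ids = torch.tensor([[41]])              # shape [batch, 1]
    #     q_rot, k_rot = rope(q_new, k_new, position_ids)  # q_new: [batch, 1, heads, head_dim]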

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass - apply RoPE to Q and K only.

        CRITICAL: Never apply RoPE to the V (value) tensor!
        """
        return self.apply_rotary_pos_emb(q, k, position_ids)
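

# Usage sketch referenced by the module docstring (illustrative only; this attention
# block is not part of the original module, and it relies on
# torch.nn.functional.scaled_dot_product_attention, available in PyTorch >= 2.0).
# It shows the three critical details: RoPE is built with head_dim, applied to the
# per-head Q and K tensors only, and V passes through unrotated.
class SelfAttentionWithRoPE(nn.Module):
    def __init__(self, d_model: int, num_heads: int, max_seq_len: int = 2048):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.out = nn.Linear(d_model, d_model)
        # Detail 2: the RoPE module is sized by head_dim, not d_model.
        self.rope = RotaryPositionEmbeddings(self.head_dim, max_seq_len=max_seq_len)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, s, d = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        q = q.reshape(b, s, self.num_heads, self.head_dim)
        k = k.reshape(b, s, self.num_heads, self.head_dim)
        v = v.reshape(b, s, self.num_heads, self.head_dim)
        # Detail 1: rotate Q and K only; V is left untouched.
        q, k = self.rope(q, k)
        # scaled_dot_product_attention expects [batch, heads, seq, head_dim].
        attn = nn.functional.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
        )
        return self.out(attn.transpose(1, 2).reshape(b, s, d))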


class RotaryPositionEmbeddingsComplex(nn.Module):
    """Alternative RoPE implementation with fully precomputed, broadcast-ready cos/sin caches.

    Note: despite the class name, this variant does not use torch's complex dtype;
    a complex-multiplication formulation is sketched after this class. Caching the
    broadcast-shaped tables can be more efficient on some hardware but requires
    careful handling of sequence length, dtype, and device placement.
    """

    def __init__(
        self,
        head_dim: int,
        max_seq_len: int = 2048,
        base: int = 10000,
        device: Optional[torch.device] = None,
    ):
        super().__init__()
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.base = base

        assert head_dim % 2 == 0, f"head_dim must be even, got {head_dim}"

        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
        t = torch.arange(max_seq_len, device=device, dtype=inv_freq.dtype)
        freqs = torch.einsum('i,j->ij', t, inv_freq)

        # Duplicate and pre-shape to [1, max_seq_len, 1, head_dim] so the tables
        # broadcast directly against [batch, seq_len, num_heads, head_dim].
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer('cos_cached', emb.cos()[None, :, None, :], persistent=False)
        self.register_buffer('sin_cached', emb.sin()[None, :, None, :], persistent=False)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        seq_len: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply RoPE to q and k using the cached cos/sin tables."""
        if seq_len is None:
            seq_len = q.shape[1]

        q_embed = (q * self.cos_cached[:, :seq_len]) + \
                  (self.rotate_half(q) * self.sin_cached[:, :seq_len])
        k_embed = (k * self.cos_cached[:, :seq_len]) + \
                  (self.rotate_half(k) * self.sin_cached[:, :seq_len])

        return q_embed, k_embed

    def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        """Rotate half the hidden dims (same pairing convention as above)."""
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat([-x2, x1], dim=-1)
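

# For reference, a genuinely complex-valued RoPE (in the style popularised by the
# LLaMA reference code) is sketched below. This is a standalone illustrative
# function, not wired into the classes above; it uses torch.polar /
# torch.view_as_complex and the interleaved (x0, x1), (x2, x3), ... pairing, so
# its outputs are a fixed permutation of the half-split version's rather than
# bit-identical to them.
def apply_rope_complex(
    q: torch.Tensor,
    k: torch.Tensor,
    base: float = 10000.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate q/k of shape [batch, seq_len, num_heads, head_dim] via complex multiplication."""
    seq_len, head_dim = q.shape[1], q.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=q.device).float() / head_dim))
    angles = torch.outer(torch.arange(seq_len, device=q.device).float(), inv_freq)
    # Unit complex numbers e^{i * m * theta_j}, shaped [1, seq_len, 1, head_dim // 2].
    freqs_cis = torch.polar(torch.ones_like(angles), angles)[None, :, None, :]
    # View the last dimension as complex pairs, rotate by multiplication, flatten back.
    q_c = torch.view_as_complex(q.float().reshape(*q.shape[:-1], -1, 2))
    k_c = torch.view_as_complex(k.float().reshape(*k.shape[:-1], -1, 2))
    q_out = torch.view_as_real(q_c * freqs_cis).flatten(-2)
    k_out = torch.view_as_real(k_c * freqs_cis).flatten(-2)
    return q_out.type_as(q), k_out.type_as(k)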


def test_rope():
    """Test RoPE implementation."""
    print("Testing RoPE implementation...")

    batch_size = 2
    seq_len = 128
    n_heads = 12
    head_dim = 64

    rope = RotaryPositionEmbeddings(head_dim=head_dim, max_seq_len=2048)

    q = torch.randn(batch_size, seq_len, n_heads, head_dim)
    k = torch.randn(batch_size, seq_len, n_heads, head_dim)

    q_rot, k_rot = rope(q, k)

    assert q_rot.shape == q.shape, f"Q shape mismatch: {q_rot.shape} != {q.shape}"
    assert k_rot.shape == k.shape, f"K shape mismatch: {k_rot.shape} != {k.shape}"

    assert not torch.isnan(q_rot).any(), "Q contains NaN after RoPE"
    assert not torch.isnan(k_rot).any(), "K contains NaN after RoPE"

    print("✓ RoPE test passed!")
    print(f"  Input shape: {q.shape}")
    print(f"  Output shape: {q_rot.shape}")

    return True
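

# Additional property checks (illustrative; not called from __main__ below). They
# verify three consequences of RoPE being a pure rotation: position 0 is rotated
# by angle zero and comes back unchanged, per-vector norms are preserved, and the
# q-k dot product depends only on the relative offset between positions.
def test_rope_properties():
    rope = RotaryPositionEmbeddings(head_dim=64, max_seq_len=256)

    q = torch.randn(2, 32, 4, 64)
    k = torch.randn(2, 32, 4, 64)
    q_rot, k_rot = rope(q, k)

    # Position 0: the angle is zero, so the rotation is the identity.
    assert torch.allclose(q_rot[:, 0], q[:, 0], atol=1e-6), "Position 0 should be unchanged"

    # Rotations preserve the norm of every query/key vector.
    assert torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-4), "Norms should be preserved"

    # Relative-position property: repeat one q vector and one k vector across all
    # positions, so only the rotation differs, then compare scores at equal offsets.
    q1 = torch.randn(1, 1, 1, 64).repeat(1, 32, 1, 1)
    k1 = torch.randn(1, 1, 1, 64).repeat(1, 32, 1, 1)
    q1_rot, k1_rot = rope(q1, k1)
    offset = 5
    score_a = (q1_rot[0, 10, 0] * k1_rot[0, 10 + offset, 0]).sum()
    score_b = (q1_rot[0, 20, 0] * k1_rot[0, 20 + offset, 0]).sum()
    assert torch.allclose(score_a, score_b, atol=1e-4), "Scores should depend only on the offset"

    print("✓ RoPE property test passed!")
    return True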


if __name__ == "__main__":
    test_rope()