| """ |
| Rotary Position Embedding (RoPE) for FrawdLLM. |
| |
| RoPE encodes position by rotating the Q and K vectors. This has several advantages: |
| 1. No learned position embeddings (saves parameters) |
| 2. Better length generalization (can extrapolate beyond training length) |
| 3. Relative position encoding (attention depends on distance, not absolute position) |
| |
| How it works: |
| - Each position gets a rotation angle based on its index |
| - Q and K are rotated by their position's angle |
| - The dot product Q·K then naturally encodes relative distance |
| |
| Reference: https://arxiv.org/abs/2104.09864 |
| """ |
import torch
import torch.nn as nn


def precompute_freqs(dim: int, max_seq_len: int, theta: float = 10000.0) -> torch.Tensor:
    """
    Precompute the frequency tensor for RoPE.

    Args:
        dim: Dimension of each head (must be even)
        max_seq_len: Maximum sequence length
        theta: Base for frequency computation (10000 is standard)

    Returns:
        Complex tensor of shape [max_seq_len, dim//2] containing rotation frequencies
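
    Example (illustrative shape check; sizes chosen arbitrarily):
        >>> fc = precompute_freqs(dim=8, max_seq_len=4)
        >>> fc.shape
        torch.Size([4, 4])
        >>> fc.dtype
        torch.complex64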
| """ |
| |
| |
| freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) |
|
|
| |
| positions = torch.arange(max_seq_len) |
|
|
| |
| |
| angles = torch.outer(positions, freqs) |
|
|
| |
| |
| freqs_complex = torch.polar(torch.ones_like(angles), angles) |
|
|
| return freqs_complex |
|
|
|
|
def apply_rope(
    x: torch.Tensor,
    freqs: torch.Tensor,
    start_pos: int = 0,
) -> torch.Tensor:
    """
    Apply rotary position embedding to Q or K tensor.

    Args:
        x: [batch, n_head, seq_len, head_dim] - Q or K tensor
        freqs: [max_seq_len, head_dim//2] - precomputed frequencies
        start_pos: Starting position (for KV cache during generation)

    Returns:
        Rotated tensor with same shape as input
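
    Example (illustrative; shapes only):
        >>> fc = precompute_freqs(dim=8, max_seq_len=32)
        >>> q = torch.randn(2, 4, 10, 8)
        >>> apply_rope(q, fc).shape
        torch.Size([2, 4, 10, 8])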
| """ |
| batch, n_head, seq_len, head_dim = x.shape |
|
|
| |
| |
| seq_freqs = freqs[start_pos:start_pos + seq_len] |
|
|
| |
| |
| x_pairs = x.float().reshape(batch, n_head, seq_len, -1, 2) |
|
|
| |
| x_complex = torch.view_as_complex(x_pairs) |
|
|
| |
| seq_freqs = seq_freqs.unsqueeze(0).unsqueeze(0) |
|
|
| |
| x_rotated = x_complex * seq_freqs |
|
|
| |
| x_out = torch.view_as_real(x_rotated) |
|
|
| |
| x_out = x_out.reshape(batch, n_head, seq_len, head_dim) |
|
|
| return x_out.type_as(x) |
|
|
|
|
class RotaryEmbedding(nn.Module):
    """
    Module wrapper for rotary embeddings.

    Precomputes and caches the frequency tensor.
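
    Example (illustrative usage):
        >>> rope = RotaryEmbedding(dim=64, max_seq_len=128)
        >>> q = torch.randn(1, 2, 10, 64)
        >>> rope(q).shape
        torch.Size([1, 2, 10, 64])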
| """ |
|
|
| def __init__(self, dim: int, max_seq_len: int = 4096, theta: float = 10000.0): |
| super().__init__() |
| self.dim = dim |
| self.max_seq_len = max_seq_len |
| self.theta = theta |
|
|
| |
| freqs = precompute_freqs(dim, max_seq_len, theta) |
| self.register_buffer("freqs", freqs, persistent=False) |
|
|
| def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor: |
| """Apply RoPE to input tensor.""" |
| return apply_rope(x, self.freqs, start_pos) |
|
|
|
|
if __name__ == "__main__":
    print("Testing RoPE...")
    print("=" * 50)

    batch, n_head, seq_len, head_dim = 2, 4, 16, 64

    rope = RotaryEmbedding(dim=head_dim, max_seq_len=512)

    q = torch.randn(batch, n_head, seq_len, head_dim)
    k = torch.randn(batch, n_head, seq_len, head_dim)

    print(f"Input shape: {q.shape}")

    q_rotated = rope(q)
    k_rotated = rope(k)

    print(f"Output shape: {q_rotated.shape}")

    print("\nVerifying relative position property...")

    # Use the same underlying vector at every position, so any difference between scores
    # comes from the rotations alone
    q_same = torch.randn(batch, n_head, 1, head_dim).repeat(1, 1, seq_len, 1)
    k_same = torch.randn(batch, n_head, 1, head_dim).repeat(1, 1, seq_len, 1)
    q_rot = rope(q_same)
    k_rot = rope(k_same)

    # Pairs (0, 1) and (5, 6) have the same relative distance, so their scores should match
    attn_0_1 = q_rot[:, :, 0:1, :] @ k_rot[:, :, 1:2, :].transpose(-2, -1)
    attn_5_6 = q_rot[:, :, 5:6, :] @ k_rot[:, :, 6:7, :].transpose(-2, -1)

    diff = (attn_0_1 - attn_5_6).abs().mean().item()
    print(f"  Attention (0,1) vs (5,6) difference: {diff:.6f}")
    print("  (Should be ~0 - same vectors, same relative distance)")
| print("\nRoPE working!") |
|
|