File size: 5,572 Bytes
7f974df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
"""
model/rope.py

Rotary Position Embedding (RoPE) — Su et al. 2021 (RoFormer).
Used in LLaMA, Mistral, Gemma, etc.

Core idea:
    Instead of adding position embeddings to token vectors, we ROTATE
    the query and key vectors in attention using position-dependent angles.

    - Relative positions are encoded implicitly via dot-product invariance.
    - Works for any sequence length (extrapolates beyond training length).
    - Only applied to Q and K, NOT V.

Implementation:
    1. Precompute cos/sin tables for all positions up to max_seq_len.
       Shape: (max_seq_len, head_dim)

    2. At forward time, slice cos/sin to the current seq_len and
       apply rotation to Q and K.

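Rotation formula (pairs of dims):
    Given a vector x with dims [x0, x1, ..., x(d-1)] and h = d // 2:
    This file pairs dim i with dim i + h:  (x0,xh), (x1,x(h+1)), ...
    Rotate each pair by angle theta_i * position:
        out[i]     = x[i]*cos - x[i+h]*sin
        out[i + h] = x[i]*sin + x[i+h]*cos

    Implemented using rotate_half:
        rotated = concat([-x_second_half, x_first_half])  # swapped halves
        out = x * cos + rotated * sin

    Note: the RoFormer paper pairs consecutive dims (x0,x1), (x2,x3), ...
    giving [x0*cos - x1*sin,  x0*sin + x1*cos, ...]. That layout is the
    same rotation up to a fixed permutation of the head dims, but the two
    layouts are not interchangeable without applying that permutation.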
"""

from typing import Optional, Tuple

import torch
import torch.nn as nn


def precompute_rope_freqs(
    head_dim: int,
    max_seq_len: int,
    theta: float = 10_000.0,
    device: Optional[torch.device] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Precompute RoPE cosine and sine tables.

    The angle for position p and frequency index i is p * inv_freq[i], where
    inv_freq[i] = 1 / theta^(2i / head_dim) for i in [0, head_dim // 2).
    The half-width angle table is concatenated with itself, so column j and
    column j + head_dim // 2 hold the same angle.

    Args:
        head_dim    : dimension of each attention head (must be even)
        max_seq_len : max sequence length to precompute
        theta       : RoPE base frequency (default 10_000, use 500_000 for long context)
        device      : torch device for the tables (None = torch default device)

    Returns:
        cos : (max_seq_len, head_dim), float32
        sin : (max_seq_len, head_dim), float32

    Raises:
        AssertionError: if head_dim is odd.
    """
    if head_dim % 2 != 0:
        # Raised explicitly rather than with `assert` so the check is not
        # stripped under `python -O`. An odd head_dim would otherwise yield
        # tables with head_dim + 1 columns. The exception type and message
        # are the ones the `assert` form produced.
        raise AssertionError(f"head_dim must be even, got {head_dim}")

    # Inverse frequencies: shape (head_dim // 2,)
    # inv_freq[i] = 1 / theta^(2i / head_dim)
    i        = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
    inv_freq = 1.0 / (theta ** (i / head_dim))

    # Position indices: shape (max_seq_len,)
    positions = torch.arange(max_seq_len, dtype=torch.float32, device=device)

    # Outer product: (max_seq_len, head_dim // 2)
    freqs = torch.outer(positions, inv_freq)

    # Concatenate the table with itself to match head_dim:
    # (max_seq_len, head_dim // 2) -> (max_seq_len, head_dim)
    # Per-row layout is [a0, a1, ..., a(h-1), a0, a1, ..., a(h-1)] with
    # h = head_dim // 2 (half-split layout, NOT interleaved [a0,a0,a1,a1,...]).
    freqs = torch.cat([freqs, freqs], dim=-1)

    return freqs.cos(), freqs.sin()


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    Swap the two halves of the last axis and negate the one moved to the front.

    With mid = x.shape[-1] // 2, the result is
        concat([-x[..., mid:], x[..., :mid]])
    so, for an even last dim, dim i is paired with dim i + mid.

    Args:
        x: (..., head_dim)
    Returns:
        rotated: (..., head_dim)
    """
    width = x.shape[-1]
    mid = width // 2
    front = x.narrow(-1, 0, mid)            # dims [0, mid)
    back = x.narrow(-1, mid, width - mid)   # dims [mid, width)
    return torch.cat((torch.neg(back), front), dim=-1)


def apply_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Rotate query and key tensors with precomputed RoPE tables.

    Each input t is mapped to  t * cos + rotate_half(t) * sin,  with cos/sin
    given two leading singleton axes so they broadcast over the batch and
    head axes.

    Args:
        q   : (B, n_heads, T, head_dim)
        k   : (B, n_heads, T, head_dim)
        cos : (T, head_dim)  - precomputed from precompute_rope_freqs
        sin : (T, head_dim)  - precomputed from precompute_rope_freqs

    Returns:
        q_rot, k_rot : same shapes as inputs
    """
    # (T, head_dim) -> (1, 1, T, head_dim)
    cos_b = cos[None, None]
    sin_b = sin[None, None]

    def _rotate(t: torch.Tensor) -> torch.Tensor:
        # Same expression for q and k; V is never rotated.
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return _rotate(q), _rotate(k)


class RoPECache(nn.Module):
    """
    Module that holds the RoPE cos/sin cache as buffers.

    Not a learnable module: it has no parameters. The tables are registered
    with register_buffer so they follow the module through .to(device), and
    with persistent=True so they are included in the module's state_dict.

    Args:
        head_dim    : dimension of each attention head (must be even)
        max_seq_len : number of positions to precompute
        theta       : RoPE base frequency
    """

    def __init__(self, head_dim: int, max_seq_len: int, theta: float = 10_000.0):
        super().__init__()
        cos, sin = precompute_rope_freqs(head_dim, max_seq_len, theta)
        # register_buffer: not a parameter, but moves with .to(device)
        self.register_buffer("cos", cos, persistent=True)
        self.register_buffer("sin", sin, persistent=True)

    def get(self, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Slice cos/sin to the current sequence length.

        Args:
            seq_len : number of leading positions to return

        Returns:
            cos, sin : each with seq_len rows (views of the cached buffers)

        Raises:
            ValueError: if seq_len is negative or exceeds the number of
                cached positions.
        """
        cached_len = self.cos.shape[0]
        # Plain slicing would clamp an oversized seq_len (returning fewer
        # rows than requested) and treat a negative one as "drop from the
        # end"; both hide the real mistake until a later shape error.
        if not 0 <= seq_len <= cached_len:
            raise ValueError(
                f"seq_len must be in [0, {cached_len}], got {seq_len}"
            )
        return self.cos[:seq_len], self.sin[:seq_len]


# ------------------------------------------------------------------ #
#  QUICK CHECK
# ------------------------------------------------------------------ #

if __name__ == "__main__":
    # Quick self-check: runs RoPE on random tensors and exits non-zero
    # (via SystemExit) if any check fails, so "PASS" is only printed when
    # every check actually passed.
    torch.manual_seed(0)

    B, n_heads, T, head_dim = 2, 12, 16, 64
    max_seq_len = 1024

    cos, sin = precompute_rope_freqs(head_dim, max_seq_len=max_seq_len)
    cos_T    = cos[:T]
    sin_T    = sin[:T]

    q = torch.randn(B, n_heads, T, head_dim)
    k = torch.randn(B, n_heads, T, head_dim)

    q_rot, k_rot = apply_rope(q, k, cos_T, sin_T)

    print(f"q shape     : {q.shape}")
    print(f"q_rot shape : {q_rot.shape}")
    print(f"k_rot shape : {k_rot.shape}")

    failures = []
    if q_rot.shape != q.shape or k_rot.shape != k.shape:
        failures.append("apply_rope changed the tensor shape")

    # Verify: rotation should preserve norm (|x| = |Rx|)
    q_norm     = q.norm(dim=-1)
    q_rot_norm = q_rot.norm(dim=-1)
    norm_ok    = torch.allclose(q_norm, q_rot_norm, atol=1e-5)
    print(f"Norm preserved (q): {norm_ok}")
    if not norm_ok:
        failures.append("rotation did not preserve the norm of q")

    # Test RoPECache: built from the same settings, so its slices must
    # match the tables computed directly above.
    cache = RoPECache(head_dim=head_dim, max_seq_len=max_seq_len)
    c, s  = cache.get(T)
    print(f"Cache cos shape: {c.shape}")
    if not (torch.equal(c, cos_T) and torch.equal(s, sin_T)):
        failures.append("RoPECache tables differ from precompute_rope_freqs")

    if failures:
        raise SystemExit("FAIL: " + "; ".join(failures))
    print("PASS")