"""Rotary Position Embeddings (RoPE).

RoPE encodes position in the *relationship* between query and key vectors. When the
attention dot product Q·Kᵀ is computed, the per-position rotations cancel to produce
a score that depends only on the relative distance — not on absolute positions.

Two modes are supported:

  default  Standard RoPE with base frequency b. Each dimension pair d is assigned
           frequency θ_d = b^{-2d/u} where u is the head dimension. The attention
           scaling A_rope = 1.

  yarn     YaRN frequency interpolation for long-context extrapolation (Peng et al.,
           "YaRN: Efficient Context Window Extension of Large Language Models", 2023,
           §A.2). Three frequency regimes:
             - Low-frequency dimensions (r < α): fully interpolated by scale s.
               These dimensions have long wavelengths relative to the training window
               and must be compressed to avoid out-of-distribution positions.
             - High-frequency dimensions (r > β): left unchanged. Short-wavelength
               dimensions already encode relative position accurately at any scale.
             - Intermediate dimensions (α ≤ r ≤ β): linearly blended via ramp γ(r).
           Returns A_rope = (0.1·ln(s)+1)². When s = 1, YaRN reduces exactly to
           standard RoPE.

Each attention path (h_l and BEA) constructs its own RotaryEmbedding with explicit
parameters — no shared instance, no config reading. See Unit 5.A design decisions.

Cache sharing: all instances with identical parameters share one cos/sin table via a
class-level registry. The first instance that needs a particular (parameters, seq_len,
device, dtype) combination builds the table; all subsequent instances reference it
directly. This avoids redundant builds across the num_hidden_layers instances that
share the same parametrisation.
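
A minimal usage sketch (parameter values are illustrative, not a prescribed
configuration). The first instance mirrors the default mode used by h_l, the second
the YaRN mode used by BEA::

    import torch

    rope = RotaryEmbedding(mode="default", head_dim=64, theta=10000.0)
    q = torch.randn(2, 8, 128, 64)                  # (batch, heads, seq, head_dim)
    k = torch.randn(2, 8, 128, 64)
    position_ids = torch.arange(128).expand(2, -1)  # (batch, seq)
    q_rot, k_rot, scaling = rope(q, k, position_ids)  # scaling == 1.0 in default mode

    yarn_rope = RotaryEmbedding(
        mode="yarn", head_dim=64, theta=10000.0,
        initial_seq_length=4096, dilation=4.0, alpha=1.0, beta=32.0,
    )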
"""

import math

import torch
import torch.nn as nn


# ---------------------------------------------------------------------------
# Rotation helper
# ---------------------------------------------------------------------------

def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Apply the 90° rotation used in the RoPE update formula.

    Splits the last dimension into two halves [x1, x2] and returns [-x2, x1].
    Combined with ``x * cos + rotate_half(x) * sin``, this implements a 2D rotation
    on each dimension pair (x_i, x_{i+u/2}), the half-split equivalent (up to a fixed
    permutation of dimensions) of the block-diagonal operator R^u_{Θ,p} in the paper.
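
    For example, ``x = [1., 2., 3., 4.]`` splits into halves ``[1., 2.]`` and
    ``[3., 4.]`` and is returned as ``[-3., -4., 1., 2.]``.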
    """
    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]
    return torch.cat([-x2, x1], dim=-1)


# ---------------------------------------------------------------------------
# RotaryEmbedding
# ---------------------------------------------------------------------------

class RotaryEmbedding(nn.Module):
    """Rotary Position Embeddings with explicit mode and parameter control.

    Each caller constructs its own instance with the exact parameters it needs.
    h_l always uses ``mode="default"``; BEA always uses ``mode="yarn"``. No
    config object is read inside this module.

    The cos/sin cache is built lazily on the first forward call and extended
    automatically when a longer sequence is encountered. Instances with identical
    parameters share one cache via the class-level ``_cache`` registry,
    avoiding redundant computation across decoder layers.

    Args:
        mode: ``"default"`` for standard RoPE; ``"yarn"`` for YaRN extrapolation.
        head_dim: Per-head embedding dimension ``u``. Must be even.
        theta: Base frequency ``b`` in θ_d = b^{-2d/u}.
        initial_seq_length: ``C_train`` — context length the model was trained at.
            Required for ``mode="yarn"``.
        dilation: Scale factor ``s = C_target / C_train`` — how much the context
            window is extended beyond training length. Required for ``mode="yarn"``.
            When ``dilation=1.0``, YaRN reduces to standard RoPE.
        alpha: YaRN ramp lower boundary α. Dimensions with r(d) < α are fully
            interpolated. Required for ``mode="yarn"``.
        beta: YaRN ramp upper boundary β. Dimensions with r(d) > β are left
            unchanged. Required for ``mode="yarn"``.
        device: Optional device for initial buffer placement.

    Raises:
        NotImplementedError: If ``mode`` is not ``"default"`` or ``"yarn"``.
        ValueError: If ``mode="yarn"`` and any of ``initial_seq_length``,
            ``dilation``, ``alpha``, ``beta`` are absent.
    """

    # Maps (freq_key, seq_len, device_str, dtype_str) → (cos_table, sin_table).
    # Shared across all RotaryEmbedding instances in the process. Keys include device
    # and dtype so that tables built on different devices or in different precisions
    # are stored independently.
    _cache: dict = {}

    def __init__(
        self,
        mode: str,
        head_dim: int,
        theta: float,
        initial_seq_length: int | None = None,
        dilation: float | None = None,
        alpha: float | None = None,
        beta: float | None = None,
        device: torch.device | None = None,
    ) -> None:
        super().__init__()

        self._validate_mode(mode)
        self._validate_yarn_params(mode, initial_seq_length, dilation, alpha, beta)
        self.mode = mode

        # Compute per-dimension rotation frequencies θ_d (default) or θ_d' (yarn).
        # d_index ranges over 0, 2, 4, ..., head_dim-2 — one index per dimension pair,
        # so rotation_freqs has head_dim/2 entries.
        d_index = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
        base_freqs = 1.0 / (theta ** (d_index / head_dim))  # θ_d = b^{-2d/u}

        if mode == "default":
            rotation_freqs = base_freqs
            self.attention_scaling: float = 1.0

        else:  # yarn
            s = dilation

            # r(d) = C_train · θ_d / (2π) — normalized frequency used by the ramp
            # function to classify each dimension into one of three regimes.
            normalized_freqs = initial_seq_length * base_freqs / (2.0 * math.pi)

            # γ(r) ramp: 0 for r < α (fully interpolate), 1 for r > β (unchanged),
            # linear blend between α and β.
            blend_weights = ((normalized_freqs - alpha) / (beta - alpha)).clamp(0.0, 1.0)
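
            # Illustration (assumed, typical values): with initial_seq_length=4096,
            # theta=10000.0, head_dim=64, alpha=1.0, beta=32.0, the first dimension
            # pair has r ≈ 652 (> β, unchanged) while the last has r ≈ 0.09 (< α,
            # fully interpolated); pairs in between are linearly blended.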

            # θ_d' = (1 − γ) · θ_d / s + γ · θ_d
            rotation_freqs = (1.0 - blend_weights) * (base_freqs / s) + blend_weights * base_freqs

            # A_rope = (0.1 · ln(s) + 1)² — attention logit scaling returned to caller.
            self.attention_scaling = (0.1 * math.log(s) + 1.0) ** 2
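            # For example (illustrative): s = 4.0 gives A_rope ≈ 1.30; s = 1.0 gives 1.0.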

        # freq_key uniquely identifies the parameter set that produced rotation_freqs.
        # Used as the primary component of the cache registry key.
        if mode == "default":
            self._freq_key: tuple = ("default", head_dim, float(theta))
        else:
            self._freq_key = (
                "yarn", head_dim, float(theta),
                int(initial_seq_length), float(dilation),
                float(alpha), float(beta),
            )

        # rotation_freqs is a non-persistent buffer so it moves with the model across
        # devices via .to() / .cuda() without appearing in saved checkpoints.
        # It is stored per-instance rather than in the shared cache because it is
        # small (head_dim/2 floats) — negligible cost compared to the cos/sin tables
        # it is used to build. The meaningful sharing win is on those tables.
        self.register_buffer("rotation_freqs", rotation_freqs, persistent=False)

        # Cache tensors are plain instance attributes (not registered buffers) so that
        # sharing across identically-parametrised instances survives .to() calls.
        # Registered buffers are copied on device move; plain attributes are aliased,
        # preserving the shared-tensor identity that the cache design depends on.
        self._cos_cached: torch.Tensor | None = None
        self._sin_cached: torch.Tensor | None = None

    # ---------------------------------------------------------------------------
    # Validation helpers
    # ---------------------------------------------------------------------------

    @staticmethod
    def _validate_mode(mode: str) -> None:
        """Raise NotImplementedError if mode is not a supported value."""
        if mode not in {"default", "yarn"}:
            raise NotImplementedError(
                f"RoPE mode '{mode}' is not supported. Supported modes: 'default', 'yarn'."
            )

    @staticmethod
    def _validate_yarn_params(
        mode: str,
        initial_seq_length: int | None,
        dilation: float | None,
        alpha: float | None,
        beta: float | None,
    ) -> None:
        """Raise ValueError if mode='yarn' and any required parameter is absent."""
        if mode != "yarn":
            return
        missing = [
            name for name, val in [
                ("initial_seq_length", initial_seq_length),
                ("dilation", dilation),
                ("alpha", alpha),
                ("beta", beta),
            ]
            if val is None
        ]
        if missing:
            raise ValueError(f"mode='yarn' requires {missing}.")

    # ---------------------------------------------------------------------------
    # Cache management
    # ---------------------------------------------------------------------------

    def _extend_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> None:
        """Build the cos/sin table to cover positions [0, seq_len).

        Checks the class-level registry first. If a table already exists for this
        exact (parameters, seq_len, device, dtype) combination it is reused directly;
        otherwise it is computed and stored. The instance attributes are pointed at
        the registry entry so that all layers sharing the same parametrisation
        reference the same tensor.
        """
        cache_key = (self._freq_key, seq_len, str(device), str(dtype))
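        # Illustrative key (assumed values): (("default", 64, 10000.0), 2048,
        # "cuda:0", "torch.bfloat16").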

        if cache_key not in RotaryEmbedding._cache:
            positions = torch.arange(seq_len, device=device, dtype=torch.float32)
            # outer product → (seq_len, head_dim // 2); duplicate to (seq_len, head_dim)
            freqs = torch.outer(
                positions,
                self.rotation_freqs.to(device=device, dtype=torch.float32),
            )
            angle_embedding = torch.cat((freqs, freqs), dim=-1)
            RotaryEmbedding._cache[cache_key] = (
                angle_embedding.cos().to(dtype),
                angle_embedding.sin().to(dtype),
            )

        self._cos_cached, self._sin_cached = RotaryEmbedding._cache[cache_key]

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, float]:
        """Apply rotary embeddings to query and key tensors.

        The cos/sin cache is extended lazily when position_ids reference positions
        beyond its current length, or when the device or dtype has changed.

        ``position_ids`` may be an integer tensor of any shape; its values are used
        directly as position indices into the cos/sin cache:

        - h_l (standard causal): position_ids (B, N), q/k (B, H, N, head_dim).
        - BEA (packed):          position_ids (B, L, T), q/k (B, L, T, head_dim).

        When q/k have head dimensions absent from position_ids, broadcast dimensions
        are inserted automatically at dim 1.

        Args:
            q: Query tensor of shape (batch, [heads,] *pos_dims, head_dim).
            k: Key tensor of shape (batch, [heads,] *pos_dims, head_dim).
            position_ids: Integer positions of shape (batch, *pos_dims).

        Returns:
            Tuple of (q_rotated, k_rotated, attention_scaling). attention_scaling is
            1.0 for default mode; YaRN returns (0.1·ln(s)+1)² which the caller must
            apply to attention logits before softmax.
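
        Example (packed/BEA-style shapes; values illustrative, ``rope`` being any
        instance of this class)::

            q = torch.randn(2, 4, 16, 64)                     # (B, L, T, head_dim)
            k = torch.randn(2, 4, 16, 64)
            position_ids = torch.arange(16).expand(2, 4, 16)  # (B, L, T)
            q_rot, k_rot, scaling = rope(q, k, position_ids)
            # cos/sin gather to (2, 4, 16, 64), already matching q's rank, so no
            # broadcast dimension is inserted; q_rot and k_rot keep q's shape.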
        """
        seq_len = int(position_ids.max().item()) + 1

        # The cache is valid when it exists, covers all positions referenced by
        # position_ids, and matches q's dtype and device. Each condition is named
        # separately so the rebuild trigger is readable rather than a compound predicate.
        cache_missing = self._cos_cached is None
        cache_too_short = not cache_missing and seq_len > self._cos_cached.shape[0]
        wrong_dtype = not cache_missing and self._cos_cached.dtype != q.dtype
        wrong_device = not cache_missing and self._cos_cached.device != q.device

        if cache_missing or cache_too_short or wrong_dtype or wrong_device:
            self._extend_cache(seq_len, device=q.device, dtype=q.dtype)

        cos = self._cos_cached[position_ids]
        sin = self._sin_cached[position_ids]

        # Insert broadcast dimensions for any head axes present in q/k but absent
        # from position_ids. Standard: pos (B,N) → cos (B,N,D), q (B,H,N,D) → unsqueeze once.
        # BEA: pos (B,L,T) → cos (B,L,T,D), q (B,L,T,D) → no unsqueeze needed.
        while cos.ndim < q.ndim:
            cos = cos.unsqueeze(1)
            sin = sin.unsqueeze(1)

        q_rotated = q * cos + _rotate_half(q) * sin
        k_rotated = k * cos + _rotate_half(k) * sin

        return q_rotated, k_rotated, self.attention_scaling