"""Rotary Position Embeddings (RoPE). RoPE encodes position in the *relationship* between query and key vectors. When the attention dot product Q·Kᵀ is computed, the per-position rotations cancel to produce a score that depends only on the relative distance — not on absolute positions. Two modes are supported: default Standard RoPE with base frequency b. Each dimension pair d is assigned frequency θ_d = b^{-2d/u} where u is the head dimension. The attention scaling A_rope = 1. yarn YaRN frequency interpolation for long-context extrapolation (Peng et al., "YaRN: Efficient Context Window Extension of Large Language Models", 2023, §A.2). Three frequency regimes: - Low-frequency dimensions (r < α): fully interpolated by scale s. These dimensions have long wavelengths relative to the training window and must be compressed to avoid out-of-distribution positions. - High-frequency dimensions (r > β): left unchanged. Short-wavelength dimensions already encode relative position accurately at any scale. - Intermediate dimensions (α ≤ r ≤ β): linearly blended via ramp γ(r). Returns A_rope = (0.1·ln(s)+1)². When s = 1, YaRN reduces exactly to standard RoPE. Each attention path (h_l and BEA) constructs its own RotaryEmbedding with explicit parameters — no shared instance, no config reading. See Unit 5.A design decisions. Cache sharing: all instances with identical parameters share one cos/sin table via a class-level registry. The first instance that needs a particular (parameters, seq_len, device, dtype) combination builds the table; all subsequent instances reference it directly. This avoids redundant builds across the num_hidden_layers instances that share the same parametrisation. """ import math import torch import torch.nn as nn # --------------------------------------------------------------------------- # Rotation helper # --------------------------------------------------------------------------- def _rotate_half(x: torch.Tensor) -> torch.Tensor: """Apply the 90° rotation used in the RoPE update formula. Splits the last dimension into two halves [x1, x2] and returns [-x2, x1]. Combined with ``x * cos + rotate_half(x) * sin``, this implements a 2D rotation on each consecutive pair of dimensions, matching the block-diagonal operator R^u_{Θ,p} in the paper. """ d = x.shape[-1] // 2 x1, x2 = x[..., :d], x[..., d:] return torch.cat([-x2, x1], dim=-1) # --------------------------------------------------------------------------- # RotaryEmbedding # --------------------------------------------------------------------------- class RotaryEmbedding(nn.Module): """Rotary Position Embeddings with explicit mode and parameter control. Each caller constructs its own instance with the exact parameters it needs. h_l always uses ``mode="default"``; BEA always uses ``mode="yarn"``. No config object is read inside this module. The cos/sin cache is built lazily on the first forward call and extended automatically when a longer sequence is encountered. Instances with identical parameters share one cache via the class-level ``_cache`` registry, avoiding redundant computation across decoder layers. Args: mode: ``"default"`` for standard RoPE; ``"yarn"`` for YaRN extrapolation. head_dim: Per-head embedding dimension ``u``. Must be even. theta: Base frequency ``b`` in θ_d = b^{-2d/u}. initial_seq_length: ``C_train`` — context length the model was trained at. Required for ``mode="yarn"``. dilation: Scale factor ``s = C_target / C_train`` — how much the context window is extended beyond training length. 
# ---------------------------------------------------------------------------
# RotaryEmbedding
# ---------------------------------------------------------------------------

class RotaryEmbedding(nn.Module):
    """Rotary Position Embeddings with explicit mode and parameter control.

    Each caller constructs its own instance with the exact parameters it
    needs. h_l always uses ``mode="default"``; BEA always uses
    ``mode="yarn"``. No config object is read inside this module.

    The cos/sin cache is built lazily on the first forward call and extended
    automatically when a longer sequence is encountered. Instances with
    identical parameters share one cache via the class-level ``_cache``
    registry, avoiding redundant computation across decoder layers.

    Args:
        mode: ``"default"`` for standard RoPE; ``"yarn"`` for YaRN
            extrapolation.
        head_dim: Per-head embedding dimension ``u``. Must be even.
        theta: Base frequency ``b`` in θ_d = b^{-2d/u}.
        initial_seq_length: ``C_train`` — the context length the model was
            trained at. Required for ``mode="yarn"``.
        dilation: Scale factor ``s = C_target / C_train`` — how much the
            context window is extended beyond training length. Required for
            ``mode="yarn"``. When ``dilation=1.0``, YaRN reduces to standard
            RoPE.
        alpha: YaRN ramp lower boundary α. Dimensions with r(d) < α are
            fully interpolated. Required for ``mode="yarn"``.
        beta: YaRN ramp upper boundary β. Dimensions with r(d) > β are left
            unchanged. Required for ``mode="yarn"``.
        device: Optional device for initial buffer placement.

    Raises:
        NotImplementedError: If ``mode`` is not ``"default"`` or ``"yarn"``.
        ValueError: If ``mode="yarn"`` and any of ``initial_seq_length``,
            ``dilation``, ``alpha``, ``beta`` is absent.
    """

    # Maps (freq_key, seq_len, device_str, dtype_str) → (cos_table, sin_table).
    # Shared across all RotaryEmbedding instances in the process. Keys include
    # device and dtype so that tables built on different devices or in
    # different precisions are stored independently.
    _cache: dict = {}

    def __init__(
        self,
        mode: str,
        head_dim: int,
        theta: float,
        initial_seq_length: int | None = None,
        dilation: float | None = None,
        alpha: float | None = None,
        beta: float | None = None,
        device: torch.device | None = None,
    ) -> None:
        super().__init__()
        self._validate_mode(mode)
        self._validate_yarn_params(mode, initial_seq_length, dilation, alpha, beta)
        self.mode = mode

        # Compute per-dimension rotation frequencies θ_d (default) or θ_d' (yarn).
        # d_index ranges over 0, 2, 4, ..., head_dim-2 — one index per dimension
        # pair, so rotation_freqs has head_dim/2 entries.
        d_index = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
        base_freqs = 1.0 / (theta ** (d_index / head_dim))  # θ_d = b^{-2d/u}

        if mode == "default":
            rotation_freqs = base_freqs
            self.attention_scaling: float = 1.0
        else:  # yarn
            s = dilation
            # r(d) = C_train · θ_d / (2π) — normalized frequency used by the
            # ramp function to classify each dimension into one of three regimes.
            normalized_freqs = initial_seq_length * base_freqs / (2.0 * math.pi)
            # γ(r) ramp: 0 for r < α (fully interpolate), 1 for r > β (unchanged),
            # linear blend between α and β.
            blend_weights = ((normalized_freqs - alpha) / (beta - alpha)).clamp(0.0, 1.0)
            # θ_d' = (1 − γ) · θ_d / s + γ · θ_d
            rotation_freqs = (1.0 - blend_weights) * (base_freqs / s) + blend_weights * base_freqs
            # A_rope = (0.1 · ln(s) + 1)² — attention logit scaling returned to caller.
            self.attention_scaling = (0.1 * math.log(s) + 1.0) ** 2

        # freq_key uniquely identifies the parameter set that produced
        # rotation_freqs. Used as the primary component of the cache registry key.
        if mode == "default":
            self._freq_key: tuple = ("default", head_dim, float(theta))
        else:
            self._freq_key = (
                "yarn",
                head_dim,
                float(theta),
                int(initial_seq_length),
                float(dilation),
                float(alpha),
                float(beta),
            )

        # rotation_freqs is a non-persistent buffer so it moves with the model
        # across devices via .to() / .cuda() without appearing in saved
        # checkpoints. It is stored per-instance rather than in the shared
        # cache because it is small (head_dim/2 floats) — negligible cost
        # compared to the cos/sin tables it is used to build. The meaningful
        # sharing win is on those tables.
        self.register_buffer("rotation_freqs", rotation_freqs, persistent=False)

        # Cache tensors are plain instance attributes (not registered buffers)
        # so that sharing across identically-parametrised instances survives
        # .to() calls. Registered buffers are copied on device move; plain
        # attributes are aliased, preserving the shared-tensor identity that
        # the cache design depends on.
        self._cos_cached: torch.Tensor | None = None
        self._sin_cached: torch.Tensor | None = None
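
    # A minimal construction sketch with hypothetical parameter values (the
    # real values come from the calling attention path, not this module):
    #
    #     rope_hl = RotaryEmbedding(mode="default", head_dim=64, theta=10000.0)
    #     rope_bea = RotaryEmbedding(
    #         mode="yarn", head_dim=64, theta=10000.0,
    #         initial_seq_length=4096, dilation=4.0, alpha=1.0, beta=32.0,
    #     )
    #     rope_hl.attention_scaling    # 1.0
    #     rope_bea.attention_scaling   # (0.1·ln(4) + 1)² ≈ 1.296
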
    # -----------------------------------------------------------------------
    # Validation helpers
    # -----------------------------------------------------------------------

    @staticmethod
    def _validate_mode(mode: str) -> None:
        """Raise NotImplementedError if mode is not a supported value."""
        if mode not in {"default", "yarn"}:
            raise NotImplementedError(
                f"RoPE mode '{mode}' is not supported. "
                f"Supported modes: 'default', 'yarn'."
            )

    @staticmethod
    def _validate_yarn_params(
        mode: str,
        initial_seq_length: int | None,
        dilation: float | None,
        alpha: float | None,
        beta: float | None,
    ) -> None:
        """Raise ValueError if mode='yarn' and any required parameter is absent."""
        if mode != "yarn":
            return
        missing = [
            name
            for name, val in [
                ("initial_seq_length", initial_seq_length),
                ("dilation", dilation),
                ("alpha", alpha),
                ("beta", beta),
            ]
            if val is None
        ]
        if missing:
            raise ValueError(f"mode='yarn' requires {missing}.")

    # -----------------------------------------------------------------------
    # Cache management
    # -----------------------------------------------------------------------

    def _extend_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> None:
        """Build the cos/sin table to cover positions [0, seq_len).

        Checks the class-level registry first. If a table already exists for
        this exact (parameters, seq_len, device, dtype) combination it is
        reused directly; otherwise it is computed and stored. The instance
        attributes are pointed at the registry entry so that all layers
        sharing the same parametrisation reference the same tensor.
        """
        cache_key = (self._freq_key, seq_len, str(device), str(dtype))
        if cache_key not in RotaryEmbedding._cache:
            positions = torch.arange(seq_len, device=device, dtype=torch.float32)
            # outer product → (seq_len, head_dim // 2); duplicate to (seq_len, head_dim)
            freqs = torch.outer(
                positions,
                self.rotation_freqs.to(device=device, dtype=torch.float32),
            )
            angle_embedding = torch.cat((freqs, freqs), dim=-1)
            RotaryEmbedding._cache[cache_key] = (
                angle_embedding.cos().to(dtype),
                angle_embedding.sin().to(dtype),
            )
        self._cos_cached, self._sin_cached = RotaryEmbedding._cache[cache_key]
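
    # A minimal sharing sketch (hypothetical sizes): two identically
    # parametrised instances end up pointing at the same table object, so the
    # per-layer memory cost is one table, not num_hidden_layers tables.
    #
    #     a = RotaryEmbedding(mode="default", head_dim=64, theta=10000.0)
    #     b = RotaryEmbedding(mode="default", head_dim=64, theta=10000.0)
    #     a._extend_cache(128, torch.device("cpu"), torch.float32)  # builds
    #     b._extend_cache(128, torch.device("cpu"), torch.float32)  # reuses
    #     assert a._cos_cached is b._cos_cached   # one table, two references
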
    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, float]:
        """Apply rotary embeddings to query and key tensors.

        The cos/sin cache is extended lazily when position_ids reference
        positions beyond its current length, or when the device or dtype has
        changed.

        ``position_ids`` may have any integer tensor shape; its values must be
        valid position indices into the cos/sin cache:

        - h_l (standard causal): position_ids (B, N), q/k (B, H, N, head_dim).
        - BEA (packed): position_ids (B, L, T), q/k (B, L, T, head_dim).

        When q/k have head dimensions absent from position_ids, broadcast
        dimensions are inserted automatically at dim 1.

        Args:
            q: Query tensor of shape (batch, [heads,] *pos_dims, head_dim).
            k: Key tensor of shape (batch, [heads,] *pos_dims, head_dim).
            position_ids: Integer positions of shape (batch, *pos_dims).

        Returns:
            Tuple of (q_rotated, k_rotated, attention_scaling).
            attention_scaling is 1.0 for default mode; YaRN returns
            (0.1·ln(s)+1)², which the caller must apply to attention logits
            before softmax.
        """
        seq_len = int(position_ids.max().item()) + 1

        # The cache is valid when it exists, covers all positions referenced
        # by position_ids, and matches q's dtype and device. Each condition is
        # named separately so the rebuild trigger is readable rather than a
        # compound predicate.
        cache_missing = self._cos_cached is None
        cache_too_short = not cache_missing and seq_len > self._cos_cached.shape[0]
        wrong_dtype = not cache_missing and self._cos_cached.dtype != q.dtype
        wrong_device = not cache_missing and self._cos_cached.device != q.device
        if cache_missing or cache_too_short or wrong_dtype or wrong_device:
            self._extend_cache(seq_len, device=q.device, dtype=q.dtype)

        cos = self._cos_cached[position_ids]
        sin = self._sin_cached[position_ids]

        # Insert broadcast dimensions for any head axes present in q/k but
        # absent from position_ids.
        # Standard: pos (B,N) → cos (B,N,D), q (B,H,N,D) → unsqueeze once.
        # BEA: pos (B,L,T) → cos (B,L,T,D), q (B,L,T,D) → no unsqueeze needed.
        while cos.ndim < q.ndim:
            cos = cos.unsqueeze(1)
            sin = sin.unsqueeze(1)

        q_rotated = q * cos + _rotate_half(q) * sin
        k_rotated = k * cos + _rotate_half(k) * sin
        return q_rotated, k_rotated, self.attention_scaling
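
# A minimal end-to-end sketch (hypothetical shapes; the attention modules in
# this codebase own the real call sites):
#
#     rope = RotaryEmbedding(mode="default", head_dim=64, theta=10000.0)
#     B, H, N = 2, 8, 16
#     q = torch.randn(B, H, N, 64)
#     k = torch.randn(B, H, N, 64)
#     position_ids = torch.arange(N).expand(B, N)       # (B, N)
#     q_rot, k_rot, scale = rope(q, k, position_ids)    # scale == 1.0 here
#     logits = scale * (q_rot @ k_rot.transpose(-2, -1)) / 64 ** 0.5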