| | """ |
| | Rotary Position Embedding (RoPE) implementation. |
| | Applied to Q and K only, with fixed base (no dynamic scaling). |
| | """ |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from typing import Tuple |
| |
|
| |
|
class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE).

    RoPE encodes position information by rotating query and key vectors in
    2D subspaces, one rotation angle per frequency per position.
    Key properties:
    - Parameter-free (no learnable embeddings)
    - Naturally encodes relative positions
    - Extrapolates well to longer sequences

    Reference: https://arxiv.org/abs/2104.09864
    """

    def __init__(
        self,
        dim: int,
        max_position_embeddings: int = 1024,
        base: float = 10000.0,
    ):
        """Initialize RoPE.

        Args:
            dim: Dimension of the rotary embedding (usually head_dim).
                Expected to be even, since angles are paired across the
                two halves of the head dimension.
            max_position_embeddings: Maximum sequence length to precompute.
            base: Base for the geometric frequency progression.
        """
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # Inverse frequencies: theta_i = base^(-2i/dim) for i in [0, dim/2).
        # Non-persistent buffer: follows .to(device) but is excluded from
        # state_dict (it is cheap to recompute).
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Precompute cos/sin tables up to the expected maximum length.
        self._set_cos_sin_cache(max_position_embeddings)

    def _set_cos_sin_cache(self, seq_len: int):
        """Precompute and cache cos/sin values for positions [0, seq_len)."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)

        # freqs[p, i] = p * theta_i, shape [seq_len, dim/2].
        freqs = torch.outer(t, self.inv_freq)

        # Duplicate so each angle applies to both halves of the head dim,
        # matching the half-rotation layout used by _rotate_half.
        emb = torch.cat((freqs, freqs), dim=-1)

        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary embeddings to query and key tensors.

        Args:
            q: Query tensor of shape [batch, num_heads, seq_len, head_dim]
            k: Key tensor of shape [batch, num_heads, seq_len, head_dim]
            position_ids: Position indices of shape [batch, seq_len]

        Returns:
            Tuple of (rotated_q, rotated_k) with the same shapes and dtypes
            as the inputs.
        """
        # FIX: .max() returns a 0-dim tensor; convert to a plain int so
        # max_seq_len_cached stays an int (it is an int after __init__) and
        # torch.arange receives a scalar.
        seq_len = int(position_ids.max()) + 1

        # Grow the cache lazily when positions exceed what was precomputed.
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len)

        # Gather per-position tables: [batch, seq_len, dim].
        cos = self.cos_cached[position_ids]
        sin = self.sin_cached[position_ids]

        # Broadcast over the heads axis: [batch, 1, seq_len, dim].
        # FIX: cast to the input dtype so fp16/bf16 inputs are not silently
        # promoted to float32 by the float32 cache.
        cos = cos.unsqueeze(1).to(dtype=q.dtype)
        sin = sin.unsqueeze(1).to(dtype=q.dtype)

        # Rotation in each 2D subspace: x' = x*cos + rotate_half(x)*sin.
        q_embed = (q * cos) + (self._rotate_half(q) * sin)
        k_embed = (k * cos) + (self._rotate_half(k) * sin)

        return q_embed, k_embed

    @staticmethod
    def _rotate_half(x: torch.Tensor) -> torch.Tensor:
        """Rotate half the hidden dims of the input.

        Splits the last dim into two halves and rotates:
        [x1, x2, x3, x4] -> [-x3, -x4, x1, x2]
        """
        half = x.shape[-1] // 2
        x1 = x[..., :half]
        x2 = x[..., half:]
        return torch.cat((-x2, x1), dim=-1)
| |
|