|
|
""" |
|
|
Rotary Position Embedding (RoPE) implementation. |
|
|
Applied to Q and K only, with fixed base (no dynamic scaling). |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from typing import Tuple |
|
|
|
|
|
|
|
|
class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE).

    RoPE encodes position information by rotating the query and key vectors.
    Key properties:

    - Parameter-free (no learnable embeddings)
    - Naturally encodes relative positions
    - Extrapolates well to longer sequences

    Applied to Q and K only, with a fixed base (no dynamic scaling).

    Reference: https://arxiv.org/abs/2104.09864
    """

    def __init__(
        self,
        dim: int,
        max_position_embeddings: int = 1024,
        base: float = 10000.0,
    ):
        """Initialize RoPE.

        Args:
            dim: Dimension of the rotary embedding (usually head_dim)
            max_position_embeddings: Maximum sequence length to precompute
            base: Base for the frequency computation
        """
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # Per-pair inverse frequencies: base^(-2i/dim) for i in [0, dim/2).
        # Registered as a non-persistent buffer so it follows .to()/.cuda()
        # but is not saved in state_dict.
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Precompute cos/sin for the expected maximum length up front.
        self._set_cos_sin_cache(max_position_embeddings)

    def _set_cos_sin_cache(self, seq_len: int):
        """Precompute and cache cos/sin values for positions [0, seq_len)."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)

        # freqs[p, i] = p * inv_freq[i], shape [seq_len, dim/2]
        freqs = torch.outer(t, self.inv_freq)

        # Duplicate to full dim so each half shares the same angles,
        # matching the half-rotation layout used by _rotate_half.
        emb = torch.cat((freqs, freqs), dim=-1)

        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary embeddings to query and key tensors.

        Args:
            q: Query tensor of shape [batch, num_heads, seq_len, head_dim]
            k: Key tensor of shape [batch, num_heads, seq_len, head_dim]
            position_ids: Position indices of shape [batch, seq_len]

        Returns:
            Tuple of (rotated_q, rotated_k) with same shapes and dtype as inputs
        """
        # Convert to a plain Python int: the original kept this as a 0-dim
        # tensor, which leaked into max_seq_len_cached and torch.arange.
        seq_len = int(position_ids.max().item()) + 1

        # Grow the cache on demand if positions exceed what was precomputed.
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len)

        # Gather per-position angles: [batch, seq_len, dim]. Cast to q's
        # dtype so fp16/bf16 inputs are not silently promoted to fp32 by
        # the float32 cache.
        cos = self.cos_cached[position_ids].to(dtype=q.dtype)
        sin = self.sin_cached[position_ids].to(dtype=q.dtype)

        # Insert a head axis for broadcasting: [batch, 1, seq_len, dim].
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)

        # Rotation identity: R(theta) @ x == x*cos + rotate_half(x)*sin.
        q_embed = (q * cos) + (self._rotate_half(q) * sin)
        k_embed = (k * cos) + (self._rotate_half(k) * sin)

        return q_embed, k_embed

    @staticmethod
    def _rotate_half(x: torch.Tensor) -> torch.Tensor:
        """Rotate half the hidden dims of the input.

        Splits the input into two halves and rotates:
        [x1, x2, x3, x4] -> [-x3, -x4, x1, x2]
        """
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)
|
|
|