# Provenance: Hugging Face commit 27871e7 ("Add model architecture code")
"""
Rotary Position Embedding (RoPE) implementation.
Applied to Q and K only, with fixed base (no dynamic scaling).
"""
import torch
import torch.nn as nn
from typing import Tuple
class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE).

    RoPE encodes position information by rotating the query and key vectors.

    Key properties:
    - Parameter-free (no learnable embeddings)
    - Naturally encodes relative positions
    - Extrapolates well to longer sequences

    Reference: https://arxiv.org/abs/2104.09864
    """

    def __init__(
        self,
        dim: int,
        max_position_embeddings: int = 1024,
        base: float = 10000.0,
    ):
        """Initialize RoPE.

        Args:
            dim: Dimension of the rotary embedding (usually head_dim).
            max_position_embeddings: Maximum sequence length to precompute.
            base: Base for the frequency computation.
        """
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Inverse frequencies: theta_i = base^(-2i/dim) for i in [0, dim/2).
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Precompute cos/sin tables for positions [0, max_position_embeddings).
        self._set_cos_sin_cache(max_position_embeddings)

    def _set_cos_sin_cache(self, seq_len: int):
        """Precompute and cache cos/sin values for positions [0, seq_len)."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        # Outer product: [seq_len] x [dim/2] -> [seq_len, dim/2]
        freqs = torch.outer(t, self.inv_freq)
        # Concatenate to cover both halves: [seq_len, dim]
        emb = torch.cat((freqs, freqs), dim=-1)
        # Cached in float32 for precision; cast to the input dtype at use time
        # (forward), so half-precision callers get half-precision outputs.
        self.register_buffer("cos_cached", emb.cos(), persistent=False)
        self.register_buffer("sin_cached", emb.sin(), persistent=False)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply rotary embeddings to query and key tensors.

        Args:
            q: Query tensor of shape [batch, num_heads, seq_len, head_dim]
            k: Key tensor of shape [batch, num_heads, seq_len, head_dim]
            position_ids: Position indices of shape [batch, seq_len]

        Returns:
            Tuple of (rotated_q, rotated_k) with the same shapes AND dtypes
            as the inputs.
        """
        # Plain Python int, not a 0-dim tensor, so max_seq_len_cached stays an
        # int and the comparison below is a cheap scalar comparison.
        seq_len = int(position_ids.max().item()) + 1
        # Extend the cache if a position beyond the precomputed range appears.
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len)
        # Gather cos/sin for the requested positions: [batch, seq_len, dim].
        # Cast to q's dtype so fp16/bf16 inputs are not silently promoted
        # to float32 by the multiplication below.
        cos = self.cos_cached[position_ids].to(dtype=q.dtype)
        sin = self.sin_cached[position_ids].to(dtype=q.dtype)
        # Add head dimension for broadcasting: [batch, 1, seq_len, dim].
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)
        # Apply the rotation to queries and keys.
        q_embed = (q * cos) + (self._rotate_half(q) * sin)
        k_embed = (k * cos) + (self._rotate_half(k) * sin)
        return q_embed, k_embed

    @staticmethod
    def _rotate_half(x: torch.Tensor) -> torch.Tensor:
        """Rotate half the hidden dims of the input.

        Splits the input into two halves and rotates:
        [x1, x2, x3, x4] -> [-x3, -x4, x1, x2]
        """
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)