# Author: Ahmed
# Uploaded as code/rope.py via huggingface_hub (commit cf2a3ce, verified)
"""Rotary Position Embeddings (RoPE).
Extracted from nanochat-v3/nanochat/gpt.py — identical math, standalone module.
"""
import torch
def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply rotary embeddings to input tensor x.

    The last dimension of ``x`` is split into two halves, ``x1 = x[..., :d]``
    and ``x2 = x[..., d:]`` with ``d = head_dim // 2``. Channel ``i`` is paired
    with channel ``i + d`` (half-split layout, not interleaved). Each pair is
    rotated as::

        y1 = x1 * cos + x2 * sin
        y2 = -x1 * sin + x2 * cos

    Args:
        x: [batch, heads, seq_len, head_dim]
        cos: [seq_len, head_dim//2], broadcast against each half of ``x``.
        sin: [seq_len, head_dim//2], broadcast against each half of ``x``.

    Returns:
        Tensor with the same shape AND dtype as ``x``.
    """
    assert x.ndim == 4
    d = x.shape[3] // 2
    x1, x2 = x[..., :d], x[..., d:]
    y1 = x1 * cos + x2 * sin
    y2 = x1 * (-sin) + x2 * cos
    out = torch.cat([y1, y2], 3)
    # cos/sin from precompute_rotary_embeddings are float32, so a bf16/fp16 x
    # is promoted during the multiply. Cast back so q/k keep the caller's dtype
    # (the nanochat original does the same).
    return out.to(x.dtype)
def precompute_rotary_embeddings(
    seq_len: int, head_dim: int, base: float = 10000.0, device=None
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build the float32 cos/sin lookup tables used by RoPE.

    For pair index ``i`` the inverse frequency is
    ``1 / base ** (2 * i / head_dim)``. For position ``p`` the angle is
    ``p * inv_freq[i]``.

    Args:
        seq_len: number of positions (rows) to generate.
        head_dim: attention head width. One column per even channel index
            ``0, 2, 4, ...`` below ``head_dim``.
        base: frequency base of the geometric progression.
        device: device on which the tables are created.

    Returns:
        (cos, sin), each of shape [seq_len, head_dim//2] for an even head_dim,
        dtype float32.
    """
    even_channels = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
    inverse_frequencies = 1.0 / (base ** (even_channels / head_dim))
    positions = torch.arange(seq_len, dtype=torch.float32, device=device)
    # Column vector times row vector -> [seq_len, n_pairs] angle table.
    angles = positions.unsqueeze(1) * inverse_frequencies.unsqueeze(0)
    return torch.cos(angles), torch.sin(angles)