"""Rotary Position Embeddings (RoPE).

Extracted from nanochat-v3/nanochat/gpt.py — identical math, standalone module.
"""

import torch


def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply rotary embeddings to input tensor x.

    Rotates each channel pair (x[..., i], x[..., i + d]) by the per-position
    angle encoded in ``cos``/``sin``; the rotation convention matches the
    original nanochat code: y1 = x1*cos + x2*sin, y2 = -x1*sin + x2*cos.

    Args:
        x: [batch, heads, seq_len, head_dim] (head_dim must be even)
        cos: [seq_len, head_dim//2] — broadcasts over the batch/head dims
        sin: [seq_len, head_dim//2]

    Returns:
        Rotated tensor with the same shape and the same dtype as ``x``.
    """
    assert x.ndim == 4
    d = x.shape[3] // 2
    x1, x2 = x[..., :d], x[..., d:]
    y1 = x1 * cos + x2 * sin
    y2 = x1 * (-sin) + x2 * cos
    # Fix: cos/sin are float32 buffers, so for a bf16/fp16 x the products
    # promote to float32 and the output dtype silently changes; cast back so
    # the result matches the input dtype (as the upstream nanochat code does).
    return torch.cat([y1, y2], 3).type_as(x)


def precompute_rotary_embeddings(
    seq_len: int, head_dim: int, base: float = 10000.0, device=None
) -> tuple[torch.Tensor, torch.Tensor]:
    """Precompute cos/sin buffers for RoPE.

    Each of the ``head_dim // 2`` channel pairs gets an inverse frequency
    ``base ** (-2i / head_dim)``; the angle at position ``t`` is ``t`` times
    that frequency. Buffers are computed in float32 on ``device``.

    Returns:
        cos: [seq_len, head_dim//2]
        sin: [seq_len, head_dim//2]
    """
    even_channels = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
    inv_freq = base ** (-even_channels / head_dim)
    positions = torch.arange(seq_len, dtype=torch.float32, device=device)
    # Outer product via broadcasting: angles[t, i] = positions[t] * inv_freq[i].
    angles = positions[:, None] * inv_freq[None, :]
    return angles.cos(), angles.sin()