File size: 2,123 Bytes
3f40093
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
09597f3
 
 
 
 
 
 
 
 
 
 
4fe8766
09597f3
 
 
 
 
4fe8766
09597f3
 
 
 
 
 
 
 
 
 
 
3f40093
09597f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1168c80
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# Ethically sourced from https://github.com/xjdr-alt/entropix

import torch


def precompute_freqs_cis(
    dim: int,
    end: int,
    theta: float = 10000.0,
    use_scaled: bool = False,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Build the rotary-embedding cos/sin lookup table.

    Args:
        dim: Number of channels being rotated; ``dim // 2`` frequencies
            are generated.
        end: Number of positions (rows) in the table.
        theta: Base of the geometric frequency progression.
        use_scaled: Accepted for signature compatibility; this
            implementation does not read it.
        dtype: Floating dtype used for the frequency and position ranges.

    Returns:
        Tensor of shape ``(end, dim // 2, 2)``. ``[..., 0]`` is the real
        part (cosine) and ``[..., 1]`` the imaginary part (sine) of
        ``exp(i * position * frequency)``.
    """
    half = dim // 2

    # Exponents 0/dim, 2/dim, ... give frequencies theta ** (-2k / dim).
    exponents = torch.arange(0, dim, 2, dtype=dtype)[:half] / dim
    inv_freq = 1.0 / (theta ** exponents)

    # Outer product of positions and frequencies -> rotation angles.
    positions = torch.arange(end, dtype=dtype)
    angles = positions[:, None] * inv_freq[None, :]

    # Unit complex numbers, then split into (cos, sin) on a trailing axis.
    rotations = torch.exp(1j * angles)
    return torch.stack((rotations.real, rotations.imag), dim=-1)

# rope.py
import torch

def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: torch.Tensor,
    position_ids: torch.Tensor,
    num_heads: int,
    rot_dim: int = 32,
    interleave: bool = False,
) -> torch.Tensor:
    """
    Apply rotary position embeddings (RoPE) to the leading channels of ``x``.

    RoPE as used in the original moondream2 text stack.

    Args:
        x: Tensor of shape (B, H, T, D).
        freqs_cis: Table of shape (max_T, rot_dim // 2, 2) where
            ``[..., 0]`` is cos and ``[..., 1]`` is sin.
        position_ids: Index tensor into the first axis of ``freqs_cis``.
            May be a scalar, (T,), (1, T) or (B, T); anything without a
            full batch axis is broadcast across the batch.
        num_heads: Must equal ``x.shape[1]``.
        rot_dim: Number of leading channels to rotate; must equal
            ``2 * freqs_cis.shape[-2]``. Clamped to D when D is smaller.
        interleave: Layout of the *input* rotated channels. ``True`` means
            adjacent (real, imag) pairs; ``False`` means the first half is
            the real part and the second half the imaginary part.

    Returns:
        Tensor with the same shape and dtype as ``x``. The rotated channels
        are always written back as adjacent (real, imag) pairs, regardless
        of ``interleave``; the remaining channels are passed through
        unchanged.

    Raises:
        AssertionError: If ``rot_dim`` or ``num_heads`` disagree with the
            tensor shapes.
        ValueError: If ``position_ids`` has more than two dimensions or its
            batch size is neither 1 nor B.
    """
    assert rot_dim == freqs_cis.shape[-2] * 2, (
        "rot_dim must equal 2 * freqs_cis.shape[-2]"
    )
    assert num_heads == x.shape[1], "num_heads must equal x.shape[1]"
    if position_ids.dim() > 2:
        raise ValueError(
            f"position_ids must have at most 2 dims, got {position_ids.dim()}"
        )

    # Unpacking also enforces that x is 4-D.
    B, _, _, D = x.shape
    rd = min(rot_dim, D)
    rot_half = rd // 2
    x_rot, x_pass = x[..., :rd], x[..., rd:]

    # Split real/imag parts depending on layout.
    if interleave:
        # Pair view is built once; the interleaved path computes in float32.
        pairs = x_rot.float().reshape(*x_rot.shape[:-1], -1, 2)
        xr, xi = pairs[..., 0], pairs[..., 1]
    else:
        xr, xi = x_rot[..., :rot_half], x_rot[..., rot_half:]

    # Gather cos/sin for these positions and left-pad to
    # (batch, T, rot_dim // 2, 2). Size-1 axes broadcast against x, so
    # scalar, (T,) and (1, T) position_ids all work for any batch size.
    freq = freqs_cis[position_ids]
    while freq.dim() < 4:
        freq = freq.unsqueeze(0)
    if freq.shape[0] not in (1, B):
        raise ValueError(
            f"position_ids batch size {freq.shape[0]} does not match x batch size {B}"
        )

    cos = freq[..., :rot_half, 0].unsqueeze(1).to(x.dtype)  # (b,1,T,rot_half)
    sin = freq[..., :rot_half, 1].unsqueeze(1).to(x.dtype)

    # Complex multiply: (xr + i*xi) * (cos + i*sin).
    yr = xr * cos - xi * sin
    yi = xr * sin + xi * cos
    y = torch.stack((yr, yi), dim=-1).flatten(-2)            # (B,H,T,rd)

    # Cast back so the float32 interleaved path does not promote the
    # output dtype when x is half precision.
    return torch.cat([y.to(x.dtype), x_pass], dim=-1)