File size: 4,689 Bytes
4f2517b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import einops as E
import torch


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    """Precompute the rotary-embedding table of complex exponentials (cis).

    Builds an ``(end, dim // 2)`` complex64 tensor whose entry ``[t, j]`` is
    ``e^{i * t * w_j}``, where ``w_j = theta^(-2j / dim)`` is the j-th
    inverse frequency of the standard RoPE progression.

    Args:
        dim (int): Head dimension; only ``dim // 2`` frequencies are produced.
        end (int): Number of positions (sequence length) to precompute.
        theta (float, optional): Base of the frequency progression.
            Defaults to 10000.0.

    Returns:
        torch.Tensor: complex64 tensor of shape (end, dim // 2).
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freqs = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freqs.device)
    # Outer product gives the rotation angle for every (position, frequency).
    angles = torch.outer(positions, inv_freqs).float()
    # polar(1, angle) = cos(angle) + i*sin(angle): unit-circle complex64.
    return torch.polar(torch.ones_like(angles), angles)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """1D rotary embedding: rotate query/key pairs by the precomputed cis table."""
    assert freqs_cis.ndim == 3, (
        "Freqs_cis must be indexed by position ids already and has shape (B,S,D)"
    )

    def to_complex(x: torch.Tensor) -> torch.Tensor:
        # Pair the last dim into (real, imag) couples and view as complex.
        return torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))

    xq_c = to_complex(xq)
    xk_c = to_complex(xk)
    # Broadcast the table over the head axis: (B, S, D//2) -> (B, S, 1, D//2).
    rot = freqs_cis.unsqueeze(2)
    xq_rot = torch.view_as_real(xq_c * rot).flatten(3)
    xk_rot = torch.view_as_real(xk_c * rot).flatten(3)
    return xq_rot.type_as(xq), xk_rot.type_as(xk)


###### 2D golden rope
"""
Dimension key:
    B: batch size
    S: number of tokens per sample, Seqlen
    T: Number of selected Tokens
    P: pos_dim
    h: n_heads
    d: head_dim
    F: num_freqs == head_dim // 2
"""


def apply_golden_freqs_cis_to_visual_pos(freqs_hFP, pos_BSP) -> torch.Tensor:
    """Build per-image-token complex rotation factors from 2D positions.

    This function is applied once per input batch, and the cached
    freqs_cis is passed through to all layers.
    Safe for Torch-Inductor because it never uses boolean indexing on a
    symbolic tensor.
    """
    # 1. Image tokens are the (b, s) slots whose coordinates are all non-NaN;
    #    convert the boolean mask to integer indices (no unbacked shapes).
    valid_BS = torch.isnan(pos_BSP).logical_not().all(dim=-1)
    batch_idx, seq_idx = torch.nonzero(valid_BS, as_tuple=True)  # each shape: (N,)

    # 2. Gather the positional tensor for those tokens -> (N, P).
    pos_tP = pos_BSP[batch_idx, seq_idx].float()

    # 3. Project positions onto each head's frequency basis -> angles θ (t, h, f).
    theta_thF = torch.einsum("tp,hfp->thf", pos_tP, freqs_hFP.float())

    # 4. Unit-magnitude complex exponentials e^{iθ} on the unit circle.
    return torch.polar(torch.ones_like(theta_thF), theta_thF)


def apply_golden_rotary_emb(input_BShd, freqs_cis_thF, pos_BSP) -> torch.Tensor:
    """Rotate *only* the image tokens in `input_BShd`.

    No boolean indexing, so it is safe for Torch-Inductor.  Non-image
    tokens (any NaN coordinate in `pos_BSP`) pass through unchanged.
    """
    # Image tokens: (b, s) slots whose coordinates are all non-NaN.
    valid_BS = torch.isnan(pos_BSP).logical_not().all(dim=-1)
    batch_idx, seq_idx = torch.nonzero(valid_BS, as_tuple=True)  # (N,)

    tokens_thd = input_BShd[batch_idx, seq_idx].float()  # (N, h, d)
    even_thF = tokens_thd[..., 0::2]  # (N, h, F)
    odd_thF = tokens_thd[..., 1::2]   # (N, h, F)

    cos_thF, sin_thF = freqs_cis_thF.real, freqs_cis_thF.imag

    # Complex product in real arithmetic:
    # (a + ib) * (c + id) = (ac - bd) + i(ad + bc)
    rotated_even = even_thF * cos_thF - odd_thF * sin_thF
    rotated_odd = even_thF * sin_thF + odd_thF * cos_thF

    # Re-interleave the even/odd lanes back into the last dimension.
    rotated_thd = torch.empty_like(tokens_thd)
    rotated_thd[..., 0::2] = rotated_even
    rotated_thd[..., 1::2] = rotated_odd

    # Scatter rotated tokens back; everything else keeps its original value.
    output_BShd = input_BShd.clone()
    output_BShd[batch_idx, seq_idx] = rotated_thd.type_as(input_BShd)
    return output_BShd


def apply_3d_rotary_emb(
    xq: torch.Tensor,  # (B, S, H, D)
    xk: torch.Tensor,  # (B, S, H, D)
    freqs_cis: torch.Tensor,
    freqs_cis_2d: torch.Tensor | None,
    pos_hw: torch.Tensor | None,  # (B,S,3)
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply split rotary embeddings along the head dimension.

    The head dim is halved: the first half receives the standard 1D rotary
    embedding from `freqs_cis`; the second half receives the 2D "golden"
    rotation for image tokens, but only when both `freqs_cis_2d` and
    `pos_hw` are provided — otherwise that half passes through unrotated.

    Args:
        xq: Query tensor of shape (B, S, H, D).
        xk: Key tensor of shape (B, S, H, D).
        freqs_cis: 1D rotary table, already indexed by position ids.
        freqs_cis_2d: Optional per-image-token complex rotation factors.
        pos_hw: Optional per-token positions; NaN rows mark non-image tokens.

    Returns:
        Rotated (xq, xk), each with the dtype of its input.
    """
    # Fix: the original unpacked `B, S, H, D = xq.shape` and never used it.
    xq_t, xq_hw = xq.chunk(chunks=2, dim=-1)
    xk_t, xk_hw = xk.chunk(chunks=2, dim=-1)

    xq_t, xk_t = apply_rotary_emb(xq_t, xk_t, freqs_cis)
    if freqs_cis_2d is not None and pos_hw is not None:
        xq_hw = apply_golden_rotary_emb(xq_hw, freqs_cis_2d, pos_hw)
        xk_hw = apply_golden_rotary_emb(xk_hw, freqs_cis_2d, pos_hw)

    xq_out = torch.concat([xq_t, xq_hw], dim=-1).type_as(xq)
    xk_out = torch.concat([xk_t, xk_hw], dim=-1).type_as(xk)
    return xq_out, xk_out