| import einops as E |
| import torch |
|
|
|
|
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    """
    Build the complex-exponential ("cis") rotation table used by rotary embeddings.

    Each entry is `exp(i * t * f_k)` where `t` is a position index in
    `[0, end)` and `f_k = 1 / theta**(2k / dim)` for `k in [0, dim // 2)`.

    Args:
        dim (int): Head dimension; only `dim // 2` frequencies are produced.
        end (int): Number of positions to precompute.
        theta (float, optional): Frequency base. Defaults to 10000.0.

    Returns:
        torch.Tensor: Complex64 tensor of shape (end, dim // 2) holding
        unit-magnitude rotations.
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freqs = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freqs.device)
    # Outer product gives the rotation angle for every (position, frequency) pair.
    angles = torch.outer(positions, inv_freqs).float()
    # Unit magnitude + angle -> complex exponential exp(i * angle).
    return torch.polar(torch.ones_like(angles), angles)
|
|
|
|
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply 1D rotary embedding to query and key tensors.

    Args:
        xq: Query tensor of shape (B, S, h, d); d must be even.
        xk: Key tensor of shape (B, S, h, d); d must be even.
        freqs_cis: Complex rotation tensor of shape (B, S, d // 2), already
            indexed by position ids.

    Returns:
        Rotated (xq, xk), each cast back to the dtype of its input.

    Raises:
        ValueError: If `freqs_cis` is not a 3D (B, S, D) tensor.
    """
    # Explicit raise instead of `assert` so the check survives `python -O`.
    if freqs_cis.ndim != 3:
        raise ValueError(
            "Freqs_cis must be indexed by position ids already and has shape (B,S,D)"
        )
    # Pair consecutive channels into complex numbers: (..., d) -> (..., d // 2).
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # Broadcast over the heads axis: (B, S, F) -> (B, S, 1, F).
    freqs_cis = freqs_cis.unsqueeze(2)
    # Complex multiply rotates each channel pair; flatten back to (..., d).
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
|
|
|
|
| |
| """ |
| Dimension key: |
| B: batch size |
| S: number of tokens per sample, Seqlen |
| T: Number of selected Tokens |
| P: pos_dim |
| h: n_heads |
| d: head_dim |
| F: num_freqs == head_dim // 2 |
| """ |
|
|
|
|
def apply_golden_freqs_cis_to_visual_pos(freqs_hFP, pos_BSP) -> torch.Tensor:
    """
    Build the complex rotation table for the visual (image) tokens only.

    Called once per input batch; the resulting cached freqs_cis is shared by
    all layers. Avoids boolean indexing on a symbolic tensor so it stays
    Torch-Inductor friendly.
    """
    # A token counts as visual iff every coordinate of its position is non-NaN.
    is_img_BS = (~torch.isnan(pos_BSP)).all(dim=-1)
    batch_idx, seq_idx = torch.nonzero(is_img_BS, as_tuple=True)

    # Gather the positions of the selected tokens: (T, P).
    sel_pos_tP = pos_BSP[batch_idx, seq_idx].float()

    # Per-token, per-head, per-frequency rotation angle.
    theta_thF = torch.einsum("tp,hfp->thf", sel_pos_tP, freqs_hFP.float())

    # Unit magnitude + angle -> complex exponential exp(i * theta).
    return torch.polar(torch.ones_like(theta_thF), theta_thF)
|
|
|
|
def apply_golden_rotary_emb(input_BShd, freqs_cis_thF, pos_BSP) -> torch.Tensor:
    """
    Rotate *only* the image tokens in `input_BShd`; all other tokens pass
    through unchanged. No boolean indexing, so it is safe for Torch-Inductor.
    """
    # Image tokens are those whose position has no NaN coordinate.
    is_img_BS = (~torch.isnan(pos_BSP)).all(dim=-1)
    batch_idx, seq_idx = torch.nonzero(is_img_BS, as_tuple=True)

    # Gather selected tokens and split channels into (even, odd) halves.
    tokens_thd = input_BShd[batch_idx, seq_idx].float()
    even_thF = tokens_thd[..., 0::2]
    odd_thF = tokens_thd[..., 1::2]

    # Real/imag parts of the precomputed unit rotations are cos/sin of theta.
    cos_thF = freqs_cis_thF.real
    sin_thF = freqs_cis_thF.imag

    # Standard 2D rotation of each (even, odd) channel pair, written back
    # interleaved. Safe: even/odd are views of tokens_thd, not of the output.
    rotated_thd = torch.empty_like(tokens_thd)
    rotated_thd[..., 0::2] = even_thF * cos_thF - odd_thF * sin_thF
    rotated_thd[..., 1::2] = even_thF * sin_thF + odd_thF * cos_thF
    rotated_thd = rotated_thd.type_as(input_BShd)

    # Scatter the rotated tokens back; non-image tokens keep original values.
    out_BShd = input_BShd.clone()
    out_BShd[batch_idx, seq_idx] = rotated_thd
    return out_BShd
|
|
|
|
def apply_3d_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
    freqs_cis_2d: torch.Tensor | None,
    pos_hw: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply combined temporal + spatial rotary embedding to queries and keys.

    The head dimension is split in half: the first half is rotated with the
    1D (temporal) `freqs_cis`, the second half with the 2D (h, w) rotation
    for image tokens when both `freqs_cis_2d` and `pos_hw` are provided;
    otherwise the second half passes through unrotated.

    Args:
        xq: Query tensor of shape (B, S, H, D).
        xk: Key tensor of shape (B, S, H, D).
        freqs_cis: 1D rotary table, already indexed by position ids (B, S, D//4 pairs).
        freqs_cis_2d: Optional per-image-token 2D rotary table.
        pos_hw: Optional (B, S, P) positions; NaN rows mark non-image tokens.

    Returns:
        Rotated (xq, xk), cast back to the input dtypes.
    """
    # Split the head dim: first half temporal, second half spatial.
    xq_t, xq_hw = xq.chunk(chunks=2, dim=-1)
    xk_t, xk_hw = xk.chunk(chunks=2, dim=-1)

    xq_t, xk_t = apply_rotary_emb(xq_t, xk_t, freqs_cis)
    if freqs_cis_2d is not None and pos_hw is not None:
        xq_hw = apply_golden_rotary_emb(xq_hw, freqs_cis_2d, pos_hw)
        xk_hw = apply_golden_rotary_emb(xk_hw, freqs_cis_2d, pos_hw)

    xq_out = torch.concat([xq_t, xq_hw], dim=-1).type_as(xq)
    xk_out = torch.concat([xk_t, xk_hw], dim=-1).type_as(xk)
    return xq_out, xk_out
|
|