import einops as E
import torch


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    """
    Precompute the frequency tensor for complex exponentials (cis).

    Builds ``dim // 2`` inverse frequencies ``1 / theta ** (2i / dim)`` and
    takes their outer product with the positions ``0 .. end - 1``. Each
    resulting angle is turned into a unit-magnitude complex number, so entry
    ``[s, i]`` equals ``exp(1j * s * freq_i)``.

    Args:
        dim (int): Dimension being rotated; ``dim // 2`` frequencies are produced.
        end (int): Number of positions to precompute (exclusive upper bound).
        theta (float, optional): Base of the geometric frequency progression.
            Defaults to 10000.0.

    Returns:
        torch.Tensor: complex64 tensor of shape ``(end, dim // 2)``, created on
        the default device.
    """
    # The slice guards against odd `dim`, where arange would yield one extra entry.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis  # [S, D//2]


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Apply a 1D rotary embedding to query and key tensors.

    Consecutive pairs along the last dimension of ``xq`` / ``xk`` are viewed as
    complex numbers and multiplied by ``freqs_cis``, which is broadcast over
    the head axis.

    Args:
        xq: Query tensor; the code flattens from dim 3, so a 4-D layout
            ``(B, S, H, D)`` with even ``D`` is expected.
        xk: Key tensor with the same layout as ``xq``.
        freqs_cis: Complex rotation factors already gathered per position id,
            3-D, broadcastable against ``(B, S, H, D // 2)`` after a head axis
            is inserted.

    Returns:
        The rotated ``(xq, xk)``, each cast back to the dtype of its input.

    Raises:
        AssertionError: If ``freqs_cis`` is not 3-D.
    """
    # Validate before doing any reshaping work.
    assert freqs_cis.ndim == 3, (
        "Freqs_cis must be indexed by position ids already and has shape (B,S,D)"
    )
    # Rotation is done in float32 regardless of the input dtype.
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    # Insert a singleton head axis so the same rotation applies to every head.
    freqs_cis = E.rearrange(freqs_cis, "b s d -> b s 1 d")
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)


###### 2D golden rope
"""
Dimension key:
    B: batch size
    S: number of tokens per sample, Seqlen
    T: Number of selected Tokens (written as lowercase `t` in name suffixes)
    P: pos_dim
    h: n_heads
    d: head_dim
    F: num_freqs == head_dim // 2
"""


def _image_token_indices(pos_BSP: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Return ``(batch_idx, seq_idx)`` of tokens whose position has no NaN component."""
    img_mask_BS = E.reduce(~torch.isnan(pos_BSP), "b s p -> b s", reduction="all")
    # Integer indices instead of boolean-mask indexing. The number of selected
    # tokens is still data dependent.
    idx_b, idx_s = torch.nonzero(img_mask_BS, as_tuple=True)  # each shape: (N,)
    return idx_b, idx_s


def _rotate_selected_tokens(
    input_BShd: torch.Tensor,
    freqs_cis_thF: torch.Tensor,
    idx_b: torch.Tensor,
    idx_s: torch.Tensor,
) -> torch.Tensor:
    """Rotate the tokens at ``(idx_b, idx_s)`` and return a modified copy of the input."""
    input_thd = input_BShd[idx_b, idx_s].float()  # (N, h, d)
    x_even = input_thd[..., 0::2]  # (N, h, F)
    x_odd = input_thd[..., 1::2]  # (N, h, F)

    cos_thF = freqs_cis_thF.real
    sin_thF = freqs_cis_thF.imag

    # (a + ib) * (c + id) = (ac - bd) + i(ad + bc)
    rot_even = x_even * cos_thF - x_odd * sin_thF
    rot_odd = x_even * sin_thF + x_odd * cos_thF

    # Re-interleave the real/imaginary parts into the original even/odd slots.
    output_real = torch.empty_like(input_thd)
    output_real[..., 0::2] = rot_even
    output_real[..., 1::2] = rot_odd
    output_real = output_real.type_as(input_BShd)

    # Clone so the caller's tensor is left untouched.
    output_BShd = input_BShd.clone()
    output_BShd[idx_b, idx_s] = output_real
    return output_BShd


def apply_golden_freqs_cis_to_visual_pos(
    freqs_hFP: torch.Tensor, pos_BSP: torch.Tensor
) -> torch.Tensor:
    """
    Build the per-token complex rotation factors for the selected tokens.

    This function is applied once per input batch, and the cached freqs_cis is
    passed through to all layers. A token is selected when none of its
    position components is NaN.

    Selection uses integer indices from ``torch.nonzero`` rather than
    boolean-mask indexing (the original author's note says this is for
    Torch-Inductor friendliness). The number of selected tokens ``N`` depends
    on the data.

    Args:
        freqs_hFP: Frequency table indexed as (head, freq, pos_dim).
        pos_BSP: Positions indexed as (batch, seq, pos_dim); NaN marks tokens
            that are skipped.

    Returns:
        Complex tensor of shape ``(N, h, F)`` with unit magnitude.
    """
    # 1. Boolean mask -> integer indices
    idx_b, idx_s = _image_token_indices(pos_BSP)

    # 2. Gather the positional tensor for those tokens
    pos_tP = pos_BSP[idx_b, idx_s].float()  # (N, p)

    # 3. Project positions onto the frequency table -> angles theta
    theta_thF = torch.einsum("tp,hfp->thf", pos_tP, freqs_hFP.float())  # (t, h, f)

    # 4. Convert to complex numbers on the unit circle
    freqs_cis_thF = torch.polar(torch.ones_like(theta_thF), theta_thF)
    return freqs_cis_thF


def apply_golden_rotary_emb(
    input_BShd: torch.Tensor, freqs_cis_thF: torch.Tensor, pos_BSP: torch.Tensor
) -> torch.Tensor:
    """
    Rotate *only* the selected (non-NaN position) tokens in `input_BShd`.

    Tokens are selected with integer indices, not boolean-mask indexing.
    Unselected tokens are returned unchanged. The input is not modified.

    Args:
        input_BShd: Tensor indexed as (batch, seq, head, head_dim).
        freqs_cis_thF: Complex rotation factors for the selected tokens, in the
            same token order produced by `apply_golden_freqs_cis_to_visual_pos`
            for the same `pos_BSP`.
        pos_BSP: Positions indexed as (batch, seq, pos_dim); NaN marks tokens
            that are left untouched.

    Returns:
        A new tensor with the same shape and dtype as `input_BShd`.
    """
    idx_b, idx_s = _image_token_indices(pos_BSP)
    return _rotate_selected_tokens(input_BShd, freqs_cis_thF, idx_b, idx_s)


def apply_3d_rotary_emb(
    xq: torch.Tensor,  # (B, S, H, D)
    xk: torch.Tensor,  # (B, S, H, D)
    freqs_cis: torch.Tensor,
    freqs_cis_2d: torch.Tensor | None,
    pos_hw: torch.Tensor | None,  # (B,S,3)
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Apply the 1D rotary embedding to the first half of the head dimension and
    the golden rotary embedding to the second half.

    The last dimension of `xq` / `xk` is split into two chunks. The first chunk
    always goes through `apply_rotary_emb`. The second chunk is rotated only
    when both `freqs_cis_2d` and `pos_hw` are given; otherwise it passes
    through unchanged.

    Args:
        xq: Query tensor, 4-D.
        xk: Key tensor with the same layout as `xq`.
        freqs_cis: Rotation factors for the first chunk (see `apply_rotary_emb`).
        freqs_cis_2d: Rotation factors for the selected tokens of the second
            chunk, or None to skip that rotation.
        pos_hw: Positions whose NaN entries mark tokens to skip, or None to
            skip the second rotation.

    Returns:
        The rotated `(xq, xk)`, cast to the dtypes of the inputs.

    Raises:
        ValueError: If `xq` is not 4-D.
    """
    # Explicit rank check in place of the old unused `B, S, H, D = xq.shape`.
    if xq.ndim != 4:
        raise ValueError(
            f"xq must have 4 dimensions (B, S, H, D), got shape {tuple(xq.shape)}"
        )

    xq_t, xq_hw = xq.chunk(chunks=2, dim=-1)
    xk_t, xk_hw = xk.chunk(chunks=2, dim=-1)

    xq_t, xk_t = apply_rotary_emb(xq_t, xk_t, freqs_cis)
    if freqs_cis_2d is not None and pos_hw is not None:
        # Compute the token indices once and reuse them for queries and keys.
        idx_b, idx_s = _image_token_indices(pos_hw)
        xq_hw = _rotate_selected_tokens(xq_hw, freqs_cis_2d, idx_b, idx_s)
        xk_hw = _rotate_selected_tokens(xk_hw, freqs_cis_2d, idx_b, idx_s)

    xq_out = torch.concat([xq_t, xq_hw], dim=-1).type_as(xq)
    xk_out = torch.concat([xk_t, xk_hw], dim=-1).type_as(xk)
    return xq_out, xk_out