# Initial release: Falcon Perception open-vocabulary segmentation
# (uploaded by yasserDahou, commit 4f2517b, verified)
import einops as E
import torch
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    """Build the complex rotary-embedding table.

    For each position ``t`` in ``[0, end)`` and frequency index ``k`` in
    ``[0, dim // 2)``, the entry is ``exp(i * t / theta ** (2k / dim))`` —
    a unit-magnitude complex number whose angle grows linearly with position.

    Args:
        dim (int): Embedding dimension; only ``dim // 2`` frequencies are used.
        end (int): Number of positions to precompute.
        theta (float, optional): Base of the geometric frequency progression.
            Defaults to 10000.0.

    Returns:
        torch.Tensor: Complex64 tensor of shape ``(end, dim // 2)``.
    """
    half = dim // 2
    # Geometric progression of inverse frequencies, one per channel pair.
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2)[:half].float() / dim))
    positions = torch.arange(end, device=inv_freq.device)
    angles = torch.outer(positions, inv_freq).float()
    # Unit-modulus complex numbers: cos(angle) + i*sin(angle).
    return torch.polar(torch.ones_like(angles), angles)  # [S, D//2], complex64
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply 1D rotary position embedding to query and key tensors.

    Args:
        xq: Queries of shape (B, S, H, D) with D even.
        xk: Keys of shape (B, S, H, D) with D even.
        freqs_cis: Complex rotation table of shape (B, S, D // 2), already
            gathered per position id.

    Returns:
        Rotated ``(xq, xk)``, each cast back to its input dtype.
    """
    # Pair consecutive channels into complex numbers: (..., D) -> (..., D // 2).
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    assert freqs_cis.ndim == 3, (
        "Freqs_cis must be indexed by position ids already and has shape (B,S,D)"
    )
    # Broadcast over the head dimension: (B, S, F) -> (B, S, 1, F).
    # Native unsqueeze replaces the einops rearrange, so this reshape no
    # longer requires the third-party einops dependency.
    freqs_cis = freqs_cis.unsqueeze(2)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
###### 2D golden rope
"""
Dimension key:
B: batch size
S: number of tokens per sample, Seqlen
T: Number of selected Tokens
P: pos_dim
h: n_heads
d: head_dim
F: num_freqs == head_dim // 2
"""
def apply_golden_freqs_cis_to_visual_pos(freqs_hFP, pos_BSP) -> torch.Tensor:
    """Compute complex rotary factors for the visual (image) tokens only.

    This function is applied once per input batch, and the cached
    freqs_cis is passed through to all layers.
    Safe for Torch-Inductor because it never uses boolean indexing on a
    symbolic tensor.

    Args:
        freqs_hFP: Per-head frequency directions of shape (h, F, P).
        pos_BSP: Token positions of shape (B, S, P); non-image tokens carry
            NaN in their coordinates.

    Returns:
        torch.Tensor: Complex64 tensor of shape (t, h, F), one row per image
        token, in the order produced by ``torch.nonzero`` over the (B, S) mask.
    """
    # 1. Boolean mask -> integer indices (no unbacked shapes). A token is an
    #    image token iff none of its P coordinates is NaN. Torch-native
    #    .all(dim=-1) replaces the einops reduce, dropping the third-party
    #    dependency for a trivial reduction.
    img_mask_BS = (~torch.isnan(pos_BSP)).all(dim=-1)
    idx_b, idx_s = torch.nonzero(img_mask_BS, as_tuple=True)  # each shape: (t,)
    # 2. Gather the positional tensor for those tokens
    pos_tP = pos_BSP[idx_b, idx_s].float()  # (t, P)
    # 3. Project positions onto the frequency table -> angles theta
    theta_thF = torch.einsum("tp,hfp->thf", pos_tP, freqs_hFP.float())  # (t, h, F)
    # 4. Convert to complex numbers on the unit circle
    return torch.polar(torch.ones_like(theta_thF), theta_thF)
def apply_golden_rotary_emb(input_BShd, freqs_cis_thF, pos_BSP) -> torch.Tensor:
    """Rotate *only* the image tokens in ``input_BShd``.

    No boolean indexing, so it is safe for Torch-Inductor.

    Args:
        input_BShd: Activations of shape (B, S, h, d) with d even.
        freqs_cis_thF: Complex rotation factors of shape (t, h, d // 2), one
            row per image token (see ``apply_golden_freqs_cis_to_visual_pos``).
        pos_BSP: Positions of shape (B, S, P); NaN marks non-image tokens.

    Returns:
        torch.Tensor: Same shape and dtype as ``input_BShd``; image tokens are
        rotated, every other token is copied through unchanged.
    """
    # Same image-token mask as in the freqs computation; torch-native
    # .all(dim=-1) replaces the einops reduce (no third-party dependency).
    img_mask_BS = (~torch.isnan(pos_BSP)).all(dim=-1)
    idx_b, idx_s = torch.nonzero(img_mask_BS, as_tuple=True)  # (t,)
    input_thd = input_BShd[idx_b, idx_s].float()  # (t, h, d)
    x_even = input_thd[..., 0::2]  # (t, h, F)
    x_odd = input_thd[..., 1::2]  # (t, h, F)
    cos_thF = freqs_cis_thF.real
    sin_thF = freqs_cis_thF.imag
    # Complex multiply done in real arithmetic:
    # (a + ib) * (c + id) = (ac - bd) + i(ad + bc)
    rot_even = x_even * cos_thF - x_odd * sin_thF
    rot_odd = x_even * sin_thF + x_odd * cos_thF
    # Re-interleave even/odd channels into the original layout.
    output_real = torch.empty_like(input_thd)
    output_real[..., 0::2] = rot_even
    output_real[..., 1::2] = rot_odd
    output_real = output_real.type_as(input_BShd)
    # Scatter rotated tokens back; untouched positions keep original values.
    output_BShd = input_BShd.clone()
    output_BShd[idx_b, idx_s] = output_real
    return output_BShd
def apply_3d_rotary_emb(
    xq: torch.Tensor,  # (B, S, H, D)
    xk: torch.Tensor,  # (B, S, H, D)
    freqs_cis: torch.Tensor,
    freqs_cis_2d: torch.Tensor | None,
    pos_hw: torch.Tensor | None,  # (B, S, 3)
) -> tuple[torch.Tensor, torch.Tensor]:
    """Apply 3D (temporal + 2D golden) rotary embedding to queries and keys.

    The head dimension is split in half: the first half receives the standard
    1D rotary embedding (``freqs_cis``); the second half receives the golden
    2D rotary embedding on image tokens (``freqs_cis_2d`` with ``pos_hw``).
    If either 2D input is None, the second half passes through unrotated.

    Returns:
        Rotated ``(xq, xk)``, each cast back to its input dtype.
    """
    xq_t, xq_hw = xq.chunk(chunks=2, dim=-1)
    xk_t, xk_hw = xk.chunk(chunks=2, dim=-1)
    # Temporal half: ordinary 1D RoPE. (Removed the unused
    # `B, S, H, D = xq.shape` unpack present in the original.)
    xq_t, xk_t = apply_rotary_emb(xq_t, xk_t, freqs_cis)
    # Spatial half: golden 2D RoPE on image tokens only, when available.
    if freqs_cis_2d is not None and pos_hw is not None:
        xq_hw = apply_golden_rotary_emb(xq_hw, freqs_cis_2d, pos_hw)
        xk_hw = apply_golden_rotary_emb(xk_hw, freqs_cis_2d, pos_hw)
    xq_out = torch.concat([xq_t, xq_hw], dim=-1).type_as(xq)
    xk_out = torch.concat([xk_t, xk_hw], dim=-1).type_as(xk)
    return xq_out, xk_out