# Falcon-Perception / modeling_falcon_perception.py
# Last change: "Update modeling_falcon_perception.py" by yasserDahou (commit ad720cd, verified).
import math
from pathlib import Path
import einops as E
import numpy as np
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from PIL import Image
from pycocotools import mask as mask_utils
from torch import Tensor as T
from torch import nn
from torch.nn.attention.flex_attention import (
AuxRequest,
BlockMask,
)
from transformers import AutoTokenizer, PreTrainedModel
from .anyup import AnyUp, get_attention_mask_mod as get_upsampler_attn_mask_mod
from .attention import (
compiled_flex_attn_decode,
compiled_flex_attn_prefill,
create_attention_mask,
create_batch_attention_mask,
offset_mask_mod,
)
from .configuration_falcon_perception import FalconPerceptionConfig
from .processing_falcon_perception import load_image, process_batch
from .rope import (
apply_3d_rotary_emb,
apply_golden_freqs_cis_to_visual_pos,
precompute_freqs_cis,
)
# ---------------------------------------------------------------------------
# Sub-modules: Heads
# ---------------------------------------------------------------------------
class FourierEncoder(nn.Module):
    """Fourier-feature encoder for low-dimensional continuous inputs.

    A bias-free linear layer maps the input to ``feat_dim // 2`` phases (scaled by 2π);
    their cosines and sines are concatenated and projected to ``out_dim`` by a second
    bias-free linear layer. ``feat_dim`` should be even so the concatenated width
    (``2 * (feat_dim // 2)``) matches the projection's input size.
    """

    def __init__(self, in_dim: int, feat_dim: int, out_dim: int):
        super().__init__()
        n_freqs = feat_dim // 2
        self.embed = nn.Linear(in_dim, n_freqs, bias=False)
        self.transform = nn.Linear(feat_dim, out_dim, bias=False)

    def forward(self, x):
        phase = self.embed(x) * (2 * math.pi)
        fourier = torch.cat((phase.cos(), phase.sin()), dim=-1)
        return self.transform(fourier)
class BboxDecoder(nn.Module):
    """Two-layer bias-free MLP with a squared-ReLU activation.

    Computes ``w2(relu(w1(x)) ** 2)``. It is used to turn hidden states into logits
    for coordinates or sizes.
    """

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int) -> None:
        super().__init__()
        self.w1 = nn.Linear(in_dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, out_dim, bias=False)

    def forward(self, x: T) -> T:
        activated = F.relu(self.w1(x))
        return self.w2(activated.square())
class SegmDecoder(nn.Module):
    """Segmentation-query head.

    Applies ``num_layers - 1`` square (``in_dim -> in_dim``) biased linear layers, each
    followed by a squared ReLU. A final bias-free ``in_dim -> out_dim`` projection comes last.
    """

    def __init__(self, in_dim: int, out_dim: int, num_layers: int) -> None:
        super().__init__()
        hidden = [nn.Linear(in_dim, in_dim) for _ in range(num_layers - 1)]
        self.layers = nn.ModuleList(hidden)
        self.pixel_layer = nn.Linear(in_dim, out_dim, bias=False)

    def forward(self, x) -> torch.Tensor:
        h = x
        for linear in self.layers:
            h = F.relu(linear(h)).square()
        return self.pixel_layer(h)
# ---------------------------------------------------------------------------
# Sub-modules: Attention
# ---------------------------------------------------------------------------
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head ``n_rep`` times along the head axis.

    ``x`` has shape ``(B, S, H, D)``. The result has shape ``(B, S, H * n_rep, D)``, with
    the copies of each head adjacent to each other: h0, h0, ..., h1, h1, ...
    When ``n_rep == 1`` the input tensor itself is returned.
    """
    batch, seq, heads, dim = x.shape
    if n_rep != 1:
        widened = x[:, :, :, None, :].expand(batch, seq, heads, n_rep, dim)
        return widened.reshape(batch, seq, heads * n_rep, dim)
    return x
class Attention(nn.Module):
    """Grouped-query self-attention with QK RMS-norm, 3D RoPE and per-head sink logits.

    Queries, keys and values come from a single fused projection. Keys and values are
    written to ``kv_cache``. Attention runs through the project's compiled
    flex-attention kernels: the decode variant when the query length is 1, and the
    prefill variant otherwise.
    """

    def __init__(self, config: FalconPerceptionConfig, layer_id: int):
        super().__init__()
        self.layer_id = layer_id
        n_heads = config.n_heads
        self.n_kv_heads = config.n_kv_heads or n_heads
        self.n_rep = n_heads // self.n_kv_heads
        self.head_dim = config.head_dim or config.dim // n_heads
        self.q_dim = n_heads * self.head_dim
        self.kv_dim = self.n_kv_heads * self.head_dim
        # Fused projection; the output is split as [q_dim, kv_dim, kv_dim].
        self.wqkv = nn.Linear(config.dim, self.q_dim + 2 * self.kv_dim, bias=False)
        self.wo = nn.Linear(self.q_dim, config.dim, bias=False)
        # One sink logit per query head. Allocated uninitialised here, so its values
        # must come from loaded weights.
        self.sinks = nn.Parameter(torch.empty((n_heads,)))

    def _pre_attention_qkv(self, x) -> tuple[T, T, T]:
        """RMS-normalise ``x`` and project it to per-head Q, K, V of shape (b, s, h, d).

        K and V are repeated so they have as many heads as Q.
        """
        normed = F.rms_norm(x, (x.size(-1),))
        q, k, v = self.wqkv(normed).split([self.q_dim, self.kv_dim, self.kv_dim], dim=-1)
        pattern = "b s (h d) -> b s h d"
        q = E.rearrange(q, pattern, d=self.head_dim)
        k = E.rearrange(k, pattern, d=self.head_dim)
        v = E.rearrange(v, pattern, d=self.head_dim)
        # QK-norm: RMS-normalise every query/key head vector.
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))
        return q, repeat_kv(k, n_rep=self.n_rep), repeat_kv(v, n_rep=self.n_rep)

    def _post_attention(self, output: T, lse: T) -> T:
        """Apply the sink gate, merge the heads and project back to model width.

        ``output`` is (b, h, s, d) and ``lse`` is (b, h, s). Scaling by
        ``sigmoid(lse - sink)`` is equivalent to adding ``exp(sink)`` to each softmax
        denominator. This assumes ``lse`` is the natural-log log-sum-exp of the
        attention scores (TODO confirm against the flex-attention version in use).
        """
        gate = torch.sigmoid(lse - self.sinks.view(1, -1, 1))
        # lse may be higher precision than output; cast back after gating.
        gated = (output * gate.unsqueeze(-1)).to(output.dtype)
        merged = gated.permute(0, 2, 1, 3).contiguous().flatten(2)
        return self.wo(merged)

    def compile_attention(self, *, dynamic: bool = True, mode: str = "default"):
        """Replace the pre/post attention helpers on this instance with torch.compile'd versions."""
        self._pre_attention_qkv = torch.compile(self._pre_attention_qkv, dynamic=dynamic, mode=mode)
        self._post_attention = torch.compile(self._post_attention, dynamic=dynamic, mode=mode)

    def forward(
        self, x: T, attention_masks: BlockMask, freqs_cis: T,
        freqs_cis_2d: T | None = None, pos_hw: T | None = None,
        kv_cache=None, input_pos=None, batch_idx=None,
        flex_attn_kernel_options=None,
    ):
        """Run attention for ``x`` (b, s, dim) and return a tensor of the same shape.

        ``kv_cache`` is required despite its default. ``flex_attn_kernel_options`` is
        accepted for interface compatibility but not used.
        """
        q, k, v = self._pre_attention_qkv(x)
        q, k = apply_3d_rotary_emb(q, k, freqs_cis, freqs_cis_2d, pos_hw)
        # Flex attention expects (b, h, s, d).
        q, k, v = (E.rearrange(t, "b s h d -> b h s d") for t in (q, k, v))
        k, v = kv_cache.insert_kv(self.layer_id, k, v, input_pos=input_pos, batch_idx=batch_idx)
        attend = compiled_flex_attn_decode if q.shape[2] == 1 else compiled_flex_attn_prefill
        out, aux = attend(q, k, v, block_mask=attention_masks, return_aux=AuxRequest(lse=True))
        return self._post_attention(out, aux.lse)
# ---------------------------------------------------------------------------
# Sub-modules: FeedForward
# ---------------------------------------------------------------------------
@triton.jit
def _squared_relu_gate_kernel(
    packed_ptr, out_ptr, n_rows, n_cols,
    in_row_stride, in_col_stride, out_row_stride, out_col_stride,
    BLOCK_SIZE: tl.constexpr,
):
    # Fused elementwise gate:
    #   out[r, c] = relu(packed[r, 2c]) ** 2 * packed[r, 2c + 1]
    # Along each input row, gate values sit at even columns and "up" values at odd
    # columns. n_cols is the output width, i.e. half the packed width.
    # Each program handles BLOCK_SIZE consecutive flat output indices.
    # NOTE(review): index arithmetic is presumably 32-bit; very large tensors
    # (n_rows * stride >= 2**31) could overflow. Confirm before using at that scale.
    pid = tl.program_id(0)
    n_elements = n_rows * n_cols
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Tail guard: the last program may extend past the end of the output.
    mask = offsets < n_elements
    # Map each flat output index to (row, col) in the (n_rows, n_cols) output.
    rows = offsets // n_cols
    cols = offsets % n_cols
    # Addresses are computed from explicit strides, so non-contiguous inputs and
    # outputs are handled.
    gate_idx = rows * in_row_stride + (2 * cols) * in_col_stride
    up_idx = rows * in_row_stride + (2 * cols + 1) * in_col_stride
    out_idx = rows * out_row_stride + cols * out_col_stride
    gate = tl.load(packed_ptr + gate_idx, mask=mask)
    up = tl.load(packed_ptr + up_idx, mask=mask)
    # Squared ReLU on the gate, then multiply by the "up" value.
    gate = tl.where(gate > 0, gate, 0.0)
    out = gate * gate * up
    tl.store(out_ptr + out_idx, out, mask=mask)
def squared_relu_gate(packed: T, hidden_dim: int) -> T:
    """Apply the squared-ReLU gate to an interleaved gate/up tensor.

    Args:
        packed: tensor of shape ``(..., 2 * hidden_dim)``. Along the last axis, even
            positions hold gate values and odd positions hold "up" values.
        hidden_dim: width of the output.

    Returns:
        Tensor of shape ``(..., hidden_dim)`` and dtype ``packed.dtype``, where
        ``out[..., c] = relu(packed[..., 2c]) ** 2 * packed[..., 2c + 1]``.

    Raises:
        ValueError: if ``packed.shape[-1] != 2 * hidden_dim``. Without this check the
            kernel would silently read the wrong, or out-of-bounds, elements.

    CUDA tensors use the fused Triton kernel. Tensors on other devices (e.g. CPU), which
    that kernel cannot address, go through an equivalent PyTorch expression.
    """
    if packed.shape[-1] != 2 * hidden_dim:
        raise ValueError(
            f"squared_relu_gate expects last dim {2 * hidden_dim}, got {packed.shape[-1]}"
        )
    if packed.device.type != "cuda":
        gate = packed[..., 0::2]
        up = packed[..., 1::2]
        return F.relu(gate).square() * up
    packed_2d = packed.flatten(0, -2)
    n_rows = packed_2d.shape[0]
    out_2d = torch.empty((n_rows, hidden_dim), device=packed.device, dtype=packed.dtype)
    n = n_rows * hidden_dim
    if n == 0:
        # Nothing to compute; skip launching a kernel with an empty grid.
        return out_2d.view(*packed.shape[:-1], hidden_dim)
    grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),)
    _squared_relu_gate_kernel[grid](
        packed_2d, out_2d, n_rows, hidden_dim,
        packed_2d.stride(0), packed_2d.stride(1),
        out_2d.stride(0), out_2d.stride(1),
        BLOCK_SIZE=1024,
    )
    return out_2d.view(*packed.shape[:-1], hidden_dim)
class FeedForward(nn.Module):
    """Pre-norm gated MLP: ``w2(relu(gate) ** 2 * up)``.

    The input is RMS-normalised first. ``w13`` produces gate and "up" values in one
    projection, and ``squared_relu_gate`` reads them as interleaved columns.
    """

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w13 = nn.Linear(dim, 2 * hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.hidden_dim = hidden_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normed = F.rms_norm(x, (x.size(-1),))
        gated = squared_relu_gate(self.w13(normed), self.hidden_dim)
        return self.w2(gated)
# ---------------------------------------------------------------------------
# Sub-modules: TransformerBlock
# ---------------------------------------------------------------------------
class TransformerBlock(nn.Module):
    """Residual block: ``h = x + attention(x)`` followed by ``h + feed_forward(h)``.

    The attention and feed-forward sub-modules each RMS-normalise their own input.
    """

    def __init__(self, layer_id: int, config: FalconPerceptionConfig):
        super().__init__()
        self.attention = Attention(config, layer_id)
        self.feed_forward = FeedForward(config.dim, config.ffn_dim)

    def compile(self, *, dynamic: bool = True, mode: str = "default"):
        """Compile the feed-forward module and the attention helpers in place, and return ``self``.

        This overrides ``nn.Module.compile``.
        """
        self.feed_forward = torch.compile(self.feed_forward, dynamic=dynamic, mode=mode)
        self.attention.compile_attention(dynamic=dynamic, mode=mode)
        return self

    def forward(
        self, x: T, freqs_cis: T, freqs_cis_2d: T | None = None,
        pos_hw: T | None = None, attention_masks=None, kv_cache=None,
        input_pos=None, batch_idx=None, flex_attn_kernel_options=None,
    ):
        batch, seq, width = x.shape
        attn_out = self.attention(
            x,
            freqs_cis=freqs_cis,
            freqs_cis_2d=freqs_cis_2d,
            pos_hw=pos_hw,
            attention_masks=attention_masks,
            kv_cache=kv_cache,
            input_pos=input_pos,
            batch_idx=batch_idx,
            flex_attn_kernel_options=flex_attn_kernel_options,
        )
        hidden = x + attn_out
        return (hidden + self.feed_forward(hidden)).reshape(batch, seq, width)
# ---------------------------------------------------------------------------
# KV Cache
# ---------------------------------------------------------------------------
class KVCache:
    """Pre-allocated key/value cache shared by all transformer layers.

    Storage has shape ``(num_layers, 2, max_batch_size, n_heads, max_seq_length, head_dim)``.
    It is allocated lazily on the first insert, using the dtype and device of the keys.

    Attributes:
        pos: number of timesteps already cached. It advances only when the last
            layer writes, so every layer within a step writes at the same offset.
        pos_t: rotary position tensor that decode steps read through
            ``increment_and_get_pos_t``.
    """

    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, num_layers):
        self.kv_shape = (num_layers, 2, max_batch_size, n_heads, max_seq_length, head_dim)
        self.kv_cache = None
        self.pos = 0
        self.pos_t: T | None = None

    def reset(self):
        """Rewind to an empty cache. Existing storage is kept and overwritten on reuse."""
        self.pos = 0
        self.pos_t = None

    def get_pos(self):
        return self.pos

    def set_pos_t(self, pos_t):
        self.pos_t = pos_t

    def increment_and_get_pos_t(self):
        """Advance the rotary position by one and return the new tensor.

        The increment is deliberately out-of-place. ``pos_t`` may be a view into a
        caller-owned tensor (for example a slice of the prompt's position ids), and
        ``+=`` would silently modify that tensor on every decode step.
        """
        assert self.pos_t is not None
        self.pos_t = self.pos_t + 1
        return self.pos_t

    def insert_kv(self, layer_id: int, k: T, v: T, **kwargs):
        """Write ``k``/``v`` for ``layer_id`` at the current position and return the cached prefix.

        Args:
            layer_id: index of the layer whose slot is written.
            k, v: tensors of shape ``(B, H, T_add, D)`` that broadcast into the cache's
                batch/head/dim sizes.
            **kwargs: accepted for interface compatibility (e.g. ``input_pos``,
                ``batch_idx``) and ignored.

        Returns:
            ``(keys, values)``: views over the storage covering timesteps ``[0, pos + T_add)``.

        Raises:
            ValueError: if the write would run past ``max_seq_length``. This is checked
                explicitly because a single-token write past the end would otherwise
                broadcast into an empty slice and be silently dropped.
        """
        del kwargs
        assert self.pos_t is not None
        _, _, t_add, _ = k.size()
        t0, t1 = self.pos, self.pos + t_add
        capacity = self.kv_shape[4]
        if t1 > capacity:
            raise ValueError(
                f"KVCache overflow: writing positions [{t0}, {t1}) exceeds max_seq_length={capacity}"
            )
        if self.kv_cache is None:
            self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device)
        self.kv_cache[layer_id, 0, :, :, t0:t1] = k
        self.kv_cache[layer_id, 1, :, :, t0:t1] = v
        key_view = self.kv_cache[layer_id, 0, :, :, :t1]
        value_view = self.kv_cache[layer_id, 1, :, :, :t1]
        # Advance only after the final layer, so all layers in this step share the offset t0.
        if layer_id == self.kv_cache.size(0) - 1:
            self.pos = t1
        return key_view, value_view
# ---------------------------------------------------------------------------
# Sampling
# ---------------------------------------------------------------------------
@torch.inference_mode()
def sample_next_token(logits, rng, temperature=0.0, top_k=None):
    """Choose the next token id from ``logits``.

    Args:
        logits: unnormalised scores of shape ``(B, V)`` or ``(V,)``.
        rng: ``torch.Generator`` used for sampling (``None`` means the global RNG).
            It is not used for greedy decoding.
        temperature: ``0.0`` selects the argmax. A positive value divides the logits
            before the softmax.
        top_k: sample only among the ``top_k`` highest logits. ``None`` or a value
            ``<= 0`` disables the restriction.

    Returns:
        Token ids with a trailing dimension of size 1: ``(B, 1)`` or ``(1,)``.
    """
    assert temperature >= 0.0
    if temperature == 0.0:
        return torch.argmax(logits, dim=-1, keepdim=True)
    if top_k is not None and top_k > 0:
        k = min(top_k, logits.size(-1))
        vals, idx = torch.topk(logits, k, dim=-1)
        probs = F.softmax(vals / temperature, dim=-1)
        choice = torch.multinomial(probs, num_samples=1, generator=rng)
        # Map the position within the top-k set back to a vocabulary id.
        # dim=-1 makes this work for both 1-D and 2-D logits.
        return idx.gather(-1, choice)
    probs = F.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1, generator=rng)
# ---------------------------------------------------------------------------
# Main Model
# ---------------------------------------------------------------------------
class FalconPerceptionForSegmentation(PreTrainedModel):
    """Decoder-only Falcon-Perception model for referring-expression segmentation.

    Image patches are projected linearly into the token sequence. The transformer then
    autoregressively emits coordinate, size and segmentation tokens, which dedicated
    heads turn into:

    * a normalised (x, y) location;
    * a relative (h, w) size;
    * per-instance mask logits, computed against an upsampled image feature map.

    ``generate`` is the high-level entry point.
    """

    config_class = FalconPerceptionConfig
    _no_split_modules = ["TransformerBlock"]

    def __init__(self, config: FalconPerceptionConfig):
        super().__init__(config)
        patch_dim = config.temporal_patch_size * config.spatial_patch_size ** 2 * config.channel_size
        self.img_projector = nn.Linear(patch_dim, config.dim, bias=False)
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
        self.layers = nn.ModuleDict()
        for layer_id in range(config.n_layers):
            self.layers[str(layer_id)] = TransformerBlock(layer_id, config)
        self.norm = nn.RMSNorm(config.dim, eps=config.norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
        # Continuous (x, y) and (h, w) pairs are Fourier-encoded into token embeddings.
        # They are decoded from hidden states into per-axis bin logits.
        self.coord_encoder = FourierEncoder(2, config.coord_enc_dim, config.dim)
        self.coord_decoder = BboxDecoder(config.dim, config.coord_dec_dim, config.coord_out_dim)
        self.size_encoder = FourierEncoder(2, config.size_enc_dim, config.dim)
        self.size_decoder = BboxDecoder(config.dim, config.size_dec_dim, config.size_out_dim)
        if config.do_segmentation:
            self.itok_upsampler = AnyUp()
            self.proj_segm = SegmDecoder(config.dim, config.segm_out_dim, config.num_segm_layers)
            self.conv_segm = nn.Conv2d(config.dim, config.segm_out_dim, kernel_size=3, padding=1)
        rope_dim = self._head_dim(config) // 2
        freqs_cis = precompute_freqs_cis(rope_dim, config.max_seq_len, config.rope_theta)
        # Persistent buffer: its values are expected to come from the checkpoint.
        freqs_cis_golden = torch.empty((config.n_heads, rope_dim // 2, 2), dtype=torch.float)
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)
        self.register_buffer("freqs_cis_golden", freqs_cis_golden, persistent=True)
        # Set to True once _ensure_device_buffers has rebuilt the non-persistent buffers.
        self._weights_fused = False
        self._is_compiled = False
        self.post_init()

    @staticmethod
    def _head_dim(config) -> int:
        """Return the per-head width, falling back to ``dim // n_heads`` the same way Attention does."""
        return config.head_dim or config.dim // config.n_heads

    # -- Weight management ---------------------------------------------------
    def _ensure_device_buffers(self):
        """Recompute non-persistent buffers that HF meta-device loading may discard.

        This runs once per instance. It rebuilds ``freqs_cis`` on the embedding's
        device and moves ``freqs_cis_golden`` to that device if needed.
        """
        if self._weights_fused:
            return
        device = self.tok_embeddings.weight.device
        c = self.config
        rope_dim = self._head_dim(c) // 2
        freqs_cis = precompute_freqs_cis(rope_dim, c.max_seq_len, c.rope_theta).to(device)
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)
        if self.freqs_cis_golden.device != device:
            self.freqs_cis_golden = self.freqs_cis_golden.to(device)
        self._weights_fused = True

    def compile_model(self):
        """torch.compile the blocks, the coord/size heads and (if enabled) the upsampler, once.

        As a side effect, this disables Inductor CUDA graphs for the whole process.
        """
        if self._is_compiled:
            return
        torch._inductor.config.triton.cudagraphs = False
        for layer in self.layers.values():
            layer.compile(dynamic=True, mode="default")
        self.coord_encoder = torch.compile(self.coord_encoder, dynamic=True, mode="default")
        self.coord_decoder = torch.compile(self.coord_decoder, dynamic=True, mode="default")
        self.size_encoder = torch.compile(self.size_encoder, dynamic=True, mode="default")
        self.size_decoder = torch.compile(self.size_decoder, dynamic=True, mode="default")
        if self.config.do_segmentation:
            self.itok_upsampler.compile(mode="default", dynamic=True)
        self._is_compiled = True

    # -- Tokenizer -----------------------------------------------------------
    def _get_tokenizer(self):
        """Load the tokenizer from ``config._name_or_path`` on first use and cache it.

        Only local files are used when that path exists on disk. Every string special
        token is exposed on the tokenizer as ``<name>`` and ``<name>_id`` attributes.
        """
        if not hasattr(self, "_tokenizer"):
            import os
            path = self.config._name_or_path
            is_local = os.path.exists(path)
            self._tokenizer = AutoTokenizer.from_pretrained(path, local_files_only=is_local, trust_remote_code=True)
            for token_name, token in self._tokenizer.special_tokens_map.items():
                if isinstance(token, str):
                    setattr(self._tokenizer, token_name, token)
                    setattr(
                        self._tokenizer, token_name + "_id",
                        self._tokenizer.convert_tokens_to_ids(token),
                    )
        return self._tokenizer

    # -- Attention mask ------------------------------------------------------
    def get_attention_mask(self, input_batch: T, max_len: int | None = None):
        """Build the batched block mask for padded token ids.

        Requires ``self._pad_token_id``, which ``generate`` sets.
        """
        return create_batch_attention_mask(
            input_batch,
            pad_token_id=self._pad_token_id,
            eos_token_id=self.config.eos_id,
            soi_token_id=self.config.image_cls_token_id,
            eoi_token_id=self.config.img_end_id,
            max_len=max_len,
        )

    def get_upsampler_attn_mask(self, H, W, h, w, device):
        """Build the block mask for the upsampler: (H*W) pixel queries over (h*w) feature tokens."""
        return create_attention_mask(
            get_upsampler_attn_mask_mod(H, W, h, w, device=device),
            B=None, H=None, Q_LEN=H * W, KV_LEN=h * w,
        )

    # -- Embedding helpers ---------------------------------------------------
    def _scatter_img_tokens_with_projector(self, h_BSD, pixel_patches_NLC, pixel_masks_NTHW, tokens_BS):
        """Project valid pixel patches and write them into the ``img_id`` positions of ``h_BSD``.

        A patch counts as valid if any pixel in it is set in ``pixel_masks_NTHW``.
        """
        B, S, D = h_BSD.shape
        pixel_patch_mask = E.reduce(
            pixel_masks_NTHW,
            "n (t pt) (h ph) (w pw) -> (n t h w)",
            reduction="any",
            pt=self.config.temporal_patch_size,
            ph=self.config.spatial_patch_size,
            pw=self.config.spatial_patch_size,
        )
        pixel_patches_flat = E.rearrange(pixel_patches_NLC, "n p c -> (n p) c")
        valid_patches = pixel_patches_flat[pixel_patch_mask]
        valid_feats = self.img_projector(valid_patches)
        img_mask_h_BSD = E.repeat(tokens_BS == self.config.img_id, "b s -> b s d", d=D)
        assert valid_feats.numel() == img_mask_h_BSD.sum()
        return torch.masked_scatter(h_BSD, img_mask_h_BSD, valid_feats)

    @staticmethod
    def _scatter_pair_embeddings(h_BSD: T, token_mask_BS: T, values: T, encoder) -> T:
        """Replace embeddings at ``token_mask_BS`` positions with ``encoder(values)``.

        ``values`` is reshaped into (x, y)-style pairs of shape ``(-1, 2)``.

        * Single-token decode step (``S == 1``) with exactly one pair per batch row:
          row ``b`` receives pair ``b`` if its token is masked, and is left unchanged
          otherwise.
        * All other cases: pairs are consumed in row-major order of the masked
          positions. The fixed-shape ``torch.where`` path is not used with ``S > 1``,
          because it would broadcast one pair over every masked position in a row.
        """
        if values.numel() == 0:
            return h_BSD
        B, S, D = h_BSD.shape
        encoded = encoder(values.reshape(-1, 2))
        mask_BS1 = token_mask_BS.unsqueeze(-1)
        if S == 1 and encoded.shape[0] == B:
            return torch.where(mask_BS1, encoded.view(B, 1, D), h_BSD)
        return h_BSD.masked_scatter_(mask_BS1, encoded)

    def _encode_coords(self, h_BSD: T, tokens_BS: T, all_xy: T):
        """Embed (x, y) pairs at ``coord_token_id`` positions."""
        mask = tokens_BS == self.config.coord_token_id
        return self._scatter_pair_embeddings(h_BSD, mask, all_xy, self.coord_encoder)

    def _encode_sizes(self, h_BSD, tokens_BS, all_hw: T):
        """Embed (h, w) pairs at ``size_token_id`` positions."""
        mask = tokens_BS == self.config.size_token_id
        return self._scatter_pair_embeddings(h_BSD, mask, all_hw, self.size_encoder)

    def decode_coords(self, h_BSD, labels):
        """Return coordinate-bin logits of shape ``(K, 2, bins)`` for the K positions where ``labels == coord_token_id``."""
        B, S, D = h_BSD.shape
        coord_masks = labels == self.config.coord_token_id
        coord_tokens = torch.masked_select(h_BSD, coord_masks.unsqueeze(-1))
        coord_logits = self.coord_decoder(coord_tokens.reshape(-1, D))
        return E.rearrange(coord_logits, "b (two dim) -> b two dim", two=2)

    def decode_sizes(self, h_BSD, labels):
        """Return size-bin logits of shape ``(K, 2, bins)`` for the K positions where ``labels == size_token_id``."""
        B, S, D = h_BSD.shape
        size_masks = labels == self.config.size_token_id
        size_tokens = torch.masked_select(h_BSD, size_masks.unsqueeze(-1))
        size_logits = self.size_decoder(size_tokens.reshape(-1, D))
        return E.rearrange(size_logits, "b (two dim) -> b two dim", two=2)

    def process_sizes(self, logits):
        """Map size-bin logits to relative sizes in ``[1/num_bins, 1]``.

        The argmax bin is placed linearly on ``[log2(1/num_bins), 0]`` and then
        exponentiated with base 2, so the bins are spaced uniformly in log space.
        """
        num_bins = logits.shape[-1]
        pred = torch.argmax(logits, dim=-1).float() / (num_bins - 1)
        min_size = torch.log2(torch.tensor(1 / num_bins))
        max_size = 0.0
        pred = pred * (max_size - min_size) + min_size
        return torch.pow(2.0, pred)

    # -- Segmentation -------------------------------------------------------
    def gather_img_tokens(self, h_BSD: T, tokens_BS: T, itok_masks_NTHW: T):
        """Scatter hidden states at image-token positions onto a zero-filled (N, T, H, W, D) grid."""
        B, S, D = h_BSD.shape
        itok_masks_BSD = E.repeat(tokens_BS == self.config.img_id, "b s -> b s d", d=D)
        itok_flatten = torch.masked_select(h_BSD, itok_masks_BSD)
        itok_masks_NTHWD = E.repeat(itok_masks_NTHW, "n t h w -> n t h w d", d=D)
        itok_NTHWD = torch.zeros_like(itok_masks_NTHWD, dtype=h_BSD.dtype, device=h_BSD.device)
        itok_NTHWD = itok_NTHWD.masked_scatter_(itok_masks_NTHWD, itok_flatten)
        return itok_NTHWD

    def upsample_img_features(self, h_BSD: T, tokens_BS: T, pixel_values_NTHWC: T, pixel_mask_NTHW: T):
        """Upsample image-token features to pixel resolution with AnyUp, one image at a time.

        Features are passed through ``conv_segm`` first, so the channel count is
        presumably ``segm_out_dim``. It must match ``proj_segm``'s output for the mask
        einsum in ``generate``.
        """
        device = h_BSD.device
        c = self.config
        itok_masks_NTHW = E.reduce(
            pixel_mask_NTHW,
            "n (t pt) (h ph) (w pw) -> n t h w",
            reduction="any",
            pt=c.temporal_patch_size, ph=c.spatial_patch_size, pw=c.spatial_patch_size,
        )
        N, _, h, w = itok_masks_NTHW.shape
        _, _, H, W = pixel_mask_NTHW.shape
        images = E.rearrange(pixel_values_NTHWC, "n 1 h w c -> n c h w")
        lr_img_features = self.gather_img_tokens(h_BSD, tokens_BS, itok_masks_NTHW)
        lr_img_features = E.rearrange(lr_img_features, "n 1 h w d -> n d h w")
        lr_img_features = self.conv_segm(lr_img_features)
        upsampler_attn_mask = self.get_upsampler_attn_mask(H, W, h, w, device=device)
        hr_parts = []
        for i in range(N):
            hr_i = self.itok_upsampler(
                images=images[i:i + 1], features=lr_img_features[i:i + 1], attn_mask=upsampler_attn_mask,
            )
            hr_parts.append(hr_i)
        return torch.cat(hr_parts, dim=0) if N > 1 else hr_parts[0]

    @staticmethod
    def _mask_to_coco_rle(binary_masks: torch.Tensor, skip_empty: bool = True) -> list[dict]:
        """Encode binary masks of shape ``(C, H, W)`` as compressed COCO RLE dicts.

        Run boundaries are found on the device in column-major order (the COCO
        convention). Run lengths are then assembled on the host and compressed with
        ``pycocotools``.

        Args:
            binary_masks: boolean or 0/1 tensor of shape ``(C, H, W)``.
            skip_empty: if True (the default, matching the original behaviour), masks
                with no foreground are omitted, so the output can be shorter than ``C``.
                If False, every mask is encoded, empty ones as a single run of
                ``H * W`` zeros, and ``len(result) == C``.

        Returns:
            List of ``{"counts": str, "size": [H, W]}`` dicts.
        """
        C, H, W = binary_masks.shape
        has_any = E.reduce(binary_masks, "c h w -> c", reduction="any")
        binary_col = E.rearrange(binary_masks, "c h w -> c (w h)")
        # diffs[:, j] is True where pixel j+1 differs from pixel j.
        diffs = binary_col[:, 1:] != binary_col[:, :-1]
        nz = torch.nonzero(diffs, as_tuple=False)
        first_vals = binary_col[:, 0]
        nz_cpu = nz.cpu().numpy()
        has_any_cpu = has_any.cpu().numpy()
        first_vals_cpu = first_vals.cpu().numpy()
        del diffs, nz, binary_col, first_vals, has_any
        N_px = H * W
        if nz_cpu.shape[0] > 0:
            # nonzero() is row-major, so each mask's change positions form one contiguous group.
            mask_ids = nz_cpu[:, 0]
            change_cols = nz_cpu[:, 1]
            uniq, grp_starts = np.unique(mask_ids, return_index=True)
            grp_ends = np.append(grp_starts[1:], len(mask_ids))
            mask_to_grp = {int(m): (int(gs), int(ge)) for m, gs, ge in zip(uniq, grp_starts, grp_ends)}
        else:
            change_cols = np.array([], dtype=np.intp)
            mask_to_grp = {}
        results = []
        for i in range(C):
            if skip_empty and not has_any_cpu[i]:
                continue
            if i in mask_to_grp:
                gs, ge = mask_to_grp[i]
                cidx = change_cols[gs:ge]
            else:
                cidx = np.array([], dtype=np.intp)
            num_runs = len(cidx) + 1
            starts = np.empty(num_runs, dtype=np.intp)
            starts[0] = 0
            if len(cidx) > 0:
                starts[1:] = cidx + 1
            counts = np.empty(num_runs, dtype=np.uint32)
            if num_runs > 1:
                counts[:-1] = np.diff(starts)
            counts[-1] = N_px - starts[-1]
            # COCO RLE always starts with a run of zeros, which may have length 0.
            if first_vals_cpu[i]:
                counts = np.concatenate([[0], counts])
            rle = {"counts": counts.tolist(), "size": [H, W]}
            rle = mask_utils.frPyObjects(rle, H, W)
            rle["counts"] = rle["counts"].decode("utf-8")
            results.append(rle)
        return results

    # -- Core forward --------------------------------------------------------
    def forward(
        self,
        tokens: T,
        attention_mask: BlockMask,
        kv_cache,
        rope_pos_t: T | None = None,
        rope_pos_hw: T | None = None,
        pixel_values: T | None = None,
        pixel_mask: T | None = None,
        coord_xy: T | None = None,
        size_hw: T | None = None,
    ):
        """Run one prefill step (``S > 1``) or one single-token decode step (``S == 1``).

        Prefill requires ``rope_pos_t`` and ``rope_pos_hw`` and seeds ``kv_cache.pos_t``.
        Decode advances that position and slices ``attention_mask`` at the current
        cache offset.

        Returns:
            ``(logits_BSV, h_BSD)``: vocabulary logits and the final-norm hidden states.
        """
        B, S = tokens.size()
        c = self.config
        block_mask = attention_mask
        T_pos = kv_cache.get_pos()
        is_prefill = S != 1
        if is_prefill:
            assert rope_pos_t is not None and rope_pos_hw is not None
            pos_t = rope_pos_t[:, T_pos:T_pos + S].long()
            # `.long()` returns the same tensor when it is already int64, so the slice
            # is a view of the caller's rope_pos_t. Clone it, so that later decode-time
            # increments of the cached position cannot write through to that tensor.
            kv_cache.pos_t = pos_t[:, -1:].clone()
            freqs_cis = self.freqs_cis[pos_t]
            rope_pos_hw = rope_pos_hw[:, T_pos:T_pos + S]
            freqs_cis_golden = apply_golden_freqs_cis_to_visual_pos(self.freqs_cis_golden, rope_pos_hw)
            block_mask.seq_lengths = (S, S)
        else:
            pos_t = kv_cache.increment_and_get_pos_t()
            freqs_cis = self.freqs_cis[pos_t]
            freqs_cis_golden = None
            block_idx = T_pos // block_mask.BLOCK_SIZE[0]
            block_mask = block_mask[:, :, block_idx]
            block_mask.seq_lengths = (S, T_pos + S)
            block_mask.mask_mod = offset_mask_mod(attention_mask.mask_mod, offset=T_pos)
        h_BSD = self.tok_embeddings(tokens)
        coord_xy = coord_xy if coord_xy is not None else h_BSD.new_empty(0)
        size_hw = size_hw if size_hw is not None else h_BSD.new_empty(0)
        h_BSD = self._encode_coords(h_BSD, tokens, coord_xy)
        h_BSD = self._encode_sizes(h_BSD, tokens, size_hw)
        if pixel_values is not None:
            assert pixel_mask is not None
            pixel_values = pixel_values.to(self.dtype)
            pixel_mask = pixel_mask.to(self.dtype)
            pixel_patches_NLC = E.rearrange(
                pixel_values,
                "n (t pt) (h ph) (w pw) c -> n (t h w) (pt ph pw c)",
                pt=c.temporal_patch_size, ph=c.spatial_patch_size, pw=c.spatial_patch_size,
            )
            h_BSD = self._scatter_img_tokens_with_projector(h_BSD, pixel_patches_NLC, pixel_mask, tokens)
        for layer in self.layers.values():
            h_BSD = layer(
                h_BSD, freqs_cis=freqs_cis, freqs_cis_2d=freqs_cis_golden,
                pos_hw=rope_pos_hw, attention_masks=block_mask, kv_cache=kv_cache,
            )
        h_BSD = self.norm(h_BSD)
        logits_BSV = self.output(h_BSD)
        return logits_BSV, h_BSD

    # -- Main API: generate --------------------------------------------------
    @torch.inference_mode()
    def generate(
        self,
        images,
        queries,
        max_new_tokens: int = 2048,
        temperature: float = 0.0,
        top_k: int | None = None,
        min_dimension: int = 256,
        max_dimension: int = 1024,
        compile: bool = True,
        seed: int | None = 42,
        segm_threshold: float = 0.5,
    ) -> list[list[dict]]:
        """
        Segment objects in images matching the given queries.
        Args:
            images: Single PIL Image (or path/URL) or list of them.
            queries: Single query string or list of query strings (one per image).
            max_new_tokens: Maximum number of decode steps. The KV cache is rounded
                up to a multiple of 128 positions, but generation still stops after
                exactly this many new tokens.
            temperature: Sampling temperature (0.0 = greedy).
            top_k: Top-k sampling (None = disabled).
            min_dimension: Min image side after resize.
            max_dimension: Max image side after resize.
            compile: Whether to torch.compile on first call.
            seed: Random seed for reproducibility (None = non-deterministic).
            segm_threshold: Sigmoid threshold for binary mask.
        Returns:
            List (per image) of lists (per detection) of dicts::
                {
                    "xy": {"x": float, "y": float},
                    "hw": {"h": float, "w": float},
                    "mask_rle": {"counts": str, "size": [H, W]},
                }
        """
        self._ensure_device_buffers()
        if compile:
            self.compile_model()
        # Normalize inputs
        if isinstance(images, (str, Path, Image.Image)):
            images = [images]
        if isinstance(queries, str):
            queries = [queries]
        assert len(images) == len(queries), "Must provide one query per image"
        device = self.device
        c = self.config
        tokenizer = self._get_tokenizer()
        self._pad_token_id = tokenizer.convert_tokens_to_ids("<|pad|>")
        stop_token_ids = [c.eos_id, tokenizer.convert_tokens_to_ids("<|end_of_query|>")]
        # Keep the original image sizes so masks can be resized back to them.
        pil_images = [load_image(img).convert("RGB") for img in images]
        original_sizes = [(img.height, img.width) for img in pil_images]
        image_prompt_pairs = [
            (img, f"<|image|>Segment these expressions in the image:<|start_of_query|>{q}<|REF_SEG|>")
            for img, q in zip(pil_images, queries)
        ]
        batch_inputs = process_batch(
            tokenizer, c, image_prompt_pairs,
            max_length=4096, min_dimension=min_dimension, max_dimension=max_dimension,
        )
        batch_inputs = {k: (v.to(device) if torch.is_tensor(v) else v) for k, v in batch_inputs.items()}
        tokens = batch_inputs["tokens"]
        B, L = tokens.size()
        block_size = 128
        S = (L + max_new_tokens + block_size - 1) // block_size * block_size
        assert S <= c.max_seq_len
        # S is rounded up to the attention block size. Stopping at L + max_new_tokens
        # (which is <= S) keeps that rounding from adding extra decode steps.
        max_pos = L + max_new_tokens
        rng = torch.Generator(device).manual_seed(seed) if seed is not None else None
        kv_cache = KVCache(
            max_batch_size=B, max_seq_length=S, n_heads=c.n_heads,
            head_dim=self._head_dim(c), num_layers=c.n_layers,
        )
        padded_tokens = torch.full((B, S), self._pad_token_id, dtype=tokens.dtype, device=device)
        padded_tokens[:, :L] = tokens
        attention_mask = self.get_attention_mask(padded_tokens, max_len=S)
        all_xy, all_hw = self._extract_coords([[]])
        coord_xy = all_xy.to(device=device, dtype=self.dtype)
        size_hw_t = all_hw.to(device=device, dtype=self.dtype)
        # Prefill
        logits_BSV, h_BSD = self.forward(
            tokens=tokens, rope_pos_t=batch_inputs["pos_t"], rope_pos_hw=batch_inputs["pos_hw"],
            attention_mask=attention_mask, kv_cache=kv_cache,
            pixel_values=batch_inputs["pixel_values"], pixel_mask=batch_inputs["pixel_mask"],
            coord_xy=coord_xy, size_hw=size_hw_t,
        )
        hr_img_features = self.upsample_img_features(
            h_BSD, tokens, batch_inputs["pixel_values"], batch_inputs["pixel_mask"],
        )
        aux_output_B = [[] for _ in range(B)]
        stop_ids = torch.tensor(stop_token_ids, device=device)
        should_stop_B = torch.full((B,), False, dtype=torch.bool, device=device)
        # Decode loop
        while not torch.all(should_stop_B) and (pos := kv_cache.get_pos()) < max_pos:
            tokens_B1 = sample_next_token(logits_BSV[:, -1], rng, temperature, top_k)
            if torch.any(should_stop_B):
                # Finished rows keep emitting padding so they trigger no further heads.
                tokens_B1 = tokens_B1.clone()
                tokens_B1[should_stop_B, :] = self._pad_token_id
            padded_tokens[:, pos] = tokens_B1[:, -1]
            h_last = h_BSD[:, -1:]
            # Coordinates (with per-row de-duplication)
            xy_b2 = self._decode_coord_step(h_last, tokens_B1, aux_output_B)
            # Sizes: hw_b2 holds only the rows whose token is the size token, in row order.
            hw_b2 = self.process_sizes(self.decode_sizes(h_last, tokens_B1))
            size_rows = torch.where(tokens_B1 == c.size_token_id)[0].tolist()
            for (pred_h, pred_w), b in zip(hw_b2.tolist(), size_rows):
                aux_output_B[b].append({"h": pred_h, "w": pred_w})
            # Segmentation mask logits against the upsampled image features
            sample_w_segm = torch.where(tokens_B1 == c.seg_token_id)[0]
            segm_tokens = self.proj_segm(h_BSD[sample_w_segm, -1, :])
            segm_masks = torch.einsum("kdhw,kd->khw", hr_img_features[sample_w_segm], segm_tokens)
            for mask_hw, b in zip(segm_masks, sample_w_segm.tolist()):
                aux_output_B[b].append(mask_hw)
            # Next step
            logits_BSV, h_BSD = self.forward(
                tokens=tokens_B1, attention_mask=attention_mask,
                coord_xy=xy_b2.to(self.dtype), size_hw=hw_b2.to(self.dtype), kv_cache=kv_cache,
            )
            hit_stop_B = torch.isin(tokens_B1, stop_ids).any(dim=-1)
            should_stop_B = should_stop_B.logical_or(hit_stop_B)
        # Post-process: convert aux outputs into structured results with RLE masks.
        pixel_mask_batch = batch_inputs["pixel_mask"][:, 0]  # (B, H, W)
        results = []
        for b in range(B):
            dets = self._postprocess_aux(
                aux_output_B[b], pixel_mask_batch[b], original_sizes[b], segm_threshold,
            )
            results.append(dets)
        return results

    def _decode_coord_step(
        self,
        h_B1D: T,
        tokens_B1: T,
        aux_output_B: list[list],
        repeat_threshold: float = 0.01,
        max_attempts: int = 100,
    ) -> T:
        """Decode a normalised (x, y) for each row whose sampled token is the coordinate token.

        For each such row, the argmax bin of each axis is taken. If the resulting point
        lies within ``repeat_threshold`` (1% of the image size) on both axes of a point
        already emitted for that row, both bins are masked out and the next best pair is
        tried, up to ``max_attempts`` times. The final choice is appended to
        ``aux_output_B[b]`` as ``{"x", "y"}``.

        Returns:
            Tensor of shape ``(B, 2)``; rows without a coordinate token are zero.
        """
        B = tokens_B1.shape[0]
        coord_logits = self.decode_coords(h_B1D, tokens_B1)
        num_bins = coord_logits.size(-1)
        xy_b2 = torch.zeros(B, 2, device=self.device, dtype=self.dtype)
        coord_rows = torch.where(tokens_B1 == self.config.coord_token_id)[0].tolist()
        for logits_2n, b in zip(coord_logits, coord_rows):
            logits_2n = logits_2n.clone()  # (2, num_bins)
            existing = [
                item for item in aux_output_B[b]
                if isinstance(item, dict) and "x" in item and "y" in item
            ]
            pred_x, pred_y = 0.0, 0.0
            for _ in range(max_attempts):
                bin_x, bin_y = torch.argmax(logits_2n, dim=-1).tolist()
                pred_x = bin_x / (num_bins - 1)
                pred_y = bin_y / (num_bins - 1)
                is_repeat = any(
                    abs(ec["x"] - pred_x) < repeat_threshold
                    and abs(ec["y"] - pred_y) < repeat_threshold
                    for ec in existing
                )
                if not is_repeat:
                    break
                logits_2n[0, bin_x] = float("-inf")
                logits_2n[1, bin_y] = float("-inf")
            xy_b2[b, 0] = pred_x
            xy_b2[b, 1] = pred_y
            aux_output_B[b].append({"x": pred_x, "y": pred_y})
        return xy_b2

    # -- Post-processing helpers ---------------------------------------------
    def _extract_coords(self, coords_BO: list[list]):
        """Flatten per-sample lists of coordinate/size dicts into two 1-D tensors.

        Values whose key starts with "x" or "y" go into the first tensor, and values
        whose key starts with "h" or "w" go into the second, in iteration order.
        """
        all_xy, all_hw = [], []
        for coords_O in coords_BO:
            if not coords_O:
                continue
            for coords in coords_O:
                for k, v in coords.items():
                    if k.startswith(("x", "y")):
                        all_xy.append(v)
                    elif k.startswith(("h", "w")):
                        all_hw.append(v)
        return torch.tensor(all_xy), torch.tensor(all_hw)

    @staticmethod
    def _mask_nms(
        binary_masks: list[torch.Tensor],
        iou_threshold: float = 0.6,
        nms_max_side: int = 256,
    ) -> list[int]:
        """
        Greedy mask NMS on binary (H, W) tensors, with larger masks taking priority.

        Masks are downscaled so the longer side is at most ``nms_max_side``. The
        pairwise IoU matrix comes from one batched matmul. The thresholded matrix is
        then copied to the host once, and the greedy suppression loop runs there,
        avoiding a device sync per candidate.

        Returns the kept indices, ordered by descending mask area.
        """
        N = len(binary_masks)
        if N <= 1:
            return list(range(N))
        base_h, base_w = binary_masks[0].shape
        scale = min(1.0, nms_max_side / max(base_h, base_w))
        th = max(1, int(round(base_h * scale)))
        tw = max(1, int(round(base_w * scale)))
        resized = []
        for m in binary_masks:
            m = m.float()
            if m.shape != (th, tw):
                m = F.interpolate(
                    m[None, None], size=(th, tw), mode="bilinear", align_corners=False
                ).squeeze()
            resized.append(m)
        binary = torch.stack(resized)                 # (N, th, tw)
        flat = binary.view(N, -1)                     # (N, th*tw)
        areas = flat.sum(dim=1)                       # (N,)
        scores = areas                                # larger mask = higher priority
        intersection = flat @ flat.T                  # (N, N)
        union = areas[:, None] + areas[None, :] - intersection
        iou = intersection / union.clamp(min=1)
        order = scores.argsort(descending=True).tolist()
        overlaps = (iou > iou_threshold).cpu()
        suppressed = torch.zeros(N, dtype=torch.bool)
        keep = []
        for idx in order:
            if suppressed[idx]:
                continue
            keep.append(idx)
            suppressed |= overlaps[idx]
        return keep

    def _postprocess_aux(
        self,
        aux_list: list,
        pixel_mask_hw: T,
        orig_hw: tuple[int, int],
        threshold: float,
        nms_iou_threshold: float = 0.6,
    ) -> list[dict]:
        """Convert one sample's raw decode outputs into detections with RLE masks.

        ``aux_list`` holds, in emission order, coordinate dicts (``{"x", "y"}``), size
        dicts (``{"h", "w"}``) and ``(H, W)`` mask-logit tensors. A detection is formed
        when a mask follows a coordinate and then a size. A new coordinate starts a new
        detection, and masks without both a coordinate and a size are dropped, so one
        missing item cannot shift the pairing of later detections.

        Each mask is:

        1. cropped to the non-padded region of ``pixel_mask_hw``;
        2. resized to ``orig_hw``;
        3. thresholded after a sigmoid;
        4. de-duplicated with mask NMS;
        5. encoded as COCO RLE. An empty mask gets a valid all-background RLE.
        """
        orig_h, orig_w = orig_hw
        # Active (non-padded) image region, from the pixel mask.
        nonzero = torch.nonzero(pixel_mask_hw, as_tuple=False)
        crop = None
        if len(nonzero) > 0:
            min_h, min_w = nonzero.min(dim=0)[0].tolist()
            max_h, max_w = nonzero.max(dim=0)[0].tolist()
            crop = (slice(min_h, max_h + 1), slice(min_w, max_w + 1))
        candidates = []
        xy = hw = None
        for item in aux_list:
            if isinstance(item, torch.Tensor):
                if xy is not None and hw is not None:
                    logits_hw = item if crop is None else item[crop]
                    logits_hw = F.interpolate(
                        logits_hw[None, None].float(), size=(orig_h, orig_w),
                        mode="bilinear", align_corners=False,
                    )[0, 0]
                    binary_mask = torch.sigmoid(logits_hw) > threshold
                    candidates.append({"xy": xy, "hw": hw, "binary_mask": binary_mask})
                xy = hw = None
            elif isinstance(item, dict) and "x" in item:
                xy, hw = item, None
            elif isinstance(item, dict) and "h" in item:
                hw = item
        if not candidates:
            return []
        # NMS on binary masks before RLE encoding.
        keep_indices = self._mask_nms(
            [cand["binary_mask"] for cand in candidates],
            iou_threshold=nms_iou_threshold,
        )
        detections = []
        for idx in keep_indices:
            cand = candidates[idx]
            mask_rle = self._mask_to_coco_rle(cand["binary_mask"].unsqueeze(0), skip_empty=False)[0]
            detections.append({"xy": cand["xy"], "hw": cand["hw"], "mask_rle": mask_rle})
        return detections