File size: 38,763 Bytes
4f2517b 1f8827e 4f2517b 1f8827e 4f2517b 1f8827e 4f2517b 1f8827e 4f2517b 1f8827e 4f2517b 1f8827e 4f2517b 1f8827e 4f2517b 1f8827e 4f2517b 9abbbe0 4f2517b 9abbbe0 4f2517b 9abbbe0 4f2517b ad720cd 4f2517b ad720cd 4f2517b ad720cd 4f2517b ad720cd 4f2517b ad720cd 4f2517b ad720cd 4f2517b 9abbbe0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 
455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 | import math
from pathlib import Path
import einops as E
import numpy as np
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from PIL import Image
from pycocotools import mask as mask_utils
from torch import Tensor as T
from torch import nn
from torch.nn.attention.flex_attention import (
AuxRequest,
BlockMask,
)
from transformers import AutoTokenizer, PreTrainedModel
from .anyup import AnyUp, get_attention_mask_mod as get_upsampler_attn_mask_mod
from .attention import (
compiled_flex_attn_decode,
compiled_flex_attn_prefill,
create_attention_mask,
create_batch_attention_mask,
offset_mask_mod,
)
from .configuration_falcon_perception import FalconPerceptionConfig
from .processing_falcon_perception import load_image, process_batch
from .rope import (
apply_3d_rotary_emb,
apply_golden_freqs_cis_to_visual_pos,
precompute_freqs_cis,
)
# ---------------------------------------------------------------------------
# Sub-modules: Heads
# ---------------------------------------------------------------------------
class FourierEncoder(nn.Module):
    """Encode low-dimensional continuous inputs with learned Fourier features.

    ``x`` is linearly projected to ``feat_dim // 2`` phases (scaled by 2*pi);
    their cosines and sines are concatenated and mapped linearly to
    ``out_dim``. ``feat_dim`` is expected to be even so the concatenation has
    exactly ``feat_dim`` features.
    """

    def __init__(self, in_dim: int, feat_dim: int, out_dim: int):
        super().__init__()
        n_freqs = feat_dim // 2
        self.embed = nn.Linear(in_dim, n_freqs, bias=False)
        self.transform = nn.Linear(feat_dim, out_dim, bias=False)

    def forward(self, x):
        phases = 2 * math.pi * self.embed(x)
        features = torch.cat((torch.cos(phases), torch.sin(phases)), dim=-1)
        return self.transform(features)
class BboxDecoder(nn.Module):
    """Two-layer, bias-free MLP with a squared-ReLU hidden activation."""

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int) -> None:
        super().__init__()
        self.w1 = nn.Linear(in_dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, out_dim, bias=False)

    def forward(self, x: T) -> T:
        hidden = torch.square(F.relu(self.w1(x)))
        return self.w2(hidden)
class SegmDecoder(nn.Module):
    """Per-token MLP producing the segmentation query vector.

    ``num_layers - 1`` hidden ``Linear(in_dim, in_dim)`` layers (with bias),
    each followed by a squared ReLU, then a bias-free projection to ``out_dim``.
    """

    def __init__(self, in_dim: int, out_dim: int, num_layers: int) -> None:
        super().__init__()
        hidden = [nn.Linear(in_dim, in_dim) for _ in range(num_layers - 1)]
        self.layers = nn.ModuleList(hidden)
        self.pixel_layer = nn.Linear(in_dim, out_dim, bias=False)

    def forward(self, x) -> torch.Tensor:
        h = x
        for linear in self.layers:
            h = torch.square(F.relu(linear(h)))
        return self.pixel_layer(h)
# ---------------------------------------------------------------------------
# Sub-modules: Attention
# ---------------------------------------------------------------------------
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head ``n_rep`` times along the head axis.

    ``x`` must be 4-D ``(B, S, H, D)``; the result is ``(B, S, H * n_rep, D)``,
    with the copies of a head adjacent to each other (the same ordering as
    ``repeat_interleave`` on dim 2). When ``n_rep == 1`` the input tensor
    itself is returned.
    """
    batch, seq, heads, dim = x.shape
    if n_rep != 1:
        expanded = x[:, :, :, None, :].expand(batch, seq, heads, n_rep, dim)
        return expanded.reshape(batch, seq, heads * n_rep, dim)
    return x
class Attention(nn.Module):
    """Grouped-query self-attention with QK RMS-norm, 3D RoPE and per-head sinks.

    Keys/values are appended to an external KV cache and attention is computed
    with FlexAttention. The returned log-sum-exp is used to rescale each head's
    output by a learned per-head sink logit.
    """

    def __init__(self, config: FalconPerceptionConfig, layer_id: int):
        super().__init__()
        self.layer_id = layer_id
        n_heads = config.n_heads
        self.n_kv_heads = config.n_kv_heads or n_heads
        self.n_rep = n_heads // self.n_kv_heads
        self.head_dim = config.head_dim or config.dim // n_heads
        self.q_dim = n_heads * self.head_dim
        self.kv_dim = self.n_kv_heads * self.head_dim
        # One fused projection: [q | k | v] along the last dim.
        self.wqkv = nn.Linear(config.dim, self.q_dim + 2 * self.kv_dim, bias=False)
        self.wo = nn.Linear(n_heads * self.head_dim, config.dim, bias=False)
        # Uninitialised here; presumably filled from the checkpoint.
        self.sinks = nn.Parameter(torch.empty((n_heads,)))

    def _pre_attention_qkv(self, x) -> tuple[T, T, T]:
        """Normalise ``x``, project to per-head q/k/v, QK-normalise, and expand kv heads."""
        normed = F.rms_norm(x, (x.size(-1),))
        q_flat, k_flat, v_flat = self.wqkv(normed).split(
            [self.q_dim, self.kv_dim, self.kv_dim], dim=-1
        )
        to_heads = "b s (h d) -> b s h d"
        q = E.rearrange(q_flat, to_heads, d=self.head_dim)
        k = E.rearrange(k_flat, to_heads, d=self.head_dim)
        v = E.rearrange(v_flat, to_heads, d=self.head_dim)
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))
        return q, repeat_kv(k, n_rep=self.n_rep), repeat_kv(v, n_rep=self.n_rep)

    def _post_attention(self, output: T, lse: T) -> T:
        """Apply the sink rescaling, merge heads and project back to model dim."""
        # sigmoid(lse - sink) == exp(lse) / (exp(lse) + exp(sink)): the share of
        # softmax mass left on real keys when one extra sink logit is added.
        # Assumes lse is laid out (B, H, S) — TODO confirm against FlexAttention.
        gate = torch.sigmoid(lse - self.sinks.view(1, -1, 1)).unsqueeze(-1)
        scaled = (output * gate).to(output.dtype)
        merged = scaled.permute(0, 2, 1, 3).contiguous().flatten(2)
        return self.wo(merged)

    def compile_attention(self, *, dynamic: bool = True, mode: str = "default"):
        """Replace the pre/post helpers on this instance with torch.compile'd versions."""
        self._pre_attention_qkv = torch.compile(self._pre_attention_qkv, dynamic=dynamic, mode=mode)
        self._post_attention = torch.compile(self._post_attention, dynamic=dynamic, mode=mode)

    def forward(
        self, x: T, attention_masks: BlockMask, freqs_cis: T,
        freqs_cis_2d: T | None = None, pos_hw: T | None = None,
        kv_cache=None, input_pos=None, batch_idx=None,
        flex_attn_kernel_options=None,
    ):
        # NOTE: flex_attn_kernel_options is accepted for interface
        # compatibility but is not forwarded anywhere.
        q, k, v = self._pre_attention_qkv(x)
        q, k = apply_3d_rotary_emb(q, k, freqs_cis, freqs_cis_2d, pos_hw)
        # (b, s, h, d) -> (b, h, s, d) for the cache and FlexAttention.
        q, k, v = (E.rearrange(t, "b s h d -> b h s d") for t in (q, k, v))
        k, v = kv_cache.insert_kv(self.layer_id, k, v, input_pos=input_pos, batch_idx=batch_idx)
        # A single query position means a decode step.
        attend = compiled_flex_attn_decode if q.shape[2] == 1 else compiled_flex_attn_prefill
        out, aux = attend(q, k, v, block_mask=attention_masks, return_aux=AuxRequest(lse=True))
        return self._post_attention(out, aux.lse)
# ---------------------------------------------------------------------------
# Sub-modules: FeedForward
# ---------------------------------------------------------------------------
@triton.jit
def _squared_relu_gate_kernel(
    packed_ptr, out_ptr, n_rows, n_cols,
    in_row_stride, in_col_stride, out_row_stride, out_col_stride,
    BLOCK_SIZE: tl.constexpr,
):
    # Elementwise: out[r, c] = relu(packed[r, 2c]) ** 2 * packed[r, 2c + 1].
    # The input interleaves gate/up values along its last dim: the gate for
    # output column c is at input column 2c and its "up" value at 2c + 1.
    # Each program instance handles BLOCK_SIZE consecutive flat output indices.
    # NOTE(review): index arithmetic uses the kernel's default integer width;
    # very large tensors may need int64 offsets — TODO confirm.
    pid = tl.program_id(0)
    n_elements = n_rows * n_cols
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Guards the tail block, whose offsets may run past the last element.
    mask = offsets < n_elements
    # Flat output index -> (row, col); explicit strides allow non-contiguous tensors.
    rows = offsets // n_cols
    cols = offsets % n_cols
    gate_idx = rows * in_row_stride + (2 * cols) * in_col_stride
    up_idx = rows * in_row_stride + (2 * cols + 1) * in_col_stride
    out_idx = rows * out_row_stride + cols * out_col_stride
    gate = tl.load(packed_ptr + gate_idx, mask=mask)
    up = tl.load(packed_ptr + up_idx, mask=mask)
    # ReLU on the gate, then square it and multiply by the up value.
    gate = tl.where(gate > 0, gate, 0.0)
    out = gate * gate * up
    tl.store(out_ptr + out_idx, out, mask=mask)
def squared_relu_gate(packed: T, hidden_dim: int) -> T:
    """Compute ``relu(gate) ** 2 * up`` from an interleaved gate/up projection.

    Args:
        packed: Tensor whose last dimension has size ``2 * hidden_dim``, holding
            gate values at even and up values at odd indices (the layout read by
            ``_squared_relu_gate_kernel``). All leading dimensions are flattened
            into rows for the kernel launch.
        hidden_dim: Number of output features.

    Returns:
        Tensor of shape ``packed.shape[:-1] + (hidden_dim,)`` with ``packed``'s
        dtype and device.

    Raises:
        ValueError: If ``packed.shape[-1] != 2 * hidden_dim``. Without this
            check the kernel would read past each row (or past the tensor) or
            silently ignore trailing columns.
    """
    if packed.shape[-1] != 2 * hidden_dim:
        raise ValueError(
            f"squared_relu_gate expects last dim {2 * hidden_dim}, got {packed.shape[-1]}"
        )
    packed_2d = packed.flatten(0, -2)
    n_rows = packed_2d.shape[0]
    out_2d = torch.empty((n_rows, hidden_dim), device=packed.device, dtype=packed.dtype)
    n_elements = n_rows * hidden_dim
    # 1-D launch: one program per BLOCK_SIZE flat output elements.
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _squared_relu_gate_kernel[grid](
        packed_2d, out_2d, n_rows, hidden_dim,
        packed_2d.stride(0), packed_2d.stride(1),
        out_2d.stride(0), out_2d.stride(1),
        BLOCK_SIZE=1024,
    )
    return out_2d.view(*packed.shape[:-1], hidden_dim)
class FeedForward(nn.Module):
    """Squared-ReLU gated MLP applied to the RMS-normalised input.

    ``w13`` produces the gate and up projections in a single matmul; its output
    is consumed by ``squared_relu_gate``, which reads them interleaved.
    ``w2`` projects the ``hidden_dim`` result back to ``dim``.
    """

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w13 = nn.Linear(dim, 2 * hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.hidden_dim = hidden_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normed = F.rms_norm(x, (x.size(-1),))
        gated = squared_relu_gate(self.w13(normed), self.hidden_dim)
        return self.w2(gated)
# ---------------------------------------------------------------------------
# Sub-modules: TransformerBlock
# ---------------------------------------------------------------------------
class TransformerBlock(nn.Module):
    """One decoder layer: ``h = x + attn(x)`` followed by ``h + ffn(h)``.

    Both sub-modules RMS-normalise their own input, so the residual stream is
    not normalised here.
    """

    def __init__(self, layer_id: int, config: FalconPerceptionConfig):
        super().__init__()
        self.attention = Attention(config, layer_id)
        self.feed_forward = FeedForward(config.dim, config.ffn_dim)

    def compile(self, *, dynamic: bool = True, mode: str = "default"):
        """Compile the feed-forward module and the attention helpers in place; returns self."""
        opts = {"dynamic": dynamic, "mode": mode}
        self.feed_forward = torch.compile(self.feed_forward, **opts)
        self.attention.compile_attention(**opts)
        return self

    def forward(
        self, x: T, freqs_cis: T, freqs_cis_2d: T | None = None,
        pos_hw: T | None = None, attention_masks=None, kv_cache=None,
        input_pos=None, batch_idx=None, flex_attn_kernel_options=None,
    ):
        B, S, D = x.shape
        h = x + self.attention(
            x,
            attention_masks=attention_masks,
            freqs_cis=freqs_cis,
            freqs_cis_2d=freqs_cis_2d,
            pos_hw=pos_hw,
            kv_cache=kv_cache,
            input_pos=input_pos,
            batch_idx=batch_idx,
            flex_attn_kernel_options=flex_attn_kernel_options,
        )
        h = h + self.feed_forward(h)
        return h.reshape(B, S, D)
# ---------------------------------------------------------------------------
# KV Cache
# ---------------------------------------------------------------------------
class KVCache:
    """Pre-allocated key/value cache shared by all transformer layers.

    Storage has shape ``(num_layers, 2, max_batch_size, n_heads,
    max_seq_length, head_dim)``; index 0/1 on the second axis holds keys /
    values. The buffer is allocated lazily on the first ``insert_kv`` so it
    takes the dtype and device of the incoming keys.

    ``pos`` counts cached positions and is advanced only when the last layer
    writes, so every layer of one forward pass writes the same time window.
    ``pos_t`` is the temporal RoPE position of the most recent token.
    """

    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, num_layers):
        self.kv_shape = (num_layers, 2, max_batch_size, n_heads, max_seq_length, head_dim)
        self.kv_cache = None
        self.pos = 0
        self.pos_t: T | None = None

    def reset(self):
        """Rewind to an empty cache; an already-allocated buffer is kept for reuse."""
        self.pos = 0
        self.pos_t = None

    def get_pos(self):
        """Return the number of positions currently stored."""
        return self.pos

    def set_pos_t(self, pos_t):
        """Set the temporal RoPE position of the latest token."""
        self.pos_t = pos_t

    def increment_and_get_pos_t(self):
        """Advance the temporal RoPE position by one and return the new tensor."""
        assert self.pos_t is not None
        # Out-of-place on purpose: pos_t may be a view into a caller-owned
        # tensor (e.g. a slice of the prefill positions), and an in-place
        # ``+=`` would silently mutate that tensor on every decode step.
        self.pos_t = self.pos_t + 1
        return self.pos_t

    def insert_kv(self, layer_id: int, k: T, v: T, **kwargs):
        """Write ``k``/``v`` for ``layer_id`` at slots ``[pos, pos + T_add)``.

        ``k`` and ``v`` are ``(B, H, T_add, D)``. Returns views of all keys and
        values cached for this layer up to ``pos + T_add``. Extra keyword
        arguments (``input_pos``, ``batch_idx``) are accepted for interface
        compatibility and ignored.
        """
        del kwargs
        assert self.pos_t is not None
        if self.kv_cache is None:
            self.kv_cache = torch.empty(self.kv_shape, dtype=k.dtype, device=k.device)
        _, _, T_add, _ = k.size()
        t0, t1 = self.pos, self.pos + T_add
        self.kv_cache[layer_id, 0, :, :, t0:t1] = k
        self.kv_cache[layer_id, 1, :, :, t0:t1] = v
        key_view = self.kv_cache[layer_id, 0, :, :, :t1]
        value_view = self.kv_cache[layer_id, 1, :, :, :t1]
        # Only the last layer advances pos so earlier layers see the same window.
        if layer_id == self.kv_cache.size(0) - 1:
            self.pos = t1
        return key_view, value_view
# ---------------------------------------------------------------------------
# Sampling
# ---------------------------------------------------------------------------
@torch.inference_mode()
def sample_next_token(logits, rng, temperature=0.0, top_k=None):
    """Pick one token id per row of 2-D ``logits``; returns a ``(B, 1)`` tensor.

    ``temperature == 0`` means greedy argmax (``top_k`` and ``rng`` are then
    ignored). Otherwise logits are divided by ``temperature`` and sampled with
    ``torch.multinomial`` using generator ``rng``, optionally restricted to the
    ``top_k`` highest logits.
    """
    assert temperature >= 0.0
    if temperature == 0.0:
        return logits.argmax(dim=-1, keepdim=True)
    if top_k is None:
        probs = F.softmax(logits / temperature, dim=-1)
        return torch.multinomial(probs, num_samples=1, generator=rng)
    k = min(top_k, logits.size(-1))
    top_vals, top_idx = logits.topk(k, dim=-1)
    top_probs = F.softmax(top_vals / temperature, dim=-1)
    picked = torch.multinomial(top_probs, num_samples=1, generator=rng)
    # Map positions within the top-k back to vocabulary ids.
    return top_idx.gather(1, picked)
# ---------------------------------------------------------------------------
# Main Model
# ---------------------------------------------------------------------------
class FalconPerceptionForSegmentation(PreTrainedModel):
    """Falcon Perception decoder for referring-expression segmentation.

    A single transformer consumes text tokens interleaved with image-patch
    tokens. Besides language-model logits it has heads that decode box
    coordinates (``coord_decoder``), box sizes (``size_decoder``) and, when
    ``config.do_segmentation`` is set, mask logits obtained by projecting a
    segmentation token's hidden state onto upsampled image features.
    ``generate`` is the high-level entry point.
    """

    config_class = FalconPerceptionConfig
    _no_split_modules = ["TransformerBlock"]

    def __init__(self, config: FalconPerceptionConfig):
        super().__init__(config)
        img_in_dim = config.temporal_patch_size * config.spatial_patch_size ** 2 * config.channel_size
        self.img_projector = nn.Linear(img_in_dim, config.dim, bias=False)
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
        self.layers = nn.ModuleDict()
        for layer_id in range(config.n_layers):
            self.layers[str(layer_id)] = TransformerBlock(layer_id, config)
        self.norm = nn.RMSNorm(config.dim, eps=config.norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
        self.coord_encoder = FourierEncoder(2, config.coord_enc_dim, config.dim)
        self.coord_decoder = BboxDecoder(config.dim, config.coord_dec_dim, config.coord_out_dim)
        self.size_encoder = FourierEncoder(2, config.size_enc_dim, config.dim)
        self.size_decoder = BboxDecoder(config.dim, config.size_dec_dim, config.size_out_dim)
        if config.do_segmentation:
            self.itok_upsampler = AnyUp()
            self.proj_segm = SegmDecoder(config.dim, config.segm_out_dim, config.num_segm_layers)
            self.conv_segm = nn.Conv2d(config.dim, config.segm_out_dim, kernel_size=3, padding=1)
        rope_dim = config.head_dim // 2
        freqs_cis = precompute_freqs_cis(rope_dim, config.max_seq_len, config.rope_theta)
        # Allocated uninitialised; as a persistent buffer it is part of the
        # state dict, so its values presumably come from the checkpoint.
        freqs_cis_golden = torch.empty((config.n_heads, rope_dim // 2, 2), dtype=torch.float)
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)
        self.register_buffer("freqs_cis_golden", freqs_cis_golden, persistent=True)
        self._weights_fused = False  # set once _ensure_device_buffers has run
        self._is_compiled = False
        self.post_init()

    # -- Weight management ---------------------------------------------------
    def _ensure_device_buffers(self):
        """Recompute non-persistent buffers that HF meta-device loading may discard.

        Runs once (guarded by ``_weights_fused``): rebuilds ``freqs_cis`` on the
        embedding device and moves ``freqs_cis_golden`` there if needed.
        """
        if self._weights_fused:
            return
        device = self.tok_embeddings.weight.device
        c = self.config
        rope_dim = c.head_dim // 2
        freqs_cis = precompute_freqs_cis(rope_dim, c.max_seq_len, c.rope_theta).to(device)
        self.register_buffer("freqs_cis", freqs_cis, persistent=False)
        if self.freqs_cis_golden.device != device:
            self.freqs_cis_golden = self.freqs_cis_golden.to(device)
        self._weights_fused = True

    def compile_model(self):
        """torch.compile the transformer layers and the small heads (idempotent).

        Also disables Inductor CUDA graphs globally before compiling.
        """
        if self._is_compiled:
            return
        torch._inductor.config.triton.cudagraphs = False
        for layer in self.layers.values():
            layer.compile(dynamic=True, mode="default")
        self.coord_encoder = torch.compile(self.coord_encoder, dynamic=True, mode="default")
        self.coord_decoder = torch.compile(self.coord_decoder, dynamic=True, mode="default")
        self.size_encoder = torch.compile(self.size_encoder, dynamic=True, mode="default")
        self.size_decoder = torch.compile(self.size_decoder, dynamic=True, mode="default")
        if self.config.do_segmentation:
            self.itok_upsampler.compile(mode="default", dynamic=True)
        self._is_compiled = True

    # -- Tokenizer -----------------------------------------------------------
    def _get_tokenizer(self):
        """Lazily load and cache the tokenizer from ``config._name_or_path``.

        Every string-valued special token is also exposed on the tokenizer as
        attributes ``<name>`` and ``<name>_id``.
        """
        if not hasattr(self, "_tokenizer"):
            import os

            path = self.config._name_or_path
            # Existing local paths must not trigger a hub lookup.
            is_local = os.path.exists(path)
            self._tokenizer = AutoTokenizer.from_pretrained(path, local_files_only=is_local, trust_remote_code=True)
            for token_name, token in self._tokenizer.special_tokens_map.items():
                if isinstance(token, str):
                    setattr(self._tokenizer, token_name, token)
                    setattr(
                        self._tokenizer, token_name + "_id",
                        self._tokenizer.convert_tokens_to_ids(token),
                    )
        return self._tokenizer

    # -- Attention mask ------------------------------------------------------
    def get_attention_mask(self, input_batch: T, max_len: int | None = None):
        """Build the FlexAttention block mask for a padded token batch.

        Reads ``self._pad_token_id``, which ``generate`` sets before calling this.
        """
        return create_batch_attention_mask(
            input_batch,
            pad_token_id=self._pad_token_id,
            eos_token_id=self.config.eos_id,
            soi_token_id=self.config.image_cls_token_id,
            eoi_token_id=self.config.img_end_id,
            max_len=max_len,
        )

    def get_upsampler_attn_mask(self, H, W, h, w, device):
        """Block mask for the upsampler: ``H*W`` high-res queries over ``h*w`` low-res keys."""
        return create_attention_mask(
            get_upsampler_attn_mask_mod(H, W, h, w, device=device),
            B=None, H=None, Q_LEN=H * W, KV_LEN=h * w,
        )

    # -- Embedding helpers ---------------------------------------------------
    def _scatter_img_tokens_with_projector(self, h_BSD, pixel_patches_NLC, pixel_masks_NTHW, tokens_BS):
        """Replace image-token embeddings with projected pixel patches.

        A patch is valid if any pixel in it is set in ``pixel_masks_NTHW``.
        Valid patches are projected and written, in order, into the positions
        where ``tokens_BS == config.img_id``; their counts must match.
        """
        B, S, D = h_BSD.shape
        pixel_patch_mask = E.reduce(
            pixel_masks_NTHW,
            "n (t pt) (h ph) (w pw) -> (n t h w)",
            reduction="any",
            pt=self.config.temporal_patch_size,
            ph=self.config.spatial_patch_size,
            pw=self.config.spatial_patch_size,
        )
        pixel_patches_flat = E.rearrange(pixel_patches_NLC, "n p c -> (n p) c")
        valid_patches = pixel_patches_flat[pixel_patch_mask]
        valid_feats = self.img_projector(valid_patches)
        img_mask_h_BSD = E.repeat(tokens_BS == self.config.img_id, "b s -> b s d", d=D)
        assert valid_feats.numel() == img_mask_h_BSD.sum()
        return torch.masked_scatter(h_BSD, img_mask_h_BSD, valid_feats)

    @staticmethod
    def _insert_token_features(h_BSD: T, mask_BS: T, feats_ND: T) -> T:
        """Write per-token features ``feats_ND`` into ``h_BSD`` where ``mask_BS`` is set."""
        B, S, D = h_BSD.shape
        if S == 1 and feats_ND.shape[0] == B:
            # Decode step: one candidate feature per sample; samples whose
            # token is not selected keep their embedding.
            return torch.where(mask_BS.unsqueeze(-1), feats_ND.view(B, 1, D), h_BSD)
        # General case (e.g. prefill): features are consumed in row-major mask
        # order, which stays correct when samples hold different counts.
        return h_BSD.masked_scatter_(mask_BS.unsqueeze(-1), feats_ND)

    def _encode_coords(self, h_BSD: T, tokens_BS: T, all_xy: T):
        """Replace coord-token embeddings with Fourier-encoded ``(x, y)`` values."""
        if all_xy.numel() == 0:
            return h_BSD
        coord_tokens = self.coord_encoder(all_xy.reshape(-1, 2))
        return self._insert_token_features(h_BSD, tokens_BS == self.config.coord_token_id, coord_tokens)

    def _encode_sizes(self, h_BSD, tokens_BS, all_hw: T):
        """Replace size-token embeddings with Fourier-encoded ``(h, w)`` values."""
        if all_hw.numel() == 0:
            return h_BSD
        size_tokens = self.size_encoder(all_hw.reshape(-1, 2))
        return self._insert_token_features(h_BSD, tokens_BS == self.config.size_token_id, size_tokens)

    def decode_coords(self, h_BSD, labels):
        """Coordinate-bin logits for every position labelled with the coord token.

        Returns shape ``(num_coord_tokens, 2, coord_out_dim // 2)``.
        """
        B, S, D = h_BSD.shape
        coord_masks = labels == self.config.coord_token_id
        coord_tokens = torch.masked_select(h_BSD, coord_masks.unsqueeze(-1))
        coord_logits = self.coord_decoder(coord_tokens.reshape(-1, D))
        return E.rearrange(coord_logits, "b (two dim) -> b two dim", two=2)

    def decode_sizes(self, h_BSD, labels):
        """Size-bin logits for every position labelled with the size token.

        Returns shape ``(num_size_tokens, 2, size_out_dim // 2)``.
        """
        B, S, D = h_BSD.shape
        size_masks = labels == self.config.size_token_id
        size_tokens = torch.masked_select(h_BSD, size_masks.unsqueeze(-1))
        size_logits = self.size_decoder(size_tokens.reshape(-1, D))
        return E.rearrange(size_logits, "b (two dim) -> b two dim", two=2)

    def process_sizes(self, logits):
        """Map size-bin logits to relative sizes on a log2 scale.

        The argmax bin is mapped linearly onto ``[log2(1/num_bins), 0]`` and
        returned as ``2 ** value``, i.e. a size in ``[1/num_bins, 1]``.
        """
        num_bins = logits.shape[-1]
        pred = torch.argmax(logits, dim=-1).float() / (num_bins - 1)
        min_size = torch.log2(torch.tensor(1 / num_bins))
        max_size = 0.0
        pred = pred * (max_size - min_size) + min_size
        return torch.pow(2.0, pred)

    # -- Segmentation -------------------------------------------------------
    def gather_img_tokens(self, h_BSD: T, tokens_BS: T, itok_masks_NTHW: T):
        """Scatter image-token hidden states back onto an ``(N, T, h, w, D)`` grid.

        Grid cells not covered by ``itok_masks_NTHW`` are zero.
        """
        B, S, D = h_BSD.shape
        itok_masks_BSD = E.repeat(tokens_BS == self.config.img_id, "b s -> b s d", d=D)
        itok_flatten = torch.masked_select(h_BSD, itok_masks_BSD)
        itok_masks_NTHWD = E.repeat(itok_masks_NTHW, "n t h w -> n t h w d", d=D)
        itok_NTHWD = torch.zeros_like(itok_masks_NTHWD, dtype=h_BSD.dtype, device=h_BSD.device)
        itok_NTHWD = itok_NTHWD.masked_scatter_(itok_masks_NTHWD, itok_flatten)
        return itok_NTHWD

    def upsample_img_features(self, h_BSD: T, tokens_BS: T, pixel_values_NTHWC: T, pixel_mask_NTHW: T):
        """Upsample patch-level image features to pixel resolution.

        Patch features are gathered onto their grid, passed through
        ``conv_segm`` and upsampled one image at a time by ``itok_upsampler``,
        guided by the raw image. The rearranges require a single temporal frame.
        """
        device = h_BSD.device
        c = self.config
        itok_masks_NTHW = E.reduce(
            pixel_mask_NTHW,
            "n (t pt) (h ph) (w pw) -> n t h w",
            reduction="any",
            pt=c.temporal_patch_size, ph=c.spatial_patch_size, pw=c.spatial_patch_size,
        )
        N, _, h, w = itok_masks_NTHW.shape
        _, _, H, W = pixel_mask_NTHW.shape
        images = E.rearrange(pixel_values_NTHWC, "n 1 h w c -> n c h w")
        lr_img_features = self.gather_img_tokens(h_BSD, tokens_BS, itok_masks_NTHW)
        lr_img_features = E.rearrange(lr_img_features, "n 1 h w d -> n d h w")
        lr_img_features = self.conv_segm(lr_img_features)
        upsampler_attn_mask = self.get_upsampler_attn_mask(H, W, h, w, device=device)
        hr_parts = []
        for i in range(N):
            hr_i = self.itok_upsampler(
                images=images[i:i + 1], features=lr_img_features[i:i + 1], attn_mask=upsampler_attn_mask,
            )
            hr_parts.append(hr_i)
        return torch.cat(hr_parts, dim=0) if N > 1 else hr_parts[0]

    @staticmethod
    def _mask_to_coco_rle(binary_masks: torch.Tensor) -> list[dict]:
        """Encode a ``(C, H, W)`` boolean batch as compressed COCO RLE dicts.

        Masks without any foreground pixel are skipped, so the result may be
        shorter than ``C``. Run boundaries are found on the device; only the
        change indices are copied to the host.
        """
        C, H, W = binary_masks.shape
        has_any = E.reduce(binary_masks, "c h w -> c", reduction="any")
        # COCO RLE runs over pixels in column-major (Fortran) order.
        binary_col = E.rearrange(binary_masks, "c h w -> c (w h)")
        diffs = binary_col[:, 1:] != binary_col[:, :-1]
        nz = torch.nonzero(diffs, as_tuple=False)
        first_vals = binary_col[:, 0]
        nz_cpu = nz.cpu().numpy()
        has_any_cpu = has_any.cpu().numpy()
        first_vals_cpu = first_vals.cpu().numpy()
        del diffs, nz, binary_col, first_vals, has_any
        N_px = H * W
        if nz_cpu.shape[0] > 0:
            # Rows of nz are sorted by mask id; group change positions per mask.
            mask_ids = nz_cpu[:, 0]
            change_cols = nz_cpu[:, 1]
            uniq, grp_starts = np.unique(mask_ids, return_index=True)
            grp_ends = np.append(grp_starts[1:], len(mask_ids))
            mask_to_grp = {int(m): (int(gs), int(ge)) for m, gs, ge in zip(uniq, grp_starts, grp_ends)}
        else:
            change_cols = np.array([], dtype=np.intp)
            mask_to_grp = {}
        results = []
        for i in range(C):
            if not has_any_cpu[i]:
                continue
            if i in mask_to_grp:
                gs, ge = mask_to_grp[i]
                cidx = change_cols[gs:ge]
            else:
                cidx = np.array([], dtype=np.intp)
            # A value change between j and j + 1 starts a new run at j + 1.
            num_runs = len(cidx) + 1
            starts = np.empty(num_runs, dtype=np.intp)
            starts[0] = 0
            if len(cidx) > 0:
                starts[1:] = cidx + 1
            counts = np.empty(num_runs, dtype=np.uint32)
            if num_runs > 1:
                counts[:-1] = np.diff(starts)
            counts[-1] = N_px - starts[-1]
            # COCO RLE always begins with a background run (possibly empty).
            if first_vals_cpu[i]:
                counts = np.concatenate([[0], counts])
            rle = {"counts": counts.tolist(), "size": [H, W]}
            rle = mask_utils.frPyObjects(rle, H, W)
            rle["counts"] = rle["counts"].decode("utf-8")
            results.append(rle)
        return results

    @staticmethod
    def _empty_coco_rle(H: int, W: int) -> dict:
        """Compressed COCO RLE for an all-background ``H x W`` mask."""
        rle = mask_utils.frPyObjects({"counts": [H * W], "size": [H, W]}, H, W)
        rle["counts"] = rle["counts"].decode("utf-8")
        return rle

    # -- Core forward --------------------------------------------------------
    def forward(
        self,
        tokens: T,
        attention_mask: BlockMask,
        kv_cache,
        rope_pos_t: T | None = None,
        rope_pos_hw: T | None = None,
        pixel_values: T | None = None,
        pixel_mask: T | None = None,
        coord_xy: T | None = None,
        size_hw: T | None = None,
    ):
        """Run one prefill (``S > 1``) or decode (``S == 1``) step.

        Args:
            tokens: ``(B, S)`` token ids for this step.
            attention_mask: Block mask over the full padded sequence. In prefill
                its ``seq_lengths`` is set to ``(S, S)`` in place. In decode, the
                row block holding the current position is sliced out and its
                mask_mod is offset to that position.
            kv_cache: Object providing ``get_pos``, ``pos_t``,
                ``increment_and_get_pos_t`` and ``insert_kv`` (see ``KVCache``).
            rope_pos_t, rope_pos_hw: Temporal and spatial RoPE positions for the
                whole sequence; required in prefill.
            pixel_values, pixel_mask: When given, image-token embeddings are
                replaced by projected pixel patches.
            coord_xy, size_hw: Values embedded at coord / size token positions.

        Returns:
            ``(logits_BSV, h_BSD)``: LM logits and final normalised hidden states.
        """
        B, S = tokens.size()
        c = self.config
        block_mask = attention_mask
        T_pos = kv_cache.get_pos()
        is_prefill = S != 1
        if is_prefill:
            assert rope_pos_t is not None and rope_pos_hw is not None
            pos_t = rope_pos_t[:, T_pos:T_pos + S].long()
            kv_cache.pos_t = pos_t[:, -1:]
            freqs_cis = self.freqs_cis[pos_t]
            rope_pos_hw = rope_pos_hw[:, T_pos:T_pos + S]
            freqs_cis_golden = apply_golden_freqs_cis_to_visual_pos(self.freqs_cis_golden, rope_pos_hw)
            block_mask.seq_lengths = (S, S)
        else:
            pos_t = kv_cache.increment_and_get_pos_t()
            freqs_cis = self.freqs_cis[pos_t]
            freqs_cis_golden = None
            block_idx = T_pos // block_mask.BLOCK_SIZE[0]
            block_mask = block_mask[:, :, block_idx]
            block_mask.seq_lengths = (S, T_pos + S)
            block_mask.mask_mod = offset_mask_mod(attention_mask.mask_mod, offset=T_pos)
        h_BSD = self.tok_embeddings(tokens)
        coord_xy = coord_xy if coord_xy is not None else h_BSD.new_empty(0)
        size_hw = size_hw if size_hw is not None else h_BSD.new_empty(0)
        h_BSD = self._encode_coords(h_BSD, tokens, coord_xy)
        h_BSD = self._encode_sizes(h_BSD, tokens, size_hw)
        if pixel_values is not None:
            assert pixel_mask is not None
            pixel_values = pixel_values.to(self.dtype)
            pixel_mask = pixel_mask.to(self.dtype)
            pixel_patches_NLC = E.rearrange(
                pixel_values,
                "n (t pt) (h ph) (w pw) c -> n (t h w) (pt ph pw c)",
                pt=c.temporal_patch_size, ph=c.spatial_patch_size, pw=c.spatial_patch_size,
            )
            h_BSD = self._scatter_img_tokens_with_projector(h_BSD, pixel_patches_NLC, pixel_mask, tokens)
        for layer in self.layers.values():
            h_BSD = layer(
                h_BSD, freqs_cis=freqs_cis, freqs_cis_2d=freqs_cis_golden,
                pos_hw=rope_pos_hw, attention_masks=block_mask, kv_cache=kv_cache,
            )
        h_BSD = self.norm(h_BSD)
        logits_BSV = self.output(h_BSD)
        return logits_BSV, h_BSD

    # -- Main API: generate --------------------------------------------------
    @torch.inference_mode()
    def generate(
        self,
        images,
        queries,
        max_new_tokens: int = 2048,
        temperature: float = 0.0,
        top_k: int | None = None,
        min_dimension: int = 256,
        max_dimension: int = 1024,
        compile: bool = True,
        seed: int | None = 42,
        segm_threshold: float = 0.5,
    ) -> list[list[dict]]:
        """
        Segment objects in images matching the given queries.

        Args:
            images: Single PIL Image (or path/URL) or list of them.
            queries: Single query string or list of query strings (one per image).
            max_new_tokens: Maximum generation steps; decoding stops after this
                many sampled tokens even if no stop token was produced.
            temperature: Sampling temperature (0.0 = greedy).
            top_k: Top-k sampling (None = disabled).
            min_dimension: Min image side after resize.
            max_dimension: Max image side after resize.
            compile: Whether to torch.compile on first call.
            seed: Random seed for reproducibility (None = non-deterministic).
            segm_threshold: Sigmoid threshold for binary mask.

        Returns:
            List (per image) of lists (per detection) of dicts::

                {
                    "xy": {"x": float, "y": float},
                    "hw": {"h": float, "w": float},
                    "mask_rle": {"counts": str, "size": [H, W]},
                }
        """
        self._ensure_device_buffers()
        if compile:
            self.compile_model()

        # Normalize inputs
        if isinstance(images, (str, Path, Image.Image)):
            images = [images]
        if isinstance(queries, str):
            queries = [queries]
        assert len(images) == len(queries), "Must provide one query per image"

        device = self.device
        tokenizer = self._get_tokenizer()
        self._pad_token_id = tokenizer.convert_tokens_to_ids("<|pad|>")
        stop_token_ids = [self.config.eos_id, tokenizer.convert_tokens_to_ids("<|end_of_query|>")]

        # Store original image sizes for mask resizing
        pil_images = [load_image(img).convert("RGB") for img in images]
        original_sizes = [(img.height, img.width) for img in pil_images]

        # Build prompts
        image_prompt_pairs = [
            (img, f"<|image|>Segment these expressions in the image:<|start_of_query|>{q}<|REF_SEG|>")
            for img, q in zip(pil_images, queries)
        ]

        # Preprocess
        batch_inputs = process_batch(
            tokenizer, self.config, image_prompt_pairs,
            max_length=4096, min_dimension=min_dimension, max_dimension=max_dimension,
        )
        batch_inputs = {k: (v.to(device) if torch.is_tensor(v) else v) for k, v in batch_inputs.items()}
        tokens = batch_inputs["tokens"]
        B, L = tokens.size()
        # The cache / mask length is rounded up to a multiple of block_size.
        block_size = 128
        S = (L + max_new_tokens + block_size - 1) // block_size * block_size
        assert S <= self.config.max_seq_len
        # Rounding can leave room for more than max_new_tokens; stop at the
        # requested budget instead (S >= max_pos by construction).
        max_pos = L + max_new_tokens
        rng = torch.Generator(device).manual_seed(seed) if seed is not None else None
        kv_cache = KVCache(
            max_batch_size=B, max_seq_length=S, n_heads=self.config.n_heads,
            head_dim=self.config.head_dim, num_layers=self.config.n_layers,
        )
        padded_tokens = torch.full((B, S), self._pad_token_id, dtype=tokens.dtype, device=device)
        padded_tokens[:, :L] = tokens
        attention_mask = self.get_attention_mask(padded_tokens, max_len=S)
        all_xy, all_hw = self._extract_coords([[]])
        coord_xy = all_xy.to(device=device, dtype=self.dtype)
        size_hw_t = all_hw.to(device=device, dtype=self.dtype)

        # Prefill
        logits_BSV, h_BSD = self.forward(
            tokens=tokens, rope_pos_t=batch_inputs["pos_t"], rope_pos_hw=batch_inputs["pos_hw"],
            attention_mask=attention_mask, kv_cache=kv_cache,
            pixel_values=batch_inputs["pixel_values"], pixel_mask=batch_inputs["pixel_mask"],
            coord_xy=coord_xy, size_hw=size_hw_t,
        )
        hr_img_features = self.upsample_img_features(
            h_BSD, tokens, batch_inputs["pixel_values"], batch_inputs["pixel_mask"],
        )

        aux_output_B = [[] for _ in range(B)]
        stop_ids = torch.tensor(stop_token_ids).to(device)
        should_stop_B = torch.full((B,), False, dtype=torch.bool, device=device)
        coord_repeat_threshold = 0.01  # coords within 1% of image size are considered duplicates
        max_coord_attempts = 100

        # Decode loop
        while not torch.all(should_stop_B) and (pos := kv_cache.get_pos()) < max_pos:
            tokens_B1 = sample_next_token(logits_BSV[:, -1], rng, temperature, top_k)
            if torch.any(should_stop_B):
                # Finished samples keep emitting padding.
                tokens_B1 = tokens_B1.clone()
                tokens_B1[should_stop_B, :] = self._pad_token_id
            padded_tokens[:, pos] = tokens_B1[:, -1]

            # Decode coords (with deduplication to avoid repeating the same location)
            coord_logits = self.decode_coords(h_BSD[:, -1:], tokens_B1)
            sample_w_coord = torch.where(tokens_B1 == self.config.coord_token_id)[0]
            num_bins = coord_logits.size(-1)
            xy_b2 = torch.zeros(B, 2, device=device, dtype=self.dtype)
            for i, b in enumerate(sample_w_coord.tolist()):
                logits_b = coord_logits[i].clone()  # (2, num_bins)
                existing_coords = [
                    item for item in aux_output_B[b]
                    if isinstance(item, dict) and "x" in item and "y" in item
                ]
                pred_x, pred_y = 0.0, 0.0
                for _ in range(max_coord_attempts):
                    pred_bins = torch.argmax(logits_b, dim=-1)  # (2,)
                    pred_x = pred_bins[0].item() / (num_bins - 1)
                    pred_y = pred_bins[1].item() / (num_bins - 1)
                    is_repeat = any(
                        abs(ec["x"] - pred_x) < coord_repeat_threshold
                        and abs(ec["y"] - pred_y) < coord_repeat_threshold
                        for ec in existing_coords
                    )
                    if not is_repeat:
                        break
                    # Ban this (x, y) bin pair and retry with the next best.
                    logits_b[0, pred_bins[0]] = float("-inf")
                    logits_b[1, pred_bins[1]] = float("-inf")
                xy_b2[b, 0] = pred_x
                xy_b2[b, 1] = pred_y
                aux_output_B[b].append({"x": pred_x, "y": pred_y})

            # Decode sizes
            size_logits = self.decode_sizes(h_BSD[:, -1:], tokens_B1)
            hw_b2 = self.process_sizes(size_logits)
            size_preds = [{"h": hw[0].item(), "w": hw[1].item()} for hw in hw_b2]
            sample_w_size = torch.where(tokens_B1 == self.config.size_token_id)[0]
            for i, b in enumerate(sample_w_size.tolist()):
                aux_output_B[b].append(size_preds[i])

            # Decode segmentation: project the seg token onto upsampled features.
            sample_w_segm = torch.where(tokens_B1 == self.config.seg_token_id)[0]
            segm_tokens = h_BSD[sample_w_segm, -1, :]
            segm_tokens = self.proj_segm(segm_tokens)
            segm_masks = torch.einsum("kdhw,kd->khw", hr_img_features[sample_w_segm], segm_tokens)
            for i, b in enumerate(sample_w_segm):
                aux_output_B[b].append(segm_masks[i])

            # Next step
            logits_BSV, h_BSD = self.forward(
                tokens=tokens_B1, attention_mask=attention_mask,
                coord_xy=xy_b2.to(self.dtype), size_hw=hw_b2.to(self.dtype), kv_cache=kv_cache,
            )
            hit_stop_B = torch.isin(tokens_B1, stop_ids).any(dim=-1)
            should_stop_B = should_stop_B.logical_or(hit_stop_B)

        # Post-process: convert aux outputs to structured results with RLE masks
        pixel_mask_batch = batch_inputs["pixel_mask"][:, 0]  # (B, H, W)
        results = []
        for b in range(B):
            dets = self._postprocess_aux(
                aux_output_B[b], pixel_mask_batch[b], original_sizes[b], segm_threshold,
            )
            results.append(dets)
        return results

    # -- Post-processing helpers ---------------------------------------------
    def _extract_coords(self, coords_BO: list[list]):
        """Flatten per-sample coord dicts into ``(xy, hw)`` tensors.

        Values of keys starting with ``x``/``y`` go to the first tensor, those
        starting with ``h``/``w`` to the second.
        """
        all_xy, all_hw = [], []
        for coords_O in coords_BO:
            if not coords_O:
                continue
            for coords in coords_O:
                for k, v in coords.items():
                    if k.startswith(("x", "y")):
                        all_xy.append(v)
                    elif k.startswith(("h", "w")):
                        all_hw.append(v)
        return torch.tensor(all_xy), torch.tensor(all_hw)

    @staticmethod
    def _mask_nms(
        binary_masks: list[torch.Tensor],
        iou_threshold: float = 0.6,
        nms_max_side: int = 256,
    ) -> list[int]:
        """
        Fast vectorised mask NMS on binary (H, W) tensors.

        Masks are downscaled so their longer side is at most ``nms_max_side``
        (shape taken from the first mask), and pairwise IoU is computed with a
        single matmul. Larger masks win. Returns kept indices ordered by
        descending area.
        """
        N = len(binary_masks)
        if N <= 1:
            return list(range(N))
        device = binary_masks[0].device
        base_h, base_w = binary_masks[0].shape
        scale = min(1.0, nms_max_side / max(base_h, base_w))
        th = max(1, int(round(base_h * scale)))
        tw = max(1, int(round(base_w * scale)))
        resized = []
        for m in binary_masks:
            m = m.float()
            if m.shape != (th, tw):
                # Index instead of squeeze() so 1-pixel dims are preserved.
                m = F.interpolate(
                    m[None, None], size=(th, tw), mode="bilinear", align_corners=False
                )[0, 0]
            resized.append(m)
        binary = torch.stack(resized)  # (N, th, tw)
        flat = binary.view(N, -1)  # (N, th*tw)
        areas = flat.sum(dim=1)  # (N,)
        scores = areas  # larger mask = higher priority
        intersection = flat @ flat.T  # (N, N)
        union = areas[:, None] + areas[None, :] - intersection
        iou = intersection / union.clamp(min=1)
        order = scores.argsort(descending=True)
        suppressed = torch.zeros(N, dtype=torch.bool, device=device)
        keep = []
        for idx in order.tolist():
            if suppressed[idx]:
                continue
            keep.append(idx)
            suppressed |= iou[idx] > iou_threshold
        return keep

    def _postprocess_aux(
        self,
        aux_list: list,
        pixel_mask_hw: T,
        orig_hw: tuple[int, int],
        threshold: float,
        nms_iou_threshold: float = 0.6,
    ) -> list[dict]:
        """Convert raw aux outputs into structured detections with RLE masks.

        ``aux_list`` is read as consecutive (coord dict, size dict, mask logits)
        triplets. Each mask is cropped to the valid region of
        ``pixel_mask_hw``, resized to ``orig_hw``, thresholded after a sigmoid,
        de-duplicated with mask NMS and encoded as COCO RLE.
        """
        orig_h, orig_w = orig_hw

        # Find active image region from pixel mask
        nonzero = torch.nonzero(pixel_mask_hw, as_tuple=False)
        if len(nonzero) > 0:
            # Plain ints avoid a device sync for every slice below.
            min_h, min_w = nonzero.min(dim=0).values.tolist()
            max_h, max_w = nonzero.max(dim=0).values.tolist()
            act_h = max_h - min_h + 1
            act_w = max_w - min_w + 1
        else:
            min_h = min_w = 0
            act_h = act_w = None

        # Group into triplets: coord, size, mask — build binary masks first
        candidates = []
        step = 3  # coord, size, mask
        for i in range(0, len(aux_list), step):
            if i + 2 >= len(aux_list):
                break
            xy = aux_list[i]
            hw = aux_list[i + 1]
            mask_logits = aux_list[i + 2]
            if not isinstance(mask_logits, torch.Tensor):
                continue
            # Crop to active region
            if act_h is not None and act_w is not None:
                mask_logits = mask_logits[min_h:min_h + act_h, min_w:min_w + act_w]
            # Resize to original image size
            mask_logits = mask_logits.unsqueeze(0).unsqueeze(0).float()
            mask_logits = F.interpolate(mask_logits, size=(orig_h, orig_w), mode="bilinear", align_corners=False)
            mask_logits = mask_logits.squeeze(0).squeeze(0)
            # Threshold
            binary_mask = (torch.sigmoid(mask_logits) > threshold).bool()
            candidates.append({"xy": xy, "hw": hw, "binary_mask": binary_mask})

        if not candidates:
            return []

        # NMS on binary masks before RLE encoding
        keep_indices = self._mask_nms(
            [c["binary_mask"] for c in candidates],
            iou_threshold=nms_iou_threshold,
        )
        candidates = [candidates[i] for i in keep_indices]

        # Encode survivors as COCO RLE
        detections = []
        for c in candidates:
            rle_list = self._mask_to_coco_rle(c["binary_mask"].unsqueeze(0))
            # _mask_to_coco_rle skips empty masks; emit a valid all-background
            # RLE for them instead of an undecodable empty string.
            mask_rle = rle_list[0] if rle_list else self._empty_coco_rle(orig_h, orig_w)
            detections.append({"xy": c["xy"], "hw": c["hw"], "mask_rle": mask_rle})
        return detections