diff --git "a/modular_isaac.py" "b/modular_isaac.py" --- "a/modular_isaac.py" +++ "b/modular_isaac.py" @@ -1,174 +1,399 @@ -from __future__ import annotations +# coding=utf-8 +# Copyright 2025 Perceptron, Inc and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. -from collections import defaultdict -from typing import Any, TypedDict +from __future__ import annotations +import copy import math -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -import PIL.Image - - -from transformers import ( - AutoTokenizer, - BatchFeature, - Cache, - Qwen3Config, +from collections.abc import Callable, Sequence +from enum import IntEnum +from typing import Any, Optional, Union + +from transformers.cache_utils import DynamicCache +from transformers.configuration_utils import PretrainedConfig, layer_type_validation +from transformers.feature_extraction_utils import BatchFeature +from transformers.generation.utils import GenerationMixin +from transformers.image_processing_utils_fast import ( + BaseImageProcessorFast, + ImagesKwargs, + SizeDict, + group_images_by_shape, + reorder_images, +) +from transformers.image_utils import ( + PILImageResampling, +) +from transformers.masking_utils import ( + ALL_MASK_ATTENTION_FUNCTIONS, + create_masks_for_generate, + packed_sequence_mask_function, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, +) +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS +from transformers.models.qwen3.configuration_qwen3 import Qwen3Config +from transformers.models.qwen3.modeling_qwen3 import ( Qwen3ForCausalLM, + Qwen3Model, Qwen3PreTrainedModel, ) -from transformers.cache_utils import SlidingWindowCache, StaticCache -from transformers.generation.utils import GenerationMixin -from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from transformers.models.qwen3.modeling_qwen3 import Qwen3DecoderLayer, Qwen3Model -from transformers.models.qwen2.tokenization_qwen2 import Qwen2Tokenizer -from transformers.processing_utils import ProcessorMixin -from transformers.tokenization_utils import TensorType -from transformers.modeling_attn_mask_utils import AttentionMaskConverter -import re - -from transformers.models.siglip2.modeling_siglip2 import ( - Siglip2MLP, +from transformers.processing_utils import ProcessorMixin, Unpack +from transformers.utils import TensorType, auto_docstring +from transformers.utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN +from transformers.utils.constants import IMAGENET_STANDARD_STD as VISION_STD +from transformers.utils.generic import ( + OutputRecorder, + TransformersKwargs, + can_return_tuple, + check_model_inputs, ) -from transformers.models.siglip2.configuration_siglip2 import Siglip2VisionConfig -from perceptron.tensorstream import ( - Event, - Stream, - TensorStream, - TextType, - VisionType, - create_stream, - group_streams, +from transformers.utils.import_utils import ( + is_torch_available, + is_torchdynamo_compiling, + is_torchvision_available, + is_vision_available, ) -from perceptron.tensorstream.ops import ( - compute_mrope_pos_tensor, - modality_mask, - reconstruct_tensor_stream_from_compact_dict, - slice as ts_slice, - tensor_stream_token_view, +from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl as qwen2_5_vl_modeling +from transformers.models.siglip2.configuration_siglip2 import Siglip2VisionConfig +from transformers.models.siglip2.modeling_siglip2 import ( + Siglip2Attention, + Siglip2Encoder, + Siglip2EncoderLayer, + Siglip2VisionEmbeddings, ) -class PixelShuffleSiglip2VisionConfig(Siglip2VisionConfig): +if is_torch_available(): + import torch + import torch.nn as nn + import torch.nn.functional as F +if is_vision_available(): + from PIL.Image import Image +else: + Image = None +if is_torchvision_available(): + from transformers.models.pix2struct.image_processing_pix2struct_fast import ( + torch_extract_patches, + ) + + +class ModalityType(IntEnum): + """ + Modality identifiers for events. + + Members: + image: Vision tokens (e.g., patches). + text: Textual tokens. + """ + + image = 0 + text = 1 + + +class IsaacVisionConfig(Siglip2VisionConfig): """Vision configuration for Isaac with Pixel Shuffle support. Extends Siglip2VisionConfig with additional fields for pixel shuffle. + + Args: + pixel_shuffle_scale_factor (`int`, *optional*, defaults to 1): + Spatial factor applied before pixel shuffle reduces the resolution. + num_patches (`int`, *optional*, defaults to 256): + Maximum number of learnable positional embeddings to initialize. """ - model_type = "pixel_shuffle_siglip2" + model_type = "isaac_vision" base_config_key = "vision_config" def __init__( self, - pixel_shuffle_scale_factor: int = 1, - num_patches: int = 256, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + num_patches=256, + patch_size=16, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + pixel_shuffle_scale_factor=1, **kwargs, ): super().__init__(**kwargs) + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.num_patches = num_patches + # Add our custom fields self.pixel_shuffle_scale_factor = pixel_shuffle_scale_factor - self.num_patches = num_patches + # Ensure a sensible default attention backend + if getattr(self, "_attn_implementation", None) is None: + self._attn_implementation = "sdpa" + + +class IsaacImageProcessorFastKwargs(ImagesKwargs, total=False): + patch_size: Optional[int] + max_num_patches: Optional[int] + min_num_patches: Optional[int] + pixel_shuffle_scale: Optional[int] + + +@auto_docstring +class IsaacImageProcessorFast(BaseImageProcessorFast): + MAX_PIXELS = 60_000_000 # 60‑megapixel ceiling ≈ 8200 × 7300 px + + resample = PILImageResampling.BILINEAR + model_input_names = ["patches", "token_grids"] + valid_kwargs = IsaacImageProcessorFastKwargs + unused_kwargs = ["size", "do_center_crop", "crop_size", "pad_size", "do_pad"] + + do_resize = True + do_center_crop = False + patch_size: Optional[int] = 16 + max_num_patches: Optional[int] = 256 + min_num_patches: Optional[int] = None + pixel_shuffle_scale: Optional[int] = 1 + do_pad = False + do_rescale = True + do_normalize = True + image_mean = list(VISION_MEAN) + image_std = list(VISION_STD) + do_convert_rgb = True + disable_grouping = False -def create_cumulative_seq_lengths(seq_sizes: torch.Tensor, device: torch.device) -> tuple[torch.Tensor, int]: - """Create cumulative sequence lengths for variable-length attention.""" - cu_seqlens = torch.zeros(len(seq_sizes) + 1, dtype=torch.int32, device=device) - cu_seqlens[1:] = seq_sizes.cumsum(0) - max_seqlen = int(seq_sizes.max().item()) if len(seq_sizes) > 0 else 0 - return cu_seqlens, max_seqlen - - -def _max_from_cu(cu: torch.Tensor | None, fallback: int) -> int: - """Helper to compute max sequence length from cumulative sequence lengths.""" - if cu is None or len(cu) < 2: - return fallback - return int((cu[1:] - cu[:-1]).max().item()) - - -def flash_attention_document_mask_forward( - q_lhd: torch.Tensor, # (L, H, D) - k_lhd: torch.Tensor, # (L, H, D) - v_lhd: torch.Tensor, # (L, H, D) - attention_mask: torch.Tensor | None = None, # unused for FA path - dropout: float = 0.0, - scaling: float | None = None, - cum_seq_q: torch.Tensor | None = None, - cum_seq_k: torch.Tensor | None = None, - max_seqlen: int | None = None, - is_causal: bool = False, - **kwargs, -) -> tuple[torch.Tensor, None]: - """FlashAttention that consumes (L, H, D) directly to avoid layout churn.""" - L, H, D = q_lhd.shape - - # Compute max block length once (honor caller when provided) - if max_seqlen is not None: - max_q = max_k = int(max_seqlen) - else: - max_q = _max_from_cu(cum_seq_q, L) - max_k = _max_from_cu(cum_seq_k, L) - - # Ensure contiguity only if needed - if not q_lhd.is_contiguous(): - q_lhd = q_lhd.contiguous() - if not k_lhd.is_contiguous(): - k_lhd = k_lhd.contiguous() - if not v_lhd.is_contiguous(): - v_lhd = v_lhd.contiguous() - - out_lhd, *_ = torch.ops.aten._flash_attention_forward( - query=q_lhd, # (L, H, D) - key=k_lhd, # (L, H, D) - value=v_lhd, # (L, H, D) - cum_seq_q=cum_seq_q, - cum_seq_k=cum_seq_k, - max_q=max_q, - max_k=max_k, - dropout_p=dropout, - is_causal=is_causal, - return_debug_mask=False, - scale=scaling, - window_size_left=-1, - window_size_right=-1, - alibi_slopes=None, - ) - return out_lhd, None # (L, H, D) + def __init__( + self, + **kwargs: Unpack[IsaacImageProcessorFastKwargs], + ) -> None: + super().__init__(**kwargs) + def _validate_preprocess_kwargs(self, **kwargs): + # Allow callers to omit resize-related placeholders that BaseImageProcessorFast checks for. + kwargs.pop("do_resize", None) + kwargs.pop("size", None) + kwargs.pop("do_center_crop", None) + kwargs.pop("crop_size", None) + kwargs.pop("disable_grouping", None) + return super()._validate_preprocess_kwargs(**kwargs) -def sdpa_document_mask_forward( - q_lhd: torch.Tensor, # (L, H, D) - k_lhd: torch.Tensor, # (L, H, D) - v_lhd: torch.Tensor, # (L, H, D) - dropout: float, - scaling: float | None, - cu_seqlens: torch.Tensor | None, -) -> torch.Tensor: - """SDPA with block-diagonal masking for variable-length sequences.""" - L, H, D = q_lhd.shape + def resize( + self, + image: torch.Tensor, + size: SizeDict, + **kwargs, + ) -> torch.Tensor: + resize_kwargs: dict[str, Any] = {"align_corners": False} + resize_mode = "bilinear" + + return F.interpolate( + image, + size=(size.height, size.width), + mode=resize_mode, + **resize_kwargs, + ) - # Transpose to (1, H, L, D) format for SDPA - Q = q_lhd.permute(1, 0, 2).unsqueeze(0) - K = k_lhd.permute(1, 0, 2).unsqueeze(0) - V = v_lhd.permute(1, 0, 2).unsqueeze(0) + def _preprocess( + self, + images: list[torch.Tensor], + do_resize: bool, + interpolation: Optional[Any], + do_rescale: Optional[bool], + rescale_factor: Optional[float], + do_normalize: Optional[bool], + image_mean: Optional[Union[float, Sequence[float]]], + image_std: Optional[Union[float, Sequence[float]]], + disable_grouping: Optional[bool] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + *, + patch_size: Optional[int] = None, + max_num_patches: Optional[int] = None, + min_num_patches: Optional[int] = None, + pixel_shuffle_scale: Optional[int] = None, + **kwargs, + ) -> BatchFeature: + grouped_images, grouped_images_index = group_images_by_shape( + images, disable_grouping=disable_grouping + ) - # Build block-diagonal mask for variable-length sequences - attn_mask = None - if cu_seqlens is not None: - seq_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]).long() - seg_ids = torch.repeat_interleave(torch.arange(len(seq_sizes), device=q_lhd.device), seq_sizes) - block_mask = seg_ids[:, None] != seg_ids[None, :] # Cross-document attention blocked - attn_mask = torch.where(block_mask, -torch.inf, 0.0).to(q_lhd.dtype).view(1, 1, L, L) + grouped_outputs = {} - Y = F.scaled_dot_product_attention(Q, K, V, attn_mask=attn_mask, dropout_p=dropout, scale=scaling) - return Y.squeeze(0).permute(1, 0, 2) # Back to (L, H, D) + for shape, stacked_images in grouped_images.items(): + batch_size, channels, original_height, original_width = stacked_images.shape + if bool(self.do_convert_rgb) and channels == 1: + stacked_images = stacked_images.repeat(1, 3, 1, 1) -class Siglip2VariableSequenceEmbeddings(nn.Module): - def __init__(self, config: PixelShuffleSiglip2VisionConfig): - super().__init__() + target_height, target_width = get_image_size_for_max_num_patches( + original_height, + original_width, + patch_size, + max_num_patches, + min_num_patches=min_num_patches, + pixel_shuffle_scale=pixel_shuffle_scale, + ) + if do_resize: + image_batch = self.resize( + stacked_images, + SizeDict(height=target_height, width=target_width), + interpolation=interpolation, + ) + else: + if (original_height % patch_size) or (original_width % patch_size): + raise ValueError( + f"Image dimensions (h={original_height}, w={original_width}) must be divisible by patch_size={patch_size} when resize is disabled; enable resizing or adjust the input resolution." + ) + image_batch, target_height, target_width = ( + stacked_images, + original_height, + original_width, + ) + + image_batch = self.rescale_and_normalize( + image_batch, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + ) + + patches = torch_extract_patches(image_batch, patch_size, patch_size) + _, height_tokens, width_tokens, _ = patches.shape + + token_grid = ( + torch.tensor([height_tokens, width_tokens], device=patches.device) + .long() + .expand(batch_size, 2) + ) + + real_dim = ( + torch.tensor( + [1, height_tokens, width_tokens], + dtype=torch.long, + device=patches.device, + ) + .unsqueeze(0) + .repeat(batch_size, 1) + ) + + if (height_tokens % pixel_shuffle_scale) or ( + width_tokens % pixel_shuffle_scale + ): + raise ValueError( + f"Token grid (h={height_tokens}, w={width_tokens}) must be divisible by pixel_shuffle_scale={pixel_shuffle_scale}; adjust resize/patch parameters or disable pixel shuffle." + ) + virtual_height = height_tokens // pixel_shuffle_scale + virtual_width = width_tokens // pixel_shuffle_scale + + virtual_dim = ( + torch.tensor( + [1, virtual_height, virtual_width], + dtype=torch.long, + device=patches.device, + ) + .unsqueeze(0) + .repeat(batch_size, 1) + ) + grouped_outputs[shape] = (patches, token_grid, virtual_dim, real_dim) + + def _reorder_grouped_item( # reorder an item of tuple payloads using the same grouped_images_index + grouped: dict[ + tuple[int, ...], + tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], + ], + grouped_index: dict[tuple[int, ...], list[int]], + item_idx: int, + ) -> list[torch.Tensor]: + return reorder_images( + {k: v[item_idx] for k, v in grouped.items()}, grouped_index + ) + + keys = ("patches", "token_grids", "virtual_pixel_size", "real_pixel_size") + tensors: dict[str, torch.Tensor] = {} + + for i, key in enumerate(keys): + slices = _reorder_grouped_item(grouped_outputs, grouped_images_index, i) + tensors[key] = torch.stack(slices, dim=0) + + return BatchFeature(data=tensors, tensor_type=return_tensors) + + +def create_document_attention_mask( + config: PretrainedConfig, + input_embeds: torch.Tensor, + cu_seqlens: Optional[torch.Tensor], +) -> Optional[Union[torch.Tensor, Any]]: + """ + Materialize a backend-specific block-diagonal attention mask from packed cu_seqlens. + + Returns None if cu_seqlens is missing/degenerate. + """ + if cu_seqlens is None or cu_seqlens.numel() < 2: + return None # Degenerate input: nothing to mask + + seq_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]).long() + if seq_sizes.numel() == 0 or int(seq_sizes.sum()) == 0: + return None # All-empty segments produce no attention blocks + + seg_ids = torch.repeat_interleave( + torch.arange(seq_sizes.numel(), device=cu_seqlens.device), + seq_sizes, + ) + mask_function = packed_sequence_mask_function(seg_ids.view(1, -1)) + + seq_len = input_embeds.shape[1] + cache_position = torch.arange(seq_len, device=input_embeds.device, dtype=torch.long) + + mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation] + return mask_interface( + batch_size=input_embeds.shape[0], + cache_position=cache_position, + kv_length=seq_len, + kv_offset=0, + mask_function=mask_function, + attention_mask=None, + allow_is_causal_skip=False, + allow_is_bidirectional_skip=False, + dtype=input_embeds.dtype, + config=config, + use_vmap=False, + ) + + +class IsaacVisionEmbeddings(Siglip2VisionEmbeddings): + """Adapter around SigLIP2 vision embeddings that consumes packed patch sequences. + + Isaac accepts variable-resolution vision inputs as a single packed sequence with per-image + `token_grids`; packing/unpacking here reconstructs per-image shapes so we can resize positional + embeddings and build `cu_seqlens` for variable-length attention (not generic generation packing). + """ + + def __init__(self, config: IsaacVisionConfig): + super().__init__(config) self.config = config self.embed_dim = config.hidden_size self.patch_size = config.patch_size @@ -182,199 +407,213 @@ class Siglip2VariableSequenceEmbeddings(nn.Module): self.position_embedding_size = int(self.num_patches**0.5) self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim) - def positional_embeddings( - self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor, torch.Tensor] + @check_model_inputs + def forward( + self, seq_patches: torch.Tensor, spatial_shapes: torch.Tensor ) -> torch.Tensor: - # Prepare positional embeddings grid: (1, embed_dim, h, w) - positional_embeddings = ( - self.position_embedding.weight.reshape(self.position_embedding_size, self.position_embedding_size, -1) - .permute(2, 0, 1) - .unsqueeze(0) + # Rebatch packed variable-resolution patches to resize per-image position embeddings + # and track lengths for varlen attention metadata. + packed_pixel_values, seq_lengths = self._pack_to_batch( + seq_patches, spatial_shapes ) + if packed_pixel_values is None: + return seq_patches.new_zeros((0, self.embed_dim)) - _seq_patches, _seq_sizes, spatial_shapes = packed_seq_patches - pos_embeds_list = [] - mode = "bilinear" - align_corners = False - antialias = True - for spatial_shape in spatial_shapes: - height, width = spatial_shape - # Guard to ensure height and width are positive for torch.compile - if height > 0 and width > 0: - resized_pos_embed = F.interpolate( - positional_embeddings, - size=(height, width), - mode=mode, - align_corners=align_corners, - antialias=antialias, - ) - # Reshape from (1, embed_dim, height, width) to (height*width, embed_dim) - resized_pos_embed = resized_pos_embed.reshape(self.embed_dim, height * width).transpose(0, 1) - else: - # Fallback - should never happen in practice - resized_pos_embed = positional_embeddings.reshape( - self.embed_dim, self.position_embedding_size * self.position_embedding_size - ).transpose(0, 1)[: height * width] - pos_embeds_list.append(resized_pos_embed) + target_dtype = self.patch_embedding.weight.dtype + patch_embeds = self.patch_embedding(packed_pixel_values.to(dtype=target_dtype)) + + positional_embeddings = self.position_embedding.weight.reshape( + self.position_embedding_size, + self.position_embedding_size, + -1, + ) + resized_positional_embeddings = self.resize_positional_embeddings( + positional_embeddings, + spatial_shapes, + max_length=packed_pixel_values.shape[1], + ) - # Concatenate all positional embeddings along the sequence dimension - pos_embeds = torch.cat(pos_embeds_list, dim=0) - return pos_embeds + embeddings = patch_embeds + resized_positional_embeddings + return self._unpack_from_batch(embeddings, seq_lengths) - def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor, torch.Tensor]): - seq_patches, _seq_sizes, _spatial_shapes = packed_seq_patches + def _pack_to_batch( + self, + seq_patches: torch.Tensor, + spatial_shapes: torch.Tensor, + ) -> tuple[Optional[torch.Tensor], torch.Tensor]: + """Rebatch a packed patch sequence using per-image grids to align embeddings. - # Apply patch embeddings - target_dtype = self.patch_embedding.weight.dtype - patch_embeds = self.patch_embedding(seq_patches.to(dtype=target_dtype)) - pos_embeds = self.positional_embeddings(packed_seq_patches) + Args: + seq_patches: Packed patches of shape (total_patches, patch_dim). + spatial_shapes: Per-image patch grids of shape (num_images, 2) as (H_tokens, W_tokens). - # Add positional embeddings to patch embeddings - embeddings = patch_embeds + pos_embeds - return embeddings + Returns: + (packed_pixel_values, seq_lengths) where: + - packed_pixel_values: (batch, max_len, patch_dim) padded with zeros, or None if batch_size == 0 + - seq_lengths: (batch,) lengths for each image + """ + seq_lengths = spatial_shapes.long().prod(dim=-1) # (B,) + batch_size = int(seq_lengths.numel()) + if batch_size == 0: + return None, seq_lengths + + # Split the packed sequence into per-image chunks, then pad to a batch + lengths_list = seq_lengths.tolist() + chunks = seq_patches.split(lengths_list, dim=0) + packed_pixel_values = nn.utils.rnn.pad_sequence( + chunks, batch_first=True + ) # zero-padded by default + return packed_pixel_values, seq_lengths + + def _unpack_from_batch( + self, embeddings: torch.Tensor, seq_lengths: torch.Tensor + ) -> torch.Tensor: + """Flatten a padded batch back to packed sequence order using `seq_lengths`.""" + lengths = seq_lengths.to(device=embeddings.device).tolist() + chunks = [embeddings[i, :l] for i, l in enumerate(lengths) if l > 0] + return torch.cat(chunks, dim=0) -class Siglip2VariableLengthAttention(nn.Module): +class IsaacVisionAttention(Siglip2Attention): """Custom attention that supports variable-length sequences with flash attention.""" - def __init__(self, config): - super().__init__() - self.config = config - self.embed_dim = config.hidden_size - self.num_heads = config.num_attention_heads - self.head_dim = self.embed_dim // self.num_heads - if self.head_dim * self.num_heads != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" - f" {self.num_heads})." - ) - self.scale = self.head_dim**-0.5 - self.dropout = config.attention_dropout - - self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) - self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) - self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) - self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) - - def forward(self, hidden_states, cu_seqlens=None, max_seqlen=None): - # Expect packed sequences with batch_size == 1 - batch_size, L, _ = hidden_states.shape - if batch_size != 1: - raise ValueError("packed variable-length attention expects batch_size=1") - x = hidden_states[0] # (L, E) - - H = self.num_heads - D = self.head_dim - p_drop = self.dropout if self.training else 0.0 - - # Project and reshape to (L, H, D) - q = self.q_proj(x).view(L, H, D) - k = self.k_proj(x).view(L, H, D) - v = self.v_proj(x).view(L, H, D) - - attn_impl = getattr(self.config, "_attn_implementation", "flash_attention_3") - - if attn_impl in ("flash_attention_2", "flash_attention_3"): - y_lhd, _ = flash_attention_document_mask_forward( - q, - k, - v, - attention_mask=None, - dropout=p_drop, - scaling=self.scale, - cum_seq_q=cu_seqlens, - cum_seq_k=cu_seqlens, - max_seqlen=max_seqlen, - is_causal=False, + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + **kwargs, + ): + kwargs.pop("output_hidden_states", None) + kwargs.pop("return_dict", None) + + batch_size, seq_length, embed_dim = hidden_states.shape + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) + + queries = queries.view( + batch_size, seq_length, self.num_heads, self.head_dim + ).transpose(1, 2) + keys = keys.view( + batch_size, seq_length, self.num_heads, self.head_dim + ).transpose(1, 2) + values = values.view( + batch_size, seq_length, self.num_heads, self.head_dim + ).transpose(1, 2) + + attn_impl = self.config._attn_implementation + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS["sdpa"] + if attn_impl != "sdpa": + attention_interface = ALL_ATTENTION_FUNCTIONS[attn_impl] + + attention_kwargs: dict[str, Any] = { + "is_causal": False, + "scaling": self.scale, + } + + supports_varlen = cu_seqlens is not None and attn_impl in { + "flash_attention_2", + "flash_attention_3", + "flex_attention", + "paged|flash_attention_2", + "paged|flash_attention_3", + } + if supports_varlen: + if max_seqlen is not None: + max_q = max_k = int(max_seqlen) + elif cu_seqlens.numel() >= 2: + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + max_q = max_k = lengths.max() if lengths.numel() > 0 else seq_length + else: + max_q = max_k = seq_length + + attention_kwargs.update( + { + "cu_seq_lens_q": cu_seqlens, + "cu_seq_lens_k": cu_seqlens, + "max_length_q": max_q, + "max_length_k": max_k, + } ) - else: - y_lhd = sdpa_document_mask_forward(q, k, v, dropout=p_drop, scaling=self.scale, cu_seqlens=cu_seqlens) - # Merge heads and project - y = self.out_proj(y_lhd.reshape(L, self.embed_dim)) - return y.unsqueeze(0), None # (1, L, E) + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, + attention_mask, + **attention_kwargs, + ) + attn_output = attn_output.reshape( + batch_size, seq_length, embed_dim + ).contiguous() + attn_output = self.out_proj(attn_output) + return attn_output, attn_weights -class IsaacSiglip2EncoderLayer(nn.Module): - """Siglip2 encoder layer with variable-length attention.""" - def __init__(self, config: PixelShuffleSiglip2VisionConfig): - super().__init__() - self.embed_dim = config.hidden_size - self.self_attn = Siglip2VariableLengthAttention(config) +class IsaacVisionEncoderLayer(Siglip2EncoderLayer): + """Isaac vision encoder layer with variable-length attention.""" - self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) - self.mlp = Siglip2MLP(config) # Use HF's Siglip2MLP - self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + def __init__(self, config: IsaacVisionConfig): + super().__init__(config) + self.self_attn = IsaacVisionAttention(config) def forward( self, hidden_states: torch.Tensor, - cu_seqlens: torch.Tensor = None, - max_seqlen: int = None, - ) -> tuple[torch.FloatTensor]: + attention_mask: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + output_attentions: bool = False, + **kwargs: Unpack[TransformersKwargs], + ): + r""" + cu_seqlens (`torch.Tensor`, *optional*): + Prefix-sum tensor whose length equals the number of documents + 1. The difference between successive + entries gives each document's token count and enables block-diagonal attention masking for packed batches. + max_seqlen (`int`, *optional*): + Maximum document length referenced by `cu_seqlens`. Passed to FlashAttention so it can size temporary + buffers for packed variable-length attention. + """ + # Run attention directly so variable-length metadata reaches FlashAttention. residual = hidden_states - hidden_states = self.layer_norm1(hidden_states) - - hidden_states, attn_weights = self.self_attn( - hidden_states=hidden_states, + attn_output, _ = self.self_attn( + hidden_states, + attention_mask=attention_mask, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, + **kwargs, ) - - hidden_states = residual + hidden_states + hidden_states = residual + attn_output residual = hidden_states hidden_states = self.layer_norm2(hidden_states) hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - return (hidden_states,) + return hidden_states -class IsaacEncoder(nn.Module): +class IsaacVisionEncoder(Siglip2Encoder): """Encoder using Isaac encoder layers with variable-length attention support.""" - def __init__(self, config: PixelShuffleSiglip2VisionConfig): - super().__init__() - self.config = config - self.layers = nn.ModuleList([IsaacSiglip2EncoderLayer(config) for _ in range(config.num_hidden_layers)]) - - def forward( - self, - inputs_embeds, - cu_seqlens: torch.Tensor | None = None, - max_seqlen: int | None = None, - output_hidden_states: bool = False, - ): - all_hidden_states = () if output_hidden_states else None - - hidden_states = inputs_embeds - - for encoder_layer in self.layers: - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_outputs = encoder_layer( - hidden_states, - cu_seqlens, - max_seqlen, - ) - - hidden_states = layer_outputs[0] - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - return hidden_states, all_hidden_states, None + def __init__(self, config: IsaacVisionConfig): + super().__init__(config) + self.layers = nn.ModuleList( + [IsaacVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)] + ) def create_pixel_shuffle_index_map( seq_sizes: torch.Tensor, token_grids: torch.Tensor, scale_factor: int = 1, - device: torch.device | None = None, + device: Optional[torch.device] = None, ) -> torch.Tensor: """ Build a gather-index map that tells us, for every *output* token after @@ -394,44 +633,32 @@ def create_pixel_shuffle_index_map( packed sequence for the j-th sub-patch that forms the i-th output token. """ - if device is None: - device = seq_sizes.device - - r = int(scale_factor) - if r < 2: - raise ValueError("`scale_factor` must be ≥ 2") - - # Safety: all spatial dims must be divisible by r - # Cannot run under torch compile fullgraph mode hence - if not torch.compiler.is_compiling(): - if not ((token_grids[:, 0] % r == 0).all() and (token_grids[:, 1] % r == 0).all()): + if not is_torchdynamo_compiling(): + if (token_grids % scale_factor).any(): raise AssertionError( - f"Every (H,W) in `token_grids` must be divisible by scale_factor={r}, got {token_grids.tolist()}" + f"Every (H,W) in token_grids must be divisible by scale_factor={scale_factor}, got {token_grids.tolist()}" ) gather_chunks: list[torch.Tensor] = [] tok_offset = 0 + for seq_len, (h, w) in zip(seq_sizes.tolist(), token_grids.tolist()): + # Flat indices for this image's packed segment + grid = ( + torch.arange(seq_len, device=device, dtype=torch.int64).view(h, w) + + tok_offset + ) - for seq_len, (h, w) in zip(seq_sizes.tolist(), token_grids.tolist(), strict=False): - # Build the (H, W) grid of flat indices for this image - grid = torch.arange(seq_len, device=device, dtype=torch.int64) + tok_offset - grid = grid.view(h, w) # (H, W) - - # -------- identical ordering to your fixed-res routine -------- - # Step 1: split width into blocks of r - grid = grid.view(h, w // r, r) # (H, W/r, r) - # Step 2: now split height into blocks of r - grid = grid.view(h // r, r, w // r, r) # (H/r, r, W/r, r) - # Step 3: final permutation to (H/r, W/r, r, r) - grid = grid.permute(0, 2, 1, 3).contiguous() # (H/r, W/r, r, r) - # Step 4: each (r, r) block forms one output token - gather_chunks.append(grid.reshape(-1, r * r)) # (H*W / r², r²) + # Block into (H/s, W/s) groups; each group contributes s*s indices + grid = ( + grid.view(h // scale_factor, scale_factor, w // scale_factor, scale_factor) + .permute(0, 2, 1, 3) + .contiguous() + ) + gather_chunks.append(grid.view(-1, scale_factor * scale_factor)) tok_offset += seq_len - # Concatenate over all images in the packed batch - gather_idx = torch.cat(gather_chunks, dim=0) # (Σ_i HᵢWᵢ/r², r²) - return gather_idx + return torch.cat(gather_chunks, dim=0) def pixel_shuffle_varlen( @@ -460,47 +687,67 @@ def pixel_shuffle_varlen( Raises: ValueError: If more than one batch item is provided. """ - keep_batch_dim = x.dim() == 3 - if keep_batch_dim: + return_with_batch_dim = x.dim() == 3 + if return_with_batch_dim: if x.size(0) != 1: - raise AssertionError("Packed sequence is expected to have batch_size == 1") - x_ = x.squeeze(0) # (seq, embed) + raise ValueError( + f"Packed vision sequences expect a singleton batch dimension; received batch_size={x.size(0)}." + ) + embeddings = x.squeeze(0) # (seq, embed) else: - x_ = x # (seq, embed) + embeddings = x # (seq, embed) - embed_dim = x_.size(-1) - r = int(scale_factor) + embed_dim = embeddings.size(-1) + scale_factor = int(scale_factor) # Calculate seq_sizes from token_grids seq_sizes = torch.prod(token_grids, dim=-1) - # Build index map and gather in one go + # Build a single gather index so pixel shuffle works on the packed stream + # without unpacking per-image grids. gather_idx = create_pixel_shuffle_index_map( seq_sizes=seq_sizes, token_grids=token_grids, - scale_factor=r, - device=x_.device, - ) # (new_seq, r²) + scale_factor=scale_factor, + device=embeddings.device, + ) # (new_seq, scale_factor**2) - # Gather → (new_seq, r², embed_dim) - gathered = x_[gather_idx] # fancy indexing keeps gradient + # Gather → (new_seq, scale_factor**2, embed_dim) + gathered = embeddings[gather_idx] # fancy indexing keeps gradient - # Merge the r² group dimension into channels to finish the shuffle - out = gathered.reshape(gathered.size(0), embed_dim * r * r) + # Merge the scale_factor**2 group dimension into channels to finish the shuffle + out = gathered.reshape(gathered.size(0), embed_dim * scale_factor * scale_factor) # Restore batch dimension if needed - if keep_batch_dim: + if return_with_batch_dim: out = out.unsqueeze(0) return out -class Siglip2SequenceVisionTransformer(nn.Module): - def __init__(self, config: PixelShuffleSiglip2VisionConfig): +class IsaacVisionTransformer(nn.Module): + """Vision tower that packs variable-resolution patches, applies varlen attention, and pixel-shuffles outputs. + + Args: + config (IsaacVisionConfig): Vision configuration with pixel-shuffle and patching parameters. + + Inputs: + packed_seq_patches (Tuple[Tensor, Tensor]): ``(patches, token_grids)`` where ``patches`` is a packed + patch sequence and ``token_grids`` holds per-image (H_tokens, W_tokens). + + Returns: + torch.Tensor: Vision embeddings after encoder + pixel shuffle, shaped ``(seq_len, hidden_size * s^2)``. + """ + + _supports_sdpa = True + + def __init__(self, config: IsaacVisionConfig): super().__init__() self.config = config - self.embeddings = Siglip2VariableSequenceEmbeddings(config) - self.encoder = IsaacEncoder(config) - self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.embeddings = IsaacVisionEmbeddings(config) + self.encoder = IsaacVisionEncoder(config) + self.post_layernorm = nn.LayerNorm( + config.hidden_size, eps=config.layer_norm_eps + ) self.pixel_shuffle_scale_factor = config.pixel_shuffle_scale_factor def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]): @@ -508,30 +755,35 @@ class Siglip2SequenceVisionTransformer(nn.Module): seq_sizes = torch.prod(token_grids, dim=-1) # Get embeddings from packed sequence - hidden_states = self.embeddings((seq_patches, seq_sizes, token_grids)) + hidden_states = self.embeddings(seq_patches, token_grids) - # Add a pseudo batch dimension for the encoder + # Add a pseudo batch dimension so we can reuse the batch-first encoder stack + # while still driving per-image cu_seqlens through the varlen attention path. hidden_states = hidden_states.unsqueeze(0) # Generate cumulative sequence lengths for variable-length attention - cu_seqlens, max_seqlen = create_cumulative_seq_lengths(seq_sizes, hidden_states.device) + cu_seqlens = F.pad(seq_sizes.cumsum(0).to(torch.int32), (1, 0)) + + attention_mask = create_document_attention_mask( + self.config, hidden_states, cu_seqlens + ) # Pass through encoder with variable-length attention parameters - hidden_states, _, _ = self.encoder( + encoder_outputs = self.encoder( inputs_embeds=hidden_states, + attention_mask=attention_mask, cu_seqlens=cu_seqlens, - max_seqlen=max_seqlen, ) + hidden_states = encoder_outputs.last_hidden_state # Apply final layer normalization hidden_states = self.post_layernorm(hidden_states) - if self.pixel_shuffle_scale_factor > 1: - hidden_states = pixel_shuffle_varlen( - x=hidden_states, - token_grids=token_grids, - scale_factor=self.pixel_shuffle_scale_factor, - ) + hidden_states = pixel_shuffle_varlen( + x=hidden_states, + token_grids=token_grids, + scale_factor=self.pixel_shuffle_scale_factor, + ) # Remove the pseudo batch dimension we added earlier hidden_states = hidden_states.squeeze(0) @@ -539,44 +791,56 @@ class Siglip2SequenceVisionTransformer(nn.Module): return hidden_states -# ============================================================================ -# Configuration -# ============================================================================ +class IsaacMultiModalProjector(nn.Module): + """Maps vision tower outputs to the text hidden size with a SiLU MLP.""" + + def __init__(self, config: IsaacConfig): + super().__init__() + self.vision_hidden_size = config.vision_config.hidden_size * ( + config.vision_config.pixel_shuffle_scale_factor**2 + ) + self.backbone_hidden_size = config.hidden_size + self.linear_1 = nn.Linear( + self.vision_hidden_size, 4 * self.vision_hidden_size, bias=False + ) + self.silu = nn.SiLU() + self.linear_2 = nn.Linear( + 4 * self.vision_hidden_size, self.backbone_hidden_size, bias=False + ) -MAX_PIXELS = 60_000_000 # 60‑megapixel ceiling ≈ 8200 × 7300 px + def forward(self, image_features): + hidden_states = self.linear_1(image_features) + hidden_states = self.silu(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states -# Vision preprocessing constants -VISION_MEAN = (0.5, 0.5, 0.5) -VISION_STD = (0.5, 0.5, 0.5) -VISION_SCALE = 1 / 255 +class IsaacVisionEmbedding(nn.Module): + _supports_sdpa = True -def _make_writeable(arr: np.ndarray) -> np.ndarray: - """Return *arr* itself if it is already writeable, otherwise try to flip the - write flag in-place and finally fall back to `arr.copy()`. - This guarantees the buffer handed to `torch.from_numpy()` is always - writeable, silencing the PyTorch warning about undefined behaviour. - """ - if arr.flags.writeable: - return arr + def __init__(self, config: IsaacConfig): + super().__init__() + vision_cfg = config.vision_config - # First, try the cheap path — in‑place flag toggle (works for mmap'd arrays - # and some shared memory buffers): - try: - arr.setflags(write=True) - return arr # success: no data copy - except ValueError: - # Buffer is inherently read‑only (e.g. backed by PyAV / PIL): make copy - return arr.copy() + self.vision_tower = IsaacVisionTransformer(vision_cfg) + self.multimodal_projector = IsaacMultiModalProjector(config) + def forward(self, vision_tokens: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + hidden_states = self.vision_tower(vision_tokens) + return self.multimodal_projector(hidden_states) -def extract_image_pil(image: PIL.Image.Image) -> torch.Tensor | None: - if image.width * image.height > MAX_PIXELS: - raise ValueError(f"Image (w={image.width}, h={image.height}) > MAX=`{MAX_PIXELS}`") - img = image if image.mode == "RGB" else image.convert("RGB") - arr = np.asarray(img) - arr = _make_writeable(arr) - return torch.from_numpy(arr) + +def get_scaled_image_size( + scale: float, + original_size: int, + patch_size: int, + pixel_shuffle_scale: int, +) -> int: + scaled_size = scale * original_size + divisor = patch_size * pixel_shuffle_scale + scaled_size = math.ceil(scaled_size / divisor) * divisor + scaled_size = max(divisor, scaled_size) + return int(scaled_size) def get_image_size_for_max_num_patches( @@ -584,7 +848,7 @@ def get_image_size_for_max_num_patches( image_width: int, patch_size: int, max_num_patches: int, - min_num_patches: int | None = None, + min_num_patches: Optional[int] = None, eps: float = 1e-5, pixel_shuffle_scale: int = 1, ) -> tuple[int, int]: @@ -611,13 +875,6 @@ def get_image_size_for_max_num_patches( and respect both the maximum and optional minimum patch-count constraints. """ - def get_scaled_image_size(scale, original_size, patch_size, pixel_shuffle_scale): - scaled_size = scale * original_size - divisor = patch_size * pixel_shuffle_scale - scaled_size = math.ceil(scaled_size / divisor) * divisor - scaled_size = max(divisor, scaled_size) - return int(scaled_size) - # Ensure divisibility divisor = patch_size * pixel_shuffle_scale adjusted_height = math.ceil(image_height / divisor) * divisor @@ -628,20 +885,29 @@ def get_image_size_for_max_num_patches( num_patches = (adjusted_height / patch_size) * (adjusted_width / patch_size) if min_num_patches is not None and num_patches < min_num_patches: - # Scale up + # Scale up via binary search to satisfy the minimum patch budget while + # preserving divisibility by patch_size * pixel_shuffle_scale. scale_min, scale_max = 1.0, 100.0 while (scale_max - scale_min) >= eps: scale = (scale_min + scale_max) / 2 - target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) - target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) + target_height = get_scaled_image_size( + scale, image_height, patch_size, pixel_shuffle_scale + ) + target_width = get_scaled_image_size( + scale, image_width, patch_size, pixel_shuffle_scale + ) num_patches = (target_height / patch_size) * (target_width / patch_size) if num_patches >= min_num_patches: scale_max = scale else: scale_min = scale scale = scale_max - target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) - target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) + target_height = get_scaled_image_size( + scale, image_height, patch_size, pixel_shuffle_scale + ) + target_width = get_scaled_image_size( + scale, image_width, patch_size, pixel_shuffle_scale + ) return target_height, target_width elif num_patches <= max_num_patches: return adjusted_height, adjusted_width @@ -650,1022 +916,994 @@ def get_image_size_for_max_num_patches( scale_min, scale_max = eps / 10, 1.0 while (scale_max - scale_min) >= eps: scale = (scale_min + scale_max) / 2 - target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) - target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) + target_height = get_scaled_image_size( + scale, image_height, patch_size, pixel_shuffle_scale + ) + target_width = get_scaled_image_size( + scale, image_width, patch_size, pixel_shuffle_scale + ) num_patches = (target_height / patch_size) * (target_width / patch_size) if num_patches <= max_num_patches: scale_min = scale else: scale_max = scale scale = scale_min - target_height = get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale) - target_width = get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale) + target_height = get_scaled_image_size( + scale, image_height, patch_size, pixel_shuffle_scale + ) + target_width = get_scaled_image_size( + scale, image_width, patch_size, pixel_shuffle_scale + ) return target_height, target_width -_MEAN_TENSOR = torch.tensor(VISION_MEAN, dtype=torch.float32).view(1, 1, 1, -1) -_STD_TENSOR = torch.tensor(VISION_STD, dtype=torch.float32).view(1, 1, 1, -1) +class IsaacConfig(PretrainedConfig): + """Configuration class for Isaac multimodal model. - -def prepare_image_tensor( - image: torch.Tensor, - scale: float = VISION_SCALE, -) -> torch.Tensor: - r"""Standardize RGB images prior to patch extraction via rescaling and whitening. - - Args: - image (`torch.Tensor`): - Tensor with shape `(..., height, width, 3)` containing RGB values. The tensor is converted to floating - point if needed. - scale (`float`, *optional*, defaults to `VISION_SCALE`): - Scalar multiplier applied before normalization. - Returns: - `torch.Tensor`: Normalized tensor with the same shape as the input and dtype `torch.float32`. - """ - if not torch.is_floating_point(image): - image = image.float() - rescaled = image * scale - - # Use precomputed tensors and move to the correct device if needed - mean_tensor = _MEAN_TENSOR.to(image.device) - std_tensor = _STD_TENSOR.to(image.device) - - normalized = (rescaled - mean_tensor) / std_tensor - return normalized - - -def patchify_vision(image: torch.Tensor, patch_size: int) -> torch.Tensor: - r"""Convert normalized images into flattened ViT-style patches. - - Args: - image (`torch.Tensor`): - Tensor of shape `(num_images, height, width, channels)`. - patch_size (`int`): - Edge length of the square patches - - Returns: - `torch.Tensor`: - Patch tensor where each position stores the flattened pixels belonging to that patch. - - Raises: - ValueError: If `height` or `width` is not divisible by `patch_size`. - """ - num_images, height, width, channels = image.shape - if height % patch_size or width % patch_size: - raise ValueError(f"Dimensions of images {image.shape} are not divisible by patch_size={patch_size}.") - patches = image.reshape(num_images, height // patch_size, patch_size, width // patch_size, patch_size, channels) - patches = patches.permute(0, 1, 3, 2, 4, 5) - patches = patches.reshape(num_images, height // patch_size, width // patch_size, channels * patch_size * patch_size) - return patches - - -def process_vision_for_patches( - images: torch.Tensor, - patch_size: int, - max_num_patches: int, - min_num_patches: int | None = None, - pixel_shuffle_scale: int = 1, -) -> tuple[torch.Tensor, list[int]]: - r"""Resize, normalize, and patchify RGB images for the vision encoder. - - Args: - images (`torch.Tensor`): - Either `(height, width, channels)` for a single image or `(num_images, height, width, channels)` for a - batch. Channels are expected to be RGB. - patch_size (`int`): - Edge length of square patches; implictly controls resize grid granularity. - max_num_patches (`int`): - Maximum number of patches allowed after resizing. - min_num_patches (`int`, *optional*): - Minimum number of patches. If provided, the routine upsamples images as needed to satisfy the lower bound. - pixel_shuffle_scale (`int`, *optional*, defaults to 1): - pixel shuffle scale factor; influences the target grid that the function produces. - - Returns: - `tuple[torch.Tensor, list[int]]`: A pair `(patches, dims_virtual)` where `patches` has shape - `(num_images, target_h / patch_size, target_w / patch_size, channels * patch_size**2)` and `dims_virtual` - encodes effective `(images, height, width)` dimensions after optional pixel shuffling. + This configuration corresponds to checkpoints such as + [Perceptron/isaac-base](https://huggingface.co/Perceptron/isaac-base). """ - # Add batch dim if single image - if images.dim() == 3: - images = images.unsqueeze(0) - - # Permute to channel first for resize - images = images.permute(0, 3, 1, 2) - - # Get target dimensions - _, _, orig_height, orig_width = images.shape - target_height, target_width = get_image_size_for_max_num_patches( - orig_height, - orig_width, - patch_size, - max_num_patches, - min_num_patches=min_num_patches, - pixel_shuffle_scale=pixel_shuffle_scale, - ) - - # Resize - images = F.interpolate( - images, - size=(target_height, target_width), - mode="bilinear", - align_corners=False, - ) - - # Back to channel last - images = images.permute(0, 2, 3, 1) - - # Normalize - images = prepare_image_tensor(images) - - # Patchify - patches = patchify_vision(images, patch_size=patch_size) - - # Calculate dimensions for the patches - n_images, h_patches, w_patches, _ = patches.shape - dims_virtual = ( - [1, h_patches, w_patches] - if pixel_shuffle_scale == 1 - else [1, h_patches // pixel_shuffle_scale, w_patches // pixel_shuffle_scale] - ) - - return patches, dims_virtual - - -def precompute_inv_freq(theta: float, dim: int) -> torch.Tensor: - """ - Returns shape (dim//2,). - """ - inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - return inv_freq # type: ignore[return-value] - - -def precompute_cos_sin_3d( - position_ids: torch.Tensor, # shape (3, B, T) - inv_freq: torch.Tensor, # shape (dim//2,) - mrope_half_section: list[int], # sum to dim//2 -) -> tuple[torch.Tensor, torch.Tensor]: - r"""Generate 3D rotary embeddings for multi-axis positions. - - Args: - position_ids (`torch.Tensor`): - Tensor of shape `(3, batch_size, seq_len)` containing positional indices for the x/y/t axes. - inv_freq (`torch.Tensor`): - Precomputed inverse frequency vector used to derive rotary phases. - mrope_half_section (`list[int]`): - Sizes the axis-specific frequency blocks. - - Returns: - `tuple[torch.Tensor, torch.Tensor]`: Cosine and sine tensors, each of shape `(batch_size, seq_len, dim)`, ready - to be passed into rotary attention layers. - """ - B = position_ids.shape[1] - T = position_ids.shape[2] - dim_half = inv_freq.shape[0] - device = position_ids.device - - # Initialize with full dimension (not half) to match LLaMA - cos_3d = torch.zeros((B, T, dim_half * 2), dtype=torch.float32, device=device) - sin_3d = torch.zeros((B, T, dim_half * 2), dtype=torch.float32, device=device) - - offset = 0 - for d in range(3): - block_size = mrope_half_section[d] - freq_slice = inv_freq[offset : offset + block_size] # shape => (block_size,) - # shape => (B, T, block_size) - phase = position_ids[d].unsqueeze(-1).float() * freq_slice - - cos_part = phase.cos() - sin_part = phase.sin() - - # Duplicate values for both halves of the dimension - cos_3d[:, :, offset : offset + block_size] = cos_part - cos_3d[:, :, dim_half + offset : dim_half + offset + block_size] = cos_part - sin_3d[:, :, offset : offset + block_size] = sin_part - sin_3d[:, :, dim_half + offset : dim_half + offset + block_size] = sin_part - - offset += block_size - - return cos_3d, sin_3d - - -class RopeScaling(TypedDict, total=False): - rope_type: str - factor: float - mrope_section: list[int] - mrope_interleaved: bool - low_freq_factor: float - high_freq_factor: float - original_max_position_embeddings: int - - -class IsaacConfig(Qwen3Config): - """Configuration class for Isaac multimodal model.""" model_type = "isaac" - sub_configs = {"vision_config": PixelShuffleSiglip2VisionConfig} + sub_configs = {"vision_config": IsaacVisionConfig, "text_config": Qwen3Config} + image_processor_type = "IsaacImageProcessor" def __init__( self, - vision_config=None, - vision_patch_size: int = 16, - vision_max_num_patches: int = 256, - vision_min_num_patches: int | None = None, - pixel_shuffle_scale: int = 1, + vision_config: Optional[IsaacVisionConfig] = None, + text_config: Optional[Union[Qwen3Config, dict]] = None, + vision_rescale_factor: float = 1 / 255, max_sequence_length: int = 16384, vision_token: str = "", - vision_attn_implementation: str | None = None, **kwargs, ): + attn_implementation = kwargs.get("attn_implementation") + + if isinstance(text_config, dict): + self.text_config = self.sub_configs["text_config"](**text_config) + elif isinstance(text_config, Qwen3Config): + self.text_config = text_config + elif text_config is None: + self.text_config = self.sub_configs["text_config"]() + + # Seed RoPE parameters before base init so the shared mixin can standardize/validate them. + self.rope_parameters = getattr(self.text_config, "rope_parameters", None) + self.layer_types = getattr(self.text_config, "layer_types", None) + super().__init__(**kwargs) - # Handle vision config - either dict or PixelShuffleSiglip2VisionConfig instance + # Keep rope parameters aligned between the composite and text sub-configs. + self.text_config.rope_parameters = self.rope_parameters + + # Mirror frequently accessed Qwen3 attributes at the composite config level + self.vocab_size = self.text_config.vocab_size + self.hidden_size = self.text_config.hidden_size + self.num_hidden_layers = self.text_config.num_hidden_layers + self.num_attention_heads = self.text_config.num_attention_heads + self.head_dim = self.text_config.head_dim + self.hidden_act = self.text_config.hidden_act + self.use_cache = self.text_config.use_cache + self.rope_theta = self.rope_parameters["rope_theta"] + + self.layer_types = getattr(self.text_config, "layer_types", None) + layer_type_validation(self.layer_types, self.num_hidden_layers) + + # Handle vision config - either dict or IsaacVisionConfig instance if isinstance(vision_config, dict): self.vision_config = self.sub_configs["vision_config"](**vision_config) + elif isinstance(vision_config, IsaacVisionConfig): + self.vision_config = vision_config elif vision_config is None: self.vision_config = self.sub_configs["vision_config"]() - else: - self.vision_config = vision_config - # EventStreamProcessor parameters (for backward compatibility) - self.video_patch_size = vision_patch_size - self.vision_max_num_patches = vision_max_num_patches - self.vision_min_num_patches = vision_min_num_patches - self.pixel_shuffle_scale = pixel_shuffle_scale + # Propagate user-requested attention backend to the vision sub-config when provided. + if attn_implementation is not None: + if isinstance(attn_implementation, dict): + vision_attn = attn_implementation.get( + "vision_config", attn_implementation.get("", None) + ) + else: + vision_attn = attn_implementation + if vision_attn is not None: + self.vision_config._attn_implementation = vision_attn + + if getattr(self, "_attn_implementation", None) is None: + self._attn_implementation = "sdpa" + # Vision normalization parameters + self.vision_rescale_factor = float(vision_rescale_factor) # Processing parameters self.max_sequence_length = max_sequence_length self.vision_token = vision_token - self.vision_attn_implementation = vision_attn_implementation + def to_dict(self): + output = super().to_dict() + # Ensure nested configs round-trip through dict serialization + if hasattr(self, "text_config") and self.text_config is not None: + output["text_config"] = self.text_config.to_dict() + if hasattr(self, "vision_config") and self.vision_config is not None: + output["vision_config"] = self.vision_config.to_dict() + return output -# ============================================================================ -# Processor Components -# ============================================================================ - -def create_text_event(tokenizer: AutoTokenizer, text: str, time: float = 0.0) -> Event: - r"""Wrap a text into an `Event` compatible with the multimodal TensorStream. +class IsaacProcessor(ProcessorMixin): + """Processor that pairs the Isaac image processor with the Qwen2 tokenizer. Args: - tokenizer (`AutoTokenizer`): - Tokenizer used to convert text into model vocabulary ids. - text (`str`): - Plain-text fragment to encode. - time (`float`, *optional*, defaults to 0.0): - Timeline coordinate associated with the event. Both start and end times use the same value because text - segments are instantaneous in the scheduler. + image_processor: Vision preprocessor (fast) used for patch extraction. + tokenizer: Qwen2 tokenizer instance. + vision_token (str, optional): Placeholder token marking image locations. Defaults to "". + max_sequence_length (int, optional): Maximum combined text+vision tokens kept. Defaults to 16384. + rescale_factor (float, optional): Image rescale factor; defaults to 1/255. + config (IsaacConfig | dict, optional): If provided, overrides processor defaults from the model config. Returns: - `Event`: Event carrying a `(num_tokens, 1)` tensor of token ids with matching - metadata so that downstream processors can compute modality-specific embeddings. + BatchFeature: Contains ``input_ids`` and ``packed_inputs`` (patch tensors, grids, offsets, lengths, modality, positions). """ - tokens = tokenizer.encode(text, add_special_tokens=False, return_tensors="pt").squeeze(0) - - # Calculate dimensions for the event - num_tokens = len(tokens) - dims_virtual = [num_tokens, 1] # [sequence_length, 1] - dims_real = dims_virtual.copy() - - # Ensure tokens has the right shape for tensor_stream_token_view - # It expects a 2D tensor where sum(dim=-1) gives the token IDs - if tokens.dim() == 1: - tokens = tokens.unsqueeze(-1) - - return Event( - data=tokens, - type=TextType.text, - time=(time, time), - dims_virtual=dims_virtual, - dims_real=dims_real, - idx_range=(0, num_tokens), - ) - - -# ============================================================================ -# Processor -# ============================================================================ - -class IsaacProcessor(ProcessorMixin): - attributes = ["tokenizer"] - tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") + attributes = ["image_processor", "tokenizer"] + image_processor_class = ("IsaacImageProcessorFast",) + tokenizer_class = ("Qwen2Tokenizer",) + pad_token_id = 151643 def __init__( self, - tokenizer: Qwen2Tokenizer, - config: IsaacConfig | dict, - ): - super().__init__(tokenizer) - self.tokenizer = tokenizer - + image_processor, + tokenizer, + *, + vision_token: str = "", + max_sequence_length: int = 16384, + rescale_factor: Optional[float] = None, + config: Optional[Union[IsaacConfig, dict]] = None, + ) -> None: if isinstance(config, dict): config = IsaacConfig(**config) - self.config = config - # Use vision token from config - self.vision_token = config.vision_token + if config is not None: + vision_token = config.vision_token + max_sequence_length = config.max_sequence_length + rescale_factor = config.vision_rescale_factor - # Processing parameters - self.max_sequence_length = config.max_sequence_length + resolved_rescale_factor = ( + float(rescale_factor) if rescale_factor is not None else float(1 / 255) + ) + if config is not None: + config.vision_rescale_factor = resolved_rescale_factor - # Vision processing parameters - self.patch_size = config.video_patch_size - self.max_num_patches = config.vision_max_num_patches - self.min_num_patches = config.vision_min_num_patches - self.pixel_shuffle_scale = config.pixel_shuffle_scale + self.image_processor = image_processor + super().__init__(image_processor, tokenizer) - def apply_chat_template( - self, - messages: list[dict[str, Any]], - tokenize: bool = False, - add_generation_prompt: bool = False, - **kwargs, - ) -> Any: - return self.tokenizer.apply_chat_template( - messages, tokenize=tokenize, add_generation_prompt=add_generation_prompt, **kwargs + text_pad_token_id = getattr(self.tokenizer, "pad_token_id", None) + image_pad_token_id = self.tokenizer.convert_tokens_to_ids("<|image_pad|>") + + self.text_pad_token_id = int(text_pad_token_id) + self.image_pad_token_id = int(image_pad_token_id) + self.pad_token_id = self.text_pad_token_id + + self.current_processor = self.image_processor + self.config = config + self.chat_template = getattr(self.tokenizer, "chat_template", None) + self.vision_token = vision_token + self.max_sequence_length = max_sequence_length + + def _pack_batch( + self, texts: list[str], images_list: Optional[list[Optional[list[Image]]]] + ) -> dict[str, Optional[torch.Tensor]]: + if images_list is None: + pairs = ((t, None) for t in texts) + else: + pairs = zip(texts, images_list, strict=True) + + per_sample: list[dict[str, Optional[torch.Tensor]]] = [] + for txt, imgs in pairs: + if imgs is not None and isinstance(imgs, Image): + imgs = [imgs] + per_sample.append(self._pack_single(txt, imgs)) + + lengths = [int(p["input_ids"].shape[1]) for p in per_sample] + max_len = max(lengths, default=0) + batch = len(per_sample) + + # Use first device with data as anchor + base_device = torch.device("cpu") + for p in per_sample: + if p["input_ids"].numel() > 0: + base_device = p["input_ids"].device + break + + pad_id = self.text_pad_token_id + padded_input_ids = torch.full( + (batch, max_len), pad_id, device=base_device, dtype=torch.long + ) + padded_modality = torch.full( + (batch, max_len), + ModalityType.text.value, + device=base_device, + dtype=torch.long, + ) + padded_position_ids = torch.zeros( + (batch, max_len, 3), device=base_device, dtype=torch.long ) - def build_event_stream_simple( - self, - text: str, - images: list[PIL.Image.Image] | None = None, - ) -> Stream: - events = [] - # Process text and images - # Find all occurrences of vision token - - pattern = re.escape(self.vision_token) - parts = re.split(f"({pattern})", text) # Keep the delimiter in the result - - image_idx = 0 - for current_time, part in enumerate(parts): - if part == self.vision_token: - # Replace vision token with image event - if image_idx < len(images): - # Create vision event from PIL image - image_tensor = extract_image_pil(images[image_idx]) - if image_tensor is not None: - # Create a vision event with the image tensor - vision_event = Event( - data=image_tensor.unsqueeze(0), # HWC format from extract_image_pil - type=VisionType.image, # I-frame - time=(current_time, current_time), - ) - events.append(vision_event) - image_idx += 1 - elif part: # Non-empty text part - # tokens = self.text_processor.tokenize(part, add_special_tokens=False) - text_event = create_text_event(self.tokenizer, part, time=current_time) - events.append(text_event) - - # Process vision events if any - if any(event.type == VisionType.image for event in events): - # Separate text and vision events for processing - text_events = [event for event in events if event.type == TextType.text] - vision_events = [event for event in events if event.type == VisionType.image] - - # Process vision events using functional approach - processed_vision_events = [] - for vision_event in vision_events: - # Process the vision data - patches, dims_virtual = process_vision_for_patches( - vision_event.data.squeeze(0), # Remove the extra dimension - patch_size=self.patch_size, - max_num_patches=self.max_num_patches, - min_num_patches=self.min_num_patches, - pixel_shuffle_scale=self.pixel_shuffle_scale, + for i, (sample, l) in enumerate(zip(per_sample, lengths)): + if l: + padded_input_ids[i, -l:] = sample["input_ids"][0] + padded_modality[i, -l:] = sample["modality_tensor"][0] + padded_position_ids[i, -l:] = sample["position_ids"][0] + + # Vision-side aggregation + v_samples = [ + (b, s) for b, s in enumerate(per_sample) if s["vision_patches"] is not None + ] + if v_samples: + vision_patches_list = [s["vision_patches"] for _, s in v_samples] + vision_grids_list = [s["vision_token_grids"] for _, s in v_samples] + vision_offsets_list = [s["vision_token_offsets"] for _, s in v_samples] + vision_lengths_list = [s["vision_token_lengths"] for _, s in v_samples] + vision_batch_indices = [ + torch.full_like(s["vision_token_offsets"], b) for b, s in v_samples + ] + + vision_patches = torch.cat(vision_patches_list, dim=0) + vision_token_grids = torch.cat(vision_grids_list, dim=0) + vision_token_offsets = torch.cat(vision_offsets_list, dim=0) + vision_token_lengths = torch.cat(vision_lengths_list, dim=0) + vision_token_batch_indices = torch.cat(vision_batch_indices, dim=0) + else: + vision_patches = vision_token_grids = vision_token_offsets = ( + vision_token_lengths + ) = vision_token_batch_indices = None + + return { + "input_ids": padded_input_ids, + "vision_patches": vision_patches, + "vision_token_grids": vision_token_grids, + "vision_token_offsets": vision_token_offsets, + "vision_token_lengths": vision_token_lengths, + "vision_token_batch_indices": vision_token_batch_indices, + "modality_tensor": padded_modality, + "position_ids": padded_position_ids, + } + + def _pack_single( + self, text: str, images: Optional[list[Image]] + ) -> dict[str, Optional[torch.Tensor]]: + segments = text.split( + self.vision_token + ) # Parse by vision_token; interleave text segments and image segments. + num_images = len(segments) - 1 + items: list[dict[str, Any]] = [] + total = 0 + num_provided_images = len(images) if images is not None else 0 + if not num_images == num_provided_images: + raise ValueError( + f"IsaacProcessor expects one image per image token, got {num_images} tokens and {num_provided_images} images in sample with text {text} " + ) + + for index, segment in enumerate(segments): + if segment: + tok = ( + self.tokenizer.encode( + segment, add_special_tokens=False, return_tensors="pt" + ) + .squeeze(0) + .to(torch.long) ) + segment_length = int(tok.numel()) + items.append( + {"type": "text", "segment_length": segment_length, "tok": tok} + ) + total += segment_length - # Update event with processed data - vision_event.data = patches.unsqueeze(1) # Add back frame dimension - vision_event.dims_virtual = dims_virtual - vision_event.dims_real = ( - dims_virtual - if self.pixel_shuffle_scale == 1 - else [ - dims_virtual[0], - dims_virtual[1] * self.pixel_shuffle_scale, - dims_virtual[2] * self.pixel_shuffle_scale, - ] + if index < num_images: + feat = self.image_processor( + images=images[index], return_tensors=TensorType.PYTORCH ) - vision_event.idx_range = (0, math.prod(dims_virtual)) + patches = feat["patches"][0].reshape(-1, feat["patches"].shape[-1]) - # Flatten the patches - vision_event.data = vision_event.data.reshape(-1, vision_event.data.shape[-1]) - processed_vision_events.append(vision_event) + virtual_pixel_size = ( + feat["virtual_pixel_size"][0].to(torch.long).tolist() + ) + real_pixel_size = feat["real_pixel_size"][0].to(torch.long).tolist() + dims = tuple( + (virtual_pixel_size + [1, 1, 1])[:3] + ) # (T,H,W) in virtual space + segment_length = int(dims[0] * dims[1] * dims[2]) + + items.append( + { + "type": "image", + "segment_length": segment_length, + "dims": dims, + "patches": patches, + "grid": (int(real_pixel_size[1]), int(real_pixel_size[2])), + } + ) + total += segment_length + + # Tail crop window. + start = max(0, total - self.max_sequence_length) + end = total + + image_pad_value = self.image_pad_token_id + base_device: Optional[torch.device] = None + position_ids, modality, input_ids = [], [], [] + vpatches, grids, vision_token_offsets, vision_token_lengths = [], [], [], [] + + global_offset = 0 + position_offset = 0 + + for item in items: + segment_length = int(item["segment_length"]) + current_window_start = max(start, global_offset) + current_window_end = min(end, global_offset + segment_length) + has_overlap = current_window_end > current_window_start + + if has_overlap and base_device is None: + base_device = ( + item["patches"].device + if item["type"] == "image" + else item["tok"].device + ) - events = text_events + processed_vision_events + if has_overlap: + segment_local_start = int(current_window_start - global_offset) + segment_local_end = int(current_window_end - global_offset) + segment_local_indices = torch.arange( + segment_local_start, + segment_local_end, + device=base_device, + dtype=torch.long, + ) + segment_kept_length = segment_local_end - segment_local_start - # Create stream without scheduling (events already in order) - return create_stream(events, priority=[TextType.text, VisionType.image], schedule=True) + if item["type"] == "text": + slice_index = segment_local_indices + position_offset + zero_axis_pad = torch.zeros_like(slice_index) + position_ids.append( + torch.stack((slice_index, zero_axis_pad, zero_axis_pad), -1) + ) + modality.append( + torch.full( + (segment_kept_length,), + ModalityType.text.value, + device=base_device, + dtype=torch.long, + ) + ) + input_ids.append( + item["tok"].to(base_device)[ + segment_local_start:segment_local_end + ] + ) + position_offset += segment_length + else: + num_pos_slices, grid_height_tokens, grid_width_tokens = item["dims"] + hw = grid_height_tokens * grid_width_tokens + slice_index = (segment_local_indices // hw) + position_offset + rem = segment_local_indices % hw + row_index = rem // grid_width_tokens + col_index = rem % grid_width_tokens + position_ids.append( + torch.stack((slice_index, row_index, col_index), -1) + ) + modality.append( + torch.full( + (segment_kept_length,), + ModalityType.image.value, + device=base_device, + dtype=torch.long, + ) + ) + input_ids.append( + torch.full( + (segment_kept_length,), + image_pad_value, + device=base_device, + dtype=torch.long, + ) + ) - def __call__( - self, - text: str | list[str], - images: PIL.Image.Image | list[PIL.Image.Image] | None = None, - return_tensors: str | TensorType | None = TensorType.PYTORCH, - **kwargs, - ) -> BatchFeature: - """ - Process text and images into TensorStream format. - Args: - text: Input text or list of texts with vision tokens - images: PIL image or list of images (optional) - return_tensors: Format for output tensors + vpatches.append( + item["patches"].to(base_device) + ) # full patches; slice later via offsets/lengths + # Record per-image slice boundaries so we can drop cropped virtual tokens + # after pixel shuffle without re-packing the entire vision stream. + grids.append(item["grid"]) + vision_token_offsets.append(segment_local_start) + vision_token_lengths.append(segment_kept_length) - Returns: - BatchFeature with input_ids and tensor_stream - """ - # Normalize inputs to lists - if isinstance(text, str): - texts = [text] - else: - texts = text + position_offset += int(num_pos_slices) - if images is not None: - if isinstance(images, PIL.Image.Image): - images_list = [images] else: - images_list = images - else: - images_list = None - - if len(texts) != 1: - raise ValueError("IsaacProcessor currently supports batch_size=1") - if images_list is not None: - # Count vision tokens in text to validate image count - vision_token_count = texts[0].count(self.vision_token) - if vision_token_count != len(images_list): - raise ValueError( - f"Number of {self.vision_token} tokens in text ({vision_token_count}) " - f"must match number of images ({len(images_list)})" + position_offset += ( + segment_length if item["type"] == "text" else int(item["dims"][0]) ) - # Build event stream - stream = self.build_event_stream_simple( - text=texts[0], - images=images_list, - ) + global_offset += segment_length - # Create TensorStream - tensor_stream = TensorStream([stream]) + if base_device is None: + base_device = torch.device("cpu") - # Slice to max length if needed - _, T = tensor_stream.shape - if T > self.max_sequence_length: - tensor_stream = ts_slice(tensor_stream, start=T - self.max_sequence_length, end=T) + modality_tensor = ( + torch.cat(modality, 0).unsqueeze(0) + if modality + else torch.zeros((1, 0), device=base_device, dtype=torch.long) + ) + position_ids = ( + torch.cat(position_ids, 0).unsqueeze(0) + if position_ids + else torch.zeros((1, 0, 3), device=base_device, dtype=torch.long) + ) + input_ids = ( + torch.cat(input_ids, 0).unsqueeze(0) + if input_ids + else torch.zeros((1, 0), device=base_device, dtype=torch.long) + ) - # Get token view - tokens = tensor_stream_token_view(tensor_stream) - if return_tensors in (TensorType.PYTORCH, "pt"): - input_ids = torch.as_tensor(tokens, dtype=torch.long) + if vpatches: + vision_patches = torch.cat(vpatches, 0) + vision_token_grids = torch.tensor( + grids, device=base_device, dtype=torch.long + ) + vision_token_offsets = torch.tensor( + vision_token_offsets, device=base_device, dtype=torch.long + ) + vision_token_lengths = torch.tensor( + vision_token_lengths, device=base_device, dtype=torch.long + ) else: - input_ids = tokens + vision_patches = vision_token_grids = vision_token_offsets = ( + vision_token_lengths + ) = None - data = { + return { "input_ids": input_ids, - "tensor_stream": tensor_stream, + "vision_patches": vision_patches, + "vision_token_grids": vision_token_grids, + "vision_token_offsets": vision_token_offsets, + "vision_token_lengths": vision_token_lengths, + "modality_tensor": modality_tensor, + "position_ids": position_ids, } - return BatchFeature(data=data) + def __call__( + self, + text: Union[str, list[str]], + images: Optional[Union[Image, list[Image]]] = None, + return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, + **kwargs, + ) -> BatchFeature: + texts = [text] if isinstance(text, str) else text + images_list: Optional[list[Optional[list[Image]]]] = None + if images is not None: + if isinstance(images, list) and len(images) == len(texts): + if not images: + images_list = [] + elif isinstance(images[0], list): + images_list = images # already per-sample + else: + images_list = [ + [img] for img in images + ] # list of images, one per sample + else: + images_list = [] + for t in texts: + n_tok = t.count(self.vision_token) + if n_tok == 0: + images_list.append(None) + else: + if isinstance(images, list): + images_list.append(images) + else: + images_list.append([images]) + + packed = self._pack_batch(texts, images_list) + input_ids = packed.pop("input_ids") + return BatchFeature(data={"input_ids": input_ids, "packed_inputs": packed}) + + +class IsaacRotaryEmbedding(qwen2_5_vl_modeling.Qwen2_5_VLRotaryEmbedding): + def __init__(self, config: IsaacConfig, device=None): + rope_source_cfg = ( + config.get_text_config() if hasattr(config, "get_text_config") else config + ) + rope_scaling = getattr(rope_source_cfg, "rope_scaling", None) or {} + config_for_rope = copy.copy(rope_source_cfg) + config_for_rope.rope_scaling = rope_scaling + + init_device = ( + device + if device is not None and getattr(device, "type", None) != "meta" + else None + ) + super().__init__(config_for_rope, device=init_device) + rotary_half_dim = self.inv_freq.shape[0] + self.mrope_section = self._resolve_mrope_section( + rope_scaling.get("mrope_section"), rotary_half_dim + ) + self.hidden_size = ( + getattr(rope_source_cfg, "hidden_size", None) or config.hidden_size + ) -# ============================================================================ -# Model -# ============================================================================ + @staticmethod + def _resolve_mrope_section( + section: Optional[list[int]], rotary_half_dim: int + ) -> list[int]: + if section is None: + weights = (2, 1, 1) + base = [rotary_half_dim * w // sum(weights) for w in weights] + base[0] += rotary_half_dim - sum(base) + return base + + section = [int(v) for v in section] + return section + + def _combine_axes(self, tensor: torch.Tensor) -> torch.Tensor: + split_sections = tuple(self.mrope_section * 2) + chunks = tensor.split(split_sections, dim=-1) + return torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1) + def forward( + self, + position_ids: torch.Tensor, + modality_tensor: torch.Tensor, + hidden_states: Optional[torch.Tensor] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if hidden_states is None: + batch, seq_len, _ = position_ids.shape + hidden_states = torch.zeros( + batch, + seq_len, + self.hidden_size, + dtype=torch.float32, + device=position_ids.device, + ) -def compute_position_ids_input_ids(input_ids: torch.Tensor) -> torch.Tensor: - r"""Create 3D positional indices for token input. + with torch.no_grad(): + pos = position_ids.clone() + not_spatial = modality_tensor != ModalityType.image.value + data_1d = pos[not_spatial][..., 0].unsqueeze( + -1 + ) # Collapse non-vision modalities to 1D positions + pos[not_spatial] = data_1d.expand(-1, pos.shape[-1]) + pos_axes = pos.permute(2, 0, 1).contiguous() + + cos_axes, sin_axes = super().forward(hidden_states, pos_axes) + cos_axes, sin_axes = ( + cos_axes.to(hidden_states.dtype), + sin_axes.to(hidden_states.dtype), + ) + cos_combined, sin_combined = ( + self._combine_axes(cos_axes), + self._combine_axes(sin_axes), + ) - Args: - input_ids (`torch.Tensor`): - Tensor of shape `(batch_size, seq_len)` containing token ids. + return cos_combined, sin_combined - Returns: - `torch.Tensor`: Positional indices with shape `(batch_size, seq_len, 3)` where each channel duplicates the - 1D position so it can be consumed by the 3-axis MRoPE rotary embedding. - """ - batch_size, seq_length = input_ids.shape - position_ids = torch.arange(seq_length, device=input_ids.device) - position_ids = position_ids.view(1, -1).expand(batch_size, -1) - position_ids = position_ids.unsqueeze(2).expand(-1, -1, 3) # Add 3D for MRoPE - return position_ids +@auto_docstring +class IsaacModel(Qwen3PreTrainedModel): + supports_gradient_checkpointing = True + _can_compile_fullgraph = False + _supports_flex_attn = False + _can_record_outputs = {"attentions": OutputRecorder(IsaacVisionAttention, index=1)} + all_tied_weights_keys: dict[str, str] = {} -class IsaacRotaryEmbedding(nn.Module): - def __init__(self, config: IsaacConfig, device=None): - super().__init__() + def __init__(self, config: IsaacConfig): + Qwen3PreTrainedModel.__init__(self, config) - # Extract dimensions from config - self.hidden_size = config.hidden_size - self.num_attention_heads = config.num_attention_heads - self.head_dim = config.head_dim + text_cfg_source = config.text_config + text_cfg = copy.deepcopy(text_cfg_source) + self.text_model = Qwen3Model._from_config(text_cfg) + self.text_model.config = ( + config # Ensure downstream callers observe the composed config + ) - # Get rope_scaling config - use direct access when available - rope_scaling = getattr(config, "rope_scaling", None) or {} + self.rotary_emb = IsaacRotaryEmbedding(config, device=self.device) - # Read RopeScaling parameters - self.rope_type = rope_scaling.get("rope_type", "default") + self.vision_embedding = IsaacVisionEmbedding(config) + self.vision_embedding._supports_sdpa = True + self.max_sequence_length = config.max_sequence_length + self.vision_rescale_factor = config.vision_rescale_factor + self.vision_token = config.vision_token + self.rope_deltas = None - self.mrope_section = [ - self.head_dim // 4, # 2x more for temporal dim - self.head_dim // 8, - self.head_dim // 8, - ] + self.post_init() - rope_base = getattr(config, "rope_theta", 10000.0) - inv_freq = precompute_inv_freq(rope_base, self.head_dim) - self.register_buffer("inv_freq", inv_freq, persistent=False) + # Respect config-specified gradient checkpointing + if getattr(config, "gradient_checkpointing", False): + self.gradient_checkpointing_enable() - def forward(self, position_ids: torch.Tensor, modality_tensor: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: - with torch.no_grad(): - # Ensure non-spatial tokens have 1D rotation equivalence - not_spatial = ~(modality_tensor == VisionType.image.value) - # shape is [N, 1] - data_1d = position_ids[not_spatial][..., 0].unsqueeze(-1) - # now broadcast it from [N, 1] -> [N, D] so it matches pos[not_spatial] exactly - data_1d = data_1d.expand(-1, position_ids.shape[-1]) # expand along the last dim - position_ids = position_ids.clone() # Clone to avoid warning about in-place operations on expanded tensors - position_ids[not_spatial] = data_1d - position_ids = position_ids.permute(2, 0, 1) # pos dim first -> (3, B, L) - cos, sin = precompute_cos_sin_3d(position_ids, self.inv_freq, self.mrope_section) + def get_input_embeddings(self) -> nn.Module: + return self.text_model.get_input_embeddings() - return cos, sin + def set_input_embeddings(self, value: nn.Module) -> None: + self.text_model.set_input_embeddings(value) + vocab_size = getattr(value, "num_embeddings", None) + if vocab_size is not None: + self.config.vocab_size = vocab_size + if hasattr(self.config, "text_config"): + self.config.text_config.vocab_size = vocab_size + self.text_model.config.vocab_size = vocab_size + @property + def embed_tokens(self) -> nn.Module: + return self.text_model.embed_tokens -class IsaacModel(Qwen3Model): - def __init__(self, config: IsaacConfig): - super().__init__(config) - text_cfg = getattr(config, "get_text_config", lambda: config)() - self.layers = torch.nn.ModuleList( - [Qwen3DecoderLayer(text_cfg, layer_idx) for layer_idx in range(config.num_hidden_layers)] + @embed_tokens.setter + def embed_tokens(self, value: nn.Module) -> None: + self.text_model.embed_tokens = value + + @property + def vision_model(self) -> nn.Module: + return self.vision_embedding.vision_tower + + def embed_packed_inputs( + self, input_ids: torch.Tensor, packed_inputs: dict[str, Optional[torch.Tensor]] + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Expects input_ids for text tokens and packed_inputs containing: + - modality_tensor: (batch, seq_len) modality ids aligned to the sequence + - position_ids: (batch, seq_len, 3) MRoPE coordinates (optional) + - vision_patches: concatenated vision tokens shaped (total_tokens, embed_dim) or None + - vision_token_grids: (num_images, 2) token grid sizes or None + - vision_token_offsets: (num_images,) offsets into each image's virtual token span (optional) + - vision_token_lengths: (num_images,) surviving virtual token lengths per image (optional) + - vision_token_batch_indices: (num_images,) batch row for each image (optional; defaults to zeros) + """ + modality = packed_inputs["modality_tensor"].to( + device=input_ids.device, dtype=torch.long ) - self.rotary_emb = IsaacRotaryEmbedding(config, device=self.device) + embeds = self.text_model.embed_tokens(input_ids) - vision_cfg = config.vision_config - # Use vision_attn_implementation if specified, otherwise fall back to general attn_implementation - vision_cfg._attn_implementation = ( - config.vision_attn_implementation - if config.vision_attn_implementation is not None - else config._attn_implementation + vision_patches = packed_inputs.get("vision_patches") + if vision_patches is None: + return embeds, modality + + token_grids = packed_inputs["vision_token_grids"].to( + device=vision_patches.device, dtype=torch.long ) - if vision_cfg is None: - raise ValueError("IsaacConfig should always have vision_config") - - hidden_dim = vision_cfg.hidden_size * (vision_cfg.pixel_shuffle_scale_factor**2) - self.vision_embedding = nn.Sequential( - Siglip2SequenceVisionTransformer(vision_cfg), - nn.Linear( - hidden_dim, - 4 * hidden_dim, - bias=False, - ), - nn.SiLU(), - nn.Linear(4 * hidden_dim, config.hidden_size, bias=False), + vision = self.vision_embedding( + (vision_patches, token_grids) + ) # (total_tokens, hidden) + + # per-image token counts AFTER pixel-shuffle + vision_reduction_factor = int( + self.config.vision_config.pixel_shuffle_scale_factor ) + sizes = ( + token_grids.prod(-1) + .div( + vision_reduction_factor * vision_reduction_factor, rounding_mode="floor" + ) + .tolist() + ) + offsets = packed_inputs.get("vision_token_offsets") + lengths = packed_inputs.get("vision_token_lengths") + batch_indices = packed_inputs.get("vision_token_batch_indices") + + chunks = vision.split(sizes, dim=0) + picked: list[torch.Tensor] = [] + picked_batch: list[int] = [] + for chunk, size, offset, length, batch_index in zip( + chunks, + sizes, + offsets.tolist(), + lengths.tolist(), + (batch_indices.tolist() if batch_indices is not None else [0] * len(sizes)), + ): + if size <= 0: + continue + offset = max(0, min(int(offset), size)) + length = max(0, min(int(length), size - offset)) + if length: + picked.append(chunk[offset : offset + length]) + picked_batch.append(int(batch_index)) + if picked: + vision_chunks = picked + vision_batch_idx = picked_batch + else: + vision_chunks = vision_batch_idx = [] - # Dispatch table for TensorStream balanced embedding (text + vision) - self.embed_fns = { - TextType: self.embed_text_tokens, - VisionType: self.embed_vision, - } + vision = ( + torch.cat(vision_chunks, 0) + if vision_chunks + else vision.new_zeros((0, vision.size(-1))) + ) + embeds = embeds.clone() + num_batches = modality.shape[0] + image_positions = [ + (modality[b] == ModalityType.image.value) + .nonzero(as_tuple=False) + .squeeze(-1) + for b in range(num_batches) + ] + cursors = [0 for _ in range(num_batches)] + + for chunk, batch_index in zip(vision_chunks, vision_batch_idx): + if chunk.numel() == 0: + continue + positions = image_positions[batch_index] + start = cursors[batch_index] + end = start + chunk.shape[0] + embeds[batch_index, positions[start:end]] = chunk.to( + device=embeds.device, dtype=embeds.dtype + ) + cursors[batch_index] = end - def embed_text_tokens(self, token_ids: torch.Tensor) -> torch.Tensor: - """Embed text tokens, squeezing singleton dimensions.""" - # Text events are shaped as (..., 1); squeeze the singleton index dim - h = self.embed_tokens(token_ids) - if h.dim() >= 2 and h.size(-2) == 1: - h = h[..., 0, :] - return h + return embeds, modality + + def get_rope_index( + self, + *, + position_ids: Optional[torch.Tensor] = None, + attention_mask: torch.Tensor, + inputs_embeds: torch.Tensor, + cache_position: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Build 3D position ids and per-batch RoPE deltas.""" - def embed_vision(self, vision_tokens: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: - """Embed vision tokens using the vision encoder.""" - # vision tokens is (seq_patches, token_grids) - return self.vision_embedding(vision_tokens) + device = inputs_embeds.device + batch_size, seq_len = inputs_embeds.shape[:2] - def embed_stream(self, tensor_stream: TensorStream) -> torch.Tensor: - """ - Embed each modality stream independently, preserving the original TensorStream - structure. - """ - flat_stream = tensor_stream.flat_stream() - per_modality_stream = group_streams(flat_stream, group_fn=lambda ev: ev.type, schedule=False) - per_modality_compact_stream = {k: v.compact() for k, v in per_modality_stream.items()} - - # Collect per-event grids for vision tokens (H, W like dims sans time) - token_grids = defaultdict(list) - for stream in tensor_stream.streams: - for event in stream: - token_grids[event.type].append(event.dims(virtual=False)) - - embedded_compact = {} - for stream_type, modality_payload_tensor in per_modality_compact_stream.items(): - if stream_type.modality == VisionType: - # Build a (N_events, 2) grid tensor with spatial dims only - grids = token_grids.get(stream_type, []) - if len(grids) == 0: - input_tensor = modality_payload_tensor - else: - token_grids_tensor = torch.tensor(grids, dtype=torch.long, device=tensor_stream.device)[:, 1:] - input_tensor = (modality_payload_tensor, token_grids_tensor) - embedded_compact[stream_type] = self.embed_fns[stream_type.modality](input_tensor) - else: - embedded_compact[stream_type] = self.embed_fns[stream_type.modality](modality_payload_tensor) + if position_ids is None: + cp = cache_position.to(device=device, dtype=torch.long) + if cp.ndim == 1: + cp = cp.view(1, -1).expand(batch_size or 1, -1) + + base_delta = torch.as_tensor( + 0 if self.rope_deltas is None else self.rope_deltas, + device=device, + dtype=torch.long, + ).reshape(-1, 1) + base_delta = torch.broadcast_to(base_delta, (batch_size, 1)) + + mask_delta = attention_mask.to(device=device, dtype=torch.long).sum( + 1, keepdim=True + ) - attention_mask.size(1) + rope_position = cp + base_delta + mask_delta + pos_3d = rope_position.unsqueeze(-1).expand(-1, -1, 3) + return pos_3d, base_delta + + position_ids = position_ids.to(device=device) + if position_ids.ndim == 2: + position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3) + + if position_ids.shape[1] != seq_len: + start_positions = position_ids[:, :1, 0] + position_ids = ( + torch.arange(seq_len, device=position_ids.device).view(1, -1) + + start_positions + ) + position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3) - # Reconstruct a TensorStream with embedded payloads and compact - embedded_ts = reconstruct_tensor_stream_from_compact_dict(tensor_stream, embedded_compact) - h = embedded_ts.compact() # (B, T, D) - return h + attn = attention_mask.to(device=device, dtype=torch.long) + m_per_batch = position_ids.amax(dim=(1, 2)) + seq_lens = attn.eq(1).sum(dim=-1).to(dtype=m_per_batch.dtype, device=device) + rope_deltas = ( + (m_per_batch + 1 - seq_lens).to(dtype=position_ids.dtype).unsqueeze(1) + ) + return position_ids, rope_deltas + @auto_docstring + @check_model_inputs def forward( self, - input_ids: torch.LongTensor | None = None, - tensor_stream: TensorStream | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - modality_tensor: torch.LongTensor | None = None, - past_key_values: list[torch.FloatTensor] | None = None, - inputs_embeds: torch.FloatTensor | None = None, - use_cache: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - cache_position: torch.LongTensor | None = None, - **kwargs, + input_ids: Optional[torch.LongTensor] = None, + packed_inputs: Optional[dict[str, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], ) -> tuple | BaseModelOutputWithPast: """ Forward pass with MRoPE position embeddings. Computes position embeddings once and passes them through all layers. + + Args: + packed_inputs (`dict`, *optional*): + Plain tensor payloads. When provided, requires `input_ids` for text tokens (or `text_token_ids` so `input_ids` can be rebuilt). + modality_tensor (`torch.LongTensor`, *optional*): + Modality identifiers aligned with the embedded sequence, shaped `(batch_size, seq_len)` and containing + values from `ModalityType`. Automatically built from `packed_inputs` or treated as text-only when omitted. """ - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - # Get inputs - if tensor_stream is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both tensor_stream and inputs_embeds") - elif tensor_stream is not None: - # Embed TensorStream directly - inputs_embeds = self.embed_stream(tensor_stream) - # Create modality tensor if not provided - if modality_tensor is None: - modality_tensor = modality_mask(tensor_stream) - elif input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + + output_attentions = kwargs.pop("output_attentions", None) + + modality_tensor: Optional[torch.Tensor] = None + + if packed_inputs is not None: + inputs_embeds, modality_tensor = self.embed_packed_inputs( + input_ids, packed_inputs + ) elif input_ids is not None: - inputs_embeds = self.embed_tokens(input_ids) - # Create text modality tensor if not provided - if modality_tensor is None: - batch_size, seq_length = input_ids.shape - modality_tensor = torch.full( - (batch_size, seq_length), TextType.text.value, device=input_ids.device, dtype=torch.long - ) - elif inputs_embeds is None: - raise ValueError("You have to specify either tensor_stream, input_ids or inputs_embeds") + inputs_embeds = self.text_model.embed_tokens(input_ids) - # Create default position_ids if not provided - if position_ids is None: - if tensor_stream is not None: - position_ids = compute_mrope_pos_tensor(tensor_stream) # (B,L,3) - else: - position_ids = compute_position_ids_input_ids(input_ids) + device = inputs_embeds.device + batch_size, seq_len = inputs_embeds.shape[:2] - # Compute MRoPE position embeddings if we have custom rotary_emb - cos, sin = self.rotary_emb(position_ids, modality_tensor) - cos = cos.to(inputs_embeds.dtype) - sin = sin.to(inputs_embeds.dtype) + if use_cache and past_key_values is None: + past_key_values = DynamicCache(config=self.config.get_text_config()) - # Prepare attention mask - if attention_mask is not None: - attention_mask = self._update_causal_mask( - attention_mask, inputs_embeds, cache_position, past_key_values, False + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + seq_len, device=device + ) + + if attention_mask is None: + attention_mask = torch.ones( + inputs_embeds.shape[:2], device=inputs_embeds.device, dtype=torch.long ) - # Initialize hidden states + if ( + position_ids is None + and packed_inputs is not None + and packed_inputs.get("position_ids") is not None + ): + position_ids = packed_inputs.get("position_ids").to(device=device) + + position_ids, rope_deltas = self.get_rope_index( + position_ids=position_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + ) + self.rope_deltas = rope_deltas + + if modality_tensor is None: + modality_tensor = torch.full( + (batch_size, seq_len), + ModalityType.text.value, + device=device, + dtype=torch.long, + ) + + cos, sin = self.rotary_emb( + position_ids, modality_tensor, hidden_states=inputs_embeds + ) + + decoder_position_ids = ( + position_ids[..., 0] if position_ids.ndim == 3 else position_ids + ) + + if not isinstance(attention_mask, dict): + attention_mask = create_masks_for_generate( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + position_ids=decoder_position_ids, + ) + + is_mask_dict = isinstance(attention_mask, dict) hidden_states = inputs_embeds + all_attentions = [] if output_attentions else None - for decoder_layer in self.layers: - layer_outputs = decoder_layer( + for layer in self.text_model.layers: + layer_mask = ( + attention_mask[layer.attention_type] if is_mask_dict else attention_mask + ) + layer_outputs = layer( hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_values, + attention_mask=layer_mask, + position_ids=decoder_position_ids, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=(cos, sin), + output_attentions=output_attentions, **kwargs, ) - hidden_states = layer_outputs[0] if isinstance(layer_outputs, tuple) else layer_outputs + layer_outputs_is_tuple = isinstance(layer_outputs, tuple) + hidden_states = ( + layer_outputs[0] if layer_outputs_is_tuple else layer_outputs + ) + if output_attentions and layer_outputs_is_tuple: + all_attentions.append(layer_outputs[1]) - # Final layer norm - hidden_states = self.norm(hidden_states) + hidden_states = self.text_model.norm(hidden_states) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=past_key_values, + hidden_states=(hidden_states,), + attentions=tuple(all_attentions) if output_attentions else None, ) - def _update_causal_mask( - self, - attention_mask: torch.Tensor, - input_tensor: torch.Tensor, - cache_position: torch.Tensor, - past_key_values: Cache, - output_attentions: bool = False, - ): - if self.config._attn_implementation == "flash_attention_2": - if attention_mask is not None and past_key_values is not None: - is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Qwen3. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in - # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail - # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - using_static_cache = isinstance(past_key_values, StaticCache) - using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) - - # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if ( - self.config._attn_implementation == "sdpa" - and not (using_static_cache or using_sliding_window_cache) - and not output_attentions - ): - if AttentionMaskConverter._ignore_causal_mask_sdpa( - attention_mask, - inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, - sliding_window=self.config.sliding_window, - is_training=self.training, - ): - return None - - dtype, device = input_tensor.dtype, input_tensor.device - min_dtype = torch.finfo(dtype).min - sequence_length = input_tensor.shape[1] - # SlidingWindowCache or StaticCache - if using_sliding_window_cache or using_static_cache: - target_length = past_key_values.get_max_cache_shape() - # DynamicCache or no cache - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else past_seen_tokens + sequence_length + 1 - ) - - # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). - causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( - attention_mask, - sequence_length=sequence_length, - target_length=target_length, - dtype=dtype, - device=device, - cache_position=cache_position, - batch_size=input_tensor.shape[0], - config=self.config, - past_key_values=past_key_values, - ) - - if ( - self.config._attn_implementation == "sdpa" - and attention_mask is not None - and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not output_attentions - ): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) - - return causal_mask - - @staticmethod - def _prepare_4d_causal_attention_mask_with_cache_position( - attention_mask: torch.Tensor, - sequence_length: int, - target_length: int, - dtype: torch.dtype, - device: torch.device, - cache_position: torch.Tensor, - batch_size: int, - config: Qwen3Config, - past_key_values: Cache, - ): - """ - Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape - `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. - - Args: - attention_mask (`torch.Tensor`): - A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. - sequence_length (`int`): - The sequence length being processed. - target_length (`int`): - The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. - dtype (`torch.dtype`): - The dtype to use for the 4D attention mask. - device (`torch.device`): - The device to place the 4D attention mask on. - cache_position (`torch.Tensor`): - Indices depicting the position of the input sequence tokens in the sequence. - batch_size (`torch.Tensor`): - Batch size. - config (`Qwen3Config`): - The model's configuration class - past_key_values (`Cache`): - The cache class that is being used currently to generate - """ - if attention_mask is not None and attention_mask.dim() == 4: - # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. - causal_mask = attention_mask - else: - min_dtype = torch.finfo(dtype).min - causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) - if config.sliding_window is not None: - # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also - # the check is needed to verify is current checkpoint was trained with sliding window or not - if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( - cache_position.reshape(-1, 1) - config.sliding_window - ) - diagonal_attend_mask.bitwise_or_(sliding_attend_mask) - causal_mask *= diagonal_attend_mask - causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) - if attention_mask is not None: - causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit - if attention_mask.shape[-1] > target_length: - attention_mask = attention_mask[:, :target_length] - mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - return causal_mask - +@auto_docstring class IsaacForConditionalGeneration(Qwen3ForCausalLM, GenerationMixin): - """Isaac multimodal model for conditional generation.""" - config_class = IsaacConfig + _can_compile_fullgraph = False + _tied_weights_keys = {"lm_head.weight": "model.text_model.embed_tokens.weight"} + all_tied_weights_keys: dict[str, str] = { + "lm_head.weight": "model.text_model.embed_tokens.weight" + } def __init__(self, config: IsaacConfig): - Qwen3PreTrainedModel.__init__(self, config) - self.model = IsaacModel(config) # Use our custom model + super().__init__(config) + self.model = IsaacModel(config) self.vocab_size = config.vocab_size self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - # Tracks rotary position offsets computed during a full forward pass so decode steps can reuse them. - self.rope_deltas = None - - self.config = config - - def get_rope_index( - self, - input_ids: torch.Tensor | None, - tensor_stream: TensorStream | None, - attention_mask: torch.Tensor | None, - ) -> tuple[torch.Tensor, torch.Tensor]: - """Compute MRoPE position ids from a TensorStream (or 1D fallback). - - Returns (position_ids, rope_deltas). position_ids is (B,L,3) for MRoPE. - rope_deltas is (B,1) used to advance positions in decode. - """ - # tensor_stream present: compute 3D coords - if tensor_stream is None and input_ids is None: - raise ValueError("`tensor_stream` or `input_ids` must be provided to compute rope indices") - - if tensor_stream is not None: - pos_3d = compute_mrope_pos_tensor(tensor_stream) # (B,L,3) - else: - pos_3d = compute_position_ids_input_ids(input_ids) - B, L, _ = pos_3d.shape - - # Max position per batch across the 3 planes and sequence dimension: (B,) - m_per_batch = pos_3d.amax(dim=(1, 2)) - - # Sequence lengths per batch: (B,) - if attention_mask is None: - seq_lens = torch.full_like(m_per_batch, L) - else: - seq_lens = attention_mask.eq(1).sum(dim=-1).to(dtype=m_per_batch.dtype, device=m_per_batch.device) - - rope_deltas = (m_per_batch + 1 - seq_lens).to(dtype=pos_3d.dtype).unsqueeze(1) - return pos_3d, rope_deltas + @auto_docstring + @can_return_tuple + @check_model_inputs def forward( self, - input_ids: torch.LongTensor | None = None, - tensor_stream: TensorStream | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: list[torch.FloatTensor] | None = None, - inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, - use_cache: bool | None = None, - output_hidden_states: bool | None = None, - return_dict: bool | None = None, - cache_position: torch.LongTensor | None = None, - **kwargs, + input_ids: Optional[torch.LongTensor] = None, + packed_inputs: Optional[dict[str, torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[list[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], ) -> tuple | CausalLMOutputWithPast: - """ - Forward pass for conditional generation supporting both standard inputs and TensorStream. - Uses our embed_stream approach for multimodal inputs. - """ + """Run multimodal CausalLM forward, accepting packed vision/text inputs. - # Don't compute embeddings here - let the model handle it - if tensor_stream is not None: - input_ids = None - if input_ids is None and inputs_embeds is None and tensor_stream is None: - raise ValueError("Either input_ids, inputs_embeds, or tensor_stream must be provided.") - - # Build position ids (MRoPE) if needed and tensor_stream is available - # During decode we reuse `self.rope_deltas` computed on the initial forward pass; `rope_delta` captures how far - # cached rotary phases have progressed so we can advance `position_ids` without rebuilding the TensorStream. - if position_ids is None and tensor_stream is not None: - position_ids, self.rope_deltas = self.get_rope_index(input_ids, tensor_stream, attention_mask) - elif position_ids is None and input_ids is not None: - # For text inputs build position ids and modality tensor - position_ids = compute_position_ids_input_ids(input_ids) - if cache_position is not None and self.rope_deltas is not None: - # Combine the incremental decode step (`cache_position`) with cached offsets so hidden states continue - # rotating in lockstep across generation steps. - rope_delta = (cache_position[0] + self.rope_deltas).to(input_ids.device) - else: - rope_delta = 0 - if cache_position is not None and not isinstance(rope_delta, int): # otherwise `deltas` is an int `0` - batch_size = input_ids.shape[0] - rope_delta = rope_delta.repeat_interleave(batch_size // rope_delta.shape[0], dim=0) - position_ids = position_ids.add(rope_delta) - - if tensor_stream is not None: - modality_tensor = modality_mask(tensor_stream) - else: - batch_size, seq_len = input_ids.shape - modality_tensor = torch.empty(batch_size, seq_len, device=position_ids.device).fill_(TextType.text.value) + Args: + packed_inputs (`dict`, *optional*): + Packed vision/text payload from ``IsaacProcessor`` containing modality ids, MRoPE position ids, and + vision patch tensors/grids (with optional offsets/lengths) used to rebuild embeddings. + + Returns: + CausalLMOutputWithPast: logits, optional loss, caches, hidden states, attentions. + """ + output_attentions = kwargs.pop("output_attentions", None) outputs = self.model( input_ids=input_ids, - tensor_stream=tensor_stream, + packed_inputs=packed_inputs, attention_mask=attention_mask, position_ids=position_ids, - modality_tensor=modality_tensor, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, - output_hidden_states=output_hidden_states, - return_dict=return_dict, + output_attentions=output_attentions, cache_position=cache_position, **kwargs, ) - hidden_states = outputs[0] logits = self.lm_head(hidden_states) - loss = None if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size) + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.vocab_size + ) return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, - attentions=None, + attentions=outputs.attentions if output_attentions else None, ) def prepare_inputs_for_generation( self, input_ids: torch.LongTensor, - past_key_values: list[torch.FloatTensor] | None = None, - attention_mask: torch.Tensor | None = None, - inputs_embeds: torch.FloatTensor | None = None, - tensor_stream: TensorStream | None = None, - cache_position: torch.LongTensor | None = None, - position_ids: torch.LongTensor | None = None, - use_cache: bool = True, + past_key_values: Optional[list[torch.FloatTensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + packed_inputs: Optional[dict[str, torch.Tensor]] = None, + cache_position: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, **kwargs, ) -> dict[str, Any]: - """ - Prepare inputs for generation, handling TensorStream inputs properly. - """ - # Call parent preparation model_inputs = super().prepare_inputs_for_generation( input_ids, past_key_values=past_key_values, @@ -1673,27 +1911,41 @@ class IsaacForConditionalGeneration(Qwen3ForCausalLM, GenerationMixin): inputs_embeds=inputs_embeds, cache_position=cache_position, position_ids=position_ids, - use_cache=use_cache, **kwargs, ) + if packed_inputs is None: + return model_inputs - # Handle TensorStream for first forward pass only - if tensor_stream is not None and (cache_position is None or cache_position[0] == 0): - model_inputs["tensor_stream"] = tensor_stream - # Let forward rebuild position_ids using cached deltas during decode + past_len = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + first_step = past_len == 0 + model_inputs["packed_inputs"] = packed_inputs if first_step else None model_inputs["position_ids"] = None - # Drop tensor_stream after step 0 - if cache_position is not None and cache_position[0] != 0: - model_inputs["tensor_stream"] = None + return model_inputs - def can_generate(self) -> bool: + @classmethod + def can_generate(cls) -> bool: return True + def set_input_embeddings(self, value: nn.Module) -> None: + self.model.set_input_embeddings(value) + vocab_size = getattr(value, "num_embeddings", None) + self.config.vocab_size = vocab_size + self.model.config.vocab_size = vocab_size + self.model.text_model.config.vocab_size = vocab_size + if self.lm_head.weight.shape[0] != vocab_size: + self.lm_head = nn.Linear(self.config.hidden_size, vocab_size, bias=False) + self.lm_head.weight = self.model.text_model.embed_tokens.weight + __all__ = [ "IsaacConfig", "IsaacModel", + "IsaacPreTrainedModel", # noqa: F822 "IsaacForConditionalGeneration", + "IsaacImageProcessorFast", "IsaacProcessor", ] +