# Isaac-0.1 / modular_isaac.py
# coding=utf-8
# Copyright 2025 Perceptron, Inc and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import copy
import math
from collections.abc import Callable, Sequence
from enum import IntEnum
from typing import Any, Optional, Union
from transformers.cache_utils import DynamicCache
from transformers.configuration_utils import PretrainedConfig, layer_type_validation
from transformers.feature_extraction_utils import BatchFeature
from transformers.generation.utils import GenerationMixin
from transformers.image_processing_utils_fast import (
BaseImageProcessorFast,
ImagesKwargs,
SizeDict,
group_images_by_shape,
reorder_images,
)
from transformers.image_utils import (
PILImageResampling,
)
from transformers.masking_utils import (
ALL_MASK_ATTENTION_FUNCTIONS,
create_masks_for_generate,
packed_sequence_mask_function,
)
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config
from transformers.models.qwen3.modeling_qwen3 import (
Qwen3ForCausalLM,
Qwen3Model,
Qwen3PreTrainedModel,
)
from transformers.processing_utils import ProcessorMixin, Unpack
from transformers.utils import TensorType, auto_docstring
from transformers.utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN
from transformers.utils.constants import IMAGENET_STANDARD_STD as VISION_STD
from transformers.utils.generic import (
OutputRecorder,
TransformersKwargs,
can_return_tuple,
check_model_inputs,
)
from transformers.utils.import_utils import (
is_torch_available,
is_torchdynamo_compiling,
is_torchvision_available,
is_vision_available,
)
from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl as qwen2_5_vl_modeling
from transformers.models.siglip2.configuration_siglip2 import Siglip2VisionConfig
from transformers.models.siglip2.modeling_siglip2 import (
Siglip2Attention,
Siglip2Encoder,
Siglip2EncoderLayer,
Siglip2VisionEmbeddings,
)
if is_torch_available():
import torch
import torch.nn as nn
import torch.nn.functional as F
if is_vision_available():
from PIL.Image import Image
else:
Image = None
if is_torchvision_available():
from transformers.models.pix2struct.image_processing_pix2struct_fast import (
torch_extract_patches,
)
class ModalityType(IntEnum):
"""
Modality identifiers for events.
Members:
image: Vision tokens (e.g., patches).
text: Textual tokens.
"""
image = 0
text = 1
class IsaacVisionConfig(Siglip2VisionConfig):
"""Vision configuration for Isaac with Pixel Shuffle support.
Extends Siglip2VisionConfig with additional fields for pixel shuffle.
Args:
        pixel_shuffle_scale_factor (`int`, *optional*, defaults to 1):
            Spatial downscale factor used by pixel shuffle; values greater than 1 merge neighboring patches, reducing the token grid resolution.
num_patches (`int`, *optional*, defaults to 256):
Maximum number of learnable positional embeddings to initialize.
"""
model_type = "isaac_vision"
base_config_key = "vision_config"
def __init__(
self,
hidden_size=768,
intermediate_size=3072,
num_hidden_layers=12,
num_attention_heads=12,
num_channels=3,
num_patches=256,
patch_size=16,
hidden_act="gelu_pytorch_tanh",
layer_norm_eps=1e-6,
attention_dropout=0.0,
pixel_shuffle_scale_factor=1,
**kwargs,
):
super().__init__(**kwargs)
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_channels = num_channels
self.patch_size = patch_size
self.attention_dropout = attention_dropout
self.layer_norm_eps = layer_norm_eps
self.hidden_act = hidden_act
self.num_patches = num_patches
# Add our custom fields
self.pixel_shuffle_scale_factor = pixel_shuffle_scale_factor
# Ensure a sensible default attention backend
if getattr(self, "_attn_implementation", None) is None:
self._attn_implementation = "sdpa"
class IsaacImageProcessorFastKwargs(ImagesKwargs, total=False):
patch_size: Optional[int]
max_num_patches: Optional[int]
min_num_patches: Optional[int]
pixel_shuffle_scale: Optional[int]
@auto_docstring
class IsaacImageProcessorFast(BaseImageProcessorFast):
    MAX_PIXELS = 60_000_000  # 60-megapixel ceiling ≈ 8200 × 7300 px
resample = PILImageResampling.BILINEAR
model_input_names = ["patches", "token_grids"]
valid_kwargs = IsaacImageProcessorFastKwargs
unused_kwargs = ["size", "do_center_crop", "crop_size", "pad_size", "do_pad"]
do_resize = True
do_center_crop = False
patch_size: Optional[int] = 16
max_num_patches: Optional[int] = 256
min_num_patches: Optional[int] = None
pixel_shuffle_scale: Optional[int] = 1
do_pad = False
do_rescale = True
do_normalize = True
image_mean = list(VISION_MEAN)
image_std = list(VISION_STD)
do_convert_rgb = True
disable_grouping = False
def __init__(
self,
**kwargs: Unpack[IsaacImageProcessorFastKwargs],
) -> None:
super().__init__(**kwargs)
def _validate_preprocess_kwargs(self, **kwargs):
# Allow callers to omit resize-related placeholders that BaseImageProcessorFast checks for.
kwargs.pop("do_resize", None)
kwargs.pop("size", None)
kwargs.pop("do_center_crop", None)
kwargs.pop("crop_size", None)
kwargs.pop("disable_grouping", None)
return super()._validate_preprocess_kwargs(**kwargs)
def resize(
self,
image: torch.Tensor,
size: SizeDict,
**kwargs,
) -> torch.Tensor:
resize_kwargs: dict[str, Any] = {"align_corners": False}
resize_mode = "bilinear"
return F.interpolate(
image,
size=(size.height, size.width),
mode=resize_mode,
**resize_kwargs,
)
def _preprocess(
self,
images: list[torch.Tensor],
do_resize: bool,
interpolation: Optional[Any],
do_rescale: Optional[bool],
rescale_factor: Optional[float],
do_normalize: Optional[bool],
image_mean: Optional[Union[float, Sequence[float]]],
image_std: Optional[Union[float, Sequence[float]]],
disable_grouping: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
*,
patch_size: Optional[int] = None,
max_num_patches: Optional[int] = None,
min_num_patches: Optional[int] = None,
pixel_shuffle_scale: Optional[int] = None,
**kwargs,
) -> BatchFeature:
grouped_images, grouped_images_index = group_images_by_shape(
images, disable_grouping=disable_grouping
)
grouped_outputs = {}
for shape, stacked_images in grouped_images.items():
batch_size, channels, original_height, original_width = stacked_images.shape
if bool(self.do_convert_rgb) and channels == 1:
stacked_images = stacked_images.repeat(1, 3, 1, 1)
target_height, target_width = get_image_size_for_max_num_patches(
original_height,
original_width,
patch_size,
max_num_patches,
min_num_patches=min_num_patches,
pixel_shuffle_scale=pixel_shuffle_scale,
)
if do_resize:
image_batch = self.resize(
stacked_images,
SizeDict(height=target_height, width=target_width),
interpolation=interpolation,
)
else:
if (original_height % patch_size) or (original_width % patch_size):
raise ValueError(
f"Image dimensions (h={original_height}, w={original_width}) must be divisible by patch_size={patch_size} when resize is disabled; enable resizing or adjust the input resolution."
)
image_batch, target_height, target_width = (
stacked_images,
original_height,
original_width,
)
image_batch = self.rescale_and_normalize(
image_batch,
do_rescale=do_rescale,
rescale_factor=rescale_factor,
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
)
patches = torch_extract_patches(image_batch, patch_size, patch_size)
_, height_tokens, width_tokens, _ = patches.shape
token_grid = (
torch.tensor([height_tokens, width_tokens], device=patches.device)
.long()
.expand(batch_size, 2)
)
real_dim = (
torch.tensor(
[1, height_tokens, width_tokens],
dtype=torch.long,
device=patches.device,
)
.unsqueeze(0)
.repeat(batch_size, 1)
)
if (height_tokens % pixel_shuffle_scale) or (
width_tokens % pixel_shuffle_scale
):
raise ValueError(
f"Token grid (h={height_tokens}, w={width_tokens}) must be divisible by pixel_shuffle_scale={pixel_shuffle_scale}; adjust resize/patch parameters or disable pixel shuffle."
)
virtual_height = height_tokens // pixel_shuffle_scale
virtual_width = width_tokens // pixel_shuffle_scale
virtual_dim = (
torch.tensor(
[1, virtual_height, virtual_width],
dtype=torch.long,
device=patches.device,
)
.unsqueeze(0)
.repeat(batch_size, 1)
)
grouped_outputs[shape] = (patches, token_grid, virtual_dim, real_dim)
def _reorder_grouped_item( # reorder an item of tuple payloads using the same grouped_images_index
grouped: dict[
tuple[int, ...],
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
],
grouped_index: dict[tuple[int, ...], list[int]],
item_idx: int,
) -> list[torch.Tensor]:
return reorder_images(
{k: v[item_idx] for k, v in grouped.items()}, grouped_index
)
keys = ("patches", "token_grids", "virtual_pixel_size", "real_pixel_size")
tensors: dict[str, torch.Tensor] = {}
for i, key in enumerate(keys):
slices = _reorder_grouped_item(grouped_outputs, grouped_images_index, i)
tensors[key] = torch.stack(slices, dim=0)
return BatchFeature(data=tensors, tensor_type=return_tensors)
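
# Preprocessing sketch (hypothetical input; exact token grids depend on
# max_num_patches and pixel_shuffle_scale):
#
#     proc = IsaacImageProcessorFast()
#     feats = proc(images=[torch.rand(3, 480, 640)], return_tensors="pt")
#     feats["patches"]      # (num_images, H_tokens, W_tokens, 3 * patch_size**2)
#     feats["token_grids"]  # (num_images, 2) rows of (H_tokens, W_tokens)
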
def create_document_attention_mask(
config: PretrainedConfig,
input_embeds: torch.Tensor,
cu_seqlens: Optional[torch.Tensor],
) -> Optional[Union[torch.Tensor, Any]]:
"""
Materialize a backend-specific block-diagonal attention mask from packed cu_seqlens.
Returns None if cu_seqlens is missing/degenerate.
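    Example: cu_seqlens = [0, 3, 5] describes two packed documents of lengths 3
    and 2, so the resulting mask is block-diagonal over a 5-token sequence.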
"""
if cu_seqlens is None or cu_seqlens.numel() < 2:
return None # Degenerate input: nothing to mask
seq_sizes = (cu_seqlens[1:] - cu_seqlens[:-1]).long()
if seq_sizes.numel() == 0 or int(seq_sizes.sum()) == 0:
return None # All-empty segments produce no attention blocks
seg_ids = torch.repeat_interleave(
torch.arange(seq_sizes.numel(), device=cu_seqlens.device),
seq_sizes,
)
mask_function = packed_sequence_mask_function(seg_ids.view(1, -1))
seq_len = input_embeds.shape[1]
cache_position = torch.arange(seq_len, device=input_embeds.device, dtype=torch.long)
mask_interface = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]
return mask_interface(
batch_size=input_embeds.shape[0],
cache_position=cache_position,
kv_length=seq_len,
kv_offset=0,
mask_function=mask_function,
attention_mask=None,
allow_is_causal_skip=False,
allow_is_bidirectional_skip=False,
dtype=input_embeds.dtype,
config=config,
use_vmap=False,
)
class IsaacVisionEmbeddings(Siglip2VisionEmbeddings):
"""Adapter around SigLIP2 vision embeddings that consumes packed patch sequences.
Isaac accepts variable-resolution vision inputs as a single packed sequence with per-image
`token_grids`; packing/unpacking here reconstructs per-image shapes so we can resize positional
embeddings and build `cu_seqlens` for variable-length attention (not generic generation packing).
"""
def __init__(self, config: IsaacVisionConfig):
super().__init__(config)
self.config = config
self.embed_dim = config.hidden_size
self.patch_size = config.patch_size
self.patch_embedding = nn.Linear(
in_features=config.num_channels * self.patch_size * self.patch_size,
out_features=self.embed_dim,
)
self.num_patches = config.num_patches
self.position_embedding_size = int(self.num_patches**0.5)
self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)
@check_model_inputs
def forward(
self, seq_patches: torch.Tensor, spatial_shapes: torch.Tensor
) -> torch.Tensor:
# Rebatch packed variable-resolution patches to resize per-image position embeddings
# and track lengths for varlen attention metadata.
packed_pixel_values, seq_lengths = self._pack_to_batch(
seq_patches, spatial_shapes
)
if packed_pixel_values is None:
return seq_patches.new_zeros((0, self.embed_dim))
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(packed_pixel_values.to(dtype=target_dtype))
positional_embeddings = self.position_embedding.weight.reshape(
self.position_embedding_size,
self.position_embedding_size,
-1,
)
resized_positional_embeddings = self.resize_positional_embeddings(
positional_embeddings,
spatial_shapes,
max_length=packed_pixel_values.shape[1],
)
embeddings = patch_embeds + resized_positional_embeddings
return self._unpack_from_batch(embeddings, seq_lengths)
def _pack_to_batch(
self,
seq_patches: torch.Tensor,
spatial_shapes: torch.Tensor,
) -> tuple[Optional[torch.Tensor], torch.Tensor]:
"""Rebatch a packed patch sequence using per-image grids to align embeddings.
Args:
seq_patches: Packed patches of shape (total_patches, patch_dim).
spatial_shapes: Per-image patch grids of shape (num_images, 2) as (H_tokens, W_tokens).
Returns:
(packed_pixel_values, seq_lengths) where:
- packed_pixel_values: (batch, max_len, patch_dim) padded with zeros, or None if batch_size == 0
- seq_lengths: (batch,) lengths for each image
"""
seq_lengths = spatial_shapes.long().prod(dim=-1) # (B,)
batch_size = int(seq_lengths.numel())
if batch_size == 0:
return None, seq_lengths
# Split the packed sequence into per-image chunks, then pad to a batch
lengths_list = seq_lengths.tolist()
chunks = seq_patches.split(lengths_list, dim=0)
packed_pixel_values = nn.utils.rnn.pad_sequence(
chunks, batch_first=True
) # zero-padded by default
return packed_pixel_values, seq_lengths
def _unpack_from_batch(
self, embeddings: torch.Tensor, seq_lengths: torch.Tensor
) -> torch.Tensor:
"""Flatten a padded batch back to packed sequence order using `seq_lengths`."""
lengths = seq_lengths.to(device=embeddings.device).tolist()
chunks = [embeddings[i, :l] for i, l in enumerate(lengths) if l > 0]
return torch.cat(chunks, dim=0)
class IsaacVisionAttention(Siglip2Attention):
"""Custom attention that supports variable-length sequences with flash attention."""
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[int] = None,
**kwargs,
):
kwargs.pop("output_hidden_states", None)
kwargs.pop("return_dict", None)
batch_size, seq_length, embed_dim = hidden_states.shape
queries = self.q_proj(hidden_states)
keys = self.k_proj(hidden_states)
values = self.v_proj(hidden_states)
queries = queries.view(
batch_size, seq_length, self.num_heads, self.head_dim
).transpose(1, 2)
keys = keys.view(
batch_size, seq_length, self.num_heads, self.head_dim
).transpose(1, 2)
values = values.view(
batch_size, seq_length, self.num_heads, self.head_dim
).transpose(1, 2)
attn_impl = self.config._attn_implementation
attention_interface: Callable = ALL_ATTENTION_FUNCTIONS["sdpa"]
if attn_impl != "sdpa":
attention_interface = ALL_ATTENTION_FUNCTIONS[attn_impl]
attention_kwargs: dict[str, Any] = {
"is_causal": False,
"scaling": self.scale,
}
supports_varlen = cu_seqlens is not None and attn_impl in {
"flash_attention_2",
"flash_attention_3",
"flex_attention",
"paged|flash_attention_2",
"paged|flash_attention_3",
}
if supports_varlen:
if max_seqlen is not None:
max_q = max_k = int(max_seqlen)
elif cu_seqlens.numel() >= 2:
lengths = cu_seqlens[1:] - cu_seqlens[:-1]
                max_q = max_k = int(lengths.max()) if lengths.numel() > 0 else seq_length
else:
max_q = max_k = seq_length
attention_kwargs.update(
{
"cu_seq_lens_q": cu_seqlens,
"cu_seq_lens_k": cu_seqlens,
"max_length_q": max_q,
"max_length_k": max_k,
}
)
attn_output, attn_weights = attention_interface(
self,
queries,
keys,
values,
attention_mask,
**attention_kwargs,
)
attn_output = attn_output.reshape(
batch_size, seq_length, embed_dim
).contiguous()
attn_output = self.out_proj(attn_output)
return attn_output, attn_weights
class IsaacVisionEncoderLayer(Siglip2EncoderLayer):
"""Isaac vision encoder layer with variable-length attention."""
def __init__(self, config: IsaacVisionConfig):
super().__init__(config)
self.self_attn = IsaacVisionAttention(config)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
cu_seqlens: Optional[torch.Tensor] = None,
max_seqlen: Optional[int] = None,
output_attentions: bool = False,
**kwargs: Unpack[TransformersKwargs],
):
r"""
cu_seqlens (`torch.Tensor`, *optional*):
Prefix-sum tensor whose length equals the number of documents + 1. The difference between successive
entries gives each document's token count and enables block-diagonal attention masking for packed batches.
max_seqlen (`int`, *optional*):
Maximum document length referenced by `cu_seqlens`. Passed to FlashAttention so it can size temporary
buffers for packed variable-length attention.
"""
# Run attention directly so variable-length metadata reaches FlashAttention.
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
attn_output, _ = self.self_attn(
hidden_states,
attention_mask=attention_mask,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
**kwargs,
)
hidden_states = residual + attn_output
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class IsaacVisionEncoder(Siglip2Encoder):
"""Encoder using Isaac encoder layers with variable-length attention support."""
def __init__(self, config: IsaacVisionConfig):
super().__init__(config)
self.layers = nn.ModuleList(
[IsaacVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)]
)
def create_pixel_shuffle_index_map(
seq_sizes: torch.Tensor,
token_grids: torch.Tensor,
scale_factor: int = 1,
device: Optional[torch.device] = None,
) -> torch.Tensor:
"""
Build a gather-index map that tells us, for every *output* token after
pixel-shuffle, which `scale_factor**2` *input* tokens are being merged.
Args
----
seq_sizes : (num_images,) - #patches in each image (row-major order)
token_grids : (num_images,2) - (height, width) for every image
    scale_factor : spatial down-scale factor (1 is a no-op; values ≥2 merge patches)
device : (optional) overrides `seq_sizes.device`
Returns
-------
gather_idx : (new_total_seq_len, scale_factor**2) int64 tensor.
gather_idx[i, j] is the *flat* index into the *original*
packed sequence for the j-th sub-patch that forms the
i-th output token.
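
    Example
    -------
    A single 2x2 image with scale_factor=2 collapses into one output token:

    >>> create_pixel_shuffle_index_map(torch.tensor([4]), torch.tensor([[2, 2]]), scale_factor=2)
    tensor([[0, 1, 2, 3]])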
"""
if not is_torchdynamo_compiling():
if (token_grids % scale_factor).any():
raise AssertionError(
f"Every (H,W) in token_grids must be divisible by scale_factor={scale_factor}, got {token_grids.tolist()}"
)
gather_chunks: list[torch.Tensor] = []
tok_offset = 0
for seq_len, (h, w) in zip(seq_sizes.tolist(), token_grids.tolist()):
# Flat indices for this image's packed segment
grid = (
torch.arange(seq_len, device=device, dtype=torch.int64).view(h, w)
+ tok_offset
)
# Block into (H/s, W/s) groups; each group contributes s*s indices
grid = (
grid.view(h // scale_factor, scale_factor, w // scale_factor, scale_factor)
.permute(0, 2, 1, 3)
.contiguous()
)
gather_chunks.append(grid.view(-1, scale_factor * scale_factor))
tok_offset += seq_len
return torch.cat(gather_chunks, dim=0)
def pixel_shuffle_varlen(
x: torch.Tensor,
token_grids: torch.Tensor,
scale_factor: int = 1,
) -> torch.Tensor:
r"""Apply pixel shuffle to a packed vision sequence without unpacking per image.
Args:
x (`torch.Tensor`):
Concatenated vision embeddings. Accepts `(seq_len, hidden_size)` or `(1, seq_len, hidden_size)` shapes
produced by stacking image patches.
token_grids (`torch.Tensor`):
Integer tensor of shape `(num_images, 2)` whose rows give the `(height, width)` patch grid sizes
corresponding to each image segment inside `x`.
scale_factor (`int`, *optional*, defaults to 1):
Spatial down-sampling factor specific to pixel shuffle. Values greater than one merge `scale_factor**2` neighboring patches into a
single embedding channel-group.
Returns:
`torch.Tensor`: Pixel-shuffled embeddings with shape matching the input convention:
`(seq_len, hidden_size * scale_factor**2)` when the input was 2D, or `(1, seq_len, hidden_size * scale_factor**2)`
if the singleton batch dimension was present.
Raises:
ValueError: If more than one batch item is provided.
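
    Example (shape-level sketch with zero tensors; two packed images with 4x4
    and 2x2 patch grids):

        >>> x = torch.zeros(20, 8)
        >>> grids = torch.tensor([[4, 4], [2, 2]])
        >>> pixel_shuffle_varlen(x, grids, scale_factor=2).shape
        torch.Size([5, 32])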
"""
return_with_batch_dim = x.dim() == 3
if return_with_batch_dim:
if x.size(0) != 1:
raise ValueError(
f"Packed vision sequences expect a singleton batch dimension; received batch_size={x.size(0)}."
)
embeddings = x.squeeze(0) # (seq, embed)
else:
embeddings = x # (seq, embed)
embed_dim = embeddings.size(-1)
scale_factor = int(scale_factor)
# Calculate seq_sizes from token_grids
seq_sizes = torch.prod(token_grids, dim=-1)
# Build a single gather index so pixel shuffle works on the packed stream
# without unpacking per-image grids.
gather_idx = create_pixel_shuffle_index_map(
seq_sizes=seq_sizes,
token_grids=token_grids,
scale_factor=scale_factor,
device=embeddings.device,
) # (new_seq, scale_factor**2)
# Gather → (new_seq, scale_factor**2, embed_dim)
gathered = embeddings[gather_idx] # fancy indexing keeps gradient
# Merge the scale_factor**2 group dimension into channels to finish the shuffle
out = gathered.reshape(gathered.size(0), embed_dim * scale_factor * scale_factor)
# Restore batch dimension if needed
if return_with_batch_dim:
out = out.unsqueeze(0)
return out
class IsaacVisionTransformer(nn.Module):
"""Vision tower that packs variable-resolution patches, applies varlen attention, and pixel-shuffles outputs.
Args:
config (IsaacVisionConfig): Vision configuration with pixel-shuffle and patching parameters.
Inputs:
packed_seq_patches (Tuple[Tensor, Tensor]): ``(patches, token_grids)`` where ``patches`` is a packed
patch sequence and ``token_grids`` holds per-image (H_tokens, W_tokens).
Returns:
torch.Tensor: Vision embeddings after encoder + pixel shuffle, shaped ``(seq_len, hidden_size * s^2)``.
"""
_supports_sdpa = True
def __init__(self, config: IsaacVisionConfig):
super().__init__()
self.config = config
self.embeddings = IsaacVisionEmbeddings(config)
self.encoder = IsaacVisionEncoder(config)
self.post_layernorm = nn.LayerNorm(
config.hidden_size, eps=config.layer_norm_eps
)
self.pixel_shuffle_scale_factor = config.pixel_shuffle_scale_factor
def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]):
seq_patches, token_grids = packed_seq_patches
seq_sizes = torch.prod(token_grids, dim=-1)
# Get embeddings from packed sequence
hidden_states = self.embeddings(seq_patches, token_grids)
# Add a pseudo batch dimension so we can reuse the batch-first encoder stack
# while still driving per-image cu_seqlens through the varlen attention path.
hidden_states = hidden_states.unsqueeze(0)
# Generate cumulative sequence lengths for variable-length attention
cu_seqlens = F.pad(seq_sizes.cumsum(0).to(torch.int32), (1, 0))
attention_mask = create_document_attention_mask(
self.config, hidden_states, cu_seqlens
)
# Pass through encoder with variable-length attention parameters
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
attention_mask=attention_mask,
cu_seqlens=cu_seqlens,
)
hidden_states = encoder_outputs.last_hidden_state
# Apply final layer normalization
hidden_states = self.post_layernorm(hidden_states)
hidden_states = pixel_shuffle_varlen(
x=hidden_states,
token_grids=token_grids,
scale_factor=self.pixel_shuffle_scale_factor,
)
# Remove the pseudo batch dimension we added earlier
hidden_states = hidden_states.squeeze(0)
# Return the full sequence of embeddings
return hidden_states
class IsaacMultiModalProjector(nn.Module):
"""Maps vision tower outputs to the text hidden size with a SiLU MLP."""
def __init__(self, config: IsaacConfig):
super().__init__()
self.vision_hidden_size = config.vision_config.hidden_size * (
config.vision_config.pixel_shuffle_scale_factor**2
)
self.backbone_hidden_size = config.hidden_size
self.linear_1 = nn.Linear(
self.vision_hidden_size, 4 * self.vision_hidden_size, bias=False
)
self.silu = nn.SiLU()
self.linear_2 = nn.Linear(
4 * self.vision_hidden_size, self.backbone_hidden_size, bias=False
)
def forward(self, image_features):
hidden_states = self.linear_1(image_features)
hidden_states = self.silu(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class IsaacVisionEmbedding(nn.Module):
_supports_sdpa = True
def __init__(self, config: IsaacConfig):
super().__init__()
vision_cfg = config.vision_config
self.vision_tower = IsaacVisionTransformer(vision_cfg)
self.multimodal_projector = IsaacMultiModalProjector(config)
def forward(self, vision_tokens: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
hidden_states = self.vision_tower(vision_tokens)
return self.multimodal_projector(hidden_states)
def get_scaled_image_size(
scale: float,
original_size: int,
patch_size: int,
pixel_shuffle_scale: int,
) -> int:
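    """Scale `original_size` and round the result up to a multiple of `patch_size * pixel_shuffle_scale`.

    E.g. get_scaled_image_size(0.5, 1000, 16, 1) == 512: 500 rounded up to the 16-pixel grid.
    """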
scaled_size = scale * original_size
divisor = patch_size * pixel_shuffle_scale
scaled_size = math.ceil(scaled_size / divisor) * divisor
scaled_size = max(divisor, scaled_size)
return int(scaled_size)
def get_image_size_for_max_num_patches(
image_height: int,
image_width: int,
patch_size: int,
max_num_patches: int,
min_num_patches: Optional[int] = None,
eps: float = 1e-5,
pixel_shuffle_scale: int = 1,
) -> tuple[int, int]:
r"""Compute a target resolution whose patch grid satisfies patching parametrization.
Args:
image_height (`int`):
Height in pixels of the source image prior to any resizing.
image_width (`int`):
Width in pixels of the source image prior to any resizing.
patch_size (`int`):
Size of the square patch used by the vision encoder.
max_num_patches (`int`):
Upper bound on `(height / patch_size) * (width / patch_size)` after resizing.
min_num_patches (`int`, *optional*):
Lower bound on the number of patches. When provided the image will be scaled up if necessary.
eps (`float`, *optional*, defaults to 1e-5):
            Convergence tolerance for the internal binary search used to determine the target dimensions.
pixel_shuffle_scale (`int`, *optional*, defaults to 1):
Additional stride multiplier applied when pixel shuffle later reduces spatial resolution.
Returns:
`tuple[int, int]`: Height and width (in pixels) that are multiples of `patch_size * pixel_shuffle_scale`
and respect both the maximum and optional minimum patch-count constraints.
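
    Example (a small image is padded up to the patch grid via the early-return branch):

        >>> get_image_size_for_max_num_patches(100, 100, patch_size=16, max_num_patches=256)
        (112, 112)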
"""
# Ensure divisibility
divisor = patch_size * pixel_shuffle_scale
adjusted_height = math.ceil(image_height / divisor) * divisor
adjusted_height = max(divisor, adjusted_height)
adjusted_width = math.ceil(image_width / divisor) * divisor
adjusted_width = max(divisor, adjusted_width)
num_patches = (adjusted_height / patch_size) * (adjusted_width / patch_size)
if min_num_patches is not None and num_patches < min_num_patches:
# Scale up via binary search to satisfy the minimum patch budget while
# preserving divisibility by patch_size * pixel_shuffle_scale.
scale_min, scale_max = 1.0, 100.0
while (scale_max - scale_min) >= eps:
scale = (scale_min + scale_max) / 2
target_height = get_scaled_image_size(
scale, image_height, patch_size, pixel_shuffle_scale
)
target_width = get_scaled_image_size(
scale, image_width, patch_size, pixel_shuffle_scale
)
num_patches = (target_height / patch_size) * (target_width / patch_size)
if num_patches >= min_num_patches:
scale_max = scale
else:
scale_min = scale
scale = scale_max
target_height = get_scaled_image_size(
scale, image_height, patch_size, pixel_shuffle_scale
)
target_width = get_scaled_image_size(
scale, image_width, patch_size, pixel_shuffle_scale
)
return target_height, target_width
elif num_patches <= max_num_patches:
return adjusted_height, adjusted_width
else:
# Scale down
scale_min, scale_max = eps / 10, 1.0
while (scale_max - scale_min) >= eps:
scale = (scale_min + scale_max) / 2
target_height = get_scaled_image_size(
scale, image_height, patch_size, pixel_shuffle_scale
)
target_width = get_scaled_image_size(
scale, image_width, patch_size, pixel_shuffle_scale
)
num_patches = (target_height / patch_size) * (target_width / patch_size)
if num_patches <= max_num_patches:
scale_min = scale
else:
scale_max = scale
scale = scale_min
target_height = get_scaled_image_size(
scale, image_height, patch_size, pixel_shuffle_scale
)
target_width = get_scaled_image_size(
scale, image_width, patch_size, pixel_shuffle_scale
)
return target_height, target_width
class IsaacConfig(PretrainedConfig):
"""Configuration class for Isaac multimodal model.
This configuration corresponds to checkpoints such as
[Perceptron/isaac-base](https://huggingface.co/Perceptron/isaac-base).
"""
model_type = "isaac"
sub_configs = {"vision_config": IsaacVisionConfig, "text_config": Qwen3Config}
image_processor_type = "IsaacImageProcessor"
def __init__(
self,
vision_config: Optional[IsaacVisionConfig] = None,
text_config: Optional[Union[Qwen3Config, dict]] = None,
vision_rescale_factor: float = 1 / 255,
max_sequence_length: int = 16384,
vision_token: str = "<image>",
**kwargs,
):
attn_implementation = kwargs.get("attn_implementation")
if isinstance(text_config, dict):
self.text_config = self.sub_configs["text_config"](**text_config)
elif isinstance(text_config, Qwen3Config):
self.text_config = text_config
elif text_config is None:
self.text_config = self.sub_configs["text_config"]()
# Seed RoPE parameters before base init so the shared mixin can standardize/validate them.
self.rope_parameters = getattr(self.text_config, "rope_parameters", None)
self.layer_types = getattr(self.text_config, "layer_types", None)
super().__init__(**kwargs)
# Keep rope parameters aligned between the composite and text sub-configs.
self.text_config.rope_parameters = self.rope_parameters
# Mirror frequently accessed Qwen3 attributes at the composite config level
self.vocab_size = self.text_config.vocab_size
self.hidden_size = self.text_config.hidden_size
self.num_hidden_layers = self.text_config.num_hidden_layers
self.num_attention_heads = self.text_config.num_attention_heads
self.head_dim = self.text_config.head_dim
self.hidden_act = self.text_config.hidden_act
self.use_cache = self.text_config.use_cache
self.rope_theta = self.rope_parameters["rope_theta"]
self.layer_types = getattr(self.text_config, "layer_types", None)
layer_type_validation(self.layer_types, self.num_hidden_layers)
# Handle vision config - either dict or IsaacVisionConfig instance
if isinstance(vision_config, dict):
self.vision_config = self.sub_configs["vision_config"](**vision_config)
elif isinstance(vision_config, IsaacVisionConfig):
self.vision_config = vision_config
elif vision_config is None:
self.vision_config = self.sub_configs["vision_config"]()
# Propagate user-requested attention backend to the vision sub-config when provided.
if attn_implementation is not None:
if isinstance(attn_implementation, dict):
vision_attn = attn_implementation.get(
"vision_config", attn_implementation.get("", None)
)
else:
vision_attn = attn_implementation
if vision_attn is not None:
self.vision_config._attn_implementation = vision_attn
if getattr(self, "_attn_implementation", None) is None:
self._attn_implementation = "sdpa"
# Vision normalization parameters
self.vision_rescale_factor = float(vision_rescale_factor)
# Processing parameters
self.max_sequence_length = max_sequence_length
self.vision_token = vision_token
def to_dict(self):
output = super().to_dict()
# Ensure nested configs round-trip through dict serialization
if hasattr(self, "text_config") and self.text_config is not None:
output["text_config"] = self.text_config.to_dict()
if hasattr(self, "vision_config") and self.vision_config is not None:
output["vision_config"] = self.vision_config.to_dict()
return output
class IsaacProcessor(ProcessorMixin):
"""Processor that pairs the Isaac image processor with the Qwen2 tokenizer.
Args:
image_processor: Vision preprocessor (fast) used for patch extraction.
tokenizer: Qwen2 tokenizer instance.
vision_token (str, optional): Placeholder token marking image locations. Defaults to "<image>".
max_sequence_length (int, optional): Maximum combined text+vision tokens kept. Defaults to 16384.
rescale_factor (float, optional): Image rescale factor; defaults to 1/255.
config (IsaacConfig | dict, optional): If provided, overrides processor defaults from the model config.
Returns:
BatchFeature: Contains ``input_ids`` and ``packed_inputs`` (patch tensors, grids, offsets, lengths, modality, positions).
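
    Example (a sketch; assumes a PIL image ``img`` and compatible sub-components):

        >>> processor = IsaacProcessor(image_processor, tokenizer)  # doctest: +SKIP
        >>> batch = processor(text=f"Describe {processor.vision_token}", images=[img])  # doctest: +SKIP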
"""
attributes = ["image_processor", "tokenizer"]
image_processor_class = ("IsaacImageProcessorFast",)
tokenizer_class = ("Qwen2Tokenizer",)
pad_token_id = 151643
def __init__(
self,
image_processor,
tokenizer,
*,
vision_token: str = "<image>",
max_sequence_length: int = 16384,
rescale_factor: Optional[float] = None,
config: Optional[Union[IsaacConfig, dict]] = None,
) -> None:
if isinstance(config, dict):
config = IsaacConfig(**config)
if config is not None:
vision_token = config.vision_token
max_sequence_length = config.max_sequence_length
rescale_factor = config.vision_rescale_factor
resolved_rescale_factor = (
float(rescale_factor) if rescale_factor is not None else float(1 / 255)
)
if config is not None:
config.vision_rescale_factor = resolved_rescale_factor
self.image_processor = image_processor
super().__init__(image_processor, tokenizer)
        text_pad_token_id = getattr(self.tokenizer, "pad_token_id", None)
        image_pad_token_id = self.tokenizer.convert_tokens_to_ids("<|image_pad|>")
        # Fall back to the class-level default when the tokenizer defines no pad token.
        self.text_pad_token_id = (
            int(text_pad_token_id)
            if text_pad_token_id is not None
            else type(self).pad_token_id
        )
        self.image_pad_token_id = int(image_pad_token_id)
        self.pad_token_id = self.text_pad_token_id
self.current_processor = self.image_processor
self.config = config
self.chat_template = getattr(self.tokenizer, "chat_template", None)
self.vision_token = vision_token
self.max_sequence_length = max_sequence_length
def _pack_batch(
self, texts: list[str], images_list: Optional[list[Optional[list[Image]]]]
) -> dict[str, Optional[torch.Tensor]]:
if images_list is None:
pairs = ((t, None) for t in texts)
else:
pairs = zip(texts, images_list, strict=True)
per_sample: list[dict[str, Optional[torch.Tensor]]] = []
for txt, imgs in pairs:
if imgs is not None and isinstance(imgs, Image):
imgs = [imgs]
per_sample.append(self._pack_single(txt, imgs))
lengths = [int(p["input_ids"].shape[1]) for p in per_sample]
max_len = max(lengths, default=0)
batch = len(per_sample)
# Use first device with data as anchor
base_device = torch.device("cpu")
for p in per_sample:
if p["input_ids"].numel() > 0:
base_device = p["input_ids"].device
break
pad_id = self.text_pad_token_id
padded_input_ids = torch.full(
(batch, max_len), pad_id, device=base_device, dtype=torch.long
)
padded_modality = torch.full(
(batch, max_len),
ModalityType.text.value,
device=base_device,
dtype=torch.long,
)
padded_position_ids = torch.zeros(
(batch, max_len, 3), device=base_device, dtype=torch.long
)
for i, (sample, l) in enumerate(zip(per_sample, lengths)):
if l:
padded_input_ids[i, -l:] = sample["input_ids"][0]
padded_modality[i, -l:] = sample["modality_tensor"][0]
padded_position_ids[i, -l:] = sample["position_ids"][0]
# Vision-side aggregation
v_samples = [
(b, s) for b, s in enumerate(per_sample) if s["vision_patches"] is not None
]
if v_samples:
vision_patches_list = [s["vision_patches"] for _, s in v_samples]
vision_grids_list = [s["vision_token_grids"] for _, s in v_samples]
vision_offsets_list = [s["vision_token_offsets"] for _, s in v_samples]
vision_lengths_list = [s["vision_token_lengths"] for _, s in v_samples]
vision_batch_indices = [
torch.full_like(s["vision_token_offsets"], b) for b, s in v_samples
]
vision_patches = torch.cat(vision_patches_list, dim=0)
vision_token_grids = torch.cat(vision_grids_list, dim=0)
vision_token_offsets = torch.cat(vision_offsets_list, dim=0)
vision_token_lengths = torch.cat(vision_lengths_list, dim=0)
vision_token_batch_indices = torch.cat(vision_batch_indices, dim=0)
else:
vision_patches = vision_token_grids = vision_token_offsets = (
vision_token_lengths
) = vision_token_batch_indices = None
return {
"input_ids": padded_input_ids,
"vision_patches": vision_patches,
"vision_token_grids": vision_token_grids,
"vision_token_offsets": vision_token_offsets,
"vision_token_lengths": vision_token_lengths,
"vision_token_batch_indices": vision_token_batch_indices,
"modality_tensor": padded_modality,
"position_ids": padded_position_ids,
}
def _pack_single(
self, text: str, images: Optional[list[Image]]
) -> dict[str, Optional[torch.Tensor]]:
segments = text.split(
self.vision_token
) # Parse by vision_token; interleave text segments and image segments.
num_images = len(segments) - 1
items: list[dict[str, Any]] = []
total = 0
num_provided_images = len(images) if images is not None else 0
        if num_images != num_provided_images:
            raise ValueError(
                f"IsaacProcessor expects one image per image token, got {num_images} tokens and {num_provided_images} images in sample with text {text}"
            )
for index, segment in enumerate(segments):
if segment:
tok = (
self.tokenizer.encode(
segment, add_special_tokens=False, return_tensors="pt"
)
.squeeze(0)
.to(torch.long)
)
segment_length = int(tok.numel())
items.append(
{"type": "text", "segment_length": segment_length, "tok": tok}
)
total += segment_length
if index < num_images:
feat = self.image_processor(
images=images[index], return_tensors=TensorType.PYTORCH
)
patches = feat["patches"][0].reshape(-1, feat["patches"].shape[-1])
virtual_pixel_size = (
feat["virtual_pixel_size"][0].to(torch.long).tolist()
)
real_pixel_size = feat["real_pixel_size"][0].to(torch.long).tolist()
dims = tuple(
(virtual_pixel_size + [1, 1, 1])[:3]
) # (T,H,W) in virtual space
segment_length = int(dims[0] * dims[1] * dims[2])
items.append(
{
"type": "image",
"segment_length": segment_length,
"dims": dims,
"patches": patches,
"grid": (int(real_pixel_size[1]), int(real_pixel_size[2])),
}
)
total += segment_length
        # Tail crop window: keep only the trailing max_sequence_length tokens; anything earlier is dropped.
start = max(0, total - self.max_sequence_length)
end = total
image_pad_value = self.image_pad_token_id
base_device: Optional[torch.device] = None
position_ids, modality, input_ids = [], [], []
vpatches, grids, vision_token_offsets, vision_token_lengths = [], [], [], []
global_offset = 0
position_offset = 0
for item in items:
segment_length = int(item["segment_length"])
current_window_start = max(start, global_offset)
current_window_end = min(end, global_offset + segment_length)
has_overlap = current_window_end > current_window_start
if has_overlap and base_device is None:
base_device = (
item["patches"].device
if item["type"] == "image"
else item["tok"].device
)
if has_overlap:
segment_local_start = int(current_window_start - global_offset)
segment_local_end = int(current_window_end - global_offset)
segment_local_indices = torch.arange(
segment_local_start,
segment_local_end,
device=base_device,
dtype=torch.long,
)
segment_kept_length = segment_local_end - segment_local_start
if item["type"] == "text":
slice_index = segment_local_indices + position_offset
zero_axis_pad = torch.zeros_like(slice_index)
position_ids.append(
torch.stack((slice_index, zero_axis_pad, zero_axis_pad), -1)
)
modality.append(
torch.full(
(segment_kept_length,),
ModalityType.text.value,
device=base_device,
dtype=torch.long,
)
)
input_ids.append(
item["tok"].to(base_device)[
segment_local_start:segment_local_end
]
)
position_offset += segment_length
else:
num_pos_slices, grid_height_tokens, grid_width_tokens = item["dims"]
hw = grid_height_tokens * grid_width_tokens
slice_index = (segment_local_indices // hw) + position_offset
rem = segment_local_indices % hw
row_index = rem // grid_width_tokens
col_index = rem % grid_width_tokens
position_ids.append(
torch.stack((slice_index, row_index, col_index), -1)
)
modality.append(
torch.full(
(segment_kept_length,),
ModalityType.image.value,
device=base_device,
dtype=torch.long,
)
)
input_ids.append(
torch.full(
(segment_kept_length,),
image_pad_value,
device=base_device,
dtype=torch.long,
)
)
vpatches.append(
item["patches"].to(base_device)
) # full patches; slice later via offsets/lengths
# Record per-image slice boundaries so we can drop cropped virtual tokens
# after pixel shuffle without re-packing the entire vision stream.
grids.append(item["grid"])
vision_token_offsets.append(segment_local_start)
vision_token_lengths.append(segment_kept_length)
position_offset += int(num_pos_slices)
else:
position_offset += (
segment_length if item["type"] == "text" else int(item["dims"][0])
)
global_offset += segment_length
if base_device is None:
base_device = torch.device("cpu")
modality_tensor = (
torch.cat(modality, 0).unsqueeze(0)
if modality
else torch.zeros((1, 0), device=base_device, dtype=torch.long)
)
position_ids = (
torch.cat(position_ids, 0).unsqueeze(0)
if position_ids
else torch.zeros((1, 0, 3), device=base_device, dtype=torch.long)
)
input_ids = (
torch.cat(input_ids, 0).unsqueeze(0)
if input_ids
else torch.zeros((1, 0), device=base_device, dtype=torch.long)
)
if vpatches:
vision_patches = torch.cat(vpatches, 0)
vision_token_grids = torch.tensor(
grids, device=base_device, dtype=torch.long
)
vision_token_offsets = torch.tensor(
vision_token_offsets, device=base_device, dtype=torch.long
)
vision_token_lengths = torch.tensor(
vision_token_lengths, device=base_device, dtype=torch.long
)
else:
vision_patches = vision_token_grids = vision_token_offsets = (
vision_token_lengths
) = None
return {
"input_ids": input_ids,
"vision_patches": vision_patches,
"vision_token_grids": vision_token_grids,
"vision_token_offsets": vision_token_offsets,
"vision_token_lengths": vision_token_lengths,
"modality_tensor": modality_tensor,
"position_ids": position_ids,
}
def __call__(
self,
text: Union[str, list[str]],
images: Optional[Union[Image, list[Image]]] = None,
return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
**kwargs,
) -> BatchFeature:
texts = [text] if isinstance(text, str) else text
images_list: Optional[list[Optional[list[Image]]]] = None
if images is not None:
if isinstance(images, list) and len(images) == len(texts):
if not images:
images_list = []
elif isinstance(images[0], list):
images_list = images # already per-sample
else:
images_list = [
[img] for img in images
] # list of images, one per sample
else:
images_list = []
for t in texts:
n_tok = t.count(self.vision_token)
if n_tok == 0:
images_list.append(None)
else:
if isinstance(images, list):
images_list.append(images)
else:
images_list.append([images])
packed = self._pack_batch(texts, images_list)
input_ids = packed.pop("input_ids")
return BatchFeature(data={"input_ids": input_ids, "packed_inputs": packed})
class IsaacRotaryEmbedding(qwen2_5_vl_modeling.Qwen2_5_VLRotaryEmbedding):
def __init__(self, config: IsaacConfig, device=None):
rope_source_cfg = (
config.get_text_config() if hasattr(config, "get_text_config") else config
)
rope_scaling = getattr(rope_source_cfg, "rope_scaling", None) or {}
config_for_rope = copy.copy(rope_source_cfg)
config_for_rope.rope_scaling = rope_scaling
init_device = (
device
if device is not None and getattr(device, "type", None) != "meta"
else None
)
super().__init__(config_for_rope, device=init_device)
rotary_half_dim = self.inv_freq.shape[0]
self.mrope_section = self._resolve_mrope_section(
rope_scaling.get("mrope_section"), rotary_half_dim
)
self.hidden_size = (
getattr(rope_source_cfg, "hidden_size", None) or config.hidden_size
)
@staticmethod
def _resolve_mrope_section(
section: Optional[list[int]], rotary_half_dim: int
) -> list[int]:
if section is None:
weights = (2, 1, 1)
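            # e.g. rotary_half_dim=64 -> [32, 16, 16]: the temporal axis gets half the rotary dims, H/W a quarter each.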
base = [rotary_half_dim * w // sum(weights) for w in weights]
base[0] += rotary_half_dim - sum(base)
return base
section = [int(v) for v in section]
return section
def _combine_axes(self, tensor: torch.Tensor) -> torch.Tensor:
split_sections = tuple(self.mrope_section * 2)
chunks = tensor.split(split_sections, dim=-1)
return torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1)
def forward(
self,
position_ids: torch.Tensor,
modality_tensor: torch.Tensor,
hidden_states: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if hidden_states is None:
batch, seq_len, _ = position_ids.shape
hidden_states = torch.zeros(
batch,
seq_len,
self.hidden_size,
dtype=torch.float32,
device=position_ids.device,
)
with torch.no_grad():
pos = position_ids.clone()
not_spatial = modality_tensor != ModalityType.image.value
data_1d = pos[not_spatial][..., 0].unsqueeze(
-1
) # Collapse non-vision modalities to 1D positions
pos[not_spatial] = data_1d.expand(-1, pos.shape[-1])
pos_axes = pos.permute(2, 0, 1).contiguous()
cos_axes, sin_axes = super().forward(hidden_states, pos_axes)
cos_axes, sin_axes = (
cos_axes.to(hidden_states.dtype),
sin_axes.to(hidden_states.dtype),
)
cos_combined, sin_combined = (
self._combine_axes(cos_axes),
self._combine_axes(sin_axes),
)
return cos_combined, sin_combined
@auto_docstring
class IsaacModel(Qwen3PreTrainedModel):
supports_gradient_checkpointing = True
_can_compile_fullgraph = False
_supports_flex_attn = False
_can_record_outputs = {"attentions": OutputRecorder(IsaacVisionAttention, index=1)}
all_tied_weights_keys: dict[str, str] = {}
def __init__(self, config: IsaacConfig):
Qwen3PreTrainedModel.__init__(self, config)
text_cfg_source = config.text_config
text_cfg = copy.deepcopy(text_cfg_source)
self.text_model = Qwen3Model._from_config(text_cfg)
self.text_model.config = (
config # Ensure downstream callers observe the composed config
)
self.rotary_emb = IsaacRotaryEmbedding(config, device=self.device)
self.vision_embedding = IsaacVisionEmbedding(config)
self.vision_embedding._supports_sdpa = True
self.max_sequence_length = config.max_sequence_length
self.vision_rescale_factor = config.vision_rescale_factor
self.vision_token = config.vision_token
self.rope_deltas = None
self.post_init()
# Respect config-specified gradient checkpointing
if getattr(config, "gradient_checkpointing", False):
self.gradient_checkpointing_enable()
def get_input_embeddings(self) -> nn.Module:
return self.text_model.get_input_embeddings()
def set_input_embeddings(self, value: nn.Module) -> None:
self.text_model.set_input_embeddings(value)
vocab_size = getattr(value, "num_embeddings", None)
if vocab_size is not None:
self.config.vocab_size = vocab_size
if hasattr(self.config, "text_config"):
self.config.text_config.vocab_size = vocab_size
self.text_model.config.vocab_size = vocab_size
@property
def embed_tokens(self) -> nn.Module:
return self.text_model.embed_tokens
@embed_tokens.setter
def embed_tokens(self, value: nn.Module) -> None:
self.text_model.embed_tokens = value
@property
def vision_model(self) -> nn.Module:
return self.vision_embedding.vision_tower
def embed_packed_inputs(
self, input_ids: torch.Tensor, packed_inputs: dict[str, Optional[torch.Tensor]]
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Expects input_ids for text tokens and packed_inputs containing:
- modality_tensor: (batch, seq_len) modality ids aligned to the sequence
- position_ids: (batch, seq_len, 3) MRoPE coordinates (optional)
- vision_patches: concatenated vision tokens shaped (total_tokens, embed_dim) or None
- vision_token_grids: (num_images, 2) token grid sizes or None
- vision_token_offsets: (num_images,) offsets into each image's virtual token span (optional)
- vision_token_lengths: (num_images,) surviving virtual token lengths per image (optional)
- vision_token_batch_indices: (num_images,) batch row for each image (optional; defaults to zeros)
"""
modality = packed_inputs["modality_tensor"].to(
device=input_ids.device, dtype=torch.long
)
embeds = self.text_model.embed_tokens(input_ids)
vision_patches = packed_inputs.get("vision_patches")
if vision_patches is None:
return embeds, modality
token_grids = packed_inputs["vision_token_grids"].to(
device=vision_patches.device, dtype=torch.long
)
vision = self.vision_embedding(
(vision_patches, token_grids)
) # (total_tokens, hidden)
# per-image token counts AFTER pixel-shuffle
vision_reduction_factor = int(
self.config.vision_config.pixel_shuffle_scale_factor
)
sizes = (
token_grids.prod(-1)
.div(
vision_reduction_factor * vision_reduction_factor, rounding_mode="floor"
)
.tolist()
)
offsets = packed_inputs.get("vision_token_offsets")
lengths = packed_inputs.get("vision_token_lengths")
batch_indices = packed_inputs.get("vision_token_batch_indices")
chunks = vision.split(sizes, dim=0)
picked: list[torch.Tensor] = []
picked_batch: list[int] = []
        for chunk, size, offset, length, batch_index in zip(
            chunks,
            sizes,
            (offsets.tolist() if offsets is not None else [0] * len(sizes)),
            (lengths.tolist() if lengths is not None else sizes),
            (batch_indices.tolist() if batch_indices is not None else [0] * len(sizes)),
        ):
if size <= 0:
continue
offset = max(0, min(int(offset), size))
length = max(0, min(int(length), size - offset))
if length:
picked.append(chunk[offset : offset + length])
picked_batch.append(int(batch_index))
        vision_chunks, vision_batch_idx = picked, picked_batch
vision = (
torch.cat(vision_chunks, 0)
if vision_chunks
else vision.new_zeros((0, vision.size(-1)))
)
embeds = embeds.clone()
num_batches = modality.shape[0]
image_positions = [
(modality[b] == ModalityType.image.value)
.nonzero(as_tuple=False)
.squeeze(-1)
for b in range(num_batches)
]
cursors = [0 for _ in range(num_batches)]
for chunk, batch_index in zip(vision_chunks, vision_batch_idx):
if chunk.numel() == 0:
continue
positions = image_positions[batch_index]
start = cursors[batch_index]
end = start + chunk.shape[0]
embeds[batch_index, positions[start:end]] = chunk.to(
device=embeds.device, dtype=embeds.dtype
)
cursors[batch_index] = end
return embeds, modality
def get_rope_index(
self,
*,
position_ids: Optional[torch.Tensor] = None,
attention_mask: torch.Tensor,
inputs_embeds: torch.Tensor,
cache_position: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Build 3D position ids and per-batch RoPE deltas."""
device = inputs_embeds.device
batch_size, seq_len = inputs_embeds.shape[:2]
if position_ids is None:
cp = cache_position.to(device=device, dtype=torch.long)
if cp.ndim == 1:
cp = cp.view(1, -1).expand(batch_size or 1, -1)
base_delta = torch.as_tensor(
0 if self.rope_deltas is None else self.rope_deltas,
device=device,
dtype=torch.long,
).reshape(-1, 1)
base_delta = torch.broadcast_to(base_delta, (batch_size, 1))
mask_delta = attention_mask.to(device=device, dtype=torch.long).sum(
1, keepdim=True
) - attention_mask.size(1)
rope_position = cp + base_delta + mask_delta
pos_3d = rope_position.unsqueeze(-1).expand(-1, -1, 3)
return pos_3d, base_delta
position_ids = position_ids.to(device=device)
if position_ids.ndim == 2:
position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3)
if position_ids.shape[1] != seq_len:
start_positions = position_ids[:, :1, 0]
position_ids = (
torch.arange(seq_len, device=position_ids.device).view(1, -1)
+ start_positions
)
position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3)
attn = attention_mask.to(device=device, dtype=torch.long)
m_per_batch = position_ids.amax(dim=(1, 2))
seq_lens = attn.eq(1).sum(dim=-1).to(dtype=m_per_batch.dtype, device=device)
rope_deltas = (
(m_per_batch + 1 - seq_lens).to(dtype=position_ids.dtype).unsqueeze(1)
)
return position_ids, rope_deltas
@auto_docstring
@check_model_inputs
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
packed_inputs: Optional[dict[str, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple | BaseModelOutputWithPast:
"""
Forward pass with MRoPE position embeddings.
Computes position embeddings once and passes them through all layers.
Args:
            packed_inputs (`dict`, *optional*):
                Plain tensor payloads produced by ``IsaacProcessor``. When provided, `input_ids` must still be supplied for the text tokens.
modality_tensor (`torch.LongTensor`, *optional*):
Modality identifiers aligned with the embedded sequence, shaped `(batch_size, seq_len)` and containing
values from `ModalityType`. Automatically built from `packed_inputs` or treated as text-only when omitted.
"""
output_attentions = kwargs.pop("output_attentions", None)
modality_tensor: Optional[torch.Tensor] = None
if packed_inputs is not None:
inputs_embeds, modality_tensor = self.embed_packed_inputs(
input_ids, packed_inputs
)
elif input_ids is not None:
inputs_embeds = self.text_model.embed_tokens(input_ids)
device = inputs_embeds.device
batch_size, seq_len = inputs_embeds.shape[:2]
if use_cache and past_key_values is None:
past_key_values = DynamicCache(config=self.config.get_text_config())
if cache_position is None:
past_seen_tokens = (
past_key_values.get_seq_length() if past_key_values is not None else 0
)
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + seq_len, device=device
)
if attention_mask is None:
attention_mask = torch.ones(
inputs_embeds.shape[:2], device=inputs_embeds.device, dtype=torch.long
)
if (
position_ids is None
and packed_inputs is not None
and packed_inputs.get("position_ids") is not None
):
position_ids = packed_inputs.get("position_ids").to(device=device)
position_ids, rope_deltas = self.get_rope_index(
position_ids=position_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
)
self.rope_deltas = rope_deltas
if modality_tensor is None:
modality_tensor = torch.full(
(batch_size, seq_len),
ModalityType.text.value,
device=device,
dtype=torch.long,
)
cos, sin = self.rotary_emb(
position_ids, modality_tensor, hidden_states=inputs_embeds
)
decoder_position_ids = (
position_ids[..., 0] if position_ids.ndim == 3 else position_ids
)
if not isinstance(attention_mask, dict):
attention_mask = create_masks_for_generate(
config=self.config,
input_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
past_key_values=past_key_values,
position_ids=decoder_position_ids,
)
is_mask_dict = isinstance(attention_mask, dict)
hidden_states = inputs_embeds
all_attentions = [] if output_attentions else None
for layer in self.text_model.layers:
layer_mask = (
attention_mask[layer.attention_type] if is_mask_dict else attention_mask
)
layer_outputs = layer(
hidden_states,
attention_mask=layer_mask,
position_ids=decoder_position_ids,
past_key_values=past_key_values,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=(cos, sin),
output_attentions=output_attentions,
**kwargs,
)
layer_outputs_is_tuple = isinstance(layer_outputs, tuple)
hidden_states = (
layer_outputs[0] if layer_outputs_is_tuple else layer_outputs
)
if output_attentions and layer_outputs_is_tuple:
all_attentions.append(layer_outputs[1])
hidden_states = self.text_model.norm(hidden_states)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=past_key_values,
hidden_states=(hidden_states,),
attentions=tuple(all_attentions) if output_attentions else None,
)
@auto_docstring
class IsaacForConditionalGeneration(Qwen3ForCausalLM, GenerationMixin):
config_class = IsaacConfig
_can_compile_fullgraph = False
_tied_weights_keys = {"lm_head.weight": "model.text_model.embed_tokens.weight"}
all_tied_weights_keys: dict[str, str] = {
"lm_head.weight": "model.text_model.embed_tokens.weight"
}
def __init__(self, config: IsaacConfig):
super().__init__(config)
self.model = IsaacModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
@auto_docstring
@can_return_tuple
@check_model_inputs
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
packed_inputs: Optional[dict[str, torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[list[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[TransformersKwargs],
) -> tuple | CausalLMOutputWithPast:
"""Run multimodal CausalLM forward, accepting packed vision/text inputs.
Args:
packed_inputs (`dict`, *optional*):
Packed vision/text payload from ``IsaacProcessor`` containing modality ids, MRoPE position ids, and
vision patch tensors/grids (with optional offsets/lengths) used to rebuild embeddings.
Returns:
CausalLMOutputWithPast: logits, optional loss, caches, hidden states, attentions.
"""
output_attentions = kwargs.pop("output_attentions", None)
outputs = self.model(
input_ids=input_ids,
packed_inputs=packed_inputs,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.vocab_size
)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions if output_attentions else None,
)
def prepare_inputs_for_generation(
self,
input_ids: torch.LongTensor,
past_key_values: Optional[list[torch.FloatTensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
packed_inputs: Optional[dict[str, torch.Tensor]] = None,
cache_position: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
**kwargs,
) -> dict[str, Any]:
model_inputs = super().prepare_inputs_for_generation(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
position_ids=position_ids,
**kwargs,
)
if packed_inputs is None:
return model_inputs
past_len = (
past_key_values.get_seq_length() if past_key_values is not None else 0
)
first_step = past_len == 0
model_inputs["packed_inputs"] = packed_inputs if first_step else None
model_inputs["position_ids"] = None
return model_inputs
@classmethod
def can_generate(cls) -> bool:
return True
    def set_input_embeddings(self, value: nn.Module) -> None:
        self.model.set_input_embeddings(value)
        vocab_size = getattr(value, "num_embeddings", None)
        if vocab_size is None:
            return
        self.config.vocab_size = vocab_size
        self.model.config.vocab_size = vocab_size
        self.model.text_model.config.vocab_size = vocab_size
        if self.lm_head.weight.shape[0] != vocab_size:
            self.lm_head = nn.Linear(self.config.hidden_size, vocab_size, bias=False)
            self.lm_head.weight = self.model.text_model.embed_tokens.weight
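

# Generation sketch (the checkpoint id is the one referenced in IsaacConfig's
# docstring; treat this as an illustration, not a guaranteed entry point):
#
#     model = IsaacForConditionalGeneration.from_pretrained("Perceptron/isaac-base")
#     inputs = processor(text=f"Describe {processor.vision_token}", images=[img])
#     out = model.generate(**inputs, max_new_tokens=64)
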
__all__ = [
"IsaacConfig",
"IsaacModel",
"IsaacPreTrainedModel", # noqa: F822
"IsaacForConditionalGeneration",
"IsaacImageProcessorFast",
"IsaacProcessor",
]