| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from __future__ import annotations |
| |
|
| | import copy |
| | import math |
| | from collections.abc import Callable, Sequence |
| | from enum import IntEnum |
| | from typing import Any, Optional, Union |
| |
|
| | from transformers.cache_utils import DynamicCache |
| | from transformers.configuration_utils import PretrainedConfig, layer_type_validation |
| | from transformers.feature_extraction_utils import BatchFeature |
| | from transformers.generation.utils import GenerationMixin |
| | from transformers.image_processing_utils_fast import ( |
| | BaseImageProcessorFast, |
| | ImagesKwargs, |
| | SizeDict, |
| | group_images_by_shape, |
| | reorder_images, |
| | ) |
| | from transformers.image_utils import ( |
| | PILImageResampling, |
| | ) |
| | from transformers.masking_utils import ( |
| | ALL_MASK_ATTENTION_FUNCTIONS, |
| | create_masks_for_generate, |
| | packed_sequence_mask_function, |
| | ) |
| | from transformers.modeling_outputs import ( |
| | BaseModelOutputWithPast, |
| | CausalLMOutputWithPast, |
| | ) |
| | from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS |
| | from transformers.models.qwen3.configuration_qwen3 import Qwen3Config |
| | from transformers.models.qwen3.modeling_qwen3 import ( |
| | Qwen3ForCausalLM, |
| | Qwen3Model, |
| | Qwen3PreTrainedModel, |
| | ) |
| | from transformers.processing_utils import ProcessorMixin, Unpack |
| | from transformers.utils import TensorType, auto_docstring |
| | from transformers.utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN |
| | from transformers.utils.constants import IMAGENET_STANDARD_STD as VISION_STD |
| | from transformers.utils.generic import ( |
| | OutputRecorder, |
| | TransformersKwargs, |
| | can_return_tuple, |
| | check_model_inputs, |
| | ) |
| | from transformers.utils.import_utils import ( |
| | is_torch_available, |
| | is_torchdynamo_compiling, |
| | is_torchvision_available, |
| | is_vision_available, |
| | ) |
| | from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl as qwen2_5_vl_modeling |
| | from transformers.models.siglip2.configuration_siglip2 import Siglip2VisionConfig |
| | from transformers.models.siglip2.modeling_siglip2 import ( |
| | Siglip2Attention, |
| | Siglip2Encoder, |
| | Siglip2EncoderLayer, |
| | Siglip2VisionEmbeddings, |
| | ) |
| |
|
| |
|
| | if is_torch_available(): |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | if is_vision_available(): |
| | from PIL.Image import Image |
| | else: |
| | Image = None |
| | if is_torchvision_available(): |
| | from transformers.models.pix2struct.image_processing_pix2struct_fast import ( |
| | torch_extract_patches, |
| | ) |
| |
|
| |
|
class ModalityType(IntEnum):
    """Integer tags naming the modality an event's tokens belong to.

    Members:
        image: Vision tokens (e.g., patches).
        text: Textual tokens.
    """

    image = 0  # vision tokens
    text = 1  # textual tokens
| |
|
| |
|
class IsaacVisionConfig(Siglip2VisionConfig):
    """Configuration of the Isaac vision tower: a SigLIP2 vision config plus pixel-shuffle settings.

    Every named argument is stored on the instance under the same attribute name. Remaining
    keyword arguments are forwarded untouched to the parent constructor.

    Args:
        pixel_shuffle_scale_factor (`int`, *optional*, defaults to 1):
            Spatial factor applied before pixel shuffle reduces the resolution.
        num_patches (`int`, *optional*, defaults to 256):
            Maximum number of learnable positional embeddings to initialize.
    """

    model_type = "isaac_vision"
    base_config_key = "vision_config"

    def __init__(
        self,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_channels=3,
        num_patches=256,
        patch_size=16,
        hidden_act="gelu_pytorch_tanh",
        layer_norm_eps=1e-6,
        attention_dropout=0.0,
        pixel_shuffle_scale_factor=1,
        **kwargs,
    ):
        # The parent only receives the pass-through kwargs; the explicit arguments
        # are (re)assigned below so that the Isaac defaults win.
        super().__init__(**kwargs)

        # Encoder geometry and regularisation.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.attention_dropout = attention_dropout
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act
        self.num_patches = num_patches

        # Isaac-specific extension on top of the SigLIP2 fields.
        self.pixel_shuffle_scale_factor = pixel_shuffle_scale_factor

        # Fall back to SDPA only when no attention backend has been chosen yet.
        chosen_backend = getattr(self, "_attn_implementation", None)
        if chosen_backend is None:
            self._attn_implementation = "sdpa"
| |
|
| |
|
class IsaacImageProcessorFastKwargs(ImagesKwargs, total=False):
    r"""
    patch_size (`int`, *optional*):
        Side length, in pixels, of the square patches extracted from each image.
    max_num_patches (`int`, *optional*):
        Upper bound on the number of patches an image may produce after resizing.
    min_num_patches (`int`, *optional*):
        Lower bound on the number of patches; smaller images are scaled up to reach it.
    pixel_shuffle_scale (`int`, *optional*):
        Pixel-shuffle factor; the patch grid height and width must both be divisible by it.
    """

    patch_size: Optional[int]
    max_num_patches: Optional[int]
    min_num_patches: Optional[int]
    pixel_shuffle_scale: Optional[int]
| |
|
| |
|
@auto_docstring
class IsaacImageProcessorFast(BaseImageProcessorFast):
    # Pixel budget constant; none of the methods defined in this class reference it.
    MAX_PIXELS = 60_000_000

    resample = PILImageResampling.BILINEAR
    model_input_names = ["patches", "token_grids"]
    valid_kwargs = IsaacImageProcessorFastKwargs
    unused_kwargs = ["size", "do_center_crop", "crop_size", "pad_size", "do_pad"]

    # Default preprocessing switches and patching parameters.
    do_resize = True
    do_center_crop = False
    patch_size: Optional[int] = 16
    max_num_patches: Optional[int] = 256
    min_num_patches: Optional[int] = None
    pixel_shuffle_scale: Optional[int] = 1
    do_pad = False
    do_rescale = True
    do_normalize = True
    image_mean = list(VISION_MEAN)
    image_std = list(VISION_STD)
    do_convert_rgb = True
    disable_grouping = False

    def __init__(
        self,
        **kwargs: Unpack[IsaacImageProcessorFastKwargs],
    ) -> None:
        super().__init__(**kwargs)

    def _validate_preprocess_kwargs(self, **kwargs):
        """Drop the options this processor handles itself, then defer to the base validation."""
        for ignored in ("do_resize", "size", "do_center_crop", "crop_size", "disable_grouping"):
            kwargs.pop(ignored, None)
        return super()._validate_preprocess_kwargs(**kwargs)

    def resize(
        self,
        image: torch.Tensor,
        size: SizeDict,
        **kwargs,
    ) -> torch.Tensor:
        """Bilinearly resize `image` to `(size.height, size.width)`.

        Extra keyword arguments (including `interpolation`) are accepted for interface
        compatibility but ignored: the mode is always bilinear with `align_corners=False`.
        """
        return F.interpolate(
            image,
            size=(size.height, size.width),
            mode="bilinear",
            align_corners=False,
        )

    def _preprocess(
        self,
        images: list[torch.Tensor],
        do_resize: bool,
        interpolation: Optional[Any],
        do_rescale: Optional[bool],
        rescale_factor: Optional[float],
        do_normalize: Optional[bool],
        image_mean: Optional[Union[float, Sequence[float]]],
        image_std: Optional[Union[float, Sequence[float]]],
        disable_grouping: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        *,
        patch_size: Optional[int] = None,
        max_num_patches: Optional[int] = None,
        min_num_patches: Optional[int] = None,
        pixel_shuffle_scale: Optional[int] = None,
        **kwargs,
    ) -> BatchFeature:
        """Resize, normalize and patchify images, and describe each image's patch grid.

        Images are grouped by shape and processed group by group. The result holds four
        entries, each stacked over images in their original order:

        - `patches`: output of `torch_extract_patches` for each image.
        - `token_grids`: `(height_tokens, width_tokens)` of the patch grid.
        - `virtual_pixel_size`: `(1, height_tokens // s, width_tokens // s)` with
          `s = pixel_shuffle_scale`.
        - `real_pixel_size`: `(1, height_tokens, width_tokens)`.

        Because the per-image results are combined with `torch.stack`, all images must
        end up with the same patch-grid shape.

        Raises:
            ValueError: If resizing is disabled and an image side is not a multiple of
                `patch_size`, or if the patch grid is not divisible by `pixel_shuffle_scale`.
        """
        grouped_images, grouped_images_index = group_images_by_shape(
            images, disable_grouping=disable_grouping
        )

        grouped_outputs = {}
        for shape, stacked_images in grouped_images.items():
            batch_size, channels, original_height, original_width = stacked_images.shape

            # Single-channel inputs are replicated to three channels when RGB is requested.
            if bool(self.do_convert_rgb) and channels == 1:
                stacked_images = stacked_images.repeat(1, 3, 1, 1)

            if do_resize:
                # The target size is only needed (and only computed) when resizing.
                target_height, target_width = get_image_size_for_max_num_patches(
                    original_height,
                    original_width,
                    patch_size,
                    max_num_patches,
                    min_num_patches=min_num_patches,
                    pixel_shuffle_scale=pixel_shuffle_scale,
                )
                image_batch = self.resize(
                    stacked_images,
                    SizeDict(height=target_height, width=target_width),
                    interpolation=interpolation,
                )
            else:
                if (original_height % patch_size) or (original_width % patch_size):
                    raise ValueError(
                        f"Image dimensions (h={original_height}, w={original_width}) must be divisible by patch_size={patch_size} when resize is disabled; enable resizing or adjust the input resolution."
                    )
                image_batch = stacked_images

            image_batch = self.rescale_and_normalize(
                image_batch,
                do_rescale=do_rescale,
                rescale_factor=rescale_factor,
                do_normalize=do_normalize,
                image_mean=image_mean,
                image_std=image_std,
            )

            patches = torch_extract_patches(image_batch, patch_size, patch_size)
            _, height_tokens, width_tokens, _ = patches.shape

            if (height_tokens % pixel_shuffle_scale) or (
                width_tokens % pixel_shuffle_scale
            ):
                raise ValueError(
                    f"Token grid (h={height_tokens}, w={width_tokens}) must be divisible by pixel_shuffle_scale={pixel_shuffle_scale}; adjust resize/patch parameters or disable pixel shuffle."
                )

            token_grid = (
                torch.tensor([height_tokens, width_tokens], device=patches.device)
                .long()
                .expand(batch_size, 2)
            )
            real_dim = (
                torch.tensor(
                    [1, height_tokens, width_tokens],
                    dtype=torch.long,
                    device=patches.device,
                )
                .unsqueeze(0)
                .repeat(batch_size, 1)
            )
            virtual_dim = (
                torch.tensor(
                    [
                        1,
                        height_tokens // pixel_shuffle_scale,
                        width_tokens // pixel_shuffle_scale,
                    ],
                    dtype=torch.long,
                    device=patches.device,
                )
                .unsqueeze(0)
                .repeat(batch_size, 1)
            )
            grouped_outputs[shape] = (patches, token_grid, virtual_dim, real_dim)

        # Position of each key matches the tuple layout stored in `grouped_outputs`.
        keys = ("patches", "token_grids", "virtual_pixel_size", "real_pixel_size")
        tensors: dict[str, torch.Tensor] = {}
        for item_idx, key in enumerate(keys):
            per_shape = {
                group_shape: outputs[item_idx]
                for group_shape, outputs in grouped_outputs.items()
            }
            # Undo the shape grouping so results follow the input image order.
            ordered = reorder_images(per_shape, grouped_images_index)
            tensors[key] = torch.stack(ordered, dim=0)

        return BatchFeature(data=tensors, tensor_type=return_tensors)
| |
|
| |
|
def create_document_attention_mask(
    config: PretrainedConfig,
    input_embeds: torch.Tensor,
    cu_seqlens: Optional[torch.Tensor],
) -> Optional[Union[torch.Tensor, Any]]:
    """
    Build a block-diagonal attention mask for a packed sequence described by `cu_seqlens`.

    Each token is tagged with the index of the segment it belongs to; the mask function
    derived from those tags is handed to the mask builder registered for
    `config._attn_implementation`.

    Returns None when `cu_seqlens` is None, has fewer than two entries, or describes
    zero tokens in total.
    """
    if cu_seqlens is None or cu_seqlens.numel() < 2:
        return None

    # Consecutive differences of the prefix sums give the per-segment lengths.
    segment_lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).long()
    if segment_lengths.numel() == 0 or int(segment_lengths.sum()) == 0:
        return None

    segment_ids = torch.repeat_interleave(
        torch.arange(segment_lengths.numel(), device=cu_seqlens.device),
        segment_lengths,
    )
    mask_function = packed_sequence_mask_function(segment_ids.view(1, -1))

    total_len = input_embeds.shape[1]
    positions = torch.arange(total_len, device=input_embeds.device, dtype=torch.long)

    build_mask = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]
    return build_mask(
        batch_size=input_embeds.shape[0],
        cache_position=positions,
        kv_length=total_len,
        kv_offset=0,
        mask_function=mask_function,
        attention_mask=None,
        allow_is_causal_skip=False,
        allow_is_bidirectional_skip=False,
        dtype=input_embeds.dtype,
        config=config,
        use_vmap=False,
    )
| |
|
| |
|
class IsaacVisionEmbeddings(Siglip2VisionEmbeddings):
    """Adapter around SigLIP2 vision embeddings that consumes packed patch sequences.

    Isaac accepts variable-resolution vision inputs as a single packed sequence with per-image
    `token_grids`; packing/unpacking here reconstructs per-image shapes so we can resize positional
    embeddings and build `cu_seqlens` for variable-length attention (not generic generation packing).
    """

    def __init__(self, config: IsaacVisionConfig):
        super().__init__(config)
        self.config = config
        self.embed_dim = config.hidden_size
        self.patch_size = config.patch_size

        # Patches arrive already flattened, so a linear layer plays the role of the patch projection.
        self.patch_embedding = nn.Linear(
            in_features=config.num_channels * self.patch_size * self.patch_size,
            out_features=self.embed_dim,
        )

        self.num_patches = config.num_patches
        # Side length of the square grid the position table is reshaped to in `forward`.
        self.position_embedding_size = int(self.num_patches**0.5)
        self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)

    @check_model_inputs
    def forward(
        self, seq_patches: torch.Tensor, spatial_shapes: torch.Tensor
    ) -> torch.Tensor:
        """Embed a packed patch sequence and add per-image resized positional embeddings.

        Args:
            seq_patches: Packed patches of shape (total_patches, patch_dim).
            spatial_shapes: Per-image patch grids of shape (num_images, 2) as (H_tokens, W_tokens).

        Returns:
            Packed embeddings with one row per input patch, or an empty `(0, embed_dim)` tensor
            when `spatial_shapes` lists no images.
        """
        packed_pixel_values, seq_lengths = self._pack_to_batch(
            seq_patches, spatial_shapes
        )
        if packed_pixel_values is None:
            return seq_patches.new_zeros((0, self.embed_dim))

        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(packed_pixel_values.to(dtype=target_dtype))

        # View the learned table as a 2D grid so it can be resized to each image's patch grid.
        positional_embeddings = self.position_embedding.weight.reshape(
            self.position_embedding_size,
            self.position_embedding_size,
            -1,
        )
        # `resize_positional_embeddings` is not defined in this class; presumably inherited
        # from the SigLIP2 base class.
        resized_positional_embeddings = self.resize_positional_embeddings(
            positional_embeddings,
            spatial_shapes,
            max_length=packed_pixel_values.shape[1],
        )

        embeddings = patch_embeds + resized_positional_embeddings
        return self._unpack_from_batch(embeddings, seq_lengths)

    def _pack_to_batch(
        self,
        seq_patches: torch.Tensor,
        spatial_shapes: torch.Tensor,
    ) -> tuple[Optional[torch.Tensor], torch.Tensor]:
        """Rebatch a packed patch sequence using per-image grids to align embeddings.

        Args:
            seq_patches: Packed patches of shape (total_patches, patch_dim).
            spatial_shapes: Per-image patch grids of shape (num_images, 2) as (H_tokens, W_tokens).

        Returns:
            (packed_pixel_values, seq_lengths) where:
            - packed_pixel_values: (batch, max_len, patch_dim) padded with zeros, or None if batch_size == 0
            - seq_lengths: (batch,) lengths for each image
        """
        seq_lengths = spatial_shapes.long().prod(dim=-1)
        if seq_lengths.numel() == 0:
            return None, seq_lengths

        # Split the packed sequence per image, then right-pad every image to the longest one.
        chunks = seq_patches.split(seq_lengths.tolist(), dim=0)
        packed_pixel_values = nn.utils.rnn.pad_sequence(chunks, batch_first=True)
        return packed_pixel_values, seq_lengths

    def _unpack_from_batch(
        self, embeddings: torch.Tensor, seq_lengths: torch.Tensor
    ) -> torch.Tensor:
        """Flatten a padded batch back to packed sequence order using `seq_lengths`.

        Padding rows are dropped; images with zero length contribute nothing. If no image
        has any tokens, an empty `(0, hidden)` tensor is returned.
        """
        # `tolist()` yields Python ints regardless of device, so no device transfer is needed.
        lengths = seq_lengths.tolist()
        chunks = [
            embeddings[row, :length] for row, length in enumerate(lengths) if length > 0
        ]
        if not chunks:
            # `torch.cat` rejects an empty list; return an explicit empty result instead.
            return embeddings.new_zeros((0, embeddings.shape[-1]))
        return torch.cat(chunks, dim=0)
| |
|
| |
|
class IsaacVisionAttention(Siglip2Attention):
    """Custom attention that supports variable-length sequences with flash attention."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        **kwargs,
    ):
        """Run non-causal self-attention over `hidden_states`.

        When `cu_seqlens` is given and the configured backend is one of the flash/flex
        variants listed below, the cumulative lengths and the maximum segment length are
        forwarded to the backend as variable-length arguments.

        Returns:
            `(attn_output, attn_weights)` as produced by the backend, with `attn_output`
            passed through the output projection.
        """
        kwargs.pop("output_hidden_states", None)
        kwargs.pop("return_dict", None)

        batch_size, seq_length, embed_dim = hidden_states.shape

        def _to_heads(projected: torch.Tensor) -> torch.Tensor:
            # (batch, seq, embed) -> (batch, heads, seq, head_dim)
            return projected.view(
                batch_size, seq_length, self.num_heads, self.head_dim
            ).transpose(1, 2)

        queries = _to_heads(self.q_proj(hidden_states))
        keys = _to_heads(self.k_proj(hidden_states))
        values = _to_heads(self.v_proj(hidden_states))

        attn_impl = self.config._attn_implementation
        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[attn_impl]

        attention_kwargs: dict[str, Any] = {
            "is_causal": False,
            "scaling": self.scale,
        }

        varlen_backends = {
            "flash_attention_2",
            "flash_attention_3",
            "flex_attention",
            "paged|flash_attention_2",
            "paged|flash_attention_3",
        }
        if cu_seqlens is not None and attn_impl in varlen_backends:
            # Prefer the caller-provided maximum; otherwise derive it from cu_seqlens.
            if max_seqlen is not None:
                longest = int(max_seqlen)
            elif cu_seqlens.numel() >= 2:
                segment_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
                longest = (
                    segment_lengths.max() if segment_lengths.numel() > 0 else seq_length
                )
            else:
                longest = seq_length

            attention_kwargs["cu_seq_lens_q"] = cu_seqlens
            attention_kwargs["cu_seq_lens_k"] = cu_seqlens
            attention_kwargs["max_length_q"] = longest
            attention_kwargs["max_length_k"] = longest

        attn_output, attn_weights = attention_interface(
            self,
            queries,
            keys,
            values,
            attention_mask,
            **attention_kwargs,
        )
        attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
        return self.out_proj(attn_output), attn_weights
| |
|
| |
|
class IsaacVisionEncoderLayer(Siglip2EncoderLayer):
    """Isaac vision encoder layer with variable-length attention."""

    def __init__(self, config: IsaacVisionConfig):
        super().__init__(config)
        # Replace the parent's attention module with the varlen-aware Isaac attention.
        self.self_attn = IsaacVisionAttention(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        output_attentions: bool = False,
        **kwargs: Unpack[TransformersKwargs],
    ):
        r"""
        cu_seqlens (`torch.Tensor`, *optional*):
            Prefix-sum tensor whose length equals the number of documents + 1. The difference between successive
            entries gives each document's token count and enables block-diagonal attention masking for packed batches.
        max_seqlen (`int`, *optional*):
            Maximum document length referenced by `cu_seqlens`. Passed to FlashAttention so it can size temporary
            buffers for packed variable-length attention.
        """
        # Pre-norm attention sub-block with a residual connection.
        attn_output, _ = self.self_attn(
            self.layer_norm1(hidden_states),
            attention_mask=attention_mask,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            **kwargs,
        )
        hidden_states = hidden_states + attn_output

        # Pre-norm MLP sub-block with a residual connection.
        return hidden_states + self.mlp(self.layer_norm2(hidden_states))
| |
|
| |
|
class IsaacVisionEncoder(Siglip2Encoder):
    """Encoder using Isaac encoder layers with variable-length attention support."""

    def __init__(self, config: IsaacVisionConfig):
        super().__init__(config)
        # Swap the parent's layer stack for Isaac layers, one per configured hidden layer.
        isaac_layers = [
            IsaacVisionEncoderLayer(config) for _ in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(isaac_layers)
| |
|
| |
|
def create_pixel_shuffle_index_map(
    seq_sizes: torch.Tensor,
    token_grids: torch.Tensor,
    scale_factor: int = 1,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """
    Build a gather-index map that tells us, for every *output* token after
    pixel-shuffle, which `scale_factor**2` *input* tokens are being merged.

    Args
    ----
    seq_sizes    : (num_images,)  - #patches in each image (row-major order)
    token_grids  : (num_images,2) - (height, width) for every image
    scale_factor : spatial down-scale factor (≥1; 1 maps every token to itself)
    device       : (optional) overrides `seq_sizes.device`

    Returns
    -------
    gather_idx : (new_total_seq_len, scale_factor**2) int64 tensor.
                 gather_idx[i, j] is the *flat* index into the *original*
                 packed sequence for the j-th sub-patch that forms the
                 i-th output token. Empty (0 rows) when there are no images.

    Raises
    ------
    AssertionError : if any grid side is not divisible by `scale_factor`
                     (check skipped while torchdynamo is compiling).
    """
    # Honour the documented default: fall back to the device of `seq_sizes`.
    if device is None:
        device = seq_sizes.device

    if not is_torchdynamo_compiling():
        if (token_grids % scale_factor).any():
            raise AssertionError(
                f"Every (H,W) in token_grids must be divisible by scale_factor={scale_factor}, got {token_grids.tolist()}"
            )

    group_size = scale_factor * scale_factor
    gather_chunks: list[torch.Tensor] = []
    tok_offset = 0
    for seq_len, (h, w) in zip(seq_sizes.tolist(), token_grids.tolist()):
        # Flat indices of this image's tokens inside the packed sequence, laid out as (h, w).
        grid = (
            torch.arange(seq_len, device=device, dtype=torch.int64).view(h, w)
            + tok_offset
        )

        # Split both axes into (blocks, scale_factor) and bring the two block axes first,
        # so each scale_factor x scale_factor neighbourhood becomes contiguous.
        grid = (
            grid.view(h // scale_factor, scale_factor, w // scale_factor, scale_factor)
            .permute(0, 2, 1, 3)
            .contiguous()
        )
        gather_chunks.append(grid.view(-1, group_size))

        tok_offset += seq_len

    if not gather_chunks:
        # `torch.cat` rejects an empty list; return a well-formed empty index map.
        return torch.empty((0, group_size), dtype=torch.int64, device=device)

    return torch.cat(gather_chunks, dim=0)
| |
|
| |
|
def pixel_shuffle_varlen(
    x: torch.Tensor,
    token_grids: torch.Tensor,
    scale_factor: int = 1,
) -> torch.Tensor:
    r"""Pixel-shuffle a packed vision sequence in one gather, without unpacking per image.

    Args:
        x (`torch.Tensor`):
            Concatenated vision embeddings, either `(seq_len, hidden_size)` or
            `(1, seq_len, hidden_size)`.
        token_grids (`torch.Tensor`):
            Integer tensor of shape `(num_images, 2)` whose rows give the `(height, width)` patch grid
            of each image segment inside `x`.
        scale_factor (`int`, *optional*, defaults to 1):
            Spatial down-sampling factor. Values greater than one merge `scale_factor**2` neighboring
            patches into a single, wider embedding.

    Returns:
        `torch.Tensor`: One row per merged token, i.e. `(num_merged_tokens, hidden_size * scale_factor**2)`
        for 2D input, or the same with a leading singleton batch dimension for 3D input.

    Raises:
        ValueError: If a 3D input has a batch dimension other than 1.
    """
    has_batch_dim = x.dim() == 3
    if has_batch_dim and x.size(0) != 1:
        raise ValueError(
            f"Packed vision sequences expect a singleton batch dimension; received batch_size={x.size(0)}."
        )
    embeddings = x.squeeze(0) if has_batch_dim else x

    hidden_size = embeddings.size(-1)
    scale_factor = int(scale_factor)

    # Row i lists the input-token indices that merge into output token i.
    gather_idx = create_pixel_shuffle_index_map(
        seq_sizes=torch.prod(token_grids, dim=-1),
        token_grids=token_grids,
        scale_factor=scale_factor,
        device=embeddings.device,
    )

    # (num_merged, scale_factor**2, hidden) -> (num_merged, scale_factor**2 * hidden)
    merged = embeddings[gather_idx]
    merged = merged.reshape(merged.size(0), hidden_size * scale_factor * scale_factor)

    return merged.unsqueeze(0) if has_batch_dim else merged
| |
|
| |
|
class IsaacVisionTransformer(nn.Module):
    """Vision tower: embeds packed patches, runs the varlen encoder, then pixel-shuffles the result.

    Args:
        config (IsaacVisionConfig): Vision configuration with pixel-shuffle and patching parameters.

    Inputs:
        packed_seq_patches (Tuple[Tensor, Tensor]): ``(patches, token_grids)`` where ``patches`` is a packed
            patch sequence and ``token_grids`` holds per-image (H_tokens, W_tokens).

    Returns:
        torch.Tensor: Vision embeddings after encoder + pixel shuffle, shaped
            ``(num_merged_tokens, hidden_size * s^2)`` with ``s`` the pixel-shuffle factor.
    """

    _supports_sdpa = True

    def __init__(self, config: IsaacVisionConfig):
        super().__init__()
        self.config = config
        self.embeddings = IsaacVisionEmbeddings(config)
        self.encoder = IsaacVisionEncoder(config)
        self.post_layernorm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )
        self.pixel_shuffle_scale_factor = config.pixel_shuffle_scale_factor

    def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]):
        seq_patches, token_grids = packed_seq_patches

        # Embed the packed patches and add the singleton batch axis the encoder works on.
        hidden_states = self.embeddings(seq_patches, token_grids).unsqueeze(0)

        # Prefix sums of per-image token counts, with a leading zero.
        tokens_per_image = torch.prod(token_grids, dim=-1)
        cu_seqlens = F.pad(tokens_per_image.cumsum(0).to(torch.int32), (1, 0))

        attention_mask = create_document_attention_mask(
            self.config, hidden_states, cu_seqlens
        )

        encoded = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            cu_seqlens=cu_seqlens,
        ).last_hidden_state
        encoded = self.post_layernorm(encoded)

        shuffled = pixel_shuffle_varlen(
            x=encoded,
            token_grids=token_grids,
            scale_factor=self.pixel_shuffle_scale_factor,
        )
        # Drop the singleton batch axis again to return a packed sequence.
        return shuffled.squeeze(0)
| |
|
| |
|
class IsaacMultiModalProjector(nn.Module):
    """Maps vision tower outputs to the text hidden size with a SiLU MLP."""

    def __init__(self, config: IsaacConfig):
        super().__init__()
        vision_cfg = config.vision_config
        # Pixel shuffle widens each vision token by the square of its scale factor.
        self.vision_hidden_size = vision_cfg.hidden_size * (
            vision_cfg.pixel_shuffle_scale_factor**2
        )
        self.backbone_hidden_size = config.hidden_size

        expanded_size = 4 * self.vision_hidden_size
        self.linear_1 = nn.Linear(self.vision_hidden_size, expanded_size, bias=False)
        self.silu = nn.SiLU()
        self.linear_2 = nn.Linear(expanded_size, self.backbone_hidden_size, bias=False)

    def forward(self, image_features):
        """Apply linear -> SiLU -> linear to `image_features`."""
        return self.linear_2(self.silu(self.linear_1(image_features)))
| |
|
| |
|
class IsaacVisionEmbedding(nn.Module):
    """Vision tower followed by the multimodal projector."""

    _supports_sdpa = True

    def __init__(self, config: IsaacConfig):
        super().__init__()
        # The tower only needs the vision sub-config; the projector reads the full config.
        self.vision_tower = IsaacVisionTransformer(config.vision_config)
        self.multimodal_projector = IsaacMultiModalProjector(config)

    def forward(self, vision_tokens: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        """Encode `vision_tokens` with the vision tower and project the result."""
        tower_output = self.vision_tower(vision_tokens)
        return self.multimodal_projector(tower_output)
| |
|
| |
|
def get_scaled_image_size(
    scale: float,
    original_size: int,
    patch_size: int,
    pixel_shuffle_scale: int,
) -> int:
    """Scale one image side and round it up to a multiple of `patch_size * pixel_shuffle_scale`.

    The result is never smaller than one such multiple.
    """
    divisor = patch_size * pixel_shuffle_scale
    num_blocks = math.ceil(scale * original_size / divisor)
    return int(max(divisor, num_blocks * divisor))
| |
|
| |
|
def get_image_size_for_max_num_patches(
    image_height: int,
    image_width: int,
    patch_size: int,
    max_num_patches: int,
    min_num_patches: Optional[int] = None,
    eps: float = 1e-5,
    pixel_shuffle_scale: int = 1,
) -> tuple[int, int]:
    r"""Compute a target resolution whose patch grid satisfies patching parametrization.

    Args:
        image_height (`int`):
            Height in pixels of the source image prior to any resizing.
        image_width (`int`):
            Width in pixels of the source image prior to any resizing.
        patch_size (`int`):
            Size of the square patch used by the vision encoder.
        max_num_patches (`int`):
            Upper bound on `(height / patch_size) * (width / patch_size)` after resizing.
        min_num_patches (`int`, *optional*):
            Lower bound on the number of patches. When provided the image will be scaled up if necessary.
        eps (`float`, *optional*, defaults to 1e-5):
            Convergence tolerance for the internal binary search to determine the target dimensions.
        pixel_shuffle_scale (`int`, *optional*, defaults to 1):
            Additional stride multiplier applied when pixel shuffle later reduces spatial resolution.

    Returns:
        `tuple[int, int]`: Height and width (in pixels) that are multiples of `patch_size * pixel_shuffle_scale`.
        The minimum-patch branch is checked first and searches scales in `[1, 100]`; otherwise the
        rounded-up size is returned if it fits `max_num_patches`, else the image is scaled down.
    """

    def _size_at(scale: float) -> tuple[int, int]:
        # Both sides scaled by the same factor and snapped to the patch/shuffle grid.
        return (
            get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale),
            get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale),
        )

    def _patch_count(height: int, width: int) -> float:
        return (height / patch_size) * (width / patch_size)

    # Size obtained by only rounding each side up to the grid, without any scaling.
    divisor = patch_size * pixel_shuffle_scale
    adjusted_height = max(divisor, math.ceil(image_height / divisor) * divisor)
    adjusted_width = max(divisor, math.ceil(image_width / divisor) * divisor)
    num_patches = _patch_count(adjusted_height, adjusted_width)

    if min_num_patches is not None and num_patches < min_num_patches:
        # Too few patches: bisect for the smallest up-scale that reaches the minimum.
        # `high` always satisfies the constraint (or stays at its initial bound).
        low, high = 1.0, 100.0
        while (high - low) >= eps:
            middle = (low + high) / 2
            if _patch_count(*_size_at(middle)) >= min_num_patches:
                high = middle
            else:
                low = middle
        return _size_at(high)

    if num_patches <= max_num_patches:
        return adjusted_height, adjusted_width

    # Too many patches: bisect for the largest down-scale that fits the maximum.
    # `low` always satisfies the constraint (or stays at its initial bound).
    low, high = eps / 10, 1.0
    while (high - low) >= eps:
        middle = (low + high) / 2
        if _patch_count(*_size_at(middle)) <= max_num_patches:
            low = middle
        else:
            high = middle
    return _size_at(low)
| |
|
| |
|
class IsaacConfig(PretrainedConfig):
    """Configuration class for Isaac multimodal model.

    This configuration corresponds to checkpoints such as
    [Perceptron/isaac-base](https://huggingface.co/Perceptron/isaac-base).

    The text backbone is described by ``text_config`` and the vision tower by
    ``vision_config``. Several text attributes (``vocab_size``, ``hidden_size``,
    ``num_hidden_layers``, ...) are mirrored on this object for direct access.

    Args:
        vision_config (`IsaacVisionConfig` or `dict`, *optional*):
            Vision sub-config, or keyword arguments to build one. Defaults are used when omitted.
        text_config (`Qwen3Config` or `dict`, *optional*):
            Text sub-config, or keyword arguments to build one. Defaults are used when omitted.
        vision_rescale_factor (`float`, *optional*, defaults to `1 / 255`):
            Stored as a float on the config.
        max_sequence_length (`int`, *optional*, defaults to 16384):
            Stored on the config.
        vision_token (`str`, *optional*, defaults to `"<image>"`):
            Stored on the config.
        **kwargs:
            Forwarded to ``PretrainedConfig.__init__``. ``attn_implementation`` is
            additionally propagated to the vision sub-config.

    Raises:
        TypeError: If ``text_config`` or ``vision_config`` is neither ``None``, a
            ``dict`` nor an instance of the matching sub-config class.
    """

    model_type = "isaac"
    sub_configs = {"vision_config": IsaacVisionConfig, "text_config": Qwen3Config}
    image_processor_type = "IsaacImageProcessor"

    def __init__(
        self,
        vision_config: Optional[Union[IsaacVisionConfig, dict]] = None,
        text_config: Optional[Union[Qwen3Config, dict]] = None,
        vision_rescale_factor: float = 1 / 255,
        max_sequence_length: int = 16384,
        vision_token: str = "<image>",
        **kwargs,
    ):
        # Read without popping: the value still reaches the parent constructor via kwargs.
        attn_implementation = kwargs.get("attn_implementation")

        self.text_config = self._resolve_sub_config("text_config", text_config)

        # Set before the parent constructor runs.
        # NOTE(review): presumably so the parent can see these values - confirm.
        self.rope_parameters = getattr(self.text_config, "rope_parameters", None)
        self.layer_types = getattr(self.text_config, "layer_types", None)

        super().__init__(**kwargs)

        # Write back whatever `rope_parameters` holds after the parent constructor ran.
        self.text_config.rope_parameters = self.rope_parameters

        # Mirror the text backbone attributes on the composite config.
        self.vocab_size = self.text_config.vocab_size
        self.hidden_size = self.text_config.hidden_size
        self.num_hidden_layers = self.text_config.num_hidden_layers
        self.num_attention_heads = self.text_config.num_attention_heads
        self.head_dim = self.text_config.head_dim
        self.hidden_act = self.text_config.hidden_act
        self.use_cache = self.text_config.use_cache
        try:
            self.rope_theta = self.rope_parameters["rope_theta"]
        except (TypeError, KeyError):
            # `rope_parameters` may be None or lack the key; do not fail config construction.
            self.rope_theta = getattr(self.text_config, "rope_theta", None)

        self.layer_types = getattr(self.text_config, "layer_types", None)
        layer_type_validation(self.layer_types, self.num_hidden_layers)

        self.vision_config = self._resolve_sub_config("vision_config", vision_config)

        # Propagate the requested attention backend to the vision tower. For a dict,
        # the "vision_config" entry wins and the "" entry is the fallback.
        if attn_implementation is not None:
            if isinstance(attn_implementation, dict):
                vision_attn = attn_implementation.get(
                    "vision_config", attn_implementation.get("", None)
                )
            else:
                vision_attn = attn_implementation
            if vision_attn is not None:
                self.vision_config._attn_implementation = vision_attn

        if getattr(self, "_attn_implementation", None) is None:
            self._attn_implementation = "sdpa"

        self.vision_rescale_factor = float(vision_rescale_factor)
        self.max_sequence_length = max_sequence_length
        self.vision_token = vision_token

    def _resolve_sub_config(self, name: str, value):
        """Return a sub-config instance for ``name`` built from ``value``.

        ``None`` yields a default instance, a ``dict`` is expanded into keyword
        arguments, and an existing instance of the registered class is used as is.
        """
        config_cls = self.sub_configs[name]
        if value is None:
            return config_cls()
        if isinstance(value, dict):
            return config_cls(**value)
        if isinstance(value, config_cls):
            return value
        raise TypeError(
            f"`{name}` must be None, a dict or a {config_cls.__name__} instance, "
            f"got {type(value).__name__}."
        )

    def to_dict(self):
        """Serialize the config, expanding both sub-configs into plain dictionaries."""
        output = super().to_dict()
        for name in ("text_config", "vision_config"):
            sub_config = getattr(self, name, None)
            if sub_config is not None:
                output[name] = sub_config.to_dict()
        return output
| |
|
| |
|
class IsaacProcessor(ProcessorMixin):
    """Processor that pairs the Isaac image processor with the Qwen2 tokenizer.

    Args:
        image_processor: Vision preprocessor (fast) used for patch extraction.
        tokenizer: Qwen2 tokenizer instance.
        vision_token (str, optional): Placeholder token marking image locations. Defaults to "<image>".
        max_sequence_length (int, optional): Maximum combined text+vision tokens kept. Defaults to 16384.
        rescale_factor (float, optional): Image rescale factor; defaults to 1/255.
        config (IsaacConfig | dict, optional): If provided, overrides processor defaults from the model config.

    Returns:
        BatchFeature: Contains ``input_ids`` and ``packed_inputs`` (patch tensors, grids, offsets, lengths, modality, positions).
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = ("IsaacImageProcessorFast",)
    tokenizer_class = ("Qwen2Tokenizer",)
    # Fallback text pad id used when the tokenizer does not define one.
    pad_token_id = 151643

    def __init__(
        self,
        image_processor,
        tokenizer,
        *,
        vision_token: str = "<image>",
        max_sequence_length: int = 16384,
        rescale_factor: Optional[float] = None,
        config: Optional[Union[IsaacConfig, dict]] = None,
    ) -> None:
        """Store the sub-processors and resolve token ids and limits.

        Raises:
            ValueError: If the tokenizer cannot map ``"<|image_pad|>"`` to an id.
        """
        if isinstance(config, dict):
            config = IsaacConfig(**config)

        # Values carried by the config take precedence over the keyword arguments.
        if config is not None:
            vision_token = config.vision_token
            max_sequence_length = config.max_sequence_length
            rescale_factor = config.vision_rescale_factor

        resolved_rescale_factor = (
            float(rescale_factor) if rescale_factor is not None else float(1 / 255)
        )
        if config is not None:
            # Note: this writes the normalized float back onto the caller's config object.
            config.vision_rescale_factor = resolved_rescale_factor

        self.image_processor = image_processor
        super().__init__(image_processor, tokenizer)

        text_pad_token_id = getattr(self.tokenizer, "pad_token_id", None)
        if text_pad_token_id is None:
            # Tokenizer without a pad token: use the class-level default instead of crashing.
            text_pad_token_id = type(self).pad_token_id
        image_pad_token_id = self.tokenizer.convert_tokens_to_ids("<|image_pad|>")
        if image_pad_token_id is None:
            raise ValueError(
                "The tokenizer does not define the `<|image_pad|>` token required by IsaacProcessor."
            )

        self.text_pad_token_id = int(text_pad_token_id)
        self.image_pad_token_id = int(image_pad_token_id)
        self.pad_token_id = self.text_pad_token_id

        self.current_processor = self.image_processor
        self.config = config
        self.chat_template = getattr(self.tokenizer, "chat_template", None)
        self.vision_token = vision_token
        self.max_sequence_length = max_sequence_length

    def _pack_batch(
        self, texts: list[str], images_list: Optional[list[Optional[list[Image]]]]
    ) -> dict[str, Optional[torch.Tensor]]:
        """Pack every sample with ``_pack_single`` and collate them into one batch.

        Sequence tensors are left-padded to the longest sample (padding uses the text
        pad id, the text modality and zero positions). Vision tensors of all samples
        are concatenated along dim 0, and ``vision_token_batch_indices`` records the
        batch row each image belongs to. No attention mask is produced here.

        Args:
            texts: One prompt per sample.
            images_list: Per-sample images (``None`` entry or ``None`` overall for
                text-only input). Must have the same length as ``texts``.

        Returns:
            Dict with ``input_ids``, ``modality_tensor``, ``position_ids`` and the
            vision entries (each ``None`` when the batch contains no image).
        """
        if images_list is None:
            pairs = ((t, None) for t in texts)
        else:
            pairs = zip(texts, images_list, strict=True)

        per_sample: list[dict[str, Optional[torch.Tensor]]] = []
        for txt, imgs in pairs:
            # Tolerate a bare image where a list of images is expected.
            if imgs is not None and isinstance(imgs, Image):
                imgs = [imgs]
            per_sample.append(self._pack_single(txt, imgs))

        lengths = [int(p["input_ids"].shape[1]) for p in per_sample]
        max_len = max(lengths, default=0)
        batch = len(per_sample)

        # Use the device of the first non-empty sample; CPU when every sample is empty.
        base_device = torch.device("cpu")
        for p in per_sample:
            if p["input_ids"].numel() > 0:
                base_device = p["input_ids"].device
                break

        padded_input_ids = torch.full(
            (batch, max_len), self.text_pad_token_id, device=base_device, dtype=torch.long
        )
        padded_modality = torch.full(
            (batch, max_len),
            ModalityType.text.value,
            device=base_device,
            dtype=torch.long,
        )
        padded_position_ids = torch.zeros(
            (batch, max_len, 3), device=base_device, dtype=torch.long
        )

        for row, (sample, length) in enumerate(zip(per_sample, lengths)):
            # Guard against length 0: `x[row, -0:]` would address the whole row.
            if length:
                padded_input_ids[row, -length:] = sample["input_ids"][0]
                padded_modality[row, -length:] = sample["modality_tensor"][0]
                padded_position_ids[row, -length:] = sample["position_ids"][0]

        v_samples = [
            (b, s) for b, s in enumerate(per_sample) if s["vision_patches"] is not None
        ]
        if v_samples:
            vision_patches = torch.cat([s["vision_patches"] for _, s in v_samples], dim=0)
            vision_token_grids = torch.cat(
                [s["vision_token_grids"] for _, s in v_samples], dim=0
            )
            vision_token_offsets = torch.cat(
                [s["vision_token_offsets"] for _, s in v_samples], dim=0
            )
            vision_token_lengths = torch.cat(
                [s["vision_token_lengths"] for _, s in v_samples], dim=0
            )
            # One batch-row index per image, shaped like that sample's offsets tensor.
            vision_token_batch_indices = torch.cat(
                [torch.full_like(s["vision_token_offsets"], b) for b, s in v_samples],
                dim=0,
            )
        else:
            vision_patches = None
            vision_token_grids = None
            vision_token_offsets = None
            vision_token_lengths = None
            vision_token_batch_indices = None

        return {
            "input_ids": padded_input_ids,
            "vision_patches": vision_patches,
            "vision_token_grids": vision_token_grids,
            "vision_token_offsets": vision_token_offsets,
            "vision_token_lengths": vision_token_lengths,
            "vision_token_batch_indices": vision_token_batch_indices,
            "modality_tensor": padded_modality,
            "position_ids": padded_position_ids,
        }

    def _pack_single(
        self, text: str, images: Optional[list[Image]]
    ) -> dict[str, Optional[torch.Tensor]]:
        """Tokenize one prompt, interleave its images and keep the trailing window.

        The text is split on ``self.vision_token``; each non-empty piece is tokenized
        and each placeholder is replaced by a run of ``image_pad`` ids whose length is
        the product of the image's (padded) ``virtual_pixel_size``. When the total
        exceeds ``self.max_sequence_length`` only the last ``max_sequence_length``
        positions are kept. An image cut by the window keeps all of its patches;
        ``vision_token_offsets`` / ``vision_token_lengths`` describe the surviving span.

        Returns:
            Dict of tensors with a leading batch dim of 1 for ``input_ids``,
            ``modality_tensor`` and ``position_ids``; vision entries are ``None``
            when no image survives.

        Raises:
            ValueError: If the number of placeholders differs from the number of images.
        """
        segments = text.split(self.vision_token)
        num_images = len(segments) - 1
        num_provided_images = len(images) if images is not None else 0
        if num_images != num_provided_images:
            raise ValueError(
                f"IsaacProcessor expects one image per image token, got {num_images} tokens and {num_provided_images} images in sample with text {text} "
            )

        # Pass 1: materialize text and image segments in reading order.
        items: list[dict[str, Any]] = []
        total = 0
        for index, segment in enumerate(segments):
            if segment:
                tok = (
                    self.tokenizer.encode(
                        segment, add_special_tokens=False, return_tensors="pt"
                    )
                    .squeeze(0)
                    .to(torch.long)
                )
                segment_length = int(tok.numel())
                items.append(
                    {"type": "text", "segment_length": segment_length, "tok": tok}
                )
                total += segment_length

            # Every split point except the one after the last segment is a placeholder.
            if index < num_images:
                feat = self.image_processor(
                    images=images[index], return_tensors=TensorType.PYTORCH
                )
                patches = feat["patches"][0].reshape(-1, feat["patches"].shape[-1])

                virtual_pixel_size = (
                    feat["virtual_pixel_size"][0].to(torch.long).tolist()
                )
                real_pixel_size = feat["real_pixel_size"][0].to(torch.long).tolist()
                # Pad with ones so there are always three dims.
                dims = tuple((virtual_pixel_size + [1, 1, 1])[:3])
                segment_length = int(dims[0] * dims[1] * dims[2])

                items.append(
                    {
                        "type": "image",
                        "segment_length": segment_length,
                        "dims": dims,
                        "patches": patches,
                        "grid": (int(real_pixel_size[1]), int(real_pixel_size[2])),
                    }
                )
                total += segment_length

        # Pass 2: keep only the window [start, end) covering the trailing tokens.
        start = max(0, total - self.max_sequence_length)
        end = total

        base_device: Optional[torch.device] = None
        position_ids, modality, input_ids = [], [], []
        vpatches, grids, vision_token_offsets, vision_token_lengths = [], [], [], []

        global_offset = 0  # index of the current segment's first token in the full sequence
        position_offset = 0  # first-axis position assigned to the current segment

        for item in items:
            is_text = item["type"] == "text"
            segment_length = int(item["segment_length"])
            # Text advances the first axis by its token count, an image by dims[0];
            # this applies whether or not the segment survives truncation.
            position_advance = segment_length if is_text else int(item["dims"][0])

            window_start = max(start, global_offset)
            window_end = min(end, global_offset + segment_length)

            if window_end > window_start:
                if base_device is None:
                    base_device = item["tok"].device if is_text else item["patches"].device

                local_start = int(window_start - global_offset)
                local_end = int(window_end - global_offset)
                kept_length = local_end - local_start
                local_indices = torch.arange(
                    local_start, local_end, device=base_device, dtype=torch.long
                )

                if is_text:
                    slice_index = local_indices + position_offset
                    zeros = torch.zeros_like(slice_index)
                    position_ids.append(torch.stack((slice_index, zeros, zeros), -1))
                    modality.append(
                        torch.full(
                            (kept_length,),
                            ModalityType.text.value,
                            device=base_device,
                            dtype=torch.long,
                        )
                    )
                    input_ids.append(item["tok"].to(base_device)[local_start:local_end])
                else:
                    _, grid_height_tokens, grid_width_tokens = item["dims"]
                    hw = grid_height_tokens * grid_width_tokens
                    # Decompose the flat index into (slice, row, col) coordinates.
                    slice_index = (local_indices // hw) + position_offset
                    rem = local_indices % hw
                    row_index = rem // grid_width_tokens
                    col_index = rem % grid_width_tokens
                    position_ids.append(
                        torch.stack((slice_index, row_index, col_index), -1)
                    )
                    modality.append(
                        torch.full(
                            (kept_length,),
                            ModalityType.image.value,
                            device=base_device,
                            dtype=torch.long,
                        )
                    )
                    input_ids.append(
                        torch.full(
                            (kept_length,),
                            self.image_pad_token_id,
                            device=base_device,
                            dtype=torch.long,
                        )
                    )

                    # All patches are kept; offset/length record the surviving span.
                    vpatches.append(item["patches"].to(base_device))
                    grids.append(item["grid"])
                    vision_token_offsets.append(local_start)
                    vision_token_lengths.append(kept_length)

            position_offset += position_advance
            global_offset += segment_length

        if base_device is None:
            base_device = torch.device("cpu")

        modality_tensor = (
            torch.cat(modality, 0).unsqueeze(0)
            if modality
            else torch.zeros((1, 0), device=base_device, dtype=torch.long)
        )
        position_ids = (
            torch.cat(position_ids, 0).unsqueeze(0)
            if position_ids
            else torch.zeros((1, 0, 3), device=base_device, dtype=torch.long)
        )
        input_ids = (
            torch.cat(input_ids, 0).unsqueeze(0)
            if input_ids
            else torch.zeros((1, 0), device=base_device, dtype=torch.long)
        )

        if vpatches:
            vision_patches = torch.cat(vpatches, 0)
            vision_token_grids = torch.tensor(grids, device=base_device, dtype=torch.long)
            vision_token_offsets = torch.tensor(
                vision_token_offsets, device=base_device, dtype=torch.long
            )
            vision_token_lengths = torch.tensor(
                vision_token_lengths, device=base_device, dtype=torch.long
            )
        else:
            vision_patches = None
            vision_token_grids = None
            vision_token_offsets = None
            vision_token_lengths = None

        return {
            "input_ids": input_ids,
            "vision_patches": vision_patches,
            "vision_token_grids": vision_token_grids,
            "vision_token_offsets": vision_token_offsets,
            "vision_token_lengths": vision_token_lengths,
            "modality_tensor": modality_tensor,
            "position_ids": position_ids,
        }

    def _images_per_sample(
        self,
        texts: list[str],
        images: Optional[Union[Image, list[Image]]],
    ) -> Optional[list[Optional[list[Image]]]]:
        """Assign the ``images`` argument of ``__call__`` to the individual prompts.

        Resolution order:
        1. ``None`` stays ``None`` (text-only batch).
        2. A list of lists with one entry per prompt is used as is.
        3. A flat list (or single image) whose size equals the total number of
           placeholders is handed out in order, each prompt taking as many images
           as it has placeholders.
        4. Otherwise: a flat list with one entry per prompt is paired one-to-one,
           and anything else is shared by every prompt that has a placeholder.
           Count mismatches are reported later by ``_pack_single``.
        """
        if images is None:
            return None

        one_entry_per_text = isinstance(images, list) and len(images) == len(texts)
        if one_entry_per_text:
            if not images:
                return []
            if isinstance(images[0], list):
                return images

        flat_images = images if isinstance(images, list) else [images]
        token_counts = [t.count(self.vision_token) for t in texts]

        is_flat = not any(isinstance(img, list) for img in flat_images)
        if is_flat and sum(token_counts) == len(flat_images):
            per_sample: list[Optional[list[Image]]] = []
            cursor = 0
            for count in token_counts:
                per_sample.append(flat_images[cursor : cursor + count] if count else None)
                cursor += count
            return per_sample

        if one_entry_per_text:
            return [[img] for img in images]
        return [flat_images if count else None for count in token_counts]

    def __call__(
        self,
        text: Union[str, list[str]],
        images: Optional[Union[Image, list[Image]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
        **kwargs,
    ) -> BatchFeature:
        """Build model inputs from prompts and optional images.

        Args:
            text: A prompt or a list of prompts containing ``vision_token`` placeholders.
            images: A single image, a flat list of images, or one list of images per
                prompt. A flat list matching the total placeholder count is assigned
                to the prompts in order.
            return_tensors: Accepted for API compatibility; not used by this method.
            **kwargs: Accepted for API compatibility; not used by this method.

        Returns:
            BatchFeature with ``input_ids`` and ``packed_inputs`` (the remaining
            entries produced by ``_pack_batch``).
        """
        texts = [text] if isinstance(text, str) else text
        images_list = self._images_per_sample(texts, images)

        packed = self._pack_batch(texts, images_list)
        input_ids = packed.pop("input_ids")
        return BatchFeature(data={"input_ids": input_ids, "packed_inputs": packed})
| |
|
| |
|
class IsaacRotaryEmbedding(qwen2_5_vl_modeling.Qwen2_5_VLRotaryEmbedding):
    """Rotary embedding over three position axes, merged per ``mrope_section``."""

    def __init__(self, config: IsaacConfig, device=None):
        """Initialize from the text config of ``config``.

        Args:
            config: Model config; its text config is used when ``get_text_config`` exists.
            device: Optional device. A ``meta`` device is not forwarded to the parent.

        Raises:
            ValueError: If a configured ``mrope_section`` does not sum to the rotary
                half-dimension (``inv_freq.shape[0]``).
        """
        rope_source_cfg = (
            config.get_text_config() if hasattr(config, "get_text_config") else config
        )
        rope_scaling = getattr(rope_source_cfg, "rope_scaling", None) or {}
        # Shallow copy so that forcing `rope_scaling` to a dict does not touch the caller's config.
        config_for_rope = copy.copy(rope_source_cfg)
        config_for_rope.rope_scaling = rope_scaling

        init_device = (
            device
            if device is not None and getattr(device, "type", None) != "meta"
            else None
        )
        super().__init__(config_for_rope, device=init_device)

        rotary_half_dim = self.inv_freq.shape[0]
        self.mrope_section = self._resolve_mrope_section(
            rope_scaling.get("mrope_section"), rotary_half_dim
        )
        self.hidden_size = (
            getattr(rope_source_cfg, "hidden_size", None) or config.hidden_size
        )

    @staticmethod
    def _resolve_mrope_section(
        section: Optional[list[int]], rotary_half_dim: int
    ) -> list[int]:
        """Return the per-axis split of the rotary half-dimension.

        Without an explicit ``section`` the half-dimension is split 2:1:1, with any
        rounding remainder added to the first entry. An explicit section is converted
        to ints and must sum to ``rotary_half_dim``; ``_combine_axes`` splits tensors
        with ``section * 2``, which fails otherwise.

        Raises:
            ValueError: If the explicit section does not sum to ``rotary_half_dim``.
        """
        if section is None:
            weights = (2, 1, 1)
            base = [rotary_half_dim * w // sum(weights) for w in weights]
            base[0] += rotary_half_dim - sum(base)
            return base

        section = [int(v) for v in section]
        if sum(section) != rotary_half_dim:
            raise ValueError(
                f"`mrope_section` must sum to the rotary half-dimension ({rotary_half_dim}), "
                f"got {section} (sum={sum(section)})."
            )
        return section

    def _combine_axes(self, tensor: torch.Tensor) -> torch.Tensor:
        """Merge the leading axis dimension: chunk ``i`` of the last dim is taken from axis ``i % 3``."""
        split_sections = tuple(self.mrope_section * 2)
        chunks = tensor.split(split_sections, dim=-1)
        return torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1)

    def forward(
        self,
        position_ids: torch.Tensor,
        modality_tensor: torch.Tensor,
        hidden_states: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute the combined cos/sin tables.

        Args:
            position_ids: ``(batch, seq_len, 3)`` position coordinates.
            modality_tensor: ``(batch, seq_len)`` modality ids; positions that are not
                ``ModalityType.image`` get their first coordinate copied to all axes.
            hidden_states: Passed to the parent ``forward`` and used for the output
                dtype. A float32 zero tensor is created when omitted.

        Returns:
            ``(cos, sin)`` after merging the three axes with ``_combine_axes``.
        """
        if hidden_states is None:
            batch, seq_len, _ = position_ids.shape
            hidden_states = torch.zeros(
                batch,
                seq_len,
                self.hidden_size,
                dtype=torch.float32,
                device=position_ids.device,
            )

        with torch.no_grad():
            # Work on a copy: the caller's position ids must stay untouched.
            pos = position_ids.clone()
            not_spatial = modality_tensor != ModalityType.image.value
            data_1d = pos[not_spatial][..., 0].unsqueeze(-1)
            pos[not_spatial] = data_1d.expand(-1, pos.shape[-1])
            # Axis-first layout (3, batch, seq_len) for the parent implementation.
            pos_axes = pos.permute(2, 0, 1).contiguous()

        cos_axes, sin_axes = super().forward(hidden_states, pos_axes)
        cos_axes = cos_axes.to(hidden_states.dtype)
        sin_axes = sin_axes.to(hidden_states.dtype)

        return self._combine_axes(cos_axes), self._combine_axes(sin_axes)
| |
|
| |
|
@auto_docstring
class IsaacModel(Qwen3PreTrainedModel):
    # Text decoder (Qwen3Model) plus a vision embedding whose outputs are written
    # into the token embedding sequence at image positions.
    supports_gradient_checkpointing = True
    _can_compile_fullgraph = False
    _supports_flex_attn = False
    _can_record_outputs = {"attentions": OutputRecorder(IsaacVisionAttention, index=1)}
    all_tied_weights_keys: dict[str, str] = {}

    def __init__(self, config: IsaacConfig):
        Qwen3PreTrainedModel.__init__(self, config)

        # Build the text model from a private copy of the text config, then point
        # it at the composite config.
        text_cfg = copy.deepcopy(config.text_config)
        self.text_model = Qwen3Model._from_config(text_cfg)
        self.text_model.config = config

        self.rotary_emb = IsaacRotaryEmbedding(config, device=self.device)

        self.vision_embedding = IsaacVisionEmbedding(config)
        self.vision_embedding._supports_sdpa = True
        self.max_sequence_length = config.max_sequence_length
        self.vision_rescale_factor = config.vision_rescale_factor
        self.vision_token = config.vision_token
        # Updated by every forward call; consulted when no position ids are passed.
        self.rope_deltas = None

        self.post_init()

        if getattr(config, "gradient_checkpointing", False):
            self.gradient_checkpointing_enable()

    def get_input_embeddings(self) -> nn.Module:
        return self.text_model.get_input_embeddings()

    def set_input_embeddings(self, value: nn.Module) -> None:
        """Replace the token embedding and keep the recorded vocabulary sizes in sync."""
        self.text_model.set_input_embeddings(value)
        vocab_size = getattr(value, "num_embeddings", None)
        if vocab_size is not None:
            self.config.vocab_size = vocab_size
            if hasattr(self.config, "text_config"):
                self.config.text_config.vocab_size = vocab_size
            self.text_model.config.vocab_size = vocab_size

    @property
    def embed_tokens(self) -> nn.Module:
        return self.text_model.embed_tokens

    @embed_tokens.setter
    def embed_tokens(self, value: nn.Module) -> None:
        self.text_model.embed_tokens = value

    @property
    def vision_model(self) -> nn.Module:
        return self.vision_embedding.vision_tower

    def embed_packed_inputs(
        self, input_ids: torch.Tensor, packed_inputs: dict[str, Optional[torch.Tensor]]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Expects input_ids for text tokens and packed_inputs containing:
            - modality_tensor: (batch, seq_len) modality ids aligned to the sequence
            - position_ids: (batch, seq_len, 3) MRoPE coordinates (optional)
            - vision_patches: concatenated vision tokens shaped (total_tokens, embed_dim) or None
            - vision_token_grids: (num_images, 2) token grid sizes or None
            - vision_token_offsets: (num_images,) offsets into each image's virtual token span (optional; defaults to zeros)
            - vision_token_lengths: (num_images,) surviving virtual token lengths per image (optional; defaults to the full span)
            - vision_token_batch_indices: (num_images,) batch row for each image (optional; defaults to zeros)

        Returns the embedded sequence, with vision features written at the positions
        whose modality is ``ModalityType.image``, and the modality tensor moved to
        the device of ``input_ids``.
        """
        modality = packed_inputs["modality_tensor"].to(
            device=input_ids.device, dtype=torch.long
        )
        embeds = self.text_model.embed_tokens(input_ids)

        vision_patches = packed_inputs.get("vision_patches")
        if vision_patches is None:
            return embeds, modality

        token_grids = packed_inputs["vision_token_grids"].to(
            device=vision_patches.device, dtype=torch.long
        )
        vision = self.vision_embedding((vision_patches, token_grids))

        # Per-image count of vision features: grid area divided by the squared
        # pixel-shuffle factor.
        vision_reduction_factor = int(
            self.config.vision_config.pixel_shuffle_scale_factor
        )
        sizes = (
            token_grids.prod(-1)
            .div(
                vision_reduction_factor * vision_reduction_factor, rounding_mode="floor"
            )
            .tolist()
        )

        offsets = packed_inputs.get("vision_token_offsets")
        lengths = packed_inputs.get("vision_token_lengths")
        batch_indices = packed_inputs.get("vision_token_batch_indices")
        num_images = len(sizes)
        # All three are optional: a missing entry means "whole image, batch row 0".
        offset_list = offsets.tolist() if offsets is not None else [0] * num_images
        length_list = lengths.tolist() if lengths is not None else list(sizes)
        batch_list = (
            batch_indices.tolist() if batch_indices is not None else [0] * num_images
        )

        # Clone before in-place writes so the embedding output itself is not modified.
        embeds = embeds.clone()
        num_batches = modality.shape[0]
        image_positions = [
            (modality[b] == ModalityType.image.value)
            .nonzero(as_tuple=False)
            .squeeze(-1)
            for b in range(num_batches)
        ]
        # Next unfilled image position per batch row; images fill positions in order.
        cursors = [0] * num_batches

        chunks = vision.split(sizes, dim=0)
        for chunk, size, offset, length, batch_index in zip(
            chunks, sizes, offset_list, length_list, batch_list
        ):
            if size <= 0:
                continue
            # Clamp the surviving span to the features available for this image.
            offset = max(0, min(int(offset), size))
            length = max(0, min(int(length), size - offset))
            if not length:
                continue
            piece = chunk[offset : offset + length]
            if piece.numel() == 0:
                continue

            batch_index = int(batch_index)
            positions = image_positions[batch_index]
            start = cursors[batch_index]
            end = start + piece.shape[0]
            embeds[batch_index, positions[start:end]] = piece.to(
                device=embeds.device, dtype=embeds.dtype
            )
            cursors[batch_index] = end

        return embeds, modality

    def get_rope_index(
        self,
        *,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: torch.Tensor,
        inputs_embeds: torch.Tensor,
        cache_position: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Build 3D position ids and per-batch RoPE deltas.

        Without ``position_ids`` the positions are ``cache_position`` plus the stored
        ``self.rope_deltas`` (0 when unset) minus the number of masked-out entries in
        ``attention_mask``, repeated on all three axes; the stored delta is returned.

        With ``position_ids`` they are used directly (2D ids are repeated on three
        axes). If their length differs from the embedded sequence they are rebuilt as
        consecutive values starting at the first provided position. The delta is
        ``max position + 1 - number of attended tokens`` per batch row.

        Returns:
            ``(position_ids, rope_deltas)`` shaped ``(batch, seq_len, 3)`` and ``(batch, 1)``.
        """
        device = inputs_embeds.device
        batch_size, seq_len = inputs_embeds.shape[:2]

        if position_ids is None:
            cp = cache_position.to(device=device, dtype=torch.long)
            if cp.ndim == 1:
                cp = cp.view(1, -1).expand(batch_size or 1, -1)

            base_delta = torch.as_tensor(
                0 if self.rope_deltas is None else self.rope_deltas,
                device=device,
                dtype=torch.long,
            ).reshape(-1, 1)
            base_delta = torch.broadcast_to(base_delta, (batch_size, 1))

            # Zero or negative: minus the count of masked-out entries per row.
            mask_delta = attention_mask.to(device=device, dtype=torch.long).sum(
                1, keepdim=True
            ) - attention_mask.size(1)
            rope_position = cp + base_delta + mask_delta
            pos_3d = rope_position.unsqueeze(-1).expand(-1, -1, 3)
            return pos_3d, base_delta

        position_ids = position_ids.to(device=device)
        if position_ids.ndim == 2:
            position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3)

        if position_ids.shape[1] != seq_len:
            start_positions = position_ids[:, :1, 0]
            position_ids = (
                torch.arange(seq_len, device=position_ids.device).view(1, -1)
                + start_positions
            )
            position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3)

        attn = attention_mask.to(device=device, dtype=torch.long)
        m_per_batch = position_ids.amax(dim=(1, 2))
        seq_lens = attn.eq(1).sum(dim=-1).to(dtype=m_per_batch.dtype, device=device)
        rope_deltas = (
            (m_per_batch + 1 - seq_lens).to(dtype=position_ids.dtype).unsqueeze(1)
        )
        return position_ids, rope_deltas

    @auto_docstring
    @check_model_inputs
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        packed_inputs: Optional[dict[str, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[list[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPast:
        """
        Forward pass with MRoPE position embeddings.

        Computes position embeddings once and passes them through all layers. The
        modality ids come from `packed_inputs`; without them the whole sequence is
        treated as text. When `input_ids` is given it takes precedence over
        `inputs_embeds`.

        Args:
            packed_inputs (`dict`, *optional*):
                Plain tensor payloads (see `embed_packed_inputs`). When provided, `input_ids` is required.
        """
        output_attentions = kwargs.pop("output_attentions", None)

        modality_tensor: Optional[torch.Tensor] = None

        if packed_inputs is not None:
            if input_ids is None:
                raise ValueError(
                    "`input_ids` must be provided together with `packed_inputs`."
                )
            inputs_embeds, modality_tensor = self.embed_packed_inputs(
                input_ids, packed_inputs
            )
        elif input_ids is not None:
            inputs_embeds = self.text_model.embed_tokens(input_ids)
        elif inputs_embeds is None:
            raise ValueError(
                "You must specify at least one of `input_ids` or `inputs_embeds`."
            )

        device = inputs_embeds.device
        batch_size, seq_len = inputs_embeds.shape[:2]

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config.get_text_config())

        if cache_position is None:
            past_seen_tokens = (
                past_key_values.get_seq_length() if past_key_values is not None else 0
            )
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + seq_len, device=device
            )

        if attention_mask is None:
            attention_mask = torch.ones(
                inputs_embeds.shape[:2], device=inputs_embeds.device, dtype=torch.long
            )

        # Explicit `position_ids` win; otherwise use the ones carried by the packed payload.
        if (
            position_ids is None
            and packed_inputs is not None
            and packed_inputs.get("position_ids") is not None
        ):
            position_ids = packed_inputs.get("position_ids").to(device=device)

        position_ids, rope_deltas = self.get_rope_index(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
        )
        # Remembered on the module for later calls that pass no position ids.
        self.rope_deltas = rope_deltas

        if modality_tensor is None:
            modality_tensor = torch.full(
                (batch_size, seq_len),
                ModalityType.text.value,
                device=device,
                dtype=torch.long,
            )

        cos, sin = self.rotary_emb(
            position_ids, modality_tensor, hidden_states=inputs_embeds
        )

        # The decoder layers and mask creation receive only the first position axis.
        decoder_position_ids = (
            position_ids[..., 0] if position_ids.ndim == 3 else position_ids
        )

        if not isinstance(attention_mask, dict):
            attention_mask = create_masks_for_generate(
                config=self.config,
                input_embeds=inputs_embeds,
                attention_mask=attention_mask,
                cache_position=cache_position,
                past_key_values=past_key_values,
                position_ids=decoder_position_ids,
            )

        is_mask_dict = isinstance(attention_mask, dict)
        hidden_states = inputs_embeds
        all_attentions = [] if output_attentions else None

        for layer in self.text_model.layers:
            # A dict of masks is keyed by the layer's attention type.
            layer_mask = (
                attention_mask[layer.attention_type] if is_mask_dict else attention_mask
            )
            layer_outputs = layer(
                hidden_states,
                attention_mask=layer_mask,
                position_ids=decoder_position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=(cos, sin),
                output_attentions=output_attentions,
                **kwargs,
            )

            # Layers may return either a tensor or a tuple starting with the hidden states.
            layer_outputs_is_tuple = isinstance(layer_outputs, tuple)
            hidden_states = (
                layer_outputs[0] if layer_outputs_is_tuple else layer_outputs
            )
            if output_attentions and layer_outputs_is_tuple:
                all_attentions.append(layer_outputs[1])

        hidden_states = self.text_model.norm(hidden_states)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=(hidden_states,),
            attentions=tuple(all_attentions) if output_attentions else None,
        )
| |
|
| |
|
# Causal-LM wrapper around `IsaacModel`: runs the backbone, projects its hidden states to vocabulary
# logits with `lm_head`, and plugs into `generate()` via `prepare_inputs_for_generation`.
@auto_docstring
class IsaacForConditionalGeneration(Qwen3ForCausalLM, GenerationMixin):
    config_class = IsaacConfig
    _can_compile_fullgraph = False
    # The LM head shares its weight with the text-model token embedding table.
    _tied_weights_keys = {"lm_head.weight": "model.text_model.embed_tokens.weight"}
    all_tied_weights_keys: dict[str, str] = {
        "lm_head.weight": "model.text_model.embed_tokens.weight"
    }

    def __init__(self, config: IsaacConfig):
        super().__init__(config)
        # Assigned after the parent constructor so these attributes take precedence over
        # anything it set up under the same names.
        self.model = IsaacModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    @auto_docstring
    @can_return_tuple
    @check_model_inputs
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        packed_inputs: Optional[dict[str, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[list[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | CausalLMOutputWithPast:
        """Run multimodal CausalLM forward, accepting packed vision/text inputs.

        Args:
            packed_inputs (`dict`, *optional*):
                Packed vision/text payload from ``IsaacProcessor`` containing modality ids, MRoPE position ids, and
                vision patch tensors/grids (with optional offsets/lengths) used to rebuild embeddings.

        Returns:
            CausalLMOutputWithPast: logits, optional loss, caches, hidden states, attentions.
        """
        # Pulled out of kwargs so the flag can be forwarded explicitly and also used below
        # to decide whether attentions are exposed on the returned output.
        output_attentions = kwargs.pop("output_attentions", None)

        outputs = self.model(
            input_ids=input_ids,
            packed_inputs=packed_inputs,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            cache_position=cache_position,
            **kwargs,
        )
        # First element of the backbone output is the final hidden-state tensor.
        hidden_states = outputs[0]
        # Logits are computed for every position of the returned hidden states.
        logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits, labels=labels, vocab_size=self.config.vocab_size
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions if output_attentions else None,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[list[torch.FloatTensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        packed_inputs: Optional[dict[str, torch.Tensor]] = None,
        cache_position: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> dict[str, Any]:
        """Build the per-step model inputs used during generation.

        The parent implementation assembles the standard inputs. When `packed_inputs` is
        provided, two entries are then overridden:

        * `packed_inputs` is forwarded only while the cache is empty (first step) and is
          replaced by `None` on every later step.
        * `position_ids` is always reset to `None`.
          NOTE(review): presumably the model derives positions itself in this mode -- confirm.

        Args:
            input_ids: Token ids generated so far.
            past_key_values: Cache object; it is queried with `get_seq_length()`, so a
                cache instance (not a plain list) is expected despite the annotation.
            attention_mask: Forwarded unchanged to the parent implementation.
            inputs_embeds: Forwarded unchanged to the parent implementation.
            packed_inputs: Optional packed vision/text payload.
            cache_position: Forwarded unchanged to the parent implementation.
            position_ids: Forwarded to the parent implementation.
            **kwargs: Extra arguments forwarded to the parent implementation.

        Returns:
            The dictionary of keyword arguments for the next forward call.
        """
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            **kwargs,
        )
        if packed_inputs is None:
            return model_inputs

        past_len = (
            past_key_values.get_seq_length() if past_key_values is not None else 0
        )
        # An empty (or absent) cache means nothing has been processed yet.
        first_step = past_len == 0
        model_inputs["packed_inputs"] = packed_inputs if first_step else None
        model_inputs["position_ids"] = None

        return model_inputs

    @classmethod
    def can_generate(cls) -> bool:
        """Report that this class supports generation (always `True`)."""
        return True

    def set_input_embeddings(self, value: nn.Module) -> None:
        """Install new input embeddings and keep vocab sizes and the LM head in sync.

        The new module is handed to the backbone, the vocabulary size is propagated to
        `self.vocab_size` and to the three config objects, the LM head is rebuilt if its
        output dimension no longer matches, and the head weight is re-tied to the
        text-model embedding weight.

        Args:
            value: The new embedding module. Its `num_embeddings` attribute is used as the
                vocabulary size when present; otherwise the row count of the tied
                embedding weight is used.
        """
        self.model.set_input_embeddings(value)
        tied_weight = self.model.text_model.embed_tokens.weight

        vocab_size = getattr(value, "num_embeddings", None)
        if vocab_size is None:
            # The head weight is tied to this tensor and a linear layer's weight is laid
            # out as (out_features, in_features), so its first dimension is the vocab
            # size. This avoids writing `None` into the configs and into `nn.Linear`.
            vocab_size = tied_weight.shape[0]

        # Keep every cached copy of the vocabulary size consistent.
        self.vocab_size = vocab_size
        self.config.vocab_size = vocab_size
        self.model.config.vocab_size = vocab_size
        self.model.text_model.config.vocab_size = vocab_size

        if self.lm_head.weight.shape[0] != vocab_size:
            self.lm_head = nn.Linear(self.config.hidden_size, vocab_size, bias=False)
        # Share the embedding parameter with the head (weight tying).
        self.lm_head.weight = tied_weight
| |
|
| |
|
# Names exported when this module is star-imported.
__all__ = [
    "IsaacConfig",
    "IsaacModel",
    "IsaacPreTrainedModel",
    "IsaacForConditionalGeneration",
    "IsaacImageProcessorFast",
    "IsaacProcessor",
]
| |
|
| |
|