| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from __future__ import annotations |
|
|
| import copy |
| import math |
| from collections.abc import Callable, Sequence |
| from enum import IntEnum |
| from typing import Any, Optional, Union |
|
|
| from transformers.cache_utils import DynamicCache |
| from transformers.configuration_utils import PretrainedConfig, layer_type_validation |
| from transformers.feature_extraction_utils import BatchFeature |
| from transformers.generation.utils import GenerationMixin |
| from transformers.image_processing_utils_fast import ( |
| BaseImageProcessorFast, |
| ImagesKwargs, |
| SizeDict, |
| group_images_by_shape, |
| reorder_images, |
| ) |
| from transformers.image_utils import ( |
| PILImageResampling, |
| ) |
| from transformers.masking_utils import ( |
| ALL_MASK_ATTENTION_FUNCTIONS, |
| create_masks_for_generate, |
| packed_sequence_mask_function, |
| ) |
| from transformers.modeling_outputs import ( |
| BaseModelOutputWithPast, |
| CausalLMOutputWithPast, |
| ) |
| from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS |
| from transformers.models.qwen3.configuration_qwen3 import Qwen3Config |
| from transformers.models.qwen3.modeling_qwen3 import ( |
| Qwen3ForCausalLM, |
| Qwen3Model, |
| Qwen3PreTrainedModel, |
| ) |
| from transformers.processing_utils import ProcessorMixin, Unpack |
| from transformers.utils import TensorType, auto_docstring |
| from transformers.utils.constants import IMAGENET_STANDARD_MEAN as VISION_MEAN |
| from transformers.utils.constants import IMAGENET_STANDARD_STD as VISION_STD |
| from transformers.utils.generic import ( |
| OutputRecorder, |
| TransformersKwargs, |
| can_return_tuple, |
| check_model_inputs, |
| ) |
| from transformers.utils.import_utils import ( |
| is_torch_available, |
| is_torchdynamo_compiling, |
| is_torchvision_available, |
| is_vision_available, |
| ) |
| from transformers.models.qwen2_5_vl import modeling_qwen2_5_vl as qwen2_5_vl_modeling |
| from transformers.models.siglip2.configuration_siglip2 import Siglip2VisionConfig |
| from transformers.models.siglip2.modeling_siglip2 import ( |
| Siglip2Attention, |
| Siglip2Encoder, |
| Siglip2EncoderLayer, |
| Siglip2VisionEmbeddings, |
| ) |
|
|
|
|
| if is_torch_available(): |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| if is_vision_available(): |
| from PIL.Image import Image |
| else: |
| Image = None |
| if is_torchvision_available(): |
| from transformers.models.pix2struct.image_processing_pix2struct_fast import ( |
| torch_extract_patches, |
| ) |
|
|
|
|
class ModalityType(IntEnum):
    """Integer tags naming the modality an event's tokens belong to.

    Members:
        image: Vision tokens (e.g., patches); encoded as 0.
        text: Textual tokens; encoded as 1.
    """

    # Vision tokens (e.g., image patches).
    image = 0
    # Plain text tokens.
    text = 1
|
|
|
|
class IsaacVisionConfig(Siglip2VisionConfig):
    """Configuration of the Isaac vision tower (SigLIP2 vision config plus pixel shuffle).

    All keyword arguments not listed below are forwarded unchanged to
    `Siglip2VisionConfig.__init__`; the listed ones are then (re)assigned on the instance.

    Args:
        hidden_size (`int`, *optional*, defaults to 768):
            Width of the patch embeddings and of the final layer norm.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Stored for the inherited SigLIP2 encoder modules.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of encoder layers instantiated by the vision encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Stored for the inherited SigLIP2 attention modules.
        num_channels (`int`, *optional*, defaults to 3):
            Number of image channels; together with `patch_size` it fixes the flattened patch dimension.
        num_patches (`int`, *optional*, defaults to 256):
            Maximum number of learnable positional embeddings to initialize.
        patch_size (`int`, *optional*, defaults to 16):
            Side length, in pixels, of one square patch.
        hidden_act (`str`, *optional*, defaults to `"gelu_pytorch_tanh"`):
            Stored for the inherited SigLIP2 MLP modules.
        layer_norm_eps (`float`, *optional*, defaults to 1e-6):
            Epsilon used by the layer norms.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            Stored for the inherited SigLIP2 attention modules.
        pixel_shuffle_scale_factor (`int`, *optional*, defaults to 1):
            Spatial factor applied before pixel shuffle reduces the resolution.
    """

    model_type = "isaac_vision"
    base_config_key = "vision_config"

    def __init__(
        self,
        hidden_size=768,
        intermediate_size=3072,
        num_hidden_layers=12,
        num_attention_heads=12,
        num_channels=3,
        num_patches=256,
        patch_size=16,
        hidden_act="gelu_pytorch_tanh",
        layer_norm_eps=1e-6,
        attention_dropout=0.0,
        pixel_shuffle_scale_factor=1,
        **kwargs,
    ):
        # The parent only receives the pass-through kwargs; the explicit
        # arguments are written onto the instance afterwards.
        super().__init__(**kwargs)

        # Encoder geometry.
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # Patch layout.
        self.num_channels = num_channels
        self.patch_size = patch_size

        # Regularisation, normalisation, activation.
        self.attention_dropout = attention_dropout
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act

        # Size of the learnable positional-embedding table.
        self.num_patches = num_patches

        # Isaac-specific extension on top of SigLIP2.
        self.pixel_shuffle_scale_factor = pixel_shuffle_scale_factor

        # Fall back to SDPA when no attention backend has been selected.
        current_attn = getattr(self, "_attn_implementation", None)
        if current_attn is None:
            self._attn_implementation = "sdpa"
|
|
|
|
class IsaacImageProcessorFastKwargs(ImagesKwargs, total=False):
    r"""
    patch_size (`int`, *optional*):
        Side length in pixels of the square patches extracted from each image.
    max_num_patches (`int`, *optional*):
        Upper bound on the number of patches per image used when choosing the resize target.
    min_num_patches (`int`, *optional*):
        Lower bound on the number of patches per image; when set, small images are scaled up.
    pixel_shuffle_scale (`int`, *optional*):
        Pixel-shuffle factor; the patch grid height and width must be divisible by it.
    """

    # Every key is optional (`total=False`) and may be passed as `None`.
    patch_size: Optional[int]
    max_num_patches: Optional[int]
    min_num_patches: Optional[int]
    pixel_shuffle_scale: Optional[int]
|
|
|
|
@auto_docstring
class IsaacImageProcessorFast(BaseImageProcessorFast):
    # Torch-based image processor for Isaac: each group of same-shaped images is
    # (optionally) resized so that its patch grid respects the patch budget,
    # rescaled/normalized, and cut into `patch_size` x `patch_size` patches.
    MAX_PIXELS = 60_000_000

    resample = PILImageResampling.BILINEAR
    model_input_names = ["patches", "token_grids"]
    valid_kwargs = IsaacImageProcessorFastKwargs
    unused_kwargs = ["size", "do_center_crop", "crop_size", "pad_size", "do_pad"]

    # Class-level defaults picked up by the base processor.
    do_resize = True
    do_center_crop = False
    patch_size: Optional[int] = 16
    max_num_patches: Optional[int] = 256
    min_num_patches: Optional[int] = None
    pixel_shuffle_scale: Optional[int] = 1
    do_pad = False
    do_rescale = True
    do_normalize = True
    image_mean = list(VISION_MEAN)
    image_std = list(VISION_STD)
    do_convert_rgb = True
    disable_grouping = False

    def __init__(
        self,
        **kwargs: Unpack[IsaacImageProcessorFastKwargs],
    ) -> None:
        super().__init__(**kwargs)

    def _validate_preprocess_kwargs(self, **kwargs):
        """Run the base validation after dropping the keys removed below.

        The target size is derived from the patch budget in `_preprocess`, so the
        resize/crop/grouping keys are removed before delegating to the parent.
        """
        for ignored_key in ("do_resize", "size", "do_center_crop", "crop_size", "disable_grouping"):
            kwargs.pop(ignored_key, None)
        return super()._validate_preprocess_kwargs(**kwargs)

    def resize(
        self,
        image: torch.Tensor,
        size: SizeDict,
        **kwargs,
    ) -> torch.Tensor:
        """Bilinearly resize a batch of images to `(size.height, size.width)`.

        Extra keyword arguments (including `interpolation`) are accepted for
        interface compatibility but ignored: the mode is always ``"bilinear"``
        with ``align_corners=False``.
        """
        return F.interpolate(
            image,
            size=(size.height, size.width),
            mode="bilinear",
            align_corners=False,
        )

    def _preprocess(
        self,
        images: list[torch.Tensor],
        do_resize: bool,
        interpolation: Optional[Any],
        do_rescale: Optional[bool],
        rescale_factor: Optional[float],
        do_normalize: Optional[bool],
        image_mean: Optional[Union[float, Sequence[float]]],
        image_std: Optional[Union[float, Sequence[float]]],
        disable_grouping: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        *,
        patch_size: Optional[int] = None,
        max_num_patches: Optional[int] = None,
        min_num_patches: Optional[int] = None,
        pixel_shuffle_scale: Optional[int] = None,
        **kwargs,
    ) -> BatchFeature:
        """Turn a list of image tensors into patches plus per-image grid metadata.

        Returns:
            `BatchFeature` with four stacked tensors, one row per input image and in
            the original image order:
            - ``patches``: output of `torch_extract_patches` for each image.
            - ``token_grids``: ``(height_tokens, width_tokens)`` of the patch grid.
            - ``virtual_pixel_size``: ``(1, height_tokens // s, width_tokens // s)`` with
              ``s = pixel_shuffle_scale``.
            - ``real_pixel_size``: ``(1, height_tokens, width_tokens)``.

        Raises:
            ValueError: If resizing is disabled and the image size is not a multiple of
                `patch_size`, or if the patch grid is not divisible by `pixel_shuffle_scale`.
        """
        grouped_images, grouped_images_index = group_images_by_shape(
            images, disable_grouping=disable_grouping
        )

        # shape -> (patches, token_grid, virtual_dim, real_dim) for that group.
        grouped_outputs = {}

        for shape, stacked_images in grouped_images.items():
            batch_size, channels, original_height, original_width = stacked_images.shape

            # Single-channel inputs are tiled to three channels.
            if bool(self.do_convert_rgb) and channels == 1:
                stacked_images = stacked_images.repeat(1, 3, 1, 1)

            if do_resize:
                # Only compute the target size when it is actually used; with
                # resizing disabled the patch budget is irrelevant.
                target_height, target_width = get_image_size_for_max_num_patches(
                    original_height,
                    original_width,
                    patch_size,
                    max_num_patches,
                    min_num_patches=min_num_patches,
                    pixel_shuffle_scale=pixel_shuffle_scale,
                )
                image_batch = self.resize(
                    stacked_images,
                    SizeDict(height=target_height, width=target_width),
                    interpolation=interpolation,
                )
            else:
                if (original_height % patch_size) or (original_width % patch_size):
                    raise ValueError(
                        f"Image dimensions (h={original_height}, w={original_width}) must be divisible by patch_size={patch_size} when resize is disabled; enable resizing or adjust the input resolution."
                    )
                image_batch = stacked_images

            image_batch = self.rescale_and_normalize(
                image_batch,
                do_rescale=do_rescale,
                rescale_factor=rescale_factor,
                do_normalize=do_normalize,
                image_mean=image_mean,
                image_std=image_std,
            )

            patches = torch_extract_patches(image_batch, patch_size, patch_size)
            _, height_tokens, width_tokens, _ = patches.shape

            if (height_tokens % pixel_shuffle_scale) or (width_tokens % pixel_shuffle_scale):
                raise ValueError(
                    f"Token grid (h={height_tokens}, w={width_tokens}) must be divisible by pixel_shuffle_scale={pixel_shuffle_scale}; adjust resize/patch parameters or disable pixel shuffle."
                )

            token_grid = (
                torch.tensor([height_tokens, width_tokens], device=patches.device)
                .long()
                .expand(batch_size, 2)
            )

            # Grid size before pixel shuffle, one row per image in the group.
            real_dim = (
                torch.tensor(
                    [1, height_tokens, width_tokens],
                    dtype=torch.long,
                    device=patches.device,
                )
                .unsqueeze(0)
                .repeat(batch_size, 1)
            )

            # Grid size after pixel shuffle merges `pixel_shuffle_scale**2` tokens.
            virtual_dim = (
                torch.tensor(
                    [
                        1,
                        height_tokens // pixel_shuffle_scale,
                        width_tokens // pixel_shuffle_scale,
                    ],
                    dtype=torch.long,
                    device=patches.device,
                )
                .unsqueeze(0)
                .repeat(batch_size, 1)
            )

            grouped_outputs[shape] = (patches, token_grid, virtual_dim, real_dim)

        # The tuple position inside `grouped_outputs` values matches this key order.
        keys = ("patches", "token_grids", "virtual_pixel_size", "real_pixel_size")
        tensors: dict[str, torch.Tensor] = {}

        for item_idx, key in enumerate(keys):
            per_shape = {
                group_shape: outputs[item_idx]
                for group_shape, outputs in grouped_outputs.items()
            }
            # Restore the caller's image order before stacking.
            ordered = reorder_images(per_shape, grouped_images_index)
            tensors[key] = torch.stack(ordered, dim=0)

        return BatchFeature(data=tensors, tensor_type=return_tensors)
|
|
|
|
def create_document_attention_mask(
    config: PretrainedConfig,
    input_embeds: torch.Tensor,
    cu_seqlens: Optional[torch.Tensor],
) -> Optional[Union[torch.Tensor, Any]]:
    """
    Build the attention mask for a packed sequence described by `cu_seqlens`.

    Consecutive differences of `cu_seqlens` give per-document lengths; every token is
    tagged with its document index and the tags are handed to
    `packed_sequence_mask_function`. The mask is then materialized by the mask builder
    registered for `config._attn_implementation`.

    Returns None when `cu_seqlens` is None, has fewer than two entries, or describes
    zero tokens in total.
    """
    if cu_seqlens is None or cu_seqlens.numel() < 2:
        return None

    doc_lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).long()
    if doc_lengths.numel() == 0 or int(doc_lengths.sum()) == 0:
        return None

    # One id per document, repeated once for every token of that document.
    doc_ids = torch.arange(doc_lengths.numel(), device=cu_seqlens.device)
    token_doc_ids = torch.repeat_interleave(doc_ids, doc_lengths)
    mask_function = packed_sequence_mask_function(token_doc_ids.view(1, -1))

    total_len = input_embeds.shape[1]
    positions = torch.arange(total_len, device=input_embeds.device, dtype=torch.long)

    build_mask = ALL_MASK_ATTENTION_FUNCTIONS[config._attn_implementation]
    return build_mask(
        batch_size=input_embeds.shape[0],
        cache_position=positions,
        kv_length=total_len,
        kv_offset=0,
        mask_function=mask_function,
        attention_mask=None,
        # Both skip shortcuts are disabled for this call.
        allow_is_causal_skip=False,
        allow_is_bidirectional_skip=False,
        dtype=input_embeds.dtype,
        config=config,
        use_vmap=False,
    )
|
|
|
|
class IsaacVisionEmbeddings(Siglip2VisionEmbeddings):
    """Adapter around SigLIP2 vision embeddings that consumes packed patch sequences.

    Isaac accepts variable-resolution vision inputs as a single packed sequence with per-image
    `token_grids`; packing/unpacking here reconstructs per-image shapes so we can resize positional
    embeddings and build `cu_seqlens` for variable-length attention (not generic generation packing).
    """

    def __init__(self, config: IsaacVisionConfig):
        super().__init__(config)
        self.config = config
        self.embed_dim = config.hidden_size
        self.patch_size = config.patch_size

        # Patches arrive already flattened, so the patch projection is a linear layer.
        self.patch_embedding = nn.Linear(
            in_features=config.num_channels * self.patch_size * self.patch_size,
            out_features=self.embed_dim,
        )

        # The positional table is viewed as a square grid in `forward`.
        self.num_patches = config.num_patches
        self.position_embedding_size = int(self.num_patches**0.5)
        self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)

    @check_model_inputs
    def forward(
        self, seq_patches: torch.Tensor, spatial_shapes: torch.Tensor
    ) -> torch.Tensor:
        """Embed a packed patch sequence and add resized positional embeddings.

        Args:
            seq_patches: Packed patches of shape (total_patches, patch_dim).
            spatial_shapes: Per-image patch grids of shape (num_images, 2) as (H_tokens, W_tokens).

        Returns:
            Packed embeddings of shape (total_patches, embed_dim); an empty
            (0, embed_dim) tensor when `spatial_shapes` contains no images.
        """
        packed_pixel_values, seq_lengths = self._pack_to_batch(
            seq_patches, spatial_shapes
        )
        if packed_pixel_values is None:
            return seq_patches.new_zeros((0, self.embed_dim))

        # Match the projection weights' dtype before projecting.
        target_dtype = self.patch_embedding.weight.dtype
        patch_embeds = self.patch_embedding(packed_pixel_values.to(dtype=target_dtype))

        # View the flat positional table as (side, side, embed_dim) so that it can be
        # resized to each image's grid by the inherited helper.
        positional_embeddings = self.position_embedding.weight.reshape(
            self.position_embedding_size,
            self.position_embedding_size,
            -1,
        )
        resized_positional_embeddings = self.resize_positional_embeddings(
            positional_embeddings,
            spatial_shapes,
            max_length=packed_pixel_values.shape[1],
        )

        embeddings = patch_embeds + resized_positional_embeddings
        return self._unpack_from_batch(embeddings, seq_lengths)

    def _pack_to_batch(
        self,
        seq_patches: torch.Tensor,
        spatial_shapes: torch.Tensor,
    ) -> tuple[Optional[torch.Tensor], torch.Tensor]:
        """Rebatch a packed patch sequence using per-image grids to align embeddings.

        Args:
            seq_patches: Packed patches of shape (total_patches, patch_dim).
            spatial_shapes: Per-image patch grids of shape (num_images, 2) as (H_tokens, W_tokens).

        Returns:
            (packed_pixel_values, seq_lengths) where:
            - packed_pixel_values: (batch, max_len, patch_dim) padded with zeros, or None if batch_size == 0
            - seq_lengths: (batch,) lengths for each image
        """
        seq_lengths = spatial_shapes.long().prod(dim=-1)
        if seq_lengths.numel() == 0:
            return None, seq_lengths

        # Split the packed sequence per image, then pad every image to the longest one.
        chunks = seq_patches.split(seq_lengths.tolist(), dim=0)
        packed_pixel_values = nn.utils.rnn.pad_sequence(chunks, batch_first=True)
        return packed_pixel_values, seq_lengths

    def _unpack_from_batch(
        self, embeddings: torch.Tensor, seq_lengths: torch.Tensor
    ) -> torch.Tensor:
        """Flatten a padded batch back to packed sequence order using `seq_lengths`.

        Rows with length 0 contribute nothing. When every row is empty an empty
        (0, embed_dim) tensor is returned instead of failing in `torch.cat`.
        """
        # `tolist()` already yields host-side ints; no device move is needed first.
        lengths = seq_lengths.tolist()
        chunks = [
            embeddings[row, :length]
            for row, length in enumerate(lengths)
            if length > 0
        ]
        if not chunks:
            return embeddings.new_zeros((0, embeddings.shape[-1]))
        return torch.cat(chunks, dim=0)
|
|
|
|
class IsaacVisionAttention(Siglip2Attention):
    """Custom attention that supports variable-length sequences with flash attention."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        **kwargs,
    ):
        """Run non-causal self-attention, optionally with packed varlen metadata.

        Args:
            hidden_states: Tensor of shape (batch, seq_len, embed_dim).
            attention_mask: Mask forwarded as-is to the attention backend.
            output_attentions: Accepted for interface compatibility; not read here.
            cu_seqlens: Cumulative sequence lengths of the packed documents. Only
                forwarded when the configured backend is one of the varlen-capable ones.
            max_seqlen: Longest document length; derived from `cu_seqlens` when omitted.

        Returns:
            Tuple `(attn_output, attn_weights)` where `attn_output` has the shape of
            `hidden_states` and `attn_weights` is whatever the backend returns.
        """
        # These generic flags are dropped from `kwargs`; note that `kwargs` itself
        # is not forwarded to the attention backend.
        for dropped_key in ("output_hidden_states", "return_dict"):
            kwargs.pop(dropped_key, None)

        batch_size, seq_length, embed_dim = hidden_states.shape
        per_head_shape = (batch_size, seq_length, self.num_heads, self.head_dim)

        # Project, then reshape to (batch, heads, seq, head_dim).
        queries = self.q_proj(hidden_states)
        keys = self.k_proj(hidden_states)
        values = self.v_proj(hidden_states)
        queries = queries.view(*per_head_shape).transpose(1, 2)
        keys = keys.view(*per_head_shape).transpose(1, 2)
        values = values.view(*per_head_shape).transpose(1, 2)

        attn_impl = self.config._attn_implementation
        if attn_impl == "sdpa":
            attention_interface: Callable = ALL_ATTENTION_FUNCTIONS["sdpa"]
        else:
            attention_interface = ALL_ATTENTION_FUNCTIONS[attn_impl]

        attention_kwargs: dict[str, Any] = {
            "is_causal": False,
            "scaling": self.scale,
        }

        varlen_backends = {
            "flash_attention_2",
            "flash_attention_3",
            "flex_attention",
            "paged|flash_attention_2",
            "paged|flash_attention_3",
        }
        if cu_seqlens is not None and attn_impl in varlen_backends:
            if max_seqlen is not None:
                longest = int(max_seqlen)
            elif cu_seqlens.numel() >= 2:
                doc_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
                # NOTE(review): this branch yields a 0-d tensor while the other two
                # yield a Python int — confirm the backend accepts both.
                longest = doc_lengths.max() if doc_lengths.numel() > 0 else seq_length
            else:
                longest = seq_length

            # Queries and keys come from the same packed sequence, so they share metadata.
            attention_kwargs["cu_seq_lens_q"] = cu_seqlens
            attention_kwargs["cu_seq_lens_k"] = cu_seqlens
            attention_kwargs["max_length_q"] = longest
            attention_kwargs["max_length_k"] = longest

        attn_output, attn_weights = attention_interface(
            self,
            queries,
            keys,
            values,
            attention_mask,
            **attention_kwargs,
        )

        # Merge the heads back into the embedding dimension before the output projection.
        merged = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
        return self.out_proj(merged), attn_weights
|
|
|
|
class IsaacVisionEncoderLayer(Siglip2EncoderLayer):
    """SigLIP2 encoder layer whose self-attention is swapped for `IsaacVisionAttention`."""

    def __init__(self, config: IsaacVisionConfig):
        super().__init__(config)
        # Replace the attention module created by the parent with the varlen-aware one.
        self.self_attn = IsaacVisionAttention(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        output_attentions: bool = False,
        **kwargs: Unpack[TransformersKwargs],
    ):
        r"""
        cu_seqlens (`torch.Tensor`, *optional*):
            Prefix-sum tensor whose length equals the number of documents + 1. The difference between successive
            entries gives each document's token count and enables block-diagonal attention masking for packed batches.
        max_seqlen (`int`, *optional*):
            Maximum document length referenced by `cu_seqlens`. Passed to FlashAttention so it can size temporary
            buffers for packed variable-length attention.
        """
        # Attention sub-block: pre-norm, attention, residual add.
        normed = self.layer_norm1(hidden_states)
        attn_output, _ = self.self_attn(
            normed,
            attention_mask=attention_mask,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            **kwargs,
        )
        after_attention = hidden_states + attn_output

        # Feed-forward sub-block: pre-norm, MLP, residual add.
        mlp_output = self.mlp(self.layer_norm2(after_attention))
        return after_attention + mlp_output
|
|
|
|
class IsaacVisionEncoder(Siglip2Encoder):
    """SigLIP2 encoder whose layer stack is rebuilt from `IsaacVisionEncoderLayer`."""

    def __init__(self, config: IsaacVisionConfig):
        super().__init__(config)
        # Swap the parent's layers for `config.num_hidden_layers` Isaac layers.
        isaac_layers = []
        for _ in range(config.num_hidden_layers):
            isaac_layers.append(IsaacVisionEncoderLayer(config))
        self.layers = nn.ModuleList(isaac_layers)
|
|
|
|
def create_pixel_shuffle_index_map(
    seq_sizes: torch.Tensor,
    token_grids: torch.Tensor,
    scale_factor: int = 1,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """
    Build a gather-index map that tells us, for every *output* token after
    pixel-shuffle, which `scale_factor**2` *input* tokens are being merged.

    Args
    ----
    seq_sizes    : (num_images,)   - #patches in each image (row-major order)
    token_grids  : (num_images,2)  - (height, width) for every image
    scale_factor : spatial down-scale factor (>= 1; 1 maps every token to itself)
    device       : (optional) overrides `seq_sizes.device`

    Returns
    -------
    gather_idx : (new_total_seq_len, scale_factor**2) int64 tensor.
                 gather_idx[i, j] is the *flat* index into the *original*
                 packed sequence for the j-th sub-patch that forms the
                 i-th output token. Empty (0, scale_factor**2) when there
                 are no images.

    Raises
    ------
    AssertionError : if any grid height/width is not divisible by
                     `scale_factor` (the check is skipped while compiling).
    """
    # Honour the documented default: fall back to the device of `seq_sizes`.
    if device is None:
        device = seq_sizes.device

    if not is_torchdynamo_compiling():
        if (token_grids % scale_factor).any():
            raise AssertionError(
                f"Every (H,W) in token_grids must be divisible by scale_factor={scale_factor}, got {token_grids.tolist()}"
            )

    gather_chunks: list[torch.Tensor] = []
    tok_offset = 0
    for seq_len, (h, w) in zip(seq_sizes.tolist(), token_grids.tolist()):
        # Flat indices of this image's tokens inside the packed sequence, laid out as (h, w).
        grid = (
            torch.arange(seq_len, device=device, dtype=torch.int64).view(h, w)
            + tok_offset
        )

        # Group each `scale_factor` x `scale_factor` neighbourhood:
        # (h/s, s, w/s, s) -> (h/s, w/s, s, s).
        grid = (
            grid.view(h // scale_factor, scale_factor, w // scale_factor, scale_factor)
            .permute(0, 2, 1, 3)
            .contiguous()
        )
        gather_chunks.append(grid.view(-1, scale_factor * scale_factor))

        tok_offset += seq_len

    # `torch.cat` rejects an empty list, so return an explicit empty map instead.
    if not gather_chunks:
        return torch.empty(
            (0, scale_factor * scale_factor), dtype=torch.int64, device=device
        )

    return torch.cat(gather_chunks, dim=0)
|
|
|
|
def pixel_shuffle_varlen(
    x: torch.Tensor,
    token_grids: torch.Tensor,
    scale_factor: int = 1,
) -> torch.Tensor:
    r"""Pixel-shuffle a packed vision sequence in one gather, without splitting it per image.

    Args:
        x (`torch.Tensor`):
            Packed vision embeddings of shape `(seq_len, hidden_size)` or `(1, seq_len, hidden_size)`.
        token_grids (`torch.Tensor`):
            Integer tensor of shape `(num_images, 2)`; each row is the `(height, width)` patch grid of the
            corresponding image segment inside `x`.
        scale_factor (`int`, *optional*, defaults to 1):
            Spatial down-sampling factor. `scale_factor**2` neighbouring patches are concatenated along the
            channel dimension into one output embedding.

    Returns:
        `torch.Tensor`: `(new_seq_len, hidden_size * scale_factor**2)` for 2D input, or the same with a leading
        singleton batch dimension for 3D input.

    Raises:
        ValueError: If a 3D input has a batch dimension other than 1.
    """
    had_batch_dim = x.dim() == 3
    if had_batch_dim and x.size(0) != 1:
        raise ValueError(
            f"Packed vision sequences expect a singleton batch dimension; received batch_size={x.size(0)}."
        )
    packed = x.squeeze(0) if had_batch_dim else x

    hidden_size = packed.size(-1)
    factor = int(scale_factor)

    # Row i lists the input-token indices merged into output token i.
    index_map = create_pixel_shuffle_index_map(
        seq_sizes=torch.prod(token_grids, dim=-1),
        token_grids=token_grids,
        scale_factor=factor,
        device=packed.device,
    )

    # (new_seq_len, factor**2, hidden) -> (new_seq_len, hidden * factor**2)
    neighbourhoods = packed[index_map]
    shuffled = neighbourhoods.reshape(
        neighbourhoods.size(0), hidden_size * factor * factor
    )

    if had_batch_dim:
        return shuffled.unsqueeze(0)
    return shuffled
|
|
|
|
class IsaacVisionTransformer(nn.Module):
    """Vision tower: packed patch embeddings -> varlen encoder -> layer norm -> pixel shuffle.

    Args:
        config (IsaacVisionConfig): Vision configuration with pixel-shuffle and patching parameters.

    Inputs:
        packed_seq_patches (Tuple[Tensor, Tensor]): ``(patches, token_grids)`` where ``patches`` is a packed
            patch sequence and ``token_grids`` holds per-image (H_tokens, W_tokens).

    Returns:
        torch.Tensor: Vision embeddings after encoder + pixel shuffle, shaped ``(seq_len, hidden_size * s^2)``.
    """

    _supports_sdpa = True

    def __init__(self, config: IsaacVisionConfig):
        super().__init__()
        self.config = config
        self.embeddings = IsaacVisionEmbeddings(config)
        self.encoder = IsaacVisionEncoder(config)
        self.post_layernorm = nn.LayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
        )
        self.pixel_shuffle_scale_factor = config.pixel_shuffle_scale_factor

    def forward(self, packed_seq_patches: tuple[torch.Tensor, torch.Tensor]):
        patches, token_grids = packed_seq_patches

        # Number of patches per image, and their prefix sums with a leading 0.
        tokens_per_image = torch.prod(token_grids, dim=-1)

        # Packed embeddings get a singleton batch dimension for the encoder.
        embedded = self.embeddings(patches, token_grids).unsqueeze(0)

        cu_seqlens = F.pad(tokens_per_image.cumsum(0).to(torch.int32), (1, 0))
        document_mask = create_document_attention_mask(
            self.config, embedded, cu_seqlens
        )

        encoded = self.encoder(
            inputs_embeds=embedded,
            attention_mask=document_mask,
            cu_seqlens=cu_seqlens,
        ).last_hidden_state
        normed = self.post_layernorm(encoded)

        shuffled = pixel_shuffle_varlen(
            x=normed,
            token_grids=token_grids,
            scale_factor=self.pixel_shuffle_scale_factor,
        )

        # Drop the singleton batch dimension again.
        return shuffled.squeeze(0)
|
|
|
|
class IsaacVisionEmbedding(nn.Module):
    """Vision tower followed by the multimodal projector."""

    _supports_sdpa = True

    def __init__(self, config: IsaacConfig):
        super().__init__()
        # The tower is built from the vision sub-config, the projector from the full config.
        self.vision_tower = IsaacVisionTransformer(config.vision_config)
        self.multimodal_projector = IsaacMultiModalProjector(config)

    def forward(self, vision_tokens: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        """Encode `(patches, token_grids)` with the tower and project the result."""
        tower_output = self.vision_tower(vision_tokens)
        projected = self.multimodal_projector(tower_output)
        return projected
|
|
|
|
def get_scaled_image_size(
    scale: float,
    original_size: int,
    patch_size: int,
    pixel_shuffle_scale: int,
) -> int:
    """Scale one image dimension and round it up to the patch/pixel-shuffle stride.

    The result is the smallest multiple of `patch_size * pixel_shuffle_scale` that is
    at least `scale * original_size`, and never less than one full stride.
    """
    stride = patch_size * pixel_shuffle_scale
    rounded_up = math.ceil((scale * original_size) / stride) * stride
    # Guarantee at least one stride even when the scaled size rounds to zero.
    return int(max(stride, rounded_up))
|
|
|
|
def get_image_size_for_max_num_patches(
    image_height: int,
    image_width: int,
    patch_size: int,
    max_num_patches: int,
    min_num_patches: Optional[int] = None,
    eps: float = 1e-5,
    pixel_shuffle_scale: int = 1,
) -> tuple[int, int]:
    r"""Compute a target resolution whose patch grid satisfies patching parametrization.

    Args:
        image_height (`int`):
            Height in pixels of the source image prior to any resizing.
        image_width (`int`):
            Width in pixels of the source image prior to any resizing.
        patch_size (`int`):
            Size of the square patch used by the vision encoder.
        max_num_patches (`int`):
            Upper bound on `(height / patch_size) * (width / patch_size)` after resizing.
        min_num_patches (`int`, *optional*):
            Lower bound on the number of patches. When provided the image will be scaled up if necessary.
        eps (`float`, *optional*, defaults to 1e-5):
            Convergence tolerance for the internal binary search to determine the target dimensions.
        pixel_shuffle_scale (`int`, *optional*, defaults to 1):
            Additional stride multiplier applied when pixel shuffle later reduces spatial resolution.

    Returns:
        `tuple[int, int]`: Height and width (in pixels) that are multiples of `patch_size * pixel_shuffle_scale`.
        The minimum constraint takes precedence: when up-scaling is needed the maximum is not re-checked, and
        the up-scaling search is capped at a factor of 100. When down-scaling, each side is never smaller than
        one `patch_size * pixel_shuffle_scale` stride.
    """

    def _size_at(scale: float) -> tuple[int, int]:
        # Target (height, width) for a given scale, snapped to the stride.
        return (
            get_scaled_image_size(scale, image_height, patch_size, pixel_shuffle_scale),
            get_scaled_image_size(scale, image_width, patch_size, pixel_shuffle_scale),
        )

    def _patch_count(height: int, width: int) -> float:
        # Number of patches in a (height, width) image.
        return (height / patch_size) * (width / patch_size)

    # Start from the source size rounded up to the stride (at least one stride per side).
    divisor = patch_size * pixel_shuffle_scale
    adjusted_height = max(divisor, math.ceil(image_height / divisor) * divisor)
    adjusted_width = max(divisor, math.ceil(image_width / divisor) * divisor)
    num_patches = _patch_count(adjusted_height, adjusted_width)

    if min_num_patches is not None and num_patches < min_num_patches:
        # Too few patches: bisect for the smallest up-scale in [1, 100] that reaches the minimum.
        scale_min, scale_max = 1.0, 100.0
        while (scale_max - scale_min) >= eps:
            scale = (scale_min + scale_max) / 2
            if _patch_count(*_size_at(scale)) >= min_num_patches:
                scale_max = scale
            else:
                scale_min = scale
        return _size_at(scale_max)

    if num_patches <= max_num_patches:
        # Already within budget: keep the stride-aligned source size.
        return adjusted_height, adjusted_width

    # Too many patches: bisect for the largest down-scale in (0, 1] that fits the maximum.
    scale_min, scale_max = eps / 10, 1.0
    while (scale_max - scale_min) >= eps:
        scale = (scale_min + scale_max) / 2
        if _patch_count(*_size_at(scale)) <= max_num_patches:
            scale_min = scale
        else:
            scale_max = scale
    return _size_at(scale_min)
|
|
|
|
class IsaacConfig(PretrainedConfig):
    """Configuration class for Isaac multimodal model.

    This configuration corresponds to checkpoints such as
    [Perceptron/isaac-base](https://huggingface.co/Perceptron/isaac-base).

    Args:
        vision_config (`IsaacVisionConfig` or `dict`, *optional*):
            Vision tower configuration. A dict is expanded into `IsaacVisionConfig`; `None` uses its defaults.
        text_config (`Qwen3Config` or `dict`, *optional*):
            Text backbone configuration. A dict is expanded into `Qwen3Config`; `None` uses its defaults.
        vision_rescale_factor (`float`, *optional*, defaults to `1 / 255`):
            Image rescale factor, stored as a float.
        max_sequence_length (`int`, *optional*, defaults to 16384):
            Stored as `max_sequence_length`.
        vision_token (`str`, *optional*, defaults to `"<image>"`):
            Stored as `vision_token`.
        **kwargs:
            Forwarded to `PretrainedConfig.__init__`. If `attn_implementation` is present it is also applied
            to `vision_config` (a dict value is looked up under `"vision_config"`, then `""`).

    Raises:
        TypeError: If `text_config` or `vision_config` is neither `None`, a dict, nor the expected config class.
    """

    model_type = "isaac"
    sub_configs = {"vision_config": IsaacVisionConfig, "text_config": Qwen3Config}
    image_processor_type = "IsaacImageProcessor"

    def __init__(
        self,
        vision_config: Optional[Union[IsaacVisionConfig, dict]] = None,
        text_config: Optional[Union[Qwen3Config, dict]] = None,
        vision_rescale_factor: float = 1 / 255,
        max_sequence_length: int = 16384,
        vision_token: str = "<image>",
        **kwargs,
    ):
        # Read (not pop) so the base class still sees the value.
        attn_implementation = kwargs.get("attn_implementation")

        if isinstance(text_config, dict):
            self.text_config = self.sub_configs["text_config"](**text_config)
        elif isinstance(text_config, Qwen3Config):
            self.text_config = text_config
        elif text_config is None:
            self.text_config = self.sub_configs["text_config"]()
        else:
            # Fail here with a clear message instead of an AttributeError further down.
            raise TypeError(
                f"`text_config` must be a dict, a Qwen3Config or None, got {type(text_config).__name__}."
            )

        # Mirror the text backbone's RoPE / layer-type settings before the base init runs.
        self.rope_parameters = getattr(self.text_config, "rope_parameters", None)
        self.layer_types = getattr(self.text_config, "layer_types", None)

        super().__init__(**kwargs)

        # Write the (possibly updated) RoPE parameters back onto the text config.
        self.text_config.rope_parameters = self.rope_parameters

        # Expose the text backbone's main hyper-parameters at the top level.
        self.vocab_size = self.text_config.vocab_size
        self.hidden_size = self.text_config.hidden_size
        self.num_hidden_layers = self.text_config.num_hidden_layers
        self.num_attention_heads = self.text_config.num_attention_heads
        self.head_dim = self.text_config.head_dim
        self.hidden_act = self.text_config.hidden_act
        self.use_cache = self.text_config.use_cache
        self.rope_theta = self.rope_parameters["rope_theta"]

        # Re-read after the base init so the text config's value wins, then validate it.
        self.layer_types = getattr(self.text_config, "layer_types", None)
        layer_type_validation(self.layer_types, self.num_hidden_layers)

        if isinstance(vision_config, dict):
            self.vision_config = self.sub_configs["vision_config"](**vision_config)
        elif isinstance(vision_config, IsaacVisionConfig):
            self.vision_config = vision_config
        elif vision_config is None:
            self.vision_config = self.sub_configs["vision_config"]()
        else:
            # Previously the attribute was silently left unset for unsupported types.
            raise TypeError(
                f"`vision_config` must be a dict, an IsaacVisionConfig or None, got {type(vision_config).__name__}."
            )

        # Propagate the requested attention backend to the vision tower.
        if attn_implementation is not None:
            if isinstance(attn_implementation, dict):
                vision_attn = attn_implementation.get(
                    "vision_config", attn_implementation.get("", None)
                )
            else:
                vision_attn = attn_implementation
            if vision_attn is not None:
                self.vision_config._attn_implementation = vision_attn

        # Default to SDPA when nothing has been selected for the top-level config.
        if getattr(self, "_attn_implementation", None) is None:
            self._attn_implementation = "sdpa"

        self.vision_rescale_factor = float(vision_rescale_factor)

        self.max_sequence_length = max_sequence_length
        self.vision_token = vision_token

    def to_dict(self):
        """Serialize to a dict, with both sub-configs expanded via their own `to_dict`."""
        output = super().to_dict()
        for key in ("text_config", "vision_config"):
            sub_config = getattr(self, key, None)
            if sub_config is not None:
                output[key] = sub_config.to_dict()
        return output
|
|
|
|
class IsaacProcessor(ProcessorMixin):
    """Processor that pairs the Isaac image processor with the Qwen2 tokenizer.

    Each prompt is split on ``vision_token``; every placeholder is expanded into
    a run of ``<|image_pad|>`` ids sized from the image processor output, and the
    per-sample results are left-padded into batch tensors.

    Args:
        image_processor: Vision preprocessor (fast) used for patch extraction.
        tokenizer: Qwen2 tokenizer instance.
        vision_token (str, optional): Placeholder token marking image locations. Defaults to "<image>".
        max_sequence_length (int, optional): Maximum combined text+vision tokens kept. Defaults to 16384.
        rescale_factor (float, optional): Image rescale factor; defaults to 1/255.
        config (IsaacConfig | dict, optional): If provided, overrides processor defaults from the model config.

    Returns:
        BatchFeature: Contains ``input_ids`` and ``packed_inputs`` (patch tensors, grids, offsets, lengths, modality, positions).
    """

    attributes = ["image_processor", "tokenizer"]
    image_processor_class = ("IsaacImageProcessorFast",)
    tokenizer_class = ("Qwen2Tokenizer",)
    # Default text pad id; used as a fallback when the tokenizer has no pad token.
    pad_token_id = 151643

    def __init__(
        self,
        image_processor,
        tokenizer,
        *,
        vision_token: str = "<image>",
        max_sequence_length: int = 16384,
        rescale_factor: Optional[float] = None,
        config: Optional[Union[IsaacConfig, dict]] = None,
    ) -> None:
        if isinstance(config, dict):
            config = IsaacConfig(**config)

        # A model config, when given, takes precedence over the keyword defaults.
        if config is not None:
            vision_token = config.vision_token
            max_sequence_length = config.max_sequence_length
            rescale_factor = config.vision_rescale_factor

        resolved_rescale_factor = (
            float(rescale_factor) if rescale_factor is not None else float(1 / 255)
        )
        if config is not None:
            # Write the resolved value back so the stored config carries a float.
            config.vision_rescale_factor = resolved_rescale_factor

        self.image_processor = image_processor
        super().__init__(image_processor, tokenizer)

        text_pad_token_id = getattr(self.tokenizer, "pad_token_id", None)
        if text_pad_token_id is None:
            # Tokenizer defines no pad token: fall back to the class default
            # instead of failing on int(None).
            text_pad_token_id = type(self).pad_token_id
        image_pad_token_id = self.tokenizer.convert_tokens_to_ids("<|image_pad|>")

        self.text_pad_token_id = int(text_pad_token_id)
        self.image_pad_token_id = int(image_pad_token_id)
        self.pad_token_id = self.text_pad_token_id

        self.current_processor = self.image_processor
        self.config = config
        self.chat_template = getattr(self.tokenizer, "chat_template", None)
        self.vision_token = vision_token
        self.max_sequence_length = max_sequence_length

    def _pack_batch(
        self, texts: list[str], images_list: Optional[list[Optional[list[Image]]]]
    ) -> dict[str, Optional[torch.Tensor]]:
        """Pack every sample with ``_pack_single`` and left-pad into batch tensors.

        Args:
            texts: One prompt per sample.
            images_list: Per-sample image lists (or ``None`` entries), or ``None``
                when the batch has no images. Must have the same length as
                ``texts`` when given.

        Returns:
            Dict with ``input_ids``, ``modality_tensor`` and ``position_ids``
            (left-padded), plus the vision tensors, which are all ``None`` when
            no sample kept any image patches.
        """
        if images_list is None:
            pairs = ((t, None) for t in texts)
        else:
            pairs = zip(texts, images_list, strict=True)

        per_sample: list[dict[str, Any]] = []
        for txt, imgs in pairs:
            # Tolerate a bare image where a list of images is expected.
            if imgs is not None and isinstance(imgs, Image):
                imgs = [imgs]
            per_sample.append(self._pack_single(txt, imgs))

        lengths = [int(p["input_ids"].shape[1]) for p in per_sample]
        max_len = max(lengths, default=0)
        batch = len(per_sample)

        # Use the device of the first non-empty sample; CPU when all are empty.
        base_device = torch.device("cpu")
        for p in per_sample:
            if p["input_ids"].numel() > 0:
                base_device = p["input_ids"].device
                break

        pad_id = self.text_pad_token_id
        padded_input_ids = torch.full(
            (batch, max_len), pad_id, device=base_device, dtype=torch.long
        )
        padded_modality = torch.full(
            (batch, max_len),
            ModalityType.text.value,
            device=base_device,
            dtype=torch.long,
        )
        padded_position_ids = torch.zeros(
            (batch, max_len, 3), device=base_device, dtype=torch.long
        )

        for i, (sample, length) in enumerate(zip(per_sample, lengths)):
            # Guard against length == 0: ``[-0:]`` would address the whole row.
            if length:
                padded_input_ids[i, -length:] = sample["input_ids"][0]
                padded_modality[i, -length:] = sample["modality_tensor"][0]
                padded_position_ids[i, -length:] = sample["position_ids"][0]

        max_images = max((s["vision_image_count"] for s in per_sample), default=0)
        max_vision_len = max((s["vision_patches_total"] for s in per_sample), default=0)

        if max_images > 0 and max_vision_len > 0:
            # Take patch width and dtype from the first sample that has patches.
            patch_dim = None
            patch_dtype = None
            for s in per_sample:
                vps = s["vision_patches"]
                if vps:
                    patch_dim = int(vps[0].shape[-1])
                    patch_dtype = vps[0].dtype
                    break
            assert patch_dim is not None and patch_dtype is not None

            vision_patches = torch.zeros(
                (batch, max_vision_len, patch_dim),
                device=base_device,
                dtype=patch_dtype,
            )
            vision_patches_len = torch.zeros(
                (batch,), device=base_device, dtype=torch.long
            )
            vision_token_grids = torch.zeros(
                (batch, max_images, 2), device=base_device, dtype=torch.long
            )
            vision_token_offsets = torch.zeros(
                (batch, max_images), device=base_device, dtype=torch.long
            )
            vision_token_lengths = torch.zeros(
                (batch, max_images), device=base_device, dtype=torch.long
            )
            vision_image_count = torch.zeros(
                (batch,), device=base_device, dtype=torch.long
            )

            for i, sample in enumerate(per_sample):
                vps: list[torch.Tensor] = sample["vision_patches"]
                grids: list[tuple[int, int]] = sample["vision_token_grids"]
                offs: list[int] = sample["vision_token_offsets"]
                lens: list[int] = sample["vision_token_lengths"]
                if not vps:
                    continue
                vision_image_count[i] = len(vps)
                # Patches of consecutive images are stored back to back.
                cursor = 0
                for img_idx, (vp, grid, off, ln) in enumerate(
                    zip(vps, grids, offs, lens)
                ):
                    plen = int(vp.shape[0])
                    if plen <= 0 or img_idx >= max_images:
                        continue
                    vision_patches[i, cursor : cursor + plen] = vp.to(base_device)
                    cursor += plen
                    vision_token_grids[i, img_idx] = torch.tensor(
                        grid, device=base_device, dtype=torch.long
                    )
                    vision_token_offsets[i, img_idx] = int(off)
                    vision_token_lengths[i, img_idx] = int(ln)
                vision_patches_len[i] = cursor
        else:
            vision_patches = vision_patches_len = vision_token_grids = None
            vision_token_offsets = vision_token_lengths = vision_image_count = None

        return {
            "input_ids": padded_input_ids,
            "vision_patches": vision_patches,
            "vision_patches_len": vision_patches_len,
            "vision_token_grids": vision_token_grids,
            "vision_token_offsets": vision_token_offsets,
            "vision_token_lengths": vision_token_lengths,
            "vision_image_count": vision_image_count,
            "modality_tensor": padded_modality,
            "position_ids": padded_position_ids,
        }

    def _pack_single(self, text: str, images: Optional[list[Image]]) -> dict[str, Any]:
        """Tokenize one prompt, interleave its images and build 3-axis positions.

        The prompt is split on ``self.vision_token``; an image follows every
        segment except the last. When the total length exceeds
        ``self.max_sequence_length`` only the trailing window is kept.

        Args:
            text: Prompt containing one ``vision_token`` per image.
            images: Images in placeholder order, or ``None`` for text-only.

        Returns:
            Dict with ``input_ids``, ``modality_tensor`` (both ``(1, S)``),
            ``position_ids`` (``(1, S, 3)``), the per-image lists
            ``vision_patches`` / ``vision_token_grids`` /
            ``vision_token_offsets`` / ``vision_token_lengths``, and the totals
            ``vision_patches_total`` and ``vision_image_count``.

        Raises:
            ValueError: If the number of placeholders and images differ.
        """
        segments = text.split(self.vision_token)
        num_images = len(segments) - 1
        items: list[dict[str, Any]] = []
        total = 0
        num_provided_images = len(images) if images is not None else 0
        if num_images != num_provided_images:
            raise ValueError(
                f"IsaacProcessor expects one image per image token, got {num_images} tokens and {num_provided_images} images in sample with text {text} "
            )

        # Pass 1: collect text and image segments and the untruncated length.
        for index, segment in enumerate(segments):
            if segment:
                tok = (
                    self.tokenizer.encode(
                        segment, add_special_tokens=False, return_tensors="pt"
                    )
                    .squeeze(0)
                    .to(torch.long)
                )
                segment_length = int(tok.numel())
                items.append(
                    {"type": "text", "segment_length": segment_length, "tok": tok}
                )
                total += segment_length

            if index < num_images:
                feat = self.image_processor(
                    images=images[index], return_tensors=TensorType.PYTORCH
                )
                patches = feat["patches"][0].reshape(-1, feat["patches"].shape[-1])

                virtual_pixel_size = (
                    feat["virtual_pixel_size"][0].to(torch.long).tolist()
                )
                real_pixel_size = feat["real_pixel_size"][0].to(torch.long).tolist()
                # Pad with ones so there are always three dims to unpack.
                dims = tuple((virtual_pixel_size + [1, 1, 1])[:3])
                segment_length = int(dims[0] * dims[1] * dims[2])

                items.append(
                    {
                        "type": "image",
                        "segment_length": segment_length,
                        "dims": dims,
                        "patches": patches,
                        "grid": (int(real_pixel_size[1]), int(real_pixel_size[2])),
                    }
                )
                total += segment_length

        # Keep only the last ``max_sequence_length`` positions.
        start = max(0, total - self.max_sequence_length)
        end = total

        image_pad_value = self.image_pad_token_id
        base_device: Optional[torch.device] = None
        position_ids, modality, input_ids = [], [], []
        vpatches: list[torch.Tensor] = []
        grids: list[tuple[int, int]] = []
        vision_token_offsets: list[int] = []
        vision_token_lengths: list[int] = []

        global_offset = 0  # index of the current segment in the full sequence
        position_offset = 0  # running value of the first position axis

        # Pass 2: emit the part of each segment that falls inside the window.
        for item in items:
            segment_length = int(item["segment_length"])
            current_window_start = max(start, global_offset)
            current_window_end = min(end, global_offset + segment_length)
            has_overlap = current_window_end > current_window_start

            if has_overlap and base_device is None:
                base_device = (
                    item["patches"].device
                    if item["type"] == "image"
                    else item["tok"].device
                )

            if has_overlap:
                segment_local_start = int(current_window_start - global_offset)
                segment_local_end = int(current_window_end - global_offset)
                segment_local_indices = torch.arange(
                    segment_local_start,
                    segment_local_end,
                    device=base_device,
                    dtype=torch.long,
                )
                segment_kept_length = segment_local_end - segment_local_start

                if item["type"] == "text":
                    # Text advances only the first axis; the other two are zero.
                    slice_index = segment_local_indices + position_offset
                    zero_axis_pad = torch.zeros_like(slice_index)
                    position_ids.append(
                        torch.stack((slice_index, zero_axis_pad, zero_axis_pad), -1)
                    )
                    modality.append(
                        torch.full(
                            (segment_kept_length,),
                            ModalityType.text.value,
                            device=base_device,
                            dtype=torch.long,
                        )
                    )
                    input_ids.append(
                        item["tok"].to(base_device)[
                            segment_local_start:segment_local_end
                        ]
                    )
                    position_offset += segment_length
                else:
                    num_pos_slices, grid_height_tokens, grid_width_tokens = item["dims"]
                    hw = grid_height_tokens * grid_width_tokens
                    # Decompose the flat index into (slice, row, column).
                    slice_index = (segment_local_indices // hw) + position_offset
                    rem = segment_local_indices % hw
                    row_index = rem // grid_width_tokens
                    col_index = rem % grid_width_tokens
                    position_ids.append(
                        torch.stack((slice_index, row_index, col_index), -1)
                    )
                    modality.append(
                        torch.full(
                            (segment_kept_length,),
                            ModalityType.image.value,
                            device=base_device,
                            dtype=torch.long,
                        )
                    )
                    input_ids.append(
                        torch.full(
                            (segment_kept_length,),
                            image_pad_value,
                            device=base_device,
                            dtype=torch.long,
                        )
                    )

                    # All patches are kept; offset/length record which image
                    # tokens survived truncation.
                    vpatches.append(item["patches"].to(base_device))
                    grids.append(item["grid"])
                    vision_token_offsets.append(segment_local_start)
                    vision_token_lengths.append(segment_kept_length)

                    position_offset += int(num_pos_slices)
            else:
                # Dropped segments still advance the position counter.
                position_offset += (
                    segment_length if item["type"] == "text" else int(item["dims"][0])
                )

            global_offset += segment_length

        if base_device is None:
            base_device = torch.device("cpu")

        modality_tensor = (
            torch.cat(modality, 0).unsqueeze(0)
            if modality
            else torch.zeros((1, 0), device=base_device, dtype=torch.long)
        )
        position_ids = (
            torch.cat(position_ids, 0).unsqueeze(0)
            if position_ids
            else torch.zeros((1, 0, 3), device=base_device, dtype=torch.long)
        )
        input_ids = (
            torch.cat(input_ids, 0).unsqueeze(0)
            if input_ids
            else torch.zeros((1, 0), device=base_device, dtype=torch.long)
        )

        return {
            "input_ids": input_ids,
            "modality_tensor": modality_tensor,
            "position_ids": position_ids,
            "vision_patches": vpatches,
            "vision_token_grids": grids,
            "vision_token_offsets": vision_token_offsets,
            "vision_token_lengths": vision_token_lengths,
            "vision_patches_total": sum(int(v.shape[0]) for v in vpatches),
            "vision_image_count": len(vpatches),
        }

    def __call__(
        self,
        text: Union[str, list[str]],
        images: Optional[Union[Image, list[Image]]] = None,
        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
        **kwargs,
    ) -> BatchFeature:
        """Pack prompts and images into model inputs.

        Args:
            text: One prompt or a list of prompts.
            images: A single image, a flat list of images, or a list of
                per-sample image lists. A list whose length equals the number
                of prompts is treated as per-sample; anything else is given to
                every prompt that contains ``vision_token``.
            return_tensors: Accepted for API compatibility; not read here.
            **kwargs: Accepted for API compatibility; not read here.

        Returns:
            ``BatchFeature`` with ``input_ids`` and ``packed_inputs``.
        """
        texts = [text] if isinstance(text, str) else text
        images_list: Optional[list[Optional[list[Image]]]] = None
        if images is not None:
            if isinstance(images, list) and len(images) == len(texts):
                if not images:
                    images_list = []
                elif isinstance(images[0], list):
                    images_list = images
                else:
                    # One image per prompt.
                    images_list = [[img] for img in images]
            else:
                images_list = []
                for t in texts:
                    n_tok = t.count(self.vision_token)
                    if n_tok == 0:
                        images_list.append(None)
                    elif isinstance(images, list):
                        images_list.append(images)
                    else:
                        images_list.append([images])

        packed = self._pack_batch(texts, images_list)
        input_ids = packed.pop("input_ids")
        return BatchFeature(data={"input_ids": input_ids, "packed_inputs": packed})
|
|
|
|
class IsaacRotaryEmbedding(qwen2_5_vl_modeling.Qwen2_5_VLRotaryEmbedding):
    """Rotary embedding that interleaves three position axes by frequency section.

    The parent class produces cos/sin tables for each of the three position
    axes; ``_combine_axes`` then picks, per section of the rotary dimension,
    the table of one axis (cycling through the axes).
    """

    def __init__(self, config: IsaacConfig, device=None):
        rope_source_cfg = (
            config.get_text_config() if hasattr(config, "get_text_config") else config
        )
        # NOTE(review): only ``rope_scaling`` is consulted for ``mrope_section``;
        # confirm whether configs that carry ``rope_parameters`` should be read too.
        rope_scaling = getattr(rope_source_cfg, "rope_scaling", None) or {}
        # Shallow copy so the caller's config is not modified.
        config_for_rope = copy.copy(rope_source_cfg)
        config_for_rope.rope_scaling = rope_scaling

        # Do not forward a meta device to the parent initializer.
        init_device = (
            device
            if device is not None and getattr(device, "type", None) != "meta"
            else None
        )
        super().__init__(config_for_rope, device=init_device)

        rotary_half_dim = self.inv_freq.shape[0]
        self.mrope_section = self._resolve_mrope_section(
            rope_scaling.get("mrope_section"), rotary_half_dim
        )
        self.hidden_size = (
            getattr(rope_source_cfg, "hidden_size", None) or config.hidden_size
        )

    @staticmethod
    def _resolve_mrope_section(
        section: Optional[list[int]], rotary_half_dim: int
    ) -> list[int]:
        """Return the per-axis split of the rotary half dimension.

        Args:
            section: Explicit section sizes, or ``None`` to derive a 2:1:1 split.
            rotary_half_dim: Number of rotary frequencies the sections must cover.

        Returns:
            List of section sizes summing to ``rotary_half_dim``.

        Raises:
            ValueError: If an explicit ``section`` does not sum to
                ``rotary_half_dim`` (it could not be used to split the tables).
        """
        if section is None:
            weights = (2, 1, 1)
            base = [rotary_half_dim * w // sum(weights) for w in weights]
            # Give the rounding remainder to the first axis.
            base[0] += rotary_half_dim - sum(base)
            return base

        section = [int(v) for v in section]
        if sum(section) != rotary_half_dim:
            raise ValueError(
                f"mrope_section {section} must sum to the rotary half dimension "
                f"({rotary_half_dim}), got {sum(section)}."
            )
        return section

    def _combine_axes(self, tensor: torch.Tensor) -> torch.Tensor:
        """Merge per-axis tables: section ``i`` is taken from axis ``i % 3``."""
        # Sections are repeated because the last dim holds the frequencies twice.
        split_sections = tuple(self.mrope_section * 2)
        chunks = tensor.split(split_sections, dim=-1)
        return torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1)

    def forward(
        self,
        position_ids: torch.Tensor,
        modality_tensor: torch.Tensor,
        hidden_states: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute combined cos/sin tables for 3-axis position ids.

        Args:
            position_ids: Position ids whose last dimension holds the axes;
                unpacked here as ``(batch, seq_len, axes)``.
            modality_tensor: Per-token modality ids; tokens that are not
                ``ModalityType.image`` reuse axis 0 on every axis.
            hidden_states: Tensor handed to the parent ``forward`` and whose
                dtype the outputs are cast to. A float32 zero tensor is
                created when omitted.

        Returns:
            ``(cos, sin)`` after combining the axes.
        """
        if hidden_states is None:
            batch, seq_len, _ = position_ids.shape
            hidden_states = torch.zeros(
                batch,
                seq_len,
                self.hidden_size,
                dtype=torch.float32,
                device=position_ids.device,
            )

        with torch.no_grad():
            pos = position_ids.clone()
            not_spatial = modality_tensor != ModalityType.image.value
            # Broadcast axis 0 over all axes for non-image tokens.
            data_1d = pos[not_spatial][..., 0].unsqueeze(-1)
            pos[not_spatial] = data_1d.expand(-1, pos.shape[-1])
            # Parent expects the axis dimension first.
            pos_axes = pos.permute(2, 0, 1).contiguous()

            cos_axes, sin_axes = super().forward(hidden_states, pos_axes)
            cos_axes, sin_axes = (
                cos_axes.to(hidden_states.dtype),
                sin_axes.to(hidden_states.dtype),
            )
            cos_combined, sin_combined = (
                self._combine_axes(cos_axes),
                self._combine_axes(sin_axes),
            )

        return cos_combined, sin_combined
|
|
|
|
@auto_docstring
class IsaacModel(Qwen3PreTrainedModel):
    supports_gradient_checkpointing = True
    _can_compile_fullgraph = False
    _supports_flex_attn = False
    _can_record_outputs = {"attentions": OutputRecorder(IsaacVisionAttention, index=1)}
    all_tied_weights_keys: dict[str, str] = {}

    def __init__(self, config: IsaacConfig):
        Qwen3PreTrainedModel.__init__(self, config)

        # Build the text decoder from a private copy of the text config, then
        # point it at the composite config.
        text_cfg = copy.deepcopy(config.text_config)
        self.text_model = Qwen3Model._from_config(text_cfg)
        self.text_model.config = config

        self.rotary_emb = IsaacRotaryEmbedding(config, device=self.device)

        # Width of the vision features after pixel shuffle.
        hidden_dim = config.vision_config.hidden_size * (
            config.vision_config.pixel_shuffle_scale_factor**2
        )
        # Vision transformer followed by a two-layer projector into text space.
        self.vision_embedding = nn.Sequential(
            IsaacVisionTransformer(config.vision_config),
            nn.Linear(
                hidden_dim,
                4 * hidden_dim,
                bias=False,
            ),
            nn.SiLU(),
            nn.Linear(4 * hidden_dim, config.hidden_size, bias=False),
        )
        self.vision_embedding._supports_sdpa = True
        self.max_sequence_length = config.max_sequence_length
        self.vision_rescale_factor = config.vision_rescale_factor
        self.vision_token = config.vision_token
        # Updated on every forward; read when position ids must be derived.
        self.rope_deltas = None

        self.post_init()

        if getattr(config, "gradient_checkpointing", False):
            self.gradient_checkpointing_enable()

    def get_input_embeddings(self) -> nn.Module:
        """Return the token embedding module of the text model."""
        return self.text_model.get_input_embeddings()

    def set_input_embeddings(self, value: nn.Module) -> None:
        """Replace the token embeddings and sync ``vocab_size`` in the configs."""
        self.text_model.set_input_embeddings(value)
        vocab_size = getattr(value, "num_embeddings", None)
        if vocab_size is not None:
            self.config.vocab_size = vocab_size
            if hasattr(self.config, "text_config"):
                self.config.text_config.vocab_size = vocab_size
            self.text_model.config.vocab_size = vocab_size

    @property
    def embed_tokens(self) -> nn.Module:
        """Token embedding module of the text model."""
        return self.text_model.embed_tokens

    @embed_tokens.setter
    def embed_tokens(self, value: nn.Module) -> None:
        self.text_model.embed_tokens = value

    @property
    def vision_model(self) -> nn.Module:
        """Vision transformer, i.e. the first stage of ``vision_embedding``."""
        # ``vision_embedding`` is an ``nn.Sequential`` built from positional
        # modules, so its stages are reachable by index, not by name.
        return self.vision_embedding[0]

    def embed_packed_inputs(
        self, input_ids: torch.Tensor, packed_inputs: dict[str, Optional[torch.Tensor]]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Batch-first vision packing for DataParallel safety.
        Expects packed_inputs containing:
        - modality_tensor: (B, S)
        - position_ids: (B, S, 3) (optional, used elsewhere)
        - vision_patches: (B, max_vision_len, patch_dim) padded
        - vision_patches_len: (B,) lengths of real vision patches per sample
        - vision_token_grids: (B, max_images, 2) padded
        - vision_token_offsets: (B, max_images) padded (virtual offsets)
        - vision_token_lengths: (B, max_images) padded (virtual lengths)
        - vision_image_count: (B,) number of images per sample

        Returns the token embeddings with image positions overwritten by the
        projected vision features, together with the modality tensor.
        """
        modality = packed_inputs["modality_tensor"].to(
            device=input_ids.device, dtype=torch.long
        )
        embeds = self.text_model.embed_tokens(input_ids)

        vision_patches = packed_inputs.get("vision_patches")
        if vision_patches is None:
            # Text-only batch.
            return embeds, modality

        vision_patches_len = packed_inputs["vision_patches_len"].to(
            device=vision_patches.device, dtype=torch.long
        )
        vision_token_grids = packed_inputs["vision_token_grids"].to(
            device=vision_patches.device, dtype=torch.long
        )
        vision_token_offsets = packed_inputs["vision_token_offsets"].to(
            device=vision_patches.device, dtype=torch.long
        )
        vision_token_lengths = packed_inputs["vision_token_lengths"].to(
            device=vision_patches.device, dtype=torch.long
        )
        vision_image_count = packed_inputs["vision_image_count"].to(
            device=vision_patches.device, dtype=torch.long
        )

        # Strip the padding: gather per-image patch chunks across the batch.
        flat_patches: list[torch.Tensor] = []
        flat_grids: list[torch.Tensor] = []
        flat_offsets: list[torch.Tensor] = []
        flat_lengths: list[torch.Tensor] = []
        flat_batch_idx: list[int] = []

        batch_size = vision_patches.shape[0]
        for b in range(batch_size):
            total_len = int(vision_patches_len[b].item())
            if total_len <= 0:
                continue
            vp = vision_patches[b, :total_len]
            img_cnt = int(vision_image_count[b].item())
            if img_cnt <= 0:
                continue
            grids_b = vision_token_grids[b, :img_cnt]
            offs_b = vision_token_offsets[b, :img_cnt]
            lens_b = vision_token_lengths[b, :img_cnt]
            sizes = grids_b.prod(-1).tolist()
            cursor = 0
            for grid, off, ln, size in zip(grids_b, offs_b, lens_b, sizes):
                size_int = int(size)
                if size_int <= 0:
                    continue
                chunk = vp[cursor : cursor + size_int]
                cursor += size_int
                flat_patches.append(chunk)
                flat_grids.append(grid)
                flat_offsets.append(off)
                flat_lengths.append(ln)
                flat_batch_idx.append(b)
            assert cursor == total_len, (
                f"Vision patches cursor mismatch: cursor={cursor}, total={total_len}"
            )

        if flat_patches:
            flat_patches_t = torch.cat(flat_patches, dim=0)
            flat_grids_t = torch.stack(flat_grids, dim=0)
            flat_offsets_t = torch.stack(flat_offsets, dim=0)
            flat_lengths_t = torch.stack(flat_lengths, dim=0)
            flat_batch_idx_t = torch.tensor(
                flat_batch_idx, device=vision_patches.device, dtype=torch.long
            )
        else:
            flat_patches_t = vision_patches.new_zeros((0, vision_patches.shape[-1]))
            flat_grids_t = vision_patches.new_zeros((0, 2), dtype=torch.long)
            flat_offsets_t = vision_patches.new_zeros((0,), dtype=torch.long)
            flat_lengths_t = vision_patches.new_zeros((0,), dtype=torch.long)
            flat_batch_idx_t = vision_patches.new_zeros((0,), dtype=torch.long)

        grid_sum = (
            int(flat_grids_t.prod(-1).sum().item()) if flat_grids_t.numel() else 0
        )
        assert flat_patches_t.shape[0] == grid_sum, (
            "Packed vision mismatch after flatten: "
            f"patches={flat_patches_t.shape[0]}, grid_sum={grid_sum}"
        )

        if flat_patches_t.numel() == 0:
            return embeds, modality

        vision = self.vision_embedding((flat_patches_t, flat_grids_t))

        # Tokens per image after pixel shuffle: patch count // factor**2.
        vision_reduction_factor = int(
            self.config.vision_config.pixel_shuffle_scale_factor
        )
        sizes = (
            flat_grids_t.prod(-1)
            .div(
                vision_reduction_factor * vision_reduction_factor, rounding_mode="floor"
            )
            .tolist()
        )

        # Keep, per image, only the token window given by its offset/length.
        chunks = vision.split(sizes, dim=0)
        vision_chunks: list[torch.Tensor] = []
        vision_batch_idx: list[int] = []
        for chunk, size, offset, length, batch_index in zip(
            chunks,
            sizes,
            flat_offsets_t.tolist(),
            flat_lengths_t.tolist(),
            flat_batch_idx_t.tolist(),
        ):
            size_int = int(size)
            if size_int <= 0:
                continue
            offset_int = max(0, min(int(offset), size_int))
            length_int = max(0, min(int(length), size_int - offset_int))
            if length_int:
                vision_chunks.append(chunk[offset_int : offset_int + length_int])
                vision_batch_idx.append(int(batch_index))

        # Clone so the result of the embedding lookup is not written in place.
        embeds = embeds.clone()
        num_batches = modality.shape[0]
        image_positions = [
            (modality[b] == ModalityType.image.value)
            .nonzero(as_tuple=False)
            .squeeze(-1)
            for b in range(num_batches)
        ]
        # Next unfilled image position per sample.
        cursors = [0 for _ in range(num_batches)]

        for chunk, batch_index in zip(vision_chunks, vision_batch_idx):
            if chunk.numel() == 0:
                continue
            positions = image_positions[batch_index]
            start = cursors[batch_index]
            end = start + chunk.shape[0]
            embeds[batch_index, positions[start:end]] = chunk.to(
                device=embeds.device, dtype=embeds.dtype
            )
            cursors[batch_index] = end

        return embeds, modality

    def get_rope_index(
        self,
        *,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: torch.Tensor,
        inputs_embeds: torch.Tensor,
        cache_position: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Build 3D position ids and per-batch RoPE deltas.

        Without ``position_ids`` the positions are derived from
        ``cache_position``, the stored ``self.rope_deltas`` and the number of
        masked tokens, and the stored delta is returned unchanged. With
        ``position_ids`` they are expanded to three axes if needed, and the
        delta is recomputed as ``max_position + 1 - number_of_attended_tokens``.
        """
        device = inputs_embeds.device
        batch_size, seq_len = inputs_embeds.shape[:2]

        if position_ids is None:
            cp = cache_position.to(device=device, dtype=torch.long)
            if cp.ndim == 1:
                cp = cp.view(1, -1).expand(batch_size or 1, -1)

            base_delta = torch.as_tensor(
                0 if self.rope_deltas is None else self.rope_deltas,
                device=device,
                dtype=torch.long,
            ).reshape(-1, 1)
            base_delta = torch.broadcast_to(base_delta, (batch_size, 1))

            # Non-positive shift: minus the number of masked positions.
            mask_delta = attention_mask.to(device=device, dtype=torch.long).sum(
                1, keepdim=True
            ) - attention_mask.size(1)
            rope_position = cp + base_delta + mask_delta
            pos_3d = rope_position.unsqueeze(-1).expand(-1, -1, 3)
            return pos_3d, base_delta

        position_ids = position_ids.to(device=device)
        if position_ids.ndim == 2:
            position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3)

        if position_ids.shape[1] != seq_len:
            # Length mismatch: rebuild consecutive positions from the first one.
            start_positions = position_ids[:, :1, 0]
            position_ids = (
                torch.arange(seq_len, device=position_ids.device).view(1, -1)
                + start_positions
            )
            position_ids = position_ids.unsqueeze(-1).expand(-1, -1, 3)

        attn = attention_mask.to(device=device, dtype=torch.long)
        m_per_batch = position_ids.amax(dim=(1, 2))
        seq_lens = attn.eq(1).sum(dim=-1).to(dtype=m_per_batch.dtype, device=device)
        rope_deltas = (
            (m_per_batch + 1 - seq_lens).to(dtype=position_ids.dtype).unsqueeze(1)
        )
        return position_ids, rope_deltas

    @auto_docstring
    @check_model_inputs
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        packed_inputs: Optional[dict[str, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[list[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | BaseModelOutputWithPast:
        """
        Forward pass with MRoPE position embeddings.

        Computes position embeddings once and passes them through all layers.

        Args:
            packed_inputs (`dict`, *optional*):
                Plain tensor payloads (modality ids, position ids and padded vision tensors). When provided,
                `input_ids` is required; the embeddings and the modality tensor are built from both. When
                omitted, every token is treated as text.
        """
        output_attentions = kwargs.pop("output_attentions", None)

        modality_tensor: Optional[torch.Tensor] = None

        if packed_inputs is not None:
            if input_ids is None:
                raise ValueError(
                    "`input_ids` must be provided together with `packed_inputs`."
                )
            inputs_embeds, modality_tensor = self.embed_packed_inputs(
                input_ids, packed_inputs
            )
        elif input_ids is not None:
            inputs_embeds = self.text_model.embed_tokens(input_ids)

        if inputs_embeds is None:
            raise ValueError(
                "You must specify at least one of `input_ids` or `inputs_embeds`."
            )

        device = inputs_embeds.device
        batch_size, seq_len = inputs_embeds.shape[:2]

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config.get_text_config())

        if cache_position is None:
            past_seen_tokens = (
                past_key_values.get_seq_length() if past_key_values is not None else 0
            )
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + seq_len, device=device
            )

        if attention_mask is None:
            attention_mask = torch.ones(
                inputs_embeds.shape[:2], device=inputs_embeds.device, dtype=torch.long
            )

        # Explicit position ids win; otherwise use the ones packed by the processor.
        if (
            position_ids is None
            and packed_inputs is not None
            and packed_inputs.get("position_ids") is not None
        ):
            position_ids = packed_inputs.get("position_ids").to(device=device)

        position_ids, rope_deltas = self.get_rope_index(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
        )
        self.rope_deltas = rope_deltas

        if modality_tensor is None:
            modality_tensor = torch.full(
                (batch_size, seq_len),
                ModalityType.text.value,
                device=device,
                dtype=torch.long,
            )

        cos, sin = self.rotary_emb(
            position_ids, modality_tensor, hidden_states=inputs_embeds
        )

        # The decoder layers receive only the first position axis.
        decoder_position_ids = (
            position_ids[..., 0] if position_ids.ndim == 3 else position_ids
        )

        if not isinstance(attention_mask, dict):
            attention_mask = create_masks_for_generate(
                config=self.config,
                input_embeds=inputs_embeds,
                attention_mask=attention_mask,
                cache_position=cache_position,
                past_key_values=past_key_values,
                position_ids=decoder_position_ids,
            )

        is_mask_dict = isinstance(attention_mask, dict)
        hidden_states = inputs_embeds
        all_attentions = [] if output_attentions else None

        for layer in self.text_model.layers:
            layer_mask = (
                attention_mask[layer.attention_type] if is_mask_dict else attention_mask
            )
            layer_outputs = layer(
                hidden_states,
                attention_mask=layer_mask,
                position_ids=decoder_position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=(cos, sin),
                output_attentions=output_attentions,
                **kwargs,
            )

            layer_outputs_is_tuple = isinstance(layer_outputs, tuple)
            hidden_states = (
                layer_outputs[0] if layer_outputs_is_tuple else layer_outputs
            )
            if output_attentions and layer_outputs_is_tuple:
                all_attentions.append(layer_outputs[1])

        hidden_states = self.text_model.norm(hidden_states)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=(hidden_states,),
            attentions=tuple(all_attentions) if output_attentions else None,
        )
|
|
|
|
@auto_docstring
class IsaacForConditionalGeneration(Qwen3ForCausalLM, GenerationMixin):
    config_class = IsaacConfig
    _can_compile_fullgraph = False
    _tied_weights_keys = {"lm_head.weight": "model.text_model.embed_tokens.weight"}
    all_tied_weights_keys: dict[str, str] = {
        "lm_head.weight": "model.text_model.embed_tokens.weight"
    }

    def __init__(self, config: IsaacConfig):
        super().__init__(config)
        # Replace the parent's text-only backbone and head with the Isaac ones.
        self.model = IsaacModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    @auto_docstring
    @can_return_tuple
    @check_model_inputs
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        packed_inputs: Optional[dict[str, torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[list[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple | CausalLMOutputWithPast:
        """Run multimodal CausalLM forward, accepting packed vision/text inputs.

        Args:
            packed_inputs (`dict`, *optional*):
                Packed vision/text payload from ``IsaacProcessor`` containing modality ids, MRoPE position ids, and
                vision patch tensors/grids (with optional offsets/lengths) used to rebuild embeddings.

        Returns:
            CausalLMOutputWithPast: logits, optional loss, caches, hidden states, attentions.
        """
        output_attentions = kwargs.pop("output_attentions", None)

        outputs = self.model(
            input_ids=input_ids,
            packed_inputs=packed_inputs,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            cache_position=cache_position,
            **kwargs,
        )
        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits, labels=labels, vocab_size=self.config.vocab_size
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions if output_attentions else None,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[list[torch.FloatTensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        packed_inputs: Optional[dict[str, torch.Tensor]] = None,
        cache_position: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> dict[str, Any]:
        """Prepare one generation step, forwarding ``packed_inputs`` only once.

        The packed payload is passed on while the cache is still empty and set
        to ``None`` afterwards. Whenever ``packed_inputs`` was supplied,
        ``position_ids`` is cleared so the model derives positions itself.
        """
        model_inputs = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            **kwargs,
        )
        if packed_inputs is None:
            return model_inputs

        past_len = (
            past_key_values.get_seq_length() if past_key_values is not None else 0
        )
        first_step = past_len == 0
        model_inputs["packed_inputs"] = packed_inputs if first_step else None
        model_inputs["position_ids"] = None

        return model_inputs

    @classmethod
    def can_generate(cls) -> bool:
        """Always report that this model supports generation."""
        return True

    def set_input_embeddings(self, value: nn.Module) -> None:
        """Replace the token embeddings and keep configs and ``lm_head`` in sync.

        When ``value`` does not expose ``num_embeddings`` the vocabulary size is
        unknown, so the configs and the output head are left untouched.
        """
        self.model.set_input_embeddings(value)
        vocab_size = getattr(value, "num_embeddings", None)
        if vocab_size is None:
            return
        self.config.vocab_size = vocab_size
        self.model.config.vocab_size = vocab_size
        self.model.text_model.config.vocab_size = vocab_size
        if self.lm_head.weight.shape[0] != vocab_size:
            # Rebuild the head for the new vocabulary and tie it to the embeddings.
            # NOTE(review): the tie is applied only when the head is rebuilt —
            # confirm this matches the intended weight-tying policy.
            self.lm_head = nn.Linear(self.config.hidden_size, vocab_size, bias=False)
            self.lm_head.weight = self.model.text_model.embed_tokens.weight
|
|
|
|
| __all__ = [ |
| "IsaacConfig", |
| "IsaacModel", |
| "IsaacPreTrainedModel", |
| "IsaacForConditionalGeneration", |
| "IsaacImageProcessorFast", |
| "IsaacProcessor", |
| ] |
|
|