| from __future__ import annotations |
|
|
| import dataclasses |
| import glob |
| from collections.abc import Callable |
| from pathlib import Path |
| from typing import Any, Dict, Optional, Tuple, Union, cast |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| from einops import rearrange |
| from safetensors.torch import load_file as safetensors_load |
| from transformers import PretrainedConfig |
| from transformers.activations import ACT2FN |
| from transformers.cache_utils import Cache, DynamicCache |
| from transformers.generation.utils import GenerationMixin |
| from transformers.integrations import use_kernel_forward_from_hub |
| from transformers.masking_utils import create_causal_mask |
| from transformers.modeling_layers import GradientCheckpointingLayer |
| from transformers.modeling_outputs import ( |
| BaseModelOutputWithPast, |
| BaseModelOutputWithPooling, |
| ) |
| from transformers.modeling_rope_utils import ( |
| ROPE_INIT_FUNCTIONS, |
| dynamic_rope_update, |
| ) |
| from transformers.modeling_utils import ( |
| ALL_ATTENTION_FUNCTIONS, |
| PreTrainedModel, |
| ) |
| from transformers.processing_utils import Unpack |
| from transformers.utils import ( |
| ModelOutput, |
| TransformersKwargs, |
| auto_docstring, |
| can_return_tuple, |
| logging, |
| ) |
| from transformers.utils.deprecation import deprecate_kwarg |
# `check_model_inputs` only exists in newer transformers releases; fall back
# to a no-op decorator factory when the import is unavailable.
try:
    from transformers.utils.generic import check_model_inputs
except ImportError:

    def check_model_inputs(*args, **kwargs):
        """No-op stand-in: produce a decorator that returns the function unchanged."""

        def _identity_decorator(func):
            return func

        return _identity_decorator
|
|
| from .configuration_yasa2 import ConvNextConfig, Yasa2Config, YasaConfig |
|
|
# Module-level logger, following the transformers convention.
logger = logging.get_logger(__name__)
|
|
|
|
| |
@dataclasses.dataclass
class Yasa2ModelOutputWithPast(BaseModelOutputWithPast):
    """
    Base class for Yasa2 model outputs with past key values.

    Args:
        last_hidden_state (`torch.FloatTensor`, *optional*):
            Last hidden state of the model.
        past_key_values (`Cache`, *optional*):
            Cache of key/value tensors for each layer.
        hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
            Tuple of hidden states from the model.
        attentions (`Tuple[torch.FloatTensor]`, *optional*):
            Tuple of attention maps from the model.
        vision_hidden_states (`torch.FloatTensor`, *optional*):
            Vision embeddings after projection and pooling.
    """

    last_hidden_state: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Cache] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    # Extra field on top of BaseModelOutputWithPast for the vision tower output.
    vision_hidden_states: Optional[torch.FloatTensor] = None
|
|
|
|
@dataclasses.dataclass
class Yasa2ForConditionalGenerationModelOutput(ModelOutput):
    """
    Outputs for Yasa2 conditional generation.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`Cache`, *optional*, returned when `use_cache=True`):
            Cache of key/value tensors for each layer.
        hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
            Tuple of hidden states from the language model.
        attentions (`Tuple[torch.FloatTensor]`, *optional*):
            Tuple of attention maps from the language model.
        vision_hidden_states (`torch.FloatTensor`, *optional*):
            Vision embeddings after projection and pooling.
        language_model_outputs (`Yasa2ModelOutputWithPast`, *optional*):
            The full language model outputs.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Cache] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    vision_hidden_states: Optional[torch.FloatTensor] = None
    # NOTE(review): presumably carries the raw LM output alongside the
    # flattened fields above — confirm against the caller that fills it in.
    language_model_outputs: Optional[Yasa2ModelOutputWithPast] = None
|
|
|
|
| |
def get_2d_sincos_pos_embed(
    embed_dim: int, image_size: int | tuple[int, int]
) -> np.ndarray:
    """Generate 2D sincos positional embeddings for a vision grid.

    Args:
        embed_dim (int): Embedding dimension.
        image_size (int | tuple[int, int]): Image size as an int or (height, width) tuple.

    Returns:
        np.ndarray: Positional embedding array of shape (H, W, embed_dim) —
            the grid layout is preserved by the helper (the sin/cos features
            are concatenated along the last axis).
    """
    if isinstance(image_size, int):
        grid_h_size, grid_w_size = image_size, image_size
    else:
        grid_h_size, grid_w_size = image_size[0], image_size[1]

    grid_h = np.arange(grid_h_size, dtype=np.float32)
    grid_w = np.arange(grid_w_size, dtype=np.float32)
    # meshgrid with default 'xy' indexing yields (H, W)-shaped coordinate arrays.
    grid = np.meshgrid(grid_w, grid_h)
    grid = np.stack(grid, axis=0)

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    return pos_embed
|
|
|
|
def get_2d_sincos_pos_embed_from_grid(
    embed_dim: int, grid: np.ndarray
) -> np.ndarray:
    """Build 2D sincos positional embeddings from a coordinate grid.

    Half of the embedding dimension encodes each grid axis; the two halves are
    concatenated along the feature dimension.

    Args:
        embed_dim (int): Embedding dimension (must be even).
        grid (np.ndarray): Grid array of shape (2, H, W).

    Returns:
        np.ndarray: Positional embedding array of shape (H, W, embed_dim).
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    per_axis = [
        get_1d_sincos_pos_embed_from_grid(half_dim, axis_grid)
        for axis_grid in (grid[0], grid[1])
    ]
    return np.concatenate(per_axis, axis=-1)
|
|
|
|
def get_1d_sincos_pos_embed_from_grid(
    embed_dim: int, pos: np.ndarray
) -> np.ndarray:
    """Build 1D sincos positional embeddings for one grid axis.

    Args:
        embed_dim (int): Embedding dimension (must be even).
        pos (np.ndarray): Position grid of shape (H, W) for one axis.

    Returns:
        np.ndarray: Array of shape (H, W, embed_dim) with sin features in the
            first half of the last axis and cos features in the second half.
    """
    assert embed_dim % 2 == 0

    # Geometrically spaced inverse frequencies, as in the original transformer.
    exponents = np.arange(embed_dim // 2, dtype=np.float32) / (embed_dim / 2.0)
    inv_freq = 1.0 / 10000**exponents

    # Outer product of positions and frequencies: (H, W) x (d/2,) -> (H, W, d/2).
    angles = np.einsum("hw,d->hwd", pos, inv_freq)

    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)
|
|
|
|
| |
def drop_path(
    input: torch.Tensor, drop_prob: float = 0.0, training: bool = False
) -> torch.Tensor:
    """Apply stochastic depth (per-sample drop path) to ``input``.

    During training each sample in the batch is zeroed with probability
    ``drop_prob`` and the survivors are rescaled by ``1 / (1 - drop_prob)`` so
    the expected activation magnitude is unchanged.

    Args:
        input (torch.Tensor): Input tensor; the first dimension is the batch.
        drop_prob (float): Probability of dropping a sample. Defaults to 0.0.
        training (bool): Whether the model runs in training mode. Defaults to False.

    Returns:
        torch.Tensor: Tensor with drop path applied (``input`` unchanged when
        disabled).
    """
    if not training or drop_prob == 0.0:
        return input

    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims.
    mask_shape = (input.shape[0],) + (1,) * (input.ndim - 1)
    mask = keep_prob + torch.rand(
        mask_shape, dtype=input.dtype, device=input.device
    )
    mask.floor_()  # binarize: 1 with prob keep_prob, else 0
    return input.div(keep_prob) * mask
|
|
|
|
class ConvNextDropPath(nn.Module):
    """Stochastic depth (drop path) applied per sample in residual blocks."""

    def __init__(self, drop_prob: Optional[float] = None):
        """Store the drop probability.

        Args:
            drop_prob (Optional[float]): Probability of dropping a sample's path.
        """
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply stochastic depth to ``hidden_states``.

        Args:
            hidden_states (torch.Tensor): Tensor to apply stochastic depth to.

        Returns:
            torch.Tensor: Tensor after drop path (identity in eval mode).
        """
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        """Return the configured drop probability for module repr printing."""
        return "p={}".format(self.drop_prob)
|
|
|
|
class ConvNextLayerNorm(nn.Module):
    r"""LayerNorm supporting channels_last (default) or channels_first layouts.

    channels_last expects inputs shaped (batch_size, height, width, channels);
    channels_first expects (batch_size, channels, height, width).
    """

    def __init__(
        self,
        normalized_shape: int,
        eps: float = 1e-6,
        data_format: str = "channels_last",
    ) -> None:
        """Initialize weights, bias, and layout configuration.

        Args:
            normalized_shape (int): Number of channels to normalize over.
            eps (float): Epsilon added to the variance for numerical stability.
            data_format (str): Either 'channels_last' or 'channels_first'.

        Raises:
            NotImplementedError: If ``data_format`` is not one of the two layouts.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError(
                f"Unsupported data format: {self.data_format}"
            )
        self.normalized_shape = (normalized_shape,)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply layer normalization according to the configured data format.

        Args:
            x (torch.Tensor): Input of shape (N, H, W, C) for channels_last or
                (N, C, H, W) for channels_first.

        Returns:
            torch.Tensor: Normalized tensor with the same shape as the input.
        """
        if self.data_format == "channels_last":
            return nn.functional.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )

        # channels_first: normalize over dim 1 manually, in float32 for stability.
        original_dtype = x.dtype
        x = x.float()
        mean = x.mean(1, keepdim=True)
        var = (x - mean).pow(2).mean(1, keepdim=True)
        x = (x - mean) / torch.sqrt(var + self.eps)
        x = x.to(dtype=original_dtype)
        return self.weight[:, None, None] * x + self.bias[:, None, None]
|
|
|
|
class ConvNextV2GRN(nn.Module):
    """Global Response Normalization (GRN) layer for ConvNeXt V2."""

    def __init__(self, dim: int):
        """Create learnable scale and shift, both initialized to zero.

        With zero-initialized parameters the layer starts out as an identity
        mapping (the residual term dominates).

        Args:
            dim (int): Channel dimension of the (channels-last) input.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.bias = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        """Apply Global Response Normalization.

        Args:
            hidden_states (torch.FloatTensor): Input shaped (batch, height, width, channels).

        Returns:
            torch.FloatTensor: Tensor of the same shape.
        """
        # Per-channel L2 norm aggregated over the spatial dimensions.
        spatial_norm = torch.norm(
            hidden_states, p=2, dim=(1, 2), keepdim=True
        )
        # Divisive normalization of each channel's norm by the channel mean.
        channel_scale = spatial_norm / (
            spatial_norm.mean(dim=-1, keepdim=True) + 1e-6
        )
        return (
            self.weight * (hidden_states * channel_scale)
            + self.bias
            + hidden_states
        )
|
|
|
|
class ConvNextEmbeddings(nn.Module):
    """ConvNeXt patch embedding: strided conv followed by channels-first LayerNorm."""

    def __init__(
        self, num_channels: int = 3, hidden_size: int = 96, patch_size: int = 4
    ) -> None:
        """Initialize the patch-embedding conv and its normalization.

        Args:
            num_channels (int): Number of image channels. Defaults to 3.
            hidden_size (int): Hidden dimension size. Defaults to 96.
            patch_size (int): Patch size (kernel and stride) of the conv. Defaults to 4.
        """
        super().__init__()
        # Non-overlapping patches: kernel_size == stride == patch_size.
        self.patch_embeddings = nn.Conv2d(
            num_channels,
            hidden_size,
            kernel_size=patch_size,
            stride=patch_size,
        )
        self.layernorm = ConvNextLayerNorm(
            hidden_size, eps=1e-6, data_format="channels_first"
        )
        self.num_channels = num_channels

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """Create patch embeddings from pixel values.

        Args:
            pixel_values (torch.FloatTensor): Image tensor shaped (batch, channels, height, width).

        Returns:
            torch.Tensor: Embedded tensor after patch convolution and norm.

        Raises:
            ValueError: If the channel dimension does not match the expected count.
        """
        if pixel_values.shape[1] != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        return self.layernorm(self.patch_embeddings(pixel_values))
|
|
|
|
class ConvNextLayer(nn.Module):
    """Single ConvNeXt V2 block: depthwise conv, pointwise MLP with GRN, residual."""

    def __init__(
        self,
        dim: int,
        drop_path: float = 0,
        layer_scale_init_value: float = 1e-6,
        use_grn: bool = True,
    ) -> None:
        """Construct a ConvNeXt V2 layer with GRN and residual scaling.

        Args:
            dim (int): Input/output channel dimension.
            drop_path (float): Drop path probability for stochastic depth.
            layer_scale_init_value (float): Initial scale for the residual branch;
                <= 0 disables layer scaling.
            use_grn (bool): Must be True; this block is the V2 (GRN) variant.

        Raises:
            ValueError: If ``use_grn`` is False.
        """
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.layernorm = ConvNextLayerNorm(dim, eps=1e-6)
        # Inverted bottleneck: expand 4x, activate, normalize, project back.
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        if not use_grn:
            raise ValueError("ConvNeXt V2 requires use_grn=True.")
        self.grn = ConvNextV2GRN(4 * dim)
        self.pwconv2 = nn.Linear(4 * dim, dim)
        if layer_scale_init_value > 0:
            self.layer_scale_parameter = nn.Parameter(
                layer_scale_init_value * torch.ones((dim)), requires_grad=True
            )
        else:
            self.layer_scale_parameter = None
        self.drop_path = (
            nn.Identity() if drop_path <= 0.0 else ConvNextDropPath(drop_path)
        )

    def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
        """Run the ConvNeXt block.

        Args:
            hidden_states (torch.FloatTensor): Input shaped (batch, channels, height, width).

        Returns:
            torch.Tensor: Tensor after depthwise conv, MLP/GRN, and residual add.
        """
        residual = hidden_states
        # Depthwise 7x7 in channels-first, then switch to channels-last for
        # the norm / MLP / GRN stack.
        out = self.dwconv(hidden_states)
        out = out.permute(0, 2, 3, 1)
        out = self.layernorm(out)
        out = self.pwconv1(out)
        out = self.act(out)
        out = self.grn(out)
        out = self.pwconv2(out)
        if self.layer_scale_parameter is not None:
            out = self.layer_scale_parameter * out
        out = out.permute(0, 3, 1, 2)

        return residual + self.drop_path(out)
|
|
|
|
class ConvNextStage(nn.Module):
    """ConvNeXt V2 stage: optional downsampling followed by a stack of blocks."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 2,
        stride: int = 2,
        depth: int = 2,
        drop_path_rates: Optional[list[float]] = None,
        layer_scale_init_value: float = 1e-6,
        use_grn: bool = True,
    ) -> None:
        """Build one stage that can downsample and then stack residual blocks.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of the downsampling conv.
            stride (int): Stride of the downsampling conv.
            depth (int): Number of ConvNeXt blocks in the stage.
            drop_path_rates (Optional[list[float]]): Per-block drop path rates;
                defaults to all zeros.
            layer_scale_init_value (float): Initial residual scaling of each block.
            use_grn (bool): Whether blocks use Global Response Normalization.
        """
        super().__init__()

        # Downsample only when the resolution or channel count must change.
        needs_downsample = in_channels != out_channels or stride > 1
        if needs_downsample:
            self.downsampling_layer = nn.Sequential(
                ConvNextLayerNorm(
                    in_channels, eps=1e-6, data_format="channels_first"
                ),
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                ),
            )
        else:
            self.downsampling_layer = nn.Identity()

        rates = drop_path_rates or [0.0] * depth
        blocks = [
            ConvNextLayer(
                dim=out_channels,
                drop_path=rates[idx],
                layer_scale_init_value=layer_scale_init_value,
                use_grn=use_grn,
            )
            for idx in range(depth)
        ]
        self.layers = nn.Sequential(*blocks)

    def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
        """Run a batch through the downsampling layer and residual blocks.

        Args:
            hidden_states (torch.FloatTensor): Input shaped (batch, channels, height, width).

        Returns:
            torch.Tensor: Output tensor after the stage.
        """
        return self.layers(self.downsampling_layer(hidden_states))
|
|
|
|
class ConvNextEncoder(nn.Module):
    """ConvNeXt V2 encoder: a sequence of downsampling + block stages."""

    def __init__(
        self,
        hidden_sizes: list[int],
        depths: list[int],
        drop_path_rate: float = 0.0,
        layer_scale_init_value: float = 1e-6,
        use_grn: bool = True,
    ) -> None:
        """Assemble the encoder stages.

        Args:
            hidden_sizes (list[int]): Hidden dimensions per stage.
            depths (list[int]): Number of layers per stage.
            drop_path_rate (float): Peak drop path rate (linear ramp over depth).
            layer_scale_init_value (float): Initial residual scaling.
            use_grn (bool): Whether layers use Global Response Normalization.
        """
        super().__init__()
        self.stages = nn.ModuleList()
        self.gradient_checkpointing = False
        total_depth = sum(depths)
        # Stochastic depth ramps linearly from 0 to drop_path_rate over all
        # layers; slice the flat schedule into per-stage chunks.
        schedule = np.linspace(
            0.0, float(drop_path_rate), total_depth
        ).tolist()
        stage_rates = []
        offset = 0
        for stage_depth in depths:
            stage_rates.append(schedule[offset : offset + stage_depth])
            offset += stage_depth

        # Stage 0 keeps the embedding resolution; later stages downsample 2x.
        in_chs = hidden_sizes[0]
        for stage_idx, out_chs in enumerate(hidden_sizes):
            self.stages.append(
                ConvNextStage(
                    in_channels=in_chs,
                    out_channels=out_chs,
                    stride=1 if stage_idx == 0 else 2,
                    depth=depths[stage_idx],
                    drop_path_rates=stage_rates[stage_idx],
                    layer_scale_init_value=layer_scale_init_value,
                    use_grn=use_grn,
                )
            )
            in_chs = out_chs

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Tuple:
        """Run the input through every stage.

        Args:
            hidden_states (torch.FloatTensor): Input shaped (batch, channels, height, width).
            output_hidden_states (Optional[bool]): When True, collect the input
                of each stage plus the final output.
            return_dict (Optional[bool]): When False, drop None entries from the
                returned tuple; otherwise always return a 2-tuple.

        Returns:
            Tuple: (last_hidden_state, all_hidden_states_or_None).
        """
        all_hidden_states = () if output_hidden_states else None

        for stage in self.stages:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Recompute stage activations in backward to save memory.
                hidden_states = torch.utils.checkpoint.checkpoint(
                    stage,
                    hidden_states,
                    use_reentrant=False,
                )
            else:
                hidden_states = stage(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v for v in [hidden_states, all_hidden_states] if v is not None
            )

        return (hidden_states, all_hidden_states)
|
|
|
|
class ConvNextModel(nn.Module):
    """ConvNeXt V2 model: patch embeddings, staged encoder, and a LayerNorm
    applied to the pooled (spatially averaged) features."""

    def __init__(
        self,
        hidden_sizes: list[int],
        depths: list[int],
        num_channels: int = 3,
        patch_size: int = 4,
        drop_path_rate: float = 0.0,
        layer_scale_init_value: float = 1e-6,
        use_grn: bool = True,
    ) -> None:
        """Build the ConvNeXt V2 model with embedding, encoder, and pooling.

        Args:
            hidden_sizes (list[int]): Hidden channel sizes per stage.
            depths (list[int]): Layer counts per stage.
            num_channels (int): Number of image channels.
            patch_size (int): Patch size for initial embedding.
            drop_path_rate (float): Drop path rate range for residual blocks.
            layer_scale_init_value (float): Initial scale for residuals.
            use_grn (bool): Whether to enable GRN.

        Raises:
            ValueError: If ``use_grn`` is False (this implementation is V2-only).
        """
        super().__init__()
        if not use_grn:
            raise ValueError("ConvNeXt V2 requires use_grn=True.")
        self.embeddings = ConvNextEmbeddings(
            num_channels, hidden_sizes[0], patch_size
        )
        self.encoder = ConvNextEncoder(
            hidden_sizes,
            depths,
            drop_path_rate,
            layer_scale_init_value,
            use_grn,
        )
        # Applied to the spatially averaged features in forward(), not to the map.
        self.layernorm = nn.LayerNorm(hidden_sizes[-1], eps=1e-6)

        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize module weights following standard ConvNeXt heuristics.

        Normal(0, 0.02) for conv/linear weights, zeros for biases, and the
        usual ones/zeros for LayerNorm.

        Args:
            module (nn.Module): Module to initialize.
        """
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = True,
        return_pooled: bool = True,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """Encode images and optionally return pooled features.

        Args:
            pixel_values (Optional[torch.FloatTensor]): Input tensor shaped (batch, channels, height, width).
            output_hidden_states (Optional[bool]): Whether to return intermediate hidden states.
            return_dict (Optional[bool]): Whether to return output as BaseModelOutput.
            return_pooled (bool): Whether to include pooled output.

        Returns:
            Union[Tuple, BaseModelOutputWithPooling]: Model outputs containing last hidden states and optionally pooled output.

        Raises:
            ValueError: If `pixel_values` is None.
        """
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")

        embedding_output = self.embeddings(pixel_values)

        encoder_outputs = self.encoder(
            embedding_output,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # The encoder returns a plain tuple in both modes; index positionally.
        last_hidden_state = encoder_outputs[0]
        all_hidden_states = (
            encoder_outputs[1] if output_hidden_states else None
        )

        pooled_output = None
        if return_pooled:
            # Global average pool over the spatial dims, then LayerNorm.
            pooled_output = self.layernorm(last_hidden_state.mean([-2, -1]))

        if not return_dict:
            # Tuple layout mirrors the requested extras, in order.
            outputs = [last_hidden_state]
            if return_pooled:
                outputs.append(pooled_output)
            if output_hidden_states:
                outputs.append(all_hidden_states)
            return tuple(outputs)

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=all_hidden_states,
        )

    @staticmethod
    def from_pretrained(model_path: Path | str) -> "ConvNextModel":
        """Load ConvNeXt model weights from a pretrained checkpoint directory.

        Weight lookup order: model.safetensors, then pytorch_model.bin, then
        sharded pytorch_model-*.bin files.

        Args:
            model_path (Path | str): Directory path containing the checkpoint files.

        Returns:
            ConvNextModel: Initialized model with weights loaded from checkpoint.

        Raises:
            NotImplementedError: If config.json is missing in the directory.
            ValueError: If the checkpoint config disables GRN.
            FileNotFoundError: If no weight file is found.
        """
        model_path_str = str(model_path)
        model_path_obj = Path(model_path_str)

        # A valid checkpoint is a directory that carries a config.json.
        is_ckpt_dir = (
            model_path_obj.is_dir()
            and (model_path_obj / "config.json").exists()
        )

        if not is_ckpt_dir:
            raise NotImplementedError(
                "The checkpoint path should be a directory containing config.json "
                "and model.safetensors or pytorch_model.bin files."
            )

        config = ConvNextConfig.from_pretrained(model_path_str)

        checkpoint_dir = model_path_obj

        # This port only implements the V2 (GRN) variant.
        if not config.use_grn:
            raise ValueError(
                "ConvNeXt V2 requires use_grn=True in the checkpoint config."
            )
        logger.info(
            "Loading ConvNeXt V2 model from checkpoint: %s", checkpoint_dir
        )

        model = ConvNextModel(
            hidden_sizes=config.hidden_sizes,
            depths=config.depths,
            num_channels=config.num_channels,
            patch_size=config.patch_size,
            drop_path_rate=config.drop_path_rate,
            layer_scale_init_value=config.layer_scale_init_value,
            use_grn=config.use_grn,
        )

        state_dict = {}

        # Preferred format: a single safetensors file.
        safetensors_file = checkpoint_dir / "model.safetensors"
        if safetensors_file.exists():
            logger.info("Loading weights from %s", safetensors_file)
            state_dict = safetensors_load(str(safetensors_file))
        else:
            # Fallback: a single pickled PyTorch file.
            pytorch_file = checkpoint_dir / "pytorch_model.bin"
            if pytorch_file.exists():
                logger.info("Loading weights from %s", pytorch_file)
                # NOTE(review): weights_only=False unpickles arbitrary objects
                # and can execute code on load — only use with trusted checkpoints.
                state_dict = torch.load(
                    str(pytorch_file), map_location="cpu", weights_only=False
                )
            else:
                # Last resort: sharded pickled files, merged in sorted order.
                shard_files = sorted(
                    glob.glob(str(checkpoint_dir / "pytorch_model-*.bin"))
                )
                if shard_files:
                    logger.info(
                        "Loading weights from %s sharded files",
                        len(shard_files),
                    )
                    for shard_file in shard_files:
                        state_dict.update(
                            torch.load(
                                shard_file,
                                map_location="cpu",
                                weights_only=False,
                            )
                        )
                else:
                    raise FileNotFoundError(
                        f"Could not find model weights in {checkpoint_dir}. "
                        "Expected model.safetensors, pytorch_model.bin, or pytorch_model-*.bin files."
                    )

        # strict=False tolerates renamed/extra keys; mismatches are surfaced below.
        missing_keys, unexpected_keys = model.load_state_dict(
            state_dict, strict=False
        )

        if missing_keys:
            logger.warning(
                "Some weights of the model were not initialized from the checkpoint "
                "and are newly initialized: %s",
                missing_keys,
            )

        if unexpected_keys:
            logger.warning(
                "Some weights of the checkpoint were not used when initializing the model: %s",
                unexpected_keys,
            )

        return model
|
|
|
|
class ConvNextVisionModel(nn.Module):
    """Vision tower wrapper: runs the ConvNeXt V2 backbone and reshapes the
    feature map into a token sequence for the multimodal stack."""

    def __init__(self, config: Optional[ConvNextConfig] = None):
        """Create the wrapper and its ConvNeXt backbone.

        Args:
            config (Optional[ConvNextConfig]): Backbone configuration; defaults
                to ``ConvNextConfig.convnext_large()`` when omitted.

        Raises:
            ValueError: If the config lacks ConvNeXt attributes or disables GRN.
        """
        super().__init__()
        if config is None:
            config = ConvNextConfig.convnext_large()

        self.config = config

        # Guard clauses: require a ConvNeXt-shaped config with GRN enabled.
        if not hasattr(config, "hidden_sizes"):
            raise ValueError("Config must be a ConvNextConfig")
        if not config.use_grn:
            raise ValueError("ConvNeXt V2 requires use_grn=True.")

        self.backbone = ConvNextModel(
            hidden_sizes=config.hidden_sizes,
            depths=config.depths,
            num_channels=config.num_channels,
            patch_size=config.patch_size,
            drop_path_rate=config.drop_path_rate,
            layer_scale_init_value=config.layer_scale_init_value,
            use_grn=config.use_grn,
        )

    @staticmethod
    def from_pretrained(model_path: Path | str) -> "ConvNextVisionModel":
        """Load a vision wrapper whose backbone weights come from a checkpoint.

        Args:
            model_path (Path | str): Directory containing the pretrained weights.

        Returns:
            ConvNextVisionModel: Wrapper instance with the pretrained backbone.
        """
        pretrained_backbone = ConvNextModel.from_pretrained(model_path)
        cfg = ConvNextConfig.from_pretrained(str(model_path))
        vision_model = ConvNextVisionModel(cfg)
        vision_model.backbone = pretrained_backbone
        return vision_model

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: bool = True,
        patch_attention_mask: Optional[torch.Tensor] = None,
        return_pooled: bool = True,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """Encode pixel values and flatten the spatial map into a sequence.

        Args:
            pixel_values (torch.FloatTensor): Images shaped (batch, channels, height, width).
            output_attentions (Optional[bool]): Accepted for interface parity; unused.
            output_hidden_states (Optional[bool]): Whether to return staged hidden states.
            return_dict (bool): Whether to wrap results in ``BaseModelOutputWithPooling``.
            patch_attention_mask (Optional[torch.Tensor]): Accepted for parity; unused here.
            return_pooled (bool): Whether to request the pooled vector.

        Returns:
            Union[Tuple, BaseModelOutputWithPooling]: Vision outputs in sequence format.
        """
        backbone_outputs = self.backbone(
            pixel_values,
            output_hidden_states=output_hidden_states,
            return_dict=True,
            return_pooled=return_pooled,
        )
        backbone_outputs = cast(BaseModelOutputWithPooling, backbone_outputs)
        pooled = backbone_outputs.pooler_output if return_pooled else None

        # Flatten (batch, channels, height, width) -> (batch, h*w, channels).
        sequence_output = rearrange(
            backbone_outputs.last_hidden_state, "b c h w -> b (h w) c"
        )

        if return_dict:
            return BaseModelOutputWithPooling(
                last_hidden_state=sequence_output,
                pooler_output=pooled,
                hidden_states=(
                    backbone_outputs.hidden_states
                    if output_hidden_states
                    else None
                ),
            )

        # Tuple layout mirrors the requested extras, in order.
        if output_hidden_states:
            parts = [sequence_output]
            if return_pooled:
                parts.append(pooled)
            parts.append(backbone_outputs.hidden_states)
            return tuple(parts)
        if return_pooled:
            return (sequence_output, pooled)
        return (sequence_output,)
|
|
|
|
| |
@use_kernel_forward_from_hub("RMSNorm")
class YasaRMSNorm(nn.Module):
    """Root-mean-square LayerNorm; equivalent to T5LayerNorm."""

    def __init__(self, hidden_size, eps=1e-6):
        """Store the learnable scale and the variance epsilon.

        Args:
            hidden_size: Size of the normalized (last) dimension.
            eps: Epsilon added to the variance for numerical stability.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        """Normalize by the RMS of the last dimension, computed in float32."""
        original_dtype = hidden_states.dtype
        states = hidden_states.to(torch.float32)
        mean_square = states.pow(2).mean(-1, keepdim=True)
        states = states * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * states.to(original_dtype)

    def extra_repr(self):
        """Report the weight shape and epsilon in the module repr."""
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
|
|
|
class YasaRotaryEmbedding(nn.Module):
    # Buffer holding the inverse RoPE frequencies; registered in __init__.
    inv_freq: torch.Tensor

    def __init__(self, config: YasaConfig, device=None):
        """Initialize rotary position embedding tables from the model config.

        Args:
            config (YasaConfig): Config providing `max_position_embeddings` and
                optionally a `rope_scaling` dict selecting the RoPE variant.
            device: Optional device for the initial inverse-frequency buffer.
        """
        super().__init__()
        # Resolve the RoPE variant from config; fall back to the default rule.
        if hasattr(config, "rope_scaling") and isinstance(
            config.rope_scaling, dict
        ):
            # NOTE(review): if the dict carries neither "rope_type" nor "type",
            # this resolves to None and the ROPE_INIT_FUNCTIONS lookup below
            # raises KeyError — confirm upstream config validation covers this.
            self.rope_type = config.rope_scaling.get(
                "rope_type", config.rope_scaling.get("type")
            )
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(
            self.config, device
        )
        # Non-persistent: recomputed on load rather than stored in checkpoints.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # NOTE(review): presumably kept so @dynamic_rope_update can restore the
        # unscaled frequencies after dynamic scaling — confirm against the
        # transformers implementation of dynamic_rope_update.
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        """Compute cos/sin rotary tables for the given position ids.

        Args:
            x: Any tensor providing the target device and dtype.
            position_ids: Integer positions of shape (batch, seq_len).

        Returns:
            Tuple of (cos, sin), each scaled by `attention_scaling` and cast
            to `x.dtype`.
        """
        inv_freq_expanded = (
            self.inv_freq[None, :, None]
            .float()
            .expand(position_ids.shape[0], -1, 1)
            .to(x.device)
        )
        position_ids_expanded = position_ids[:, None, :].float()

        # Force the outer product to run in full float32: autocast is disabled,
        # and "mps" (which lacks an autocast scope here) falls back to "cpu".
        device_type = (
            x.device.type
            if isinstance(x.device.type, str) and x.device.type != "mps"
            else "cpu"
        )
        with torch.autocast(
            device_type=device_type, enabled=False
        ):
            freqs = (
                inv_freq_expanded.float() @ position_ids_expanded.float()
            ).transpose(1, 2)
            # Duplicate the frequencies so each half of the head dim rotates.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
|
|
|
def rotate_half(x):
    """Rotate the last dimension by swapping halves and negating the back.

    Maps ``[a, b]`` (split along the last axis) to ``[-b, a]`` — the
    pairing used by rotary position embeddings.
    """
    half = x.shape[-1] // 2
    front, back = x[..., :half], x[..., half:]
    return torch.cat((-back, front), dim=-1)
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Rotate query and key tensors with the given cos/sin tables.

    Args:
        q: Query tensor.
        k: Key tensor.
        cos: Cosine table, unsqueezed before broadcasting onto ``q``/``k``.
        sin: Sine table, unsqueezed the same way.
        position_ids: Deprecated and unused; kept for API compatibility.
        unsqueeze_dim: Axis inserted into ``cos``/``sin`` so they broadcast
            across the head dimension — 1 for a ``[batch, heads, seq, dim]``
            layout, 2 for ``[batch, seq, heads, dim]``.

    Returns:
        Tuple of the rotated query and key tensors.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)
    rotated_q = q * cos_b + rotate_half(q) * sin_b
    rotated_k = k * cos_b + rotate_half(k) * sin_b
    return rotated_q, rotated_k
|
|
|
|
class YasaMLP(nn.Module):
    """Gated (SwiGLU-style) feed-forward block for the Yasa decoder."""

    def __init__(self, config):
        """Create the gate/up/down projections and the activation.

        Args:
            config: Text config providing ``hidden_size``,
                ``intermediate_size``, ``mlp_bias`` and ``hidden_act``.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        use_bias = config.mlp_bias
        self.gate_proj = nn.Linear(
            self.hidden_size, self.intermediate_size, bias=use_bias
        )
        self.up_proj = nn.Linear(
            self.hidden_size, self.intermediate_size, bias=use_bias
        )
        self.down_proj = nn.Linear(
            self.intermediate_size, self.hidden_size, bias=use_bias
        )
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        """Project up, gate with the activation, and project back down."""
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
|
|
|
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head ``n_rep`` times along the head axis.

    Equivalent to ``torch.repeat_interleave(hidden_states, n_rep, dim=1)``:
    (batch, num_key_value_heads, seq, head_dim) becomes
    (batch, num_key_value_heads * n_rep, seq, head_dim). Uses a zero-copy
    ``expand`` followed by a single ``reshape`` instead of an explicit
    repeat.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states.unsqueeze(2).expand(
        batch, kv_heads, n_rep, seq_len, head_dim
    )
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
|
|
|
|
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Plain softmax attention with grouped-query key/value expansion.

    Args:
        module: Attention module providing ``num_key_value_groups`` and
            ``training`` (used for dropout).
        query: Query states, shape (batch, heads, q_len, head_dim).
        key: Key states with the (smaller) number of KV heads.
        value: Value states with the (smaller) number of KV heads.
        attention_mask: Optional additive mask, sliced to the key length.
        scaling: Multiplier applied to the raw attention scores.
        dropout: Dropout probability on the attention weights.

    Returns:
        Tuple of (attn_output, attn_weights); the output is transposed to
        (batch, q_len, heads, head_dim) and made contiguous.
    """
    expanded_key = repeat_kv(key, module.num_key_value_groups)
    expanded_value = repeat_kv(value, module.num_key_value_groups)

    scores = torch.matmul(query, expanded_key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        scores = scores + attention_mask[:, :, :, : expanded_key.shape[-2]]

    # Softmax in float32 for stability, then back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32)
    probs = probs.to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, expanded_value)
    context = context.transpose(1, 2).contiguous()

    return context, probs
|
|
|
|
class YasaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: YasaConfig, layer_idx: int):
        """Set up projections and grouped-query attention bookkeeping.

        Args:
            config: Text config with head counts, dropout, and bias flags.
            layer_idx: Index of this layer, used to address the KV cache.
        """
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # Fall back to hidden_size // num_heads when head_dim is unset.
        self.head_dim = getattr(
            config,
            "head_dim",
            config.hidden_size // config.num_attention_heads,
        )
        # Grouped-query attention: query heads served per key/value head.
        self.num_key_value_groups = (
            config.num_attention_heads // config.num_key_value_heads
        )
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True

        self.q_proj = nn.Linear(
            config.hidden_size,
            config.num_attention_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.k_proj = nn.Linear(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.v_proj = nn.Linear(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim,
            config.hidden_size,
            bias=config.attention_bias,
        )

    @deprecate_kwarg(
        "past_key_value", new_name="past_key_values", version="4.58"
    )
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Project, apply RoPE, update the KV cache, and attend.

        Args:
            hidden_states: Input of shape (batch, seq, hidden_size).
            position_embeddings: Precomputed (cos, sin) RoPE tables.
            attention_mask: Optional attention mask for the backend.
            past_key_values: Optional cache, updated in place per layer.
            cache_position: Positions of the new tokens within the cache.

        Returns:
            Tuple of (attn_output, attn_weights).
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # (batch, seq, heads*dim) -> (batch, heads, seq, dim).
        query_states = (
            self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        )
        key_states = (
            self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        )
        value_states = (
            self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        )

        # RoPE is applied before the cache update, so keys are stored in
        # their rotated form.
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )

        if past_key_values is not None:
            # sin/cos are forwarded for cache types that need them.
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "cache_position": cache_position,
            }
            key_states, value_states = past_key_values.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        # Dispatch to the configured backend (eager, sdpa, flash, ...).
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[
                self.config._attn_implementation
            ]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        # Merge heads back into the hidden dimension and project out.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
|
|
|
|
class YasaDecoderLayer(GradientCheckpointingLayer):
    """One transformer decoder block: self-attention then a gated MLP.

    Both sub-blocks are pre-norm (RMSNorm applied before the op) with a
    residual connection around each.
    """

    def __init__(self, config: YasaConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = YasaAttention(config=config, layer_idx=layer_idx)
        self.mlp = YasaMLP(config)
        self.input_layernorm = YasaRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = YasaRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    @deprecate_kwarg(
        "past_key_value", new_name="past_key_values", version="4.58"
    )
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[
            tuple[torch.Tensor, torch.Tensor]
        ] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """Run pre-norm self-attention and MLP, each with a residual add."""
        # Self-attention sub-block (attention weights are discarded here).
        attn_out, _ = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_out
|
|
|
|
class YasaPreTrainedModel(PreTrainedModel):
    """Base class wiring Yasa models into the transformers infrastructure.

    Declares capability flags (attention backends, compile, gradient
    checkpointing) and which submodules must stay on a single device.
    """

    # NOTE(review): set as `config = ...` here, while Yasa2Model below uses
    # `config_class = ...` — confirm both are honored by the targeted
    # transformers version.
    config = Yasa2Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Keep each decoder layer whole when splitting across devices.
    _no_split_modules = ["YasaDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True

    _can_compile_fullgraph = True
    _supports_attention_backend = True
    # Module types whose outputs can be recorded as hidden_states /
    # attentions by the input-checking machinery.
    _can_record_outputs = {
        "hidden_states": YasaDecoderLayer,
        "attentions": YasaAttention,
    }
|
|
|
|
@auto_docstring
class YasaModel(YasaPreTrainedModel):
    """Text-only Yasa decoder: token embeddings, stacked decoder layers,
    and a final RMSNorm."""

    def __init__(self, config: YasaConfig):
        """Build embeddings, decoder layers, final norm, and RoPE tables."""
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        self.layers = nn.ModuleList(
            [
                YasaDecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = YasaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # One rotary module shared by every layer.
        self.rotary_emb = YasaRotaryEmbedding(config=config)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing.
        self.post_init()

    @check_model_inputs()
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        # Exactly one of input_ids / inputs_embeds must be given.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds"
            )

        if inputs_embeds is None:
            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)

        # Derive absolute positions of the new tokens from the cache length.
        if cache_position is None:
            past_seen_tokens = (
                past_key_values.get_seq_length()
                if past_key_values is not None
                else 0
            )
            cache_position: torch.Tensor = (
                torch.arange(
                    inputs_embeds.shape[1], device=inputs_embeds.device
                )
                + past_seen_tokens
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # Build the backend-appropriate causal attention mask.
        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )

        hidden_states = inputs_embeds
        # cos/sin tables are computed once and shared across all layers.
        position_embeddings = self.rotary_emb(
            hidden_states, position_ids=position_ids
        )

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_embeddings=position_embeddings,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )

        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )
|
|
|
|
class Yasa2Model(YasaPreTrainedModel):
    """Pretrained base class that holds the full Yasa2 multimodal stack."""

    config_class: PretrainedConfig = Yasa2Config

    # Top-level wrapper: no prefix is prepended to checkpoint keys.
    base_model_prefix: str = ""
    _checkpoint_conversion_mapping: Dict[str, str] = {}
    # Keep decoder layers and the vision tower whole under device maps.
    _no_split_modules = ["YasaDecoderLayer", "ConvNextVisionModel"]
    config: Yasa2Config

    def __init__(
        self,
        config: Yasa2Config,
    ):
        """Initialize the full Yasa2 multimodal stack.

        Args:
            config (Yasa2Config): Configuration for the multimodal model.
        """
        super().__init__(config)

        self.vision_pooling = config.vision_pooling
        if self.vision_pooling != "adaptive_avg":
            raise ValueError(
                f"Yasa2 only supports adaptive_avg vision pooling, got {self.vision_pooling}"
            )
        # Pools the vision feature map down to a square grid holding
        # num_query_tokens cells in total.
        self.adaptive_pooling = nn.AdaptiveAvgPool2d(
            int(config.num_query_tokens**0.5)
        )

        # NOTE(review): this squareness check runs after the pool above is
        # built; a non-square value still raises, so behavior is unchanged,
        # but the check could precede the constructor call.
        if not (config.num_query_tokens**0.5).is_integer():
            raise ValueError(
                f"num_query_tokens {config.num_query_tokens} must be a "
                "square number for adaptive_avg pooling"
            )

        # Vision tower (ConvNeXt); a plain dict config is promoted first.
        vision_config = config.vision_config
        if isinstance(vision_config, dict):
            vision_config = ConvNextConfig(**vision_config)
        self.vision_model = ConvNextVisionModel(vision_config)

        # Two-layer MLP projecting vision features into the text space.
        self.language_projection = nn.Sequential(
            nn.Linear(
                config.vision_config.hidden_size,
                config.text_config.hidden_size,
            ),
            nn.GELU(),
            nn.Linear(
                config.text_config.hidden_size,
                config.text_config.hidden_size,
            ),
        )

        self.language_model = YasaModel(config.text_config)

        # Fixed 2D sin/cos table for vision tokens, built once as a numpy
        # array and converted/cached per device-dtype pair on demand.
        self.add_vision_pos_embed = config.use_vision_pos_embed
        self._vision_pos_embed_np = get_2d_sincos_pos_embed(
            config.vision_config.hidden_size,
            image_size=50,
        )
        self._vision_pos_embed_cache: Dict[str, torch.Tensor] = {}

        self.post_init()

    def get_input_embeddings(self) -> torch.nn.Module:
        """Return the multimodal head's input embeddings.

        Returns:
            torch.nn.Module: Embedding module used by the language model.
        """
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value: torch.nn.Module) -> None:
        """Override the multimodal head's input embeddings.

        Args:
            value (torch.nn.Module): Embedding module to register.
        """
        self.language_model.set_input_embeddings(value)

    def set_decoder(self, decoder: YasaModel) -> None:
        """Proxy to set the multimodal model decoder.

        Args:
            decoder: Decoder to register with the multimodal model.
        """
        self.language_model = decoder

    def get_decoder(self) -> YasaModel:
        """Return the decoder component.

        Returns:
            YasaModel: Registered decoder module.
        """
        return self.language_model

    def state_dict(self, *args: Any, **kwargs: Any) -> Dict[str, torch.Tensor]:
        """Return a filtered state dict that omits derived or non-persistent buffers.

        Args:
            *args: Positional arguments forwarded to the superclass.
            **kwargs: Keyword arguments forwarded to the superclass.

        Returns:
            Dict[str, torch.Tensor]: Filtered parameter mapping.
        """
        state_dict = super().state_dict(*args, **kwargs)
        for key in list(state_dict.keys()):
            # Derived attention bias buffers need not be checkpointed.
            if "attention.masked_bias" in key:
                state_dict.pop(key, None)
                continue
            # RoPE inverse frequencies are recomputed from the config.
            if "rotary_emb.inv_freq" in key:
                state_dict.pop(key, None)
        return state_dict

    def _encode_vision_adaptive_2d_avg_pooling(
        self,
        pixel_values: torch.Tensor,
        patch_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode vision inputs via the ConvNeXt backbone and adaptive avg pooling.

        Args:
            pixel_values (torch.Tensor): Vision input tensor.
            patch_attention_mask (Optional[torch.Tensor]): Optional patch mask.

        Returns:
            torch.Tensor: Vision embeddings projected into text hidden size.
        """
        # Unpooled patch features: (img_num, seq_len, vision_hidden_size).
        image_embeds = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=False,
            patch_attention_mask=patch_attention_mask,
            return_pooled=False,
        )[0]

        img_num, seq_length, vision_hidden_size = image_embeds.size()
        # Assumes the patch sequence forms a square grid — seq_length must
        # be a perfect square for the reshape below to be lossless.
        height, width = int(seq_length**0.5), int(seq_length**0.5)
        if self.add_vision_pos_embed:
            vision_pos_embed = self._get_vision_pos_embed(
                device=image_embeds.device,
                dtype=image_embeds.dtype,
                seq_len=image_embeds.size(1),
            )
            image_embeds = image_embeds + vision_pos_embed

        # (img, seq, C) -> (img, C, H, W) for 2D pooling.
        image_embeds = image_embeds.permute(0, 2, 1).contiguous()
        image_embeds = image_embeds.reshape(
            img_num, vision_hidden_size, height, width
        )

        # Zero out padded patches before pooling when a mask is provided.
        if (
            self.config.apply_patch_attention_mask
            and patch_attention_mask is not None
            and patch_attention_mask.numel() > 0
        ):
            patch_attention_mask = patch_attention_mask.reshape(
                img_num, height, width
            )
            image_embeds = image_embeds * patch_attention_mask.unsqueeze(1).to(
                dtype=image_embeds.dtype
            )

        # Pool in float32 with autocast disabled, then restore the dtype.
        pooled_dtype = image_embeds.dtype
        with torch.autocast(device_type="cuda", enabled=False):
            image_embeds = torch.nn.functional.adaptive_avg_pool2d(
                image_embeds.float(), self.adaptive_pooling.output_size
            )
        image_embeds = image_embeds.to(dtype=pooled_dtype)
        image_embeds = image_embeds.flatten(2)
        # Back to (img, num_query_tokens, C).
        image_embeds = image_embeds.permute(0, 2, 1).contiguous()

        vision_embeds = self.language_projection(image_embeds)

        return vision_embeds

    def _get_vision_pos_embed(
        self,
        device: torch.device,
        dtype: torch.dtype,
        seq_len: int,
    ) -> torch.Tensor:
        """Return cached/runtime-built vision positional embeddings."""
        # One cached tensor per (device, dtype) pair.
        cache_key = f"{device}:{dtype}"
        cached = self._vision_pos_embed_cache.get(cache_key)
        if cached is None:
            cached = (
                torch.from_numpy(self._vision_pos_embed_np)
                .view(-1, self.config.vision_config.hidden_size)
                .to(device=device, dtype=dtype)
                .unsqueeze(0)
            )
            self._vision_pos_embed_cache[cache_key] = cached
        # Truncate to the actual number of patches.
        return cached[:, :seq_len, :]

    def get_image_features(
        self, pixel_values: torch.Tensor, **kwargs: Any
    ) -> torch.Tensor:
        """Return vision features for vLLM compatibility."""
        patch_attention_mask = kwargs.get("patch_attention_mask")
        return self._encode_vision_adaptive_2d_avg_pooling(
            pixel_values, patch_attention_mask=patch_attention_mask
        )

    @classmethod
    def scatter_embeddings_to_target_special_id(
        cls,
        target_tensor: torch.Tensor,
        target_input_ids: torch.Tensor,
        src_embeddings: torch.Tensor,
        special_token_id: int,
    ) -> torch.Tensor:
        """Scatter vision embeddings into the language embedding buffer at special tokens.

        Args:
            target_tensor (torch.Tensor): Target embedding buffer to update.
            target_input_ids (torch.Tensor): Input IDs aligned with the target tensor.
            src_embeddings (torch.Tensor): Source embeddings to scatter from vision outputs.
            special_token_id (int): Token ID used to locate insertion positions.

        Returns:
            torch.Tensor: Updated target tensor with vision embeddings placed at special IDs.
        """
        b_source, n_source, d_embedding = src_embeddings.shape
        b_target, n_target, d_target = target_tensor.shape

        if b_target != target_input_ids.size(0):
            raise ValueError(
                "Batch size mismatch: target_input_ids "
                f"{target_input_ids.size(0)} vs target_tensor {b_target}"
            )
        if n_target != target_input_ids.size(1):
            raise ValueError(
                "Sequence length mismatch: target_input_ids "
                f"{target_input_ids.size(1)} vs target_tensor {n_target}"
            )
        if d_embedding != d_target:
            raise ValueError(
                "Embedding dimension mismatch: src_embeddings "
                f"{d_embedding} vs target_tensor {d_target}"
            )

        # Flat indices of every placeholder token across the whole batch.
        special_token_mask = target_input_ids.view(-1) == special_token_id
        special_token_indices = torch.nonzero(special_token_mask).squeeze(-1)

        if len(special_token_indices) != b_source * n_source:
            raise ValueError(
                "Special token count mismatch: found "
                f"{len(special_token_indices)}, expected {b_source * n_source}"
            )

        # Write vision rows at the placeholder positions (in place on the
        # flattened view), then restore the original shape.
        target_tensor = target_tensor.view(-1, d_embedding)
        src_embeddings = src_embeddings.view(-1, d_embedding)
        target_tensor[special_token_indices] = src_embeddings
        target_tensor = target_tensor.view(b_target, n_target, d_embedding)
        return target_tensor

    def _interleave_scatter(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        inputs_embeds: torch.Tensor,
        vision_embeds: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Scatter vision embeddings into language embeddings at the image token positions.

        Args:
            input_ids (torch.Tensor): Token IDs containing image placeholders.
            attention_mask (torch.Tensor): Attention mask for text tokens.
            inputs_embeds (torch.Tensor): Language model input embeddings.
            vision_embeds (torch.Tensor): Vision embeddings to be inserted.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Updated inputs_embeds and attention_mask.
        """
        inputs_embeds = Yasa2Model.scatter_embeddings_to_target_special_id(
            target_tensor=inputs_embeds,
            target_input_ids=input_ids,
            src_embeddings=vision_embeds,
            special_token_id=self.config.image_token_id,
        )
        # attention_mask is returned unchanged.
        return inputs_embeds, attention_mask

    @can_return_tuple
    def forward(
        self,
        input_ids: Optional[torch.LongTensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[
            Union[Cache, Tuple[Tuple[torch.FloatTensor]]]
        ] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        patch_attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        mm_token_type_ids: Optional[torch.Tensor] = None,
        **kwargs: Any,
    ) -> Union[Tuple[torch.Tensor, ...], "Yasa2ModelOutputWithPast"]:
        """Forward pass combining language and vision inputs for Yasa2.

        Args:
            input_ids (Optional[torch.LongTensor]): Token IDs for the language model.
            attention_mask (Optional[torch.Tensor]): Attention mask aligned with `input_ids`.
            position_ids (Optional[torch.LongTensor]): Position indices feeding the language model.
            inputs_embeds (Optional[torch.FloatTensor]): Precomputed token embeddings.
            past_key_values (Optional[Union[Cache, Tuple[Tuple[torch.FloatTensor]]]]): Cached decoder key/value tensors.
            cache_position (Optional[torch.LongTensor]): Positions used for cache alignment.
            use_cache (Optional[bool]): Whether to request cached key/values.
            output_attentions (Optional[bool]): Whether to return attention weights.
            output_hidden_states (Optional[bool]): Whether to return hidden states for each layer.
            return_dict (Optional[bool]): Whether to return a `ModelOutput`.
            pixel_values (Optional[torch.Tensor]): Vision inputs providing image context.
            patch_attention_mask (Optional[torch.Tensor]): Optional patch mask for vision tokens.
            token_type_ids (Optional[torch.Tensor]): Unused token type ids for compatibility.
            mm_token_type_ids (Optional[torch.Tensor]): Unused multimodal token type ids.

        Returns:
            Union[Tuple[torch.Tensor, ...], Yasa2ModelOutputWithPast]: Combined multimodal outputs.
        """
        return_dict = (
            return_dict
            if return_dict is not None
            else self.config.use_return_dict
        )
        use_cache = (
            use_cache if use_cache is not None else self.config.use_cache
        )

        if input_ids is None and inputs_embeds is None:
            raise ValueError(
                "You must provide either input_ids or inputs_embeds."
            )
        if inputs_embeds is not None and pixel_values is not None:
            raise ValueError(
                "pixel_values cannot be used when inputs_embeds is provided."
            )

        if inputs_embeds is None:
            inputs_embeds = self.language_model.get_input_embeddings()(
                input_ids
            )

        # Infer a padding mask only when pad tokens are actually present;
        # otherwise leave the mask unset.
        if attention_mask is None:
            pad_token_id = self.config.text_config.pad_token_id
            if input_ids is not None and pad_token_id is not None:
                if (input_ids == pad_token_id).any():
                    attention_mask = input_ids.ne(pad_token_id)

        # Treat an empty mask tensor the same as no mask.
        if attention_mask is not None:
            if attention_mask.numel() == 0:
                attention_mask = None

        if cache_position is not None:
            expected_len = inputs_embeds.shape[1]
            if cache_position.shape[-1] != expected_len:
                raise ValueError(
                    "cache_position length must match input sequence length: "
                    f"{cache_position.shape[-1]} vs {expected_len}"
                )

        vision_embeds = None
        if pixel_values is not None and len(pixel_values) > 0:
            if input_ids is None:
                raise ValueError(
                    "input_ids is required when pixel_values is provided."
                )
            vision_embeds = self._encode_vision_adaptive_2d_avg_pooling(
                pixel_values,
                patch_attention_mask=patch_attention_mask,
            )
            # Replace image-placeholder token embeddings with vision rows.
            inputs_embeds, attention_mask = self._interleave_scatter(
                input_ids,
                attention_mask,
                inputs_embeds,
                vision_embeds,
            )

        outputs = self.language_model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            # NOTE(review): YasaModel.forward declares no head_mask param;
            # this lands in its **kwargs — confirm it is intended.
            head_mask=None,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            return_dict=True,
            **kwargs,
        )

        return Yasa2ModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            vision_hidden_states=vision_embeds,
        )
|
|
|
|
| class Yasa2ForConditionalGeneration(YasaPreTrainedModel, GenerationMixin): |
| """Yasa2 multimodal conditional generation model (vision + text).""" |
|
|
| config_class = Yasa2Config |
|
|
| _checkpoint_conversion_mapping = {} |
| _tied_weights_keys = [] |
| config: Yasa2Config |
|
|
| def __init__(self, config: Yasa2Config): |
| """Initialize the Yasa2 conditional generation model. |
| |
| Args: |
| config: Yasa2 configuration object. |
| """ |
| super().__init__(config) |
|
|
| self.model = Yasa2Model(config) |
| self.lm_head = nn.Linear( |
| config.hidden_size, config.vocab_size, bias=False |
| ) |
| self.vocab_size = config.vocab_size |
|
|
| |
| self.post_init() |
|
|
| def get_input_embeddings(self) -> torch.nn.Module: |
| """Return the multimodal head's input embeddings. |
| |
| Returns: |
| torch.nn.Module: Embedding module used by the language model. |
| """ |
| return self.model.language_model.get_input_embeddings() |
|
|
| def set_input_embeddings(self, value: torch.nn.Module) -> None: |
| """Override the multimodal head's input embeddings. |
| |
| Args: |
| value (torch.nn.Module): Embedding module to register. |
| """ |
| self.model.language_model.set_input_embeddings(value) |
|
|
| def set_decoder(self, decoder): |
| """Proxy to set the multimodal model decoder. |
| |
| Args: |
| decoder: Decoder to register with the multimodal model. |
| """ |
| self.model.set_decoder(decoder) |
|
|
| def get_decoder(self): |
| """Proxy to return the multimodal decoder.""" |
| return self.model.get_decoder() |
|
|
| |
| @property |
| def language_model(self) -> torch.nn.Module: |
| """Expose the language model component. |
| |
| Returns: |
| torch.nn.Module: Language model module. |
| """ |
| return self.model.language_model |
|
|
| @property |
| def vision_backbone(self) -> torch.nn.Module: |
| """Expose the vision encoder backbone. |
| |
| Returns: |
| torch.nn.Module: Vision backbone module. |
| """ |
| return self.model.vision_model |
|
|
| @can_return_tuple |
| def forward( |
| self, |
| input_ids: Optional[torch.LongTensor] = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| position_ids: Optional[torch.LongTensor] = None, |
| past_key_values: Optional[ |
| Union[Cache, Tuple[Tuple[torch.FloatTensor]]] |
| ] = None, |
| inputs_embeds: Optional[torch.FloatTensor] = None, |
| use_cache: Optional[bool] = None, |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| pixel_values: Optional[torch.Tensor] = None, |
| patch_attention_mask: Optional[torch.Tensor] = None, |
| token_type_ids: Optional[torch.Tensor] = None, |
| mm_token_type_ids: Optional[torch.Tensor] = None, |
| labels: Optional[torch.LongTensor] = None, |
| return_dict: Optional[bool] = None, |
| **kwargs: Any, |
| ) -> Union[ |
| Tuple[torch.Tensor, ...], "Yasa2ForConditionalGenerationModelOutput" |
| ]: |
| """Run the multimodal model, project outputs to logits, and compute loss if needed. |
| |
| Args: |
| input_ids (Optional[torch.LongTensor]): Language token IDs. |
| attention_mask (Optional[torch.Tensor]): Attention mask for language tokens. |
| position_ids (Optional[torch.LongTensor]): Position indices. |
| past_key_values (Optional[Union[Cache, Tuple[Tuple[torch.FloatTensor]]]]): Cached decoder states. |
| inputs_embeds (Optional[torch.FloatTensor]): Input embeddings instead of token IDs. |
| use_cache (Optional[bool]): Whether to cache key/value pairs. |
| output_attentions (Optional[bool]): Whether to return attention weights. |
| output_hidden_states (Optional[bool]): Whether to return hidden states. |
| cache_position (Optional[torch.LongTensor]): Positions used for caching. |
| pixel_values (Optional[torch.Tensor]): Vision inputs. |
| patch_attention_mask (Optional[torch.Tensor]): Optional mask for vision patches. |
| token_type_ids (Optional[torch.Tensor]): Unused token type ids for compatibility. |
| mm_token_type_ids (Optional[torch.Tensor]): Unused multimodal token type ids. |
| labels (Optional[torch.LongTensor]): Labels for computing cross-entropy loss. |
| return_dict (Optional[bool]): Whether to return a dict-like output. |
| |
| Returns: |
| Union[Tuple[torch.Tensor, ...], Yasa2ForConditionalGenerationModelOutput]: Model logits, caches, and optional loss. |
| """ |
| return_dict = ( |
| return_dict |
| if return_dict is not None |
| else self.config.use_return_dict |
| ) |
|
|
| outputs = self.model( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| position_ids=position_ids, |
| past_key_values=past_key_values, |
| inputs_embeds=inputs_embeds, |
| cache_position=cache_position, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| pixel_values=pixel_values, |
| patch_attention_mask=patch_attention_mask, |
| return_dict=True, |
| **kwargs, |
| ) |
|
|
| hidden_states = outputs.last_hidden_state |
| logits = self.lm_head(hidden_states) |
|
|
| loss = None |
| if labels is not None: |
| labels = labels.to(logits.device) |
| shift_logits = logits[..., :-1, :].contiguous() |
| shift_labels = labels[..., 1:] |
| loss_fct = nn.CrossEntropyLoss( |
| ignore_index=self.config.label_ignore_index |
| ) |
| loss = loss_fct( |
| shift_logits.reshape(-1, shift_logits.size(-1)), |
| shift_labels.reshape(-1), |
| ) |
|
|
| return Yasa2ForConditionalGenerationModelOutput( |
| loss=loss, |
| logits=logits, |
| past_key_values=outputs.past_key_values, |
| hidden_states=outputs.hidden_states, |
| attentions=outputs.attentions, |
| vision_hidden_states=outputs.vision_hidden_states, |
| language_model_outputs=outputs, |
| ) |
|
|
| def generate( |
| self, |
| input_ids: Optional[torch.LongTensor], |
| attention_mask: Optional[torch.Tensor] = None, |
| pixel_values: Optional[torch.Tensor] = None, |
| patch_attention_mask: Optional[torch.Tensor] = None, |
| **generate_kwargs, |
| ) -> torch.LongTensor: |
| """Generate text tokens conditioned on vision and/or language inputs. |
| |
| Args: |
| input_ids (Optional[torch.LongTensor]): Seed language tokens. |
| attention_mask (Optional[torch.Tensor]): Language attention mask. |
| pixel_values (Optional[torch.Tensor]): Vision inputs appended to prompts. |
| patch_attention_mask (Optional[torch.Tensor]): Mask for vision patches. |
| **generate_kwargs: Additional generation options forwarded to the `super().generate`. |
| |
| Returns: |
| torch.LongTensor: Generated token IDs. |
| """ |
| return super().generate( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| pixel_values=pixel_values, |
| patch_attention_mask=patch_attention_mask, |
| **generate_kwargs, |
| ) |
|
|
| def prepare_inputs_for_generation( |
| self, |
| input_ids: torch.LongTensor, |
| past_key_values: Optional[ |
| Union[Cache, Tuple[Tuple[torch.FloatTensor]]] |
| ] = None, |
| inputs_embeds: Optional[torch.FloatTensor] = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| pixel_values: Optional[torch.Tensor] = None, |
| patch_attention_mask: Optional[torch.Tensor] = None, |
| **kwargs: Any, |
| ) -> Dict[str, Any]: |
| """Prepare multimodal inputs for generation bookkeeping. |
| |
| Args: |
| input_ids (torch.LongTensor): Current token IDs for generation. |
| past_key_values (Optional[Union[Cache, Tuple[Tuple[torch.FloatTensor]]]]): Cached past key/value tensors. |
| inputs_embeds (Optional[torch.FloatTensor]): Optional token embeddings. |
| attention_mask (Optional[torch.Tensor]): Language attention mask. |
| cache_position (Optional[torch.LongTensor]): Cache alignment positions. |
| pixel_values (Optional[torch.Tensor]): Vision inputs that should be reused. |
| patch_attention_mask (Optional[torch.Tensor]): Vision patch mask for the prefill step. |
| **kwargs: Additional arguments forwarded to the base implementation. |
| |
| Returns: |
| Dict[str, Any]: Prepared inputs for the next generation step. |
| """ |
| model_inputs = super().prepare_inputs_for_generation( |
| input_ids=input_ids, |
| past_key_values=past_key_values, |
| inputs_embeds=inputs_embeds, |
| attention_mask=attention_mask, |
| cache_position=cache_position, |
| **kwargs, |
| ) |
|
|
| is_prefill = past_key_values is None or ( |
| cache_position is not None and cache_position[0] == 0 |
| ) |
| if is_prefill: |
| model_inputs["pixel_values"] = pixel_values |
| model_inputs["patch_attention_mask"] = patch_attention_mask |
|
|
| return model_inputs |
|
|
|
|
# Register this model with Transformers' Auto-class machinery under
# AutoModelForImageTextToText, so the class can be resolved automatically
# when the model is loaded as custom (remote) code.
Yasa2ForConditionalGeneration.register_for_auto_class(
    "AutoModelForImageTextToText"
)
|
|