# reka-edge-2603 / modeling_yasa2.py — uploaded by donovanOng92 (revision 7d24555, verified)
from __future__ import annotations
import dataclasses
import glob
from collections.abc import Callable
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union, cast
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from safetensors.torch import load_file as safetensors_load
from transformers import PretrainedConfig
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.generation.utils import GenerationMixin
from transformers.integrations import use_kernel_forward_from_hub
from transformers.masking_utils import create_causal_mask
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
BaseModelOutputWithPooling,
)
from transformers.modeling_rope_utils import (
ROPE_INIT_FUNCTIONS,
dynamic_rope_update,
)
from transformers.modeling_utils import (
ALL_ATTENTION_FUNCTIONS,
PreTrainedModel,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
ModelOutput,
TransformersKwargs,
auto_docstring,
can_return_tuple,
logging,
)
from transformers.utils.deprecation import deprecate_kwarg
# `check_model_inputs` only exists in recent transformers releases; on older
# versions fall back to a no-op decorator factory with the same call shape.
try:
    from transformers.utils.generic import check_model_inputs
except ImportError:

    def check_model_inputs(*args, **kwargs):
        """No-op stand-in: return a decorator that leaves the function unchanged."""

        def _passthrough(decorated_fn):
            return decorated_fn

        return _passthrough
from .configuration_yasa2 import ConvNextConfig, Yasa2Config, YasaConfig
logger = logging.get_logger(__name__)
# ---- Model outputs ----
@dataclasses.dataclass
class Yasa2ModelOutputWithPast(BaseModelOutputWithPast):
    """
    Base class for Yasa2 model outputs with past key values.
    Args:
        last_hidden_state (`torch.FloatTensor`, *optional*):
            Last hidden state of the model.
        past_key_values (`Cache`, *optional*):
            Cache of key/value tensors for each layer.
        hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
            Tuple of hidden states from the model.
        attentions (`Tuple[torch.FloatTensor]`, *optional*):
            Tuple of attention maps from the model.
        vision_hidden_states (`torch.FloatTensor`, *optional*):
            Vision embeddings associated with this forward pass, if any were
            produced.
    """
    last_hidden_state: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Cache] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    vision_hidden_states: Optional[torch.FloatTensor] = None
@dataclasses.dataclass
class Yasa2ForConditionalGenerationModelOutput(ModelOutput):
    """
    Outputs for Yasa2 conditional generation, aggregating language-model and
    vision results.
    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`Cache`, *optional*, returned when `use_cache=True`):
            Cache of key/value tensors for each layer.
        hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
            Tuple of hidden states from the language model.
        attentions (`Tuple[torch.FloatTensor]`, *optional*):
            Tuple of attention maps from the language model.
        vision_hidden_states (`torch.FloatTensor`, *optional*):
            Vision embeddings after projection and pooling.
        language_model_outputs (`Yasa2ModelOutputWithPast`, *optional*):
            The full language model outputs.
    """
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Cache] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    vision_hidden_states: Optional[torch.FloatTensor] = None
    language_model_outputs: Optional[Yasa2ModelOutputWithPast] = None
# ---- Utilities ----
def get_2d_sincos_pos_embed(
    embed_dim: int, image_size: int | tuple[int, int]
) -> np.ndarray:
    """Generate 2D sincos positional embeddings for a vision grid.

    Args:
        embed_dim (int): Embedding dimension (must be even).
        image_size (int | tuple[int, int]): Image size as an int or
            (height, width) tuple.

    Returns:
        np.ndarray: Positional embedding array of shape (H, W, embed_dim)
        (note: not flattened).
    """
    if isinstance(image_size, int):
        height = width = image_size
    else:
        height, width = image_size[0], image_size[1]
    # Coordinate grid over the spatial dimensions; with default 'xy' indexing
    # the stacked array has shape (2, H, W).
    coords = np.stack(
        np.meshgrid(
            np.arange(width, dtype=np.float32),
            np.arange(height, dtype=np.float32),
        ),
        axis=0,
    )
    return get_2d_sincos_pos_embed_from_grid(embed_dim, coords)
def get_2d_sincos_pos_embed_from_grid(
    embed_dim: int, grid: np.ndarray
) -> np.ndarray:
    """Generate 2D sincos positional embeddings from a coordinate grid.

    Args:
        embed_dim (int): Embedding dimension (must be even).
        grid (np.ndarray): Grid array of shape (2, H, W).

    Returns:
        np.ndarray: Positional embedding array of shape (H, W, embed_dim).
    """
    assert embed_dim % 2 == 0
    # Half of the embedding channels encode each spatial axis.
    half = embed_dim // 2
    per_axis = [
        get_1d_sincos_pos_embed_from_grid(half, axis_grid)
        for axis_grid in (grid[0], grid[1])
    ]
    return np.concatenate(per_axis, axis=-1)
def get_1d_sincos_pos_embed_from_grid(
    embed_dim: int, pos: np.ndarray
) -> np.ndarray:
    """Generate 1D sincos positional embeddings from a positional array.

    Args:
        embed_dim (int): Embedding dimension (must be even).
        pos (np.ndarray): 2D position grid for one spatial axis.

    Returns:
        np.ndarray: Array of shape pos.shape + (embed_dim,), with sin features
        in the first half of the channels and cos features in the second half.
    """
    assert embed_dim % 2 == 0
    # Classic transformer frequency schedule: 1 / 10000^(2i / embed_dim).
    freqs = 1.0 / 10000 ** (
        np.arange(embed_dim // 2, dtype=np.float32) / (embed_dim / 2.0)
    )
    angles = np.einsum("hw,d->hwd", pos, freqs)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)
# ---- ConvNeXt V2 backbone ----
def drop_path(
    input: torch.Tensor,
    drop_prob: Optional[float] = 0.0,
    training: bool = False,
) -> torch.Tensor:
    """Apply stochastic depth (drop path) to the input tensor.

    Args:
        input (torch.Tensor): Input tensor to apply drop path to.
        drop_prob (Optional[float]): Probability of dropping a path. ``None``
            is treated as 0.0 (``ConvNextDropPath`` defaults its probability
            to ``None``). Defaults to 0.0.
        training (bool): Whether the model runs in training mode. Defaults to False.

    Returns:
        torch.Tensor: Tensor with drop path applied when enabled; the input
        unchanged when disabled or in eval mode.
    """
    # Fix: ConvNextDropPath's default drop_prob is None, which previously
    # raised TypeError here (`1 - None`) during training; treat None as 0.0.
    if drop_prob is None or drop_prob == 0.0 or not training:
        return input
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims.
    shape = (input.shape[0],) + (1,) * (input.ndim - 1)
    random_tensor = keep_prob + torch.rand(
        shape, dtype=input.dtype, device=input.device
    )
    random_tensor.floor_()  # binarize: 1 with prob keep_prob, else 0
    # Scale surviving paths by 1/keep_prob so the expected value is unchanged.
    output = input.div(keep_prob) * random_tensor
    return output
class ConvNextDropPath(nn.Module):
    """Drop paths (stochastic depth) per sample in residual blocks."""

    def __init__(self, drop_prob: Optional[float] = None):
        """Store the per-sample drop probability.

        Args:
            drop_prob (Optional[float]): Probability of dropping a path.
        """
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply stochastic depth to the hidden states.

        Args:
            hidden_states (torch.Tensor): Tensor to apply stochastic depth to.

        Returns:
            torch.Tensor: Tensor after stochastic depth (identity in eval mode).
        """
        return drop_path(hidden_states, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        """Describe the configured drop probability for ``repr`` output."""
        return f"p={self.drop_prob}"
class ConvNextLayerNorm(nn.Module):
    r"""LayerNorm supporting channels_last (default) or channels_first layouts.

    channels_last expects inputs shaped (batch_size, height, width, channels);
    channels_first expects (batch_size, channels, height, width).
    """

    def __init__(
        self,
        normalized_shape: int,
        eps: float = 1e-6,
        data_format: str = "channels_last",
    ) -> None:
        """Initialize the affine parameters and layout mode.

        Args:
            normalized_shape (int): Number of channels being normalized.
            eps (float): Small epsilon to avoid division by zero.
            data_format (str): Either 'channels_last' or 'channels_first'.

        Raises:
            NotImplementedError: If data_format is not supported.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError(
                f"Unsupported data format: {self.data_format}"
            )
        self.normalized_shape = (normalized_shape,)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` according to the configured layout.

        Args:
            x (torch.Tensor): (N, H, W, C) for channels_last, or (N, C, H, W)
                for channels_first.

        Returns:
            torch.Tensor: Normalized tensor with the same shape as the input.
        """
        if self.data_format == "channels_last":
            return nn.functional.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )
        # channels_first: normalize over dim 1 manually, computing statistics
        # in fp32 for stability, then restore the input dtype before affine.
        orig_dtype = x.dtype
        x32 = x.float()
        mean = x32.mean(1, keepdim=True)
        var = (x32 - mean).pow(2).mean(1, keepdim=True)
        normed = ((x32 - mean) / torch.sqrt(var + self.eps)).to(dtype=orig_dtype)
        return self.weight[:, None, None] * normed + self.bias[:, None, None]
class ConvNextV2GRN(nn.Module):
    """Global Response Normalization (GRN) layer for ConvNeXt V2."""

    def __init__(self, dim: int):
        """Create the learnable scale and bias, both initialized to zero.

        Args:
            dim (int): Channel dimension of the input tensor.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.bias = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        """Apply Global Response Normalization.

        Args:
            hidden_states (torch.FloatTensor): Tensor shaped (batch, height, width, channels).

        Returns:
            torch.FloatTensor: Normalized tensor with the same shape.
        """
        # Per-channel L2 norm over the spatial dims ...
        spatial_l2 = torch.norm(hidden_states, p=2, dim=(1, 2), keepdim=True)
        # ... divided by the mean norm across channels (divisive normalization).
        relative = spatial_l2 / (spatial_l2.mean(dim=-1, keepdim=True) + 1e-6)
        # Residual form: learnable affine on the modulated signal plus identity.
        return self.weight * (hidden_states * relative) + self.bias + hidden_states
class ConvNextEmbeddings(nn.Module):
    """ConvNeXt patch embedding layer."""

    def __init__(
        self, num_channels: int = 3, hidden_size: int = 96, patch_size: int = 4
    ) -> None:
        """Set up the patchifying convolution and its layer norm.

        Args:
            num_channels (int): Number of image channels. Defaults to 3.
            hidden_size (int): Hidden dimension size. Defaults to 96.
            patch_size (int): Side length of the square, non-overlapping
                patches. Defaults to 4.
        """
        super().__init__()
        # stride == kernel_size, so each patch is embedded independently.
        self.patch_embeddings = nn.Conv2d(
            num_channels,
            hidden_size,
            kernel_size=patch_size,
            stride=patch_size,
        )
        self.layernorm = ConvNextLayerNorm(
            hidden_size, eps=1e-6, data_format="channels_first"
        )
        self.num_channels = num_channels

    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
        """Embed an image batch into patch features.

        Args:
            pixel_values (torch.FloatTensor): Tensor shaped (batch, channels, height, width).

        Returns:
            torch.Tensor: Normalized patch embeddings in channels-first layout.

        Raises:
            ValueError: If the channel dimension does not match the configured count.
        """
        if pixel_values.shape[1] != self.num_channels:
            raise ValueError(
                "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
            )
        return self.layernorm(self.patch_embeddings(pixel_values))
class ConvNextLayer(nn.Module):
    """ConvNeXt V2 block: depthwise conv -> LN -> MLP with GRN -> residual."""

    def __init__(
        self,
        dim: int,
        drop_path: float = 0,
        layer_scale_init_value: float = 1e-6,
        use_grn: bool = True,
    ) -> None:
        """Construct one ConvNeXt V2 layer.

        Args:
            dim (int): Input/output channel dimension.
            drop_path (float): Drop path probability for stochastic depth.
            layer_scale_init_value (float): Initial residual-branch scale;
                values <= 0 disable layer scaling.
            use_grn (bool): Must be True; V2 blocks always use GRN.

        Raises:
            ValueError: If use_grn is False.
        """
        super().__init__()
        # Submodule creation order preserved (parameter init consumes RNG).
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.layernorm = ConvNextLayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        if not use_grn:
            raise ValueError("ConvNeXt V2 requires use_grn=True.")
        self.grn = ConvNextV2GRN(4 * dim)
        self.pwconv2 = nn.Linear(4 * dim, dim)
        if layer_scale_init_value > 0:
            self.layer_scale_parameter = nn.Parameter(
                layer_scale_init_value * torch.ones((dim)), requires_grad=True
            )
        else:
            self.layer_scale_parameter = None
        self.drop_path = (
            ConvNextDropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        )

    def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
        """Run one block and add the residual.

        Args:
            hidden_states (torch.FloatTensor): Tensor shaped (batch, channels, height, width).

        Returns:
            torch.Tensor: Tensor of the same shape after the residual update.
        """
        residual = hidden_states
        # Depthwise spatial mixing, then switch to channels-last for the MLP.
        out = self.dwconv(hidden_states).permute(0, 2, 3, 1)
        out = self.pwconv1(self.layernorm(out))
        out = self.grn(self.act(out))
        out = self.pwconv2(out)
        if self.layer_scale_parameter is not None:
            out = self.layer_scale_parameter * out
        # Back to channels-first before the (optionally dropped) residual add.
        out = out.permute(0, 3, 1, 2)
        return residual + self.drop_path(out)
class ConvNextStage(nn.Module):
    """ConvNeXt V2 stage: optional downsampling followed by residual blocks."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 2,
        stride: int = 2,
        depth: int = 2,
        drop_path_rates: Optional[list[float]] = None,
        layer_scale_init_value: float = 1e-6,
        use_grn: bool = True,
    ) -> None:
        """Build a ConvNeXt stage.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_size (int): Kernel size of the downsampling convolution.
            stride (int): Stride of the downsampling convolution.
            depth (int): Number of residual layers in the stage.
            drop_path_rates (Optional[list[float]]): Per-layer drop path rates;
                defaults to all zeros.
            layer_scale_init_value (float): Initial residual scaling value.
            use_grn (bool): Whether to enable GRN (required True by ConvNextLayer).
        """
        super().__init__()
        # Downsample whenever the channel count or resolution changes;
        # otherwise pass through untouched.
        if in_channels != out_channels or stride > 1:
            self.downsampling_layer = nn.Sequential(
                ConvNextLayerNorm(
                    in_channels, eps=1e-6, data_format="channels_first"
                ),
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                ),
            )
        else:
            self.downsampling_layer = nn.Identity()
        rates = drop_path_rates or [0.0] * depth
        self.layers = nn.Sequential(
            *(
                ConvNextLayer(
                    dim=out_channels,
                    drop_path=rates[idx],
                    layer_scale_init_value=layer_scale_init_value,
                    use_grn=use_grn,
                )
                for idx in range(depth)
            )
        )

    def forward(self, hidden_states: torch.FloatTensor) -> torch.Tensor:
        """Run the stage: downsample (if configured) then the residual layers.

        Args:
            hidden_states (torch.FloatTensor): Tensor shaped (batch, channels, height, width).

        Returns:
            torch.Tensor: Output tensor after the stage.
        """
        return self.layers(self.downsampling_layer(hidden_states))
class ConvNextEncoder(nn.Module):
    """ConvNeXt V2 encoder: a sequence of stages with a linear drop-path schedule."""

    def __init__(
        self,
        hidden_sizes: list[int],
        depths: list[int],
        drop_path_rate: float = 0.0,
        layer_scale_init_value: float = 1e-6,
        use_grn: bool = True,
    ) -> None:
        """Construct the encoder stages.

        Args:
            hidden_sizes (list[int]): Hidden dimensions per stage.
            depths (list[int]): Number of layers per stage.
            drop_path_rate (float): Maximum drop path rate; rates increase
                linearly from 0 across all layers.
            layer_scale_init_value (float): Initial residual scaling.
            use_grn (bool): Whether to use GRN within layers.
        """
        super().__init__()
        self.stages = nn.ModuleList()
        self.gradient_checkpointing = False
        # One global linear schedule over the total layer count, sliced into
        # per-stage chunks.
        schedule = np.linspace(
            0.0, float(drop_path_rate), sum(depths)
        ).tolist()
        per_stage_rates = []
        offset = 0
        for stage_depth in depths:
            per_stage_rates.append(schedule[offset : offset + stage_depth])
            offset += stage_depth
        # First stage keeps the input channel count and resolution.
        in_chs = hidden_sizes[0]
        for stage_idx in range(len(hidden_sizes)):
            out_chs = hidden_sizes[stage_idx]
            self.stages.append(
                ConvNextStage(
                    in_channels=in_chs,
                    out_channels=out_chs,
                    stride=1 if stage_idx == 0 else 2,
                    depth=depths[stage_idx],
                    drop_path_rates=per_stage_rates[stage_idx],
                    layer_scale_init_value=layer_scale_init_value,
                    use_grn=use_grn,
                )
            )
            in_chs = out_chs

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
    ) -> Tuple:
        """Propagate through all encoder stages.

        Args:
            hidden_states (torch.FloatTensor): Tensor shaped (batch, channels, height, width).
            output_hidden_states (Optional[bool]): Whether to collect the
                hidden state before each stage and after the last one.
            return_dict (Optional[bool]): When False, drop None entries from
                the returned tuple.

        Returns:
            Tuple: (last_hidden_state, all_hidden_states) — the second entry
            is None (or omitted when return_dict=False) unless requested.
        """
        all_hidden_states = () if output_hidden_states else None
        for stage in self.stages:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            if self.gradient_checkpointing and self.training:
                # Trade compute for memory while training.
                hidden_states = torch.utils.checkpoint.checkpoint(
                    stage,
                    hidden_states,
                    use_reentrant=False,
                )
            else:
                hidden_states = stage(hidden_states)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        if not return_dict:
            return tuple(
                item
                for item in (hidden_states, all_hidden_states)
                if item is not None
            )
        return (hidden_states, all_hidden_states)
class ConvNextModel(nn.Module):
    """ConvNeXt V2 model: patch embeddings, staged encoder, and a final
    LayerNorm applied to the pooled features."""
    def __init__(
        self,
        hidden_sizes: list[int],
        depths: list[int],
        num_channels: int = 3,
        patch_size: int = 4,
        drop_path_rate: float = 0.0,
        layer_scale_init_value: float = 1e-6,
        use_grn: bool = True,
    ) -> None:
        """Build the ConvNeXt V2 model with embedding, encoder, and pooling.

        Args:
            hidden_sizes (list[int]): Hidden channel sizes per stage.
            depths (list[int]): Layer counts per stage.
            num_channels (int): Number of image channels.
            patch_size (int): Patch size for initial embedding.
            drop_path_rate (float): Drop path rate range for residual blocks.
            layer_scale_init_value (float): Initial scale for residuals.
            use_grn (bool): Whether to enable GRN; must be True for V2.

        Raises:
            ValueError: If use_grn is False.
        """
        super().__init__()
        if not use_grn:
            raise ValueError("ConvNeXt V2 requires use_grn=True.")
        self.embeddings = ConvNextEmbeddings(
            num_channels, hidden_sizes[0], patch_size
        )
        self.encoder = ConvNextEncoder(
            hidden_sizes,
            depths,
            drop_path_rate,
            layer_scale_init_value,
            use_grn,
        )
        # This norm is only used on the pooled (spatial-mean) features in
        # forward(); the per-token last_hidden_state is returned unnormalized.
        self.layernorm = nn.LayerNorm(hidden_sizes[-1], eps=1e-6)
        # Initialize weights
        self.apply(self._init_weights)
    def _init_weights(self, module: nn.Module) -> None:
        """Initialize module weights following standard ConvNeXt heuristics.

        Args:
            module (nn.Module): Module to initialize.
        """
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # N(0, 0.02) weights with zero bias for linear/conv layers.
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = True,
        return_pooled: bool = True,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """Encode images and optionally return pooled features.

        Args:
            pixel_values (Optional[torch.FloatTensor]): Input tensor shaped (batch, channels, height, width).
            output_hidden_states (Optional[bool]): Whether to return intermediate hidden states.
            return_dict (Optional[bool]): Whether to return output as BaseModelOutput.
            return_pooled (bool): Whether to include pooled output.

        Returns:
            Union[Tuple, BaseModelOutputWithPooling]: Model outputs containing
            last hidden states (channels-first feature map) and optionally the
            pooled output.

        Raises:
            ValueError: If `pixel_values` is None.
        """
        if pixel_values is None:
            raise ValueError("You have to specify pixel_values")
        embedding_output = self.embeddings(pixel_values)
        encoder_outputs = self.encoder(
            embedding_output,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = encoder_outputs[0]
        all_hidden_states = (
            encoder_outputs[1] if output_hidden_states else None
        )
        # Skip pooled output when callers only need token features.
        pooled_output = None
        if return_pooled:
            # Global average pooling, (N, C, H, W) -> (N, C), then LayerNorm.
            pooled_output = self.layernorm(last_hidden_state.mean([-2, -1]))
        if not return_dict:
            outputs = [last_hidden_state]
            if return_pooled:
                outputs.append(pooled_output)
            if output_hidden_states:
                outputs.append(all_hidden_states)
            return tuple(outputs)
        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=all_hidden_states,
        )
    @staticmethod
    def from_pretrained(model_path: Path | str) -> "ConvNextModel":
        """Load ConvNeXt model weights from a pretrained checkpoint directory.

        Tries, in order: model.safetensors, pytorch_model.bin, then sharded
        pytorch_model-*.bin files.

        Args:
            model_path (Path | str): Directory path containing the checkpoint files.

        Returns:
            ConvNextModel: Initialized model with weights loaded from checkpoint.

        Raises:
            NotImplementedError: If config.json is missing in the directory.
            FileNotFoundError: If no weight file is found.
            ValueError: If the checkpoint config has use_grn=False.
        """
        model_path_str = str(model_path)
        model_path_obj = Path(model_path_str)
        # Check if this is a HuggingFace model path
        is_ckpt_dir = (
            model_path_obj.is_dir()
            and (model_path_obj / "config.json").exists()
        )
        if not is_ckpt_dir:
            raise NotImplementedError(
                "The checkpoint path should be a directory containing config.json "
                "and model.safetensors or pytorch_model.bin files."
            )
        # Load configuration
        config = ConvNextConfig.from_pretrained(model_path_str)
        checkpoint_dir = model_path_obj
        # Create our model directly
        if not config.use_grn:
            raise ValueError(
                "ConvNeXt V2 requires use_grn=True in the checkpoint config."
            )
        logger.info(
            "Loading ConvNeXt V2 model from checkpoint: %s", checkpoint_dir
        )
        model = ConvNextModel(
            hidden_sizes=config.hidden_sizes,
            depths=config.depths,
            num_channels=config.num_channels,
            patch_size=config.patch_size,
            drop_path_rate=config.drop_path_rate,
            layer_scale_init_value=config.layer_scale_init_value,
            use_grn=config.use_grn,
        )
        # Load state dict from checkpoint files
        state_dict = {}
        # Try to load from safetensors first (preferred)
        safetensors_file = checkpoint_dir / "model.safetensors"
        if safetensors_file.exists():
            logger.info("Loading weights from %s", safetensors_file)
            state_dict = safetensors_load(str(safetensors_file))
        else:
            # Try pytorch_model.bin
            pytorch_file = checkpoint_dir / "pytorch_model.bin"
            if pytorch_file.exists():
                logger.info("Loading weights from %s", pytorch_file)
                # NOTE(review): weights_only=False unpickles arbitrary objects
                # and can execute code — only load trusted checkpoints.
                state_dict = torch.load(
                    str(pytorch_file), map_location="cpu", weights_only=False
                )
            else:
                # Try sharded checkpoints
                shard_files = sorted(
                    glob.glob(str(checkpoint_dir / "pytorch_model-*.bin"))
                )
                if shard_files:
                    logger.info(
                        "Loading weights from %s sharded files",
                        len(shard_files),
                    )
                    for shard_file in shard_files:
                        # NOTE(review): same weights_only=False caveat as above.
                        state_dict.update(
                            torch.load(
                                shard_file,
                                map_location="cpu",
                                weights_only=False,
                            )
                        )
                else:
                    raise FileNotFoundError(
                        f"Could not find model weights in {checkpoint_dir}. "
                        "Expected model.safetensors, pytorch_model.bin, or pytorch_model-*.bin files."
                    )
        # Load the mapped state dict into our model
        # strict=False tolerates partial checkpoints; mismatches are logged below.
        missing_keys, unexpected_keys = model.load_state_dict(
            state_dict, strict=False
        )
        if missing_keys:
            logger.warning(
                "Some weights of the model were not initialized from the checkpoint "
                "and are newly initialized: %s",
                missing_keys,
            )
        if unexpected_keys:
            logger.warning(
                "Some weights of the checkpoint were not used when initializing the model: %s",
                unexpected_keys,
            )
        return model
class ConvNextVisionModel(nn.Module):
    """Vision model wrapper around the ConvNeXt V2 backbone."""

    def __init__(self, config: Optional[ConvNextConfig] = None):
        """Wrap a ConvNeXt backbone for use within the multimodal stack.

        Args:
            config (Optional[ConvNextConfig]): Backbone configuration; defaults
                to `ConvNextConfig.convnext_large()` when omitted.

        Raises:
            ValueError: If the config lacks ConvNeXt attributes or sets
                use_grn=False.
        """
        super().__init__()
        if config is None:
            config = ConvNextConfig.convnext_large()
        self.config = config
        # Guard clause: anything without ConvNeXt attributes is rejected.
        if not hasattr(config, "hidden_sizes"):
            raise ValueError("Config must be a ConvNextConfig")
        hidden_sizes = config.hidden_sizes
        depths = config.depths
        num_channels = config.num_channels
        patch_size = config.patch_size
        drop_path_rate = config.drop_path_rate
        layer_scale_init_value = config.layer_scale_init_value
        use_grn = config.use_grn
        if not use_grn:
            raise ValueError("ConvNeXt V2 requires use_grn=True.")
        self.backbone = ConvNextModel(
            hidden_sizes=hidden_sizes,
            depths=depths,
            num_channels=num_channels,
            patch_size=patch_size,
            drop_path_rate=drop_path_rate,
            layer_scale_init_value=layer_scale_init_value,
            use_grn=use_grn,
        )

    @staticmethod
    def from_pretrained(model_path: Path | str) -> "ConvNextVisionModel":
        """Load a vision wrapper with pretrained ConvNeXt weights.

        Args:
            model_path (Path | str): Directory containing the pretrained weights.

        Returns:
            ConvNextVisionModel: Wrapper whose backbone holds the loaded weights.
        """
        loaded_backbone = ConvNextModel.from_pretrained(model_path)
        loaded_config = ConvNextConfig.from_pretrained(str(model_path))
        vision_model = ConvNextVisionModel(loaded_config)
        # Replace the freshly-initialized backbone with the pretrained one.
        vision_model.backbone = loaded_backbone
        return vision_model

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: bool = True,
        patch_attention_mask: Optional[torch.Tensor] = None,
        return_pooled: bool = True,
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        """Encode pixel values and flatten the feature map to a token sequence.

        Args:
            pixel_values (torch.FloatTensor): Tensor shaped (batch, channels, height, width).
            output_attentions (Optional[bool]): Accepted for API compatibility; unused.
            output_hidden_states (Optional[bool]): Whether to return staged hidden states.
            return_dict (bool): Whether to return `BaseModelOutputWithPooling`.
            patch_attention_mask (Optional[torch.Tensor]): Accepted for API compatibility; unused.
            return_pooled (bool): Whether to request pooled output from the backbone.

        Returns:
            Union[Tuple, BaseModelOutputWithPooling]: Vision outputs with
            last_hidden_state in (batch, h*w, channels) sequence format.
        """
        # Only request pooled output when the caller needs it.
        backbone_out = cast(
            BaseModelOutputWithPooling,
            self.backbone(
                pixel_values,
                output_hidden_states=output_hidden_states,
                return_dict=True,
                return_pooled=return_pooled,
            ),
        )
        # (b, c, h, w) -> (b, h*w, c): sequence-of-patches layout.
        seq_hidden = rearrange(
            backbone_out.last_hidden_state, "b c h w -> b (h w) c"
        )
        pooled = backbone_out.pooler_output if return_pooled else None
        if return_dict:
            return BaseModelOutputWithPooling(
                last_hidden_state=seq_hidden,
                pooler_output=pooled,
                hidden_states=(
                    backbone_out.hidden_states if output_hidden_states else None
                ),
            )
        # Tuple form: sequence features, then pooled and/or hidden states
        # in that order, each only when requested.
        pieces = [seq_hidden]
        if return_pooled:
            pieces.append(pooled)
        if output_hidden_states:
            pieces.append(backbone_out.hidden_states)
        return tuple(pieces)
# ---- Yasa language model utilities (inlined) ----
@use_kernel_forward_from_hub("RMSNorm")
class YasaRMSNorm(nn.Module):
    """RMS normalization, equivalent to T5LayerNorm: scale by the inverse
    root-mean-square over the last dimension, then apply a learned gain."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Compute statistics in fp32, then cast back to the input dtype.
        orig_dtype = hidden_states.dtype
        states32 = hidden_states.to(torch.float32)
        mean_square = states32.pow(2).mean(-1, keepdim=True)
        states32 = states32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * states32.to(orig_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class YasaRotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) provider: produces cos/sin tables for
    the requested positions."""
    inv_freq: torch.Tensor  # fix linting for `register_buffer`
    def __init__(self, config: YasaConfig, device=None):
        """Initialize RoPE inverse frequencies from the model config.

        Args:
            config (YasaConfig): Config supplying `max_position_embeddings` and
                an optional `rope_scaling` dict selecting the RoPE variant.
            device: Optional device for the inverse-frequency buffer.
        """
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and isinstance(
            config.rope_scaling, dict
        ):
            self.rope_type = config.rope_scaling.get(
                "rope_type", config.rope_scaling.get("type")
            )
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        # The init function returns per-dim inverse frequencies plus a scaling
        # factor later multiplied into cos/sin.
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(
            self.config, device
        )
        # Non-persistent: recomputed from config rather than checkpointed.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Kept so dynamic RoPE updates can restore the initial frequencies.
        self.original_inv_freq = self.inv_freq
    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Compute cos/sin tables for the given positions.

        Args:
            x (torch.Tensor): Tensor supplying the target device and dtype.
            position_ids (torch.Tensor): Position indices shaped (batch, seq_len).

        Returns:
            tuple[torch.Tensor, torch.Tensor]: (cos, sin), each scaled by
            `attention_scaling` and cast to `x.dtype`.
        """
        inv_freq_expanded = (
            self.inv_freq[None, :, None]
            .float()
            .expand(position_ids.shape[0], -1, 1)
            .to(x.device)
        )
        position_ids_expanded = position_ids[:, None, :].float()
        # NOTE(review): falls back to "cpu" as the autocast device for "mps",
        # presumably because autocast control is unavailable there — confirm.
        device_type = (
            x.device.type
            if isinstance(x.device.type, str) and x.device.type != "mps"
            else "cpu"
        )
        with torch.autocast(
            device_type=device_type, enabled=False
        ):  # Force float32
            freqs = (
                inv_freq_expanded.float() @ position_ids_expanded.float()
            ).transpose(1, 2)
            # Duplicate the frequencies so cos/sin span the full rotary dim.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
def rotate_half(x):
    """Rotate half the hidden dims: (x1, x2) -> (-x2, x1), split at the midpoint."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Apply Rotary Position Embedding to query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): Cosine part of the rotary embedding.
        sin (`torch.Tensor`): Sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*): Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1): Dimension along which
            cos and sin are unsqueezed so they broadcast against q and k. Use 1
            for (batch, heads, seq_len, head_dim) layouts and 2 for
            (batch, seq_len, heads, head_dim) layouts.

    Returns:
        `tuple(torch.Tensor)`: The query and key tensors rotated with RoPE.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)
    rotated_q = q * cos_b + rotate_half(q) * sin_b
    rotated_k = k * cos_b + rotate_half(k) * sin_b
    return rotated_q, rotated_k
class YasaMLP(nn.Module):
    """Gated MLP: down_proj(act(gate_proj(x)) * up_proj(x))."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        use_bias = config.mlp_bias
        self.gate_proj = nn.Linear(
            self.hidden_size, self.intermediate_size, bias=use_bias
        )
        self.up_proj = nn.Linear(
            self.hidden_size, self.intermediate_size, bias=use_bias
        )
        self.down_proj = nn.Linear(
            self.intermediate_size, self.hidden_size, bias=use_bias
        )
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        # The activated gate modulates the up projection element-wise.
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head `n_rep` times along the head dimension.

    Equivalent to `torch.repeat_interleave(x, dim=1, repeats=n_rep)`: the
    hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis, broadcast it, and fold it into the head axis.
    expanded = hidden_states[:, :, None, :, :].expand(
        batch, kv_heads, n_rep, seq_len, head_dim
    )
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs: Unpack[TransformersKwargs],
):
    """Reference pure-PyTorch scaled-dot-product attention with GQA head repetition.

    Returns a tuple of (attention output transposed to (batch, seq, heads, head_dim),
    attention probabilities).
    """
    # Broadcast KV heads up to the query head count for grouped-query attention.
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    scores = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # The mask may cover more positions than the current key length; crop it.
        scores = scores + attention_mask[:, :, :, : key_states.shape[-2]]

    # Softmax in fp32 for numerical stability, then cast back to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, value_states)
    context = context.transpose(1, 2).contiguous()
    return context, probs
class YasaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: YasaConfig, layer_idx: int):
        super().__init__()
        self.config = config
        # Layer index lets the shared Cache object route K/V updates to this layer.
        self.layer_idx = layer_idx
        # Fall back to hidden_size // num_heads when the config has no explicit head_dim.
        self.head_dim = getattr(
            config,
            "head_dim",
            config.hidden_size // config.num_attention_heads,
        )
        # Number of query heads sharing each key/value head (grouped-query attention).
        self.num_key_value_groups = (
            config.num_attention_heads // config.num_key_value_heads
        )
        # 1/sqrt(head_dim) score scaling.
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        self.is_causal = True
        self.q_proj = nn.Linear(
            config.hidden_size,
            config.num_attention_heads * self.head_dim,
            bias=config.attention_bias,
        )
        # K/V projections are smaller than Q when num_key_value_heads < num_attention_heads.
        self.k_proj = nn.Linear(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.v_proj = nn.Linear(
            config.hidden_size,
            config.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.o_proj = nn.Linear(
            config.num_attention_heads * self.head_dim,
            config.hidden_size,
            bias=config.attention_bias,
        )

    @deprecate_kwarg(
        "past_key_value", new_name="past_key_values", version="4.58"
    )
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_values: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Project to Q/K/V, apply RoPE, optionally update the KV cache, and attend.

        Args:
            hidden_states: Input of shape (batch, seq_len, hidden_size).
            position_embeddings: Precomputed rotary (cos, sin) tensors.
            attention_mask: Additive mask forwarded to the attention kernel.
            past_key_values: Cache updated in place with this layer's K/V.
            cache_position: Absolute token positions (needed by static caches).

        Returns:
            Tuple of (attention output, attention weights).
        """
        input_shape = hidden_states.shape[:-1]
        # (batch, seq, -1, head_dim): let view() infer the per-projection head count.
        hidden_shape = (*input_shape, -1, self.head_dim)
        query_states = (
            self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        )
        key_states = (
            self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        )
        value_states = (
            self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        )
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin
        )
        if past_key_values is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "cache_position": cache_position,
            }
            key_states, value_states = past_key_values.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )
        # Dispatch to the configured backend (sdpa/flash/...); eager is the default.
        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[
                self.config._attn_implementation
            ]
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )
        # Merge heads back into the hidden dimension before the output projection.
        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
class YasaDecoderLayer(GradientCheckpointingLayer):
    """Pre-norm transformer decoder block: self-attention then MLP, each with a residual."""

    def __init__(self, config: YasaConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = YasaAttention(config=config, layer_idx=layer_idx)
        self.mlp = YasaMLP(config)
        self.input_layernorm = YasaRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.post_attention_layernorm = YasaRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    @deprecate_kwarg(
        "past_key_value", new_name="past_key_values", version="4.58"
    )
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[
            tuple[torch.Tensor, torch.Tensor]
        ] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[TransformersKwargs],
    ) -> torch.Tensor:
        """Run one decoder block and return the updated hidden states."""
        # Self-attention sub-block: pre-norm, attend, add residual.
        attn_out, _ = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block: pre-norm, MLP, add residual.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_out
class YasaPreTrainedModel(PreTrainedModel):
    """Shared PreTrainedModel base declaring capability flags for the Yasa family."""

    config = Yasa2Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Never split a decoder layer across devices when sharding.
    _no_split_modules = ["YasaDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    # Attention-backend capabilities advertised to the Transformers runtime.
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _can_compile_fullgraph = True
    _supports_attention_backend = True
    # Module types whose per-layer outputs may be recorded as
    # hidden_states / attentions by the check_model_inputs wrapper.
    _can_record_outputs = {
        "hidden_states": YasaDecoderLayer,
        "attentions": YasaAttention,
    }
@auto_docstring
class YasaModel(YasaPreTrainedModel):
    """Decoder-only text transformer: token embedding -> N decoder layers -> final RMSNorm."""

    def __init__(self, config: YasaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        self.layers = nn.ModuleList(
            [
                YasaDecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = YasaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Rotary (cos, sin) tables are computed once per forward and shared by all layers.
        self.rotary_emb = YasaRotaryEmbedding(config=config)
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    @check_model_inputs()
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutputWithPast:
        # Exactly one of input_ids / inputs_embeds must be provided (XOR of presence).
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds"
            )
        if inputs_embeds is None:
            inputs_embeds: torch.Tensor = self.embed_tokens(input_ids)
        if use_cache and past_key_values is None:
            past_key_values = DynamicCache(config=self.config)
        if cache_position is None:
            # New token positions continue from the number of tokens already cached.
            past_seen_tokens = (
                past_key_values.get_seq_length()
                if past_key_values is not None
                else 0
            )
            cache_position: torch.Tensor = (
                torch.arange(
                    inputs_embeds.shape[1], device=inputs_embeds.device
                )
                + past_seen_tokens
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)
        causal_mask = create_causal_mask(
            config=self.config,
            input_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            past_key_values=past_key_values,
            position_ids=position_ids,
        )
        hidden_states = inputs_embeds
        # One (cos, sin) pair shared by every decoder layer.
        position_embeddings = self.rotary_emb(
            hidden_states, position_ids=position_ids
        )
        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = decoder_layer(
                hidden_states,
                attention_mask=causal_mask,
                position_embeddings=position_embeddings,
                position_ids=position_ids,
                past_key_values=past_key_values,
                use_cache=use_cache,
                cache_position=cache_position,
                **kwargs,
            )
        hidden_states = self.norm(hidden_states)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
        )
class Yasa2Model(YasaPreTrainedModel):
    """Pretrained base class that holds the full Yasa2 multimodal stack."""

    config_class: PretrainedConfig = Yasa2Config
    base_model_prefix: str = ""
    _checkpoint_conversion_mapping: Dict[str, str] = {}
    # ConvNext vision blocks must not be sharded either.
    _no_split_modules = ["YasaDecoderLayer", "ConvNextVisionModel"]
    config: Yasa2Config

    def __init__(
        self,
        config: Yasa2Config,
    ):
        """Initialize the full Yasa2 multimodal stack.

        Args:
            config (Yasa2Config): Configuration for the multimodal model.
        """
        super().__init__(config)
        self.vision_pooling = config.vision_pooling
        if self.vision_pooling != "adaptive_avg":
            raise ValueError(
                f"Yasa2 only supports adaptive_avg vision pooling, got {self.vision_pooling}"
            )
        # Pool the patch grid down to sqrt(num_query_tokens) x sqrt(num_query_tokens).
        self.adaptive_pooling = nn.AdaptiveAvgPool2d(
            int(config.num_query_tokens**0.5)
        )
        if not (config.num_query_tokens**0.5).is_integer():
            raise ValueError(
                f"num_query_tokens {config.num_query_tokens} must be a "
                "square number for adaptive_avg pooling"
            )
        # Set up vision backbone
        vision_config = config.vision_config
        if isinstance(vision_config, dict):
            vision_config = ConvNextConfig(**vision_config)
        self.vision_model = ConvNextVisionModel(vision_config)
        # Two-layer MLP projecting vision features into the text hidden size.
        self.language_projection = nn.Sequential(
            nn.Linear(
                config.vision_config.hidden_size,
                config.text_config.hidden_size,
            ),
            nn.GELU(),
            nn.Linear(
                config.text_config.hidden_size,
                config.text_config.hidden_size,
            ),
        )
        # Set up language model
        self.language_model = YasaModel(config.text_config)
        # Store only the raw non-learned vision positional embedding data.
        # Build device/dtype-specific tensors lazily in forward.
        self.add_vision_pos_embed = config.use_vision_pos_embed
        self._vision_pos_embed_np = get_2d_sincos_pos_embed(
            config.vision_config.hidden_size,
            image_size=50,
        )
        self._vision_pos_embed_cache: Dict[str, torch.Tensor] = {}
        self.post_init()

    def get_input_embeddings(self) -> torch.nn.Module:
        """Return the multimodal head's input embeddings.

        Returns:
            torch.nn.Module: Embedding module used by the language model.
        """
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value: torch.nn.Module) -> None:
        """Override the multimodal head's input embeddings.

        Args:
            value (torch.nn.Module): Embedding module to register.
        """
        self.language_model.set_input_embeddings(value)

    def set_decoder(self, decoder: YasaModel) -> None:
        """Proxy to set the multimodal model decoder.

        Args:
            decoder: Decoder to register with the multimodal model.
        """
        self.language_model = decoder

    def get_decoder(self) -> YasaModel:
        """Return the decoder component.

        Returns:
            YasaModel: Registered decoder module.
        """
        return self.language_model

    def state_dict(self, *args: Any, **kwargs: Any) -> Dict[str, torch.Tensor]:
        """Return a filtered state dict that omits derived or non-persistent buffers.

        Args:
            *args: Positional arguments forwarded to the superclass.
            **kwargs: Keyword arguments forwarded to the superclass.

        Returns:
            Dict[str, torch.Tensor]: Filtered parameter mapping.
        """
        state_dict = super().state_dict(*args, **kwargs)
        # Iterate over a snapshot of the keys since we pop while iterating.
        for key in list(state_dict.keys()):
            # masked_bias is a constant non-persistent attention buffer (-1e9).
            if "attention.masked_bias" in key:
                state_dict.pop(key, None)
                continue
            # rotary_emb.inv_freq is derived from rotary dims/base and rebuilt at init.
            if "rotary_emb.inv_freq" in key:
                state_dict.pop(key, None)
        return state_dict

    def _encode_vision_adaptive_2d_avg_pooling(
        self,
        pixel_values: torch.Tensor,
        patch_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Encode vision inputs via the ConvNeXt backbone and adaptive avg pooling.

        Args:
            pixel_values (torch.Tensor): Vision input tensor.
            patch_attention_mask (Optional[torch.Tensor]): Optional patch mask.

        Returns:
            torch.Tensor: Vision embeddings projected into text hidden size.
        """
        # Vision prefill only needs patch tokens; skip pooled output.
        image_embeds = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=False,
            patch_attention_mask=patch_attention_mask,
            return_pooled=False,
        )[0]
        img_num, seq_length, vision_hidden_size = image_embeds.size()
        # Assumes the patch grid is square (seq_length is a perfect square) —
        # TODO confirm against the ConvNext backbone output.
        height, width = int(seq_length**0.5), int(seq_length**0.5)
        if self.add_vision_pos_embed:
            vision_pos_embed = self._get_vision_pos_embed(
                device=image_embeds.device,
                dtype=image_embeds.dtype,
                seq_len=image_embeds.size(1),
            )
            image_embeds = image_embeds + vision_pos_embed
        # (img, seq, hidden) -> (img, hidden, height, width) for 2D pooling.
        image_embeds = image_embeds.permute(0, 2, 1).contiguous()
        image_embeds = image_embeds.reshape(
            img_num, vision_hidden_size, height, width
        )
        if (
            self.config.apply_patch_attention_mask
            and patch_attention_mask is not None
            and patch_attention_mask.numel() > 0
        ):
            # Zero out padded patches so they don't contribute to the pooled average.
            patch_attention_mask = patch_attention_mask.reshape(
                img_num, height, width
            )
            image_embeds = image_embeds * patch_attention_mask.unsqueeze(1).to(
                dtype=image_embeds.dtype
            )
        # Force pooling in fp32 with autocast disabled; bf16 pooling can produce NaNs.
        pooled_dtype = image_embeds.dtype
        with torch.autocast(device_type="cuda", enabled=False):
            image_embeds = torch.nn.functional.adaptive_avg_pool2d(
                image_embeds.float(), self.adaptive_pooling.output_size
            )
        image_embeds = image_embeds.to(dtype=pooled_dtype)
        # Back to (img, num_query_tokens, hidden) before the projection MLP.
        image_embeds = image_embeds.flatten(2)
        image_embeds = image_embeds.permute(0, 2, 1).contiguous()
        vision_embeds = self.language_projection(image_embeds)
        return vision_embeds

    def _get_vision_pos_embed(
        self,
        device: torch.device,
        dtype: torch.dtype,
        seq_len: int,
    ) -> torch.Tensor:
        """Return cached/runtime-built vision positional embeddings."""
        # One cached tensor per (device, dtype) pair; built lazily on first use.
        cache_key = f"{device}:{dtype}"
        cached = self._vision_pos_embed_cache.get(cache_key)
        if cached is None:
            cached = (
                torch.from_numpy(self._vision_pos_embed_np)
                .view(-1, self.config.vision_config.hidden_size)
                .to(device=device, dtype=dtype)
                .unsqueeze(0)
            )
            self._vision_pos_embed_cache[cache_key] = cached
        return cached[:, :seq_len, :]

    def get_image_features(
        self, pixel_values: torch.Tensor, **kwargs: Any
    ) -> torch.Tensor:
        """Return vision features for vLLM compatibility."""
        patch_attention_mask = kwargs.get("patch_attention_mask")
        return self._encode_vision_adaptive_2d_avg_pooling(
            pixel_values, patch_attention_mask=patch_attention_mask
        )

    @classmethod
    def scatter_embeddings_to_target_special_id(
        cls,
        target_tensor: torch.Tensor,
        target_input_ids: torch.Tensor,
        src_embeddings: torch.Tensor,
        special_token_id: int,
    ) -> torch.Tensor:
        """Scatter vision embeddings into the language embedding buffer at special tokens.

        Note: the scatter writes through a flattened view, so `target_tensor`
        is modified in place as well as returned.

        Args:
            target_tensor (torch.Tensor): Target embedding buffer to update.
            target_input_ids (torch.Tensor): Input IDs aligned with the target tensor.
            src_embeddings (torch.Tensor): Source embeddings to scatter from vision outputs.
            special_token_id (int): Token ID used to locate insertion positions.

        Returns:
            torch.Tensor: Updated target tensor with vision embeddings placed at special IDs.
        """
        b_source, n_source, d_embedding = src_embeddings.shape
        b_target, n_target, d_target = target_tensor.shape
        if b_target != target_input_ids.size(0):
            raise ValueError(
                "Batch size mismatch: target_input_ids "
                f"{target_input_ids.size(0)} vs target_tensor {b_target}"
            )
        if n_target != target_input_ids.size(1):
            raise ValueError(
                "Sequence length mismatch: target_input_ids "
                f"{target_input_ids.size(1)} vs target_tensor {n_target}"
            )
        if d_embedding != d_target:
            raise ValueError(
                "Embedding dimension mismatch: src_embeddings "
                f"{d_embedding} vs target_tensor {d_target}"
            )
        # Flat positions of every image-placeholder token across the batch.
        special_token_mask = target_input_ids.view(-1) == special_token_id
        special_token_indices = torch.nonzero(special_token_mask).squeeze(-1)
        if len(special_token_indices) != b_source * n_source:
            raise ValueError(
                "Special token count mismatch: found "
                f"{len(special_token_indices)}, expected {b_source * n_source}"
            )
        target_tensor = target_tensor.view(-1, d_embedding)
        src_embeddings = src_embeddings.view(-1, d_embedding)
        target_tensor[special_token_indices] = src_embeddings
        target_tensor = target_tensor.view(b_target, n_target, d_embedding)
        return target_tensor

    def _interleave_scatter(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        inputs_embeds: torch.Tensor,
        vision_embeds: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Scatter vision embeddings into language embeddings at the image token positions.

        Args:
            input_ids (torch.Tensor): Token IDs containing image placeholders.
            attention_mask (torch.Tensor): Attention mask for text tokens.
            inputs_embeds (torch.Tensor): Language model input embeddings.
            vision_embeds (torch.Tensor): Vision embeddings to be inserted.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Updated inputs_embeds and attention_mask.
        """
        inputs_embeds = Yasa2Model.scatter_embeddings_to_target_special_id(
            target_tensor=inputs_embeds,
            target_input_ids=input_ids,
            src_embeddings=vision_embeds,
            special_token_id=self.config.image_token_id,
        )
        # attention_mask is returned unchanged; placeholders are already attended to.
        return inputs_embeds, attention_mask

    @can_return_tuple
    def forward(
        self,
        input_ids: Optional[torch.LongTensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[
            Union[Cache, Tuple[Tuple[torch.FloatTensor]]]
        ] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        patch_attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        mm_token_type_ids: Optional[torch.Tensor] = None,
        **kwargs: Any,
    ) -> Union[Tuple[torch.Tensor, ...], "Yasa2ModelOutputWithPast"]:
        """Forward pass combining language and vision inputs for Yasa2.

        Args:
            input_ids (Optional[torch.LongTensor]): Token IDs for the language model.
            attention_mask (Optional[torch.Tensor]): Attention mask aligned with `input_ids`.
            position_ids (Optional[torch.LongTensor]): Position indices feeding the language model.
            inputs_embeds (Optional[torch.FloatTensor]): Precomputed token embeddings.
            past_key_values (Optional[Union[Cache, Tuple[Tuple[torch.FloatTensor]]]]): Cached decoder key/value tensors.
            cache_position (Optional[torch.LongTensor]): Positions used for cache alignment.
            use_cache (Optional[bool]): Whether to request cached key/values.
            output_attentions (Optional[bool]): Whether to return attention weights.
            output_hidden_states (Optional[bool]): Whether to return hidden states for each layer.
            return_dict (Optional[bool]): Whether to return a `ModelOutput`.
            pixel_values (Optional[torch.Tensor]): Vision inputs providing image context.
            patch_attention_mask (Optional[torch.Tensor]): Optional patch mask for vision tokens.
            token_type_ids (Optional[torch.Tensor]): Unused token type ids for compatibility.
            mm_token_type_ids (Optional[torch.Tensor]): Unused multimodal token type ids.

        Returns:
            Union[Tuple[torch.Tensor, ...], Yasa2ModelOutputWithPast]: Combined multimodal outputs.
        """
        return_dict = (
            return_dict
            if return_dict is not None
            else self.config.use_return_dict
        )
        use_cache = (
            use_cache if use_cache is not None else self.config.use_cache
        )
        if input_ids is None and inputs_embeds is None:
            raise ValueError(
                "You must provide either input_ids or inputs_embeds."
            )
        if inputs_embeds is not None and pixel_values is not None:
            raise ValueError(
                "pixel_values cannot be used when inputs_embeds is provided."
            )
        if inputs_embeds is None:
            inputs_embeds = self.language_model.get_input_embeddings()(
                input_ids
            )
        if attention_mask is None:
            # Derive a mask from padding only when the batch actually contains pad tokens.
            pad_token_id = self.config.text_config.pad_token_id
            if input_ids is not None and pad_token_id is not None:
                if (input_ids == pad_token_id).any():
                    attention_mask = input_ids.ne(pad_token_id)
        if attention_mask is not None:
            # An empty mask tensor is treated the same as no mask.
            if attention_mask.numel() == 0:
                attention_mask = None
        if cache_position is not None:
            expected_len = inputs_embeds.shape[1]
            if cache_position.shape[-1] != expected_len:
                raise ValueError(
                    "cache_position length must match input sequence length: "
                    f"{cache_position.shape[-1]} vs {expected_len}"
                )
        vision_embeds = None
        if pixel_values is not None and len(pixel_values) > 0:
            if input_ids is None:
                raise ValueError(
                    "input_ids is required when pixel_values is provided."
                )
            # Encode images and splice the result in at image-placeholder tokens.
            vision_embeds = self._encode_vision_adaptive_2d_avg_pooling(
                pixel_values,
                patch_attention_mask=patch_attention_mask,
            )
            inputs_embeds, attention_mask = self._interleave_scatter(
                input_ids,
                attention_mask,
                inputs_embeds,
                vision_embeds,
            )
        outputs = self.language_model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            head_mask=None,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            return_dict=True,
            **kwargs,
        )
        return Yasa2ModelOutputWithPast(
            last_hidden_state=outputs.last_hidden_state,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            vision_hidden_states=vision_embeds,
        )
class Yasa2ForConditionalGeneration(YasaPreTrainedModel, GenerationMixin):
    """Yasa2 multimodal conditional generation model (vision + text)."""

    config_class = Yasa2Config
    _checkpoint_conversion_mapping = {}
    _tied_weights_keys = []  # Weights are not tied
    config: Yasa2Config

    def __init__(self, config: Yasa2Config):
        """Initialize the Yasa2 conditional generation model.

        Args:
            config: Yasa2 configuration object.
        """
        super().__init__(config)
        self.model = Yasa2Model(config)
        # NOTE(review): reads hidden_size/vocab_size off the top-level Yasa2Config —
        # confirm these mirror config.text_config for this model family.
        self.lm_head = nn.Linear(
            config.hidden_size, config.vocab_size, bias=False
        )
        self.vocab_size = config.vocab_size
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self) -> torch.nn.Module:
        """Return the multimodal head's input embeddings.

        Returns:
            torch.nn.Module: Embedding module used by the language model.
        """
        return self.model.language_model.get_input_embeddings()

    def set_input_embeddings(self, value: torch.nn.Module) -> None:
        """Override the multimodal head's input embeddings.

        Args:
            value (torch.nn.Module): Embedding module to register.
        """
        self.model.language_model.set_input_embeddings(value)

    def set_decoder(self, decoder):
        """Proxy to set the multimodal model decoder.

        Args:
            decoder: Decoder to register with the multimodal model.
        """
        self.model.set_decoder(decoder)

    def get_decoder(self):
        """Proxy to return the multimodal decoder."""
        return self.model.get_decoder()

    # Make modules available through conditional class for BC
    @property
    def language_model(self) -> torch.nn.Module:
        """Expose the language model component.

        Returns:
            torch.nn.Module: Language model module.
        """
        return self.model.language_model

    @property
    def vision_backbone(self) -> torch.nn.Module:
        """Expose the vision encoder backbone.

        Returns:
            torch.nn.Module: Vision backbone module.
        """
        return self.model.vision_model

    @can_return_tuple
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[
            Union[Cache, Tuple[Tuple[torch.FloatTensor]]]
        ] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        patch_attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        mm_token_type_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Any,
    ) -> Union[
        Tuple[torch.Tensor, ...], "Yasa2ForConditionalGenerationModelOutput"
    ]:
        """Run the multimodal model, project outputs to logits, and compute loss if needed.

        Args:
            input_ids (Optional[torch.LongTensor]): Language token IDs.
            attention_mask (Optional[torch.Tensor]): Attention mask for language tokens.
            position_ids (Optional[torch.LongTensor]): Position indices.
            past_key_values (Optional[Union[Cache, Tuple[Tuple[torch.FloatTensor]]]]): Cached decoder states.
            inputs_embeds (Optional[torch.FloatTensor]): Input embeddings instead of token IDs.
            use_cache (Optional[bool]): Whether to cache key/value pairs.
            output_attentions (Optional[bool]): Whether to return attention weights.
            output_hidden_states (Optional[bool]): Whether to return hidden states.
            cache_position (Optional[torch.LongTensor]): Positions used for caching.
            pixel_values (Optional[torch.Tensor]): Vision inputs.
            patch_attention_mask (Optional[torch.Tensor]): Optional mask for vision patches.
            token_type_ids (Optional[torch.Tensor]): Unused token type ids for compatibility.
            mm_token_type_ids (Optional[torch.Tensor]): Unused multimodal token type ids.
            labels (Optional[torch.LongTensor]): Labels for computing cross-entropy loss.
            return_dict (Optional[bool]): Whether to return a dict-like output.

        Returns:
            Union[Tuple[torch.Tensor, ...], Yasa2ForConditionalGenerationModelOutput]: Model logits, caches, and optional loss.
        """
        return_dict = (
            return_dict
            if return_dict is not None
            else self.config.use_return_dict
        )
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            pixel_values=pixel_values,
            patch_attention_mask=patch_attention_mask,
            return_dict=True,
            **kwargs,
        )
        hidden_states = outputs.last_hidden_state
        logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            # Standard causal LM shift: logits at position t predict the token at t+1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:]
            loss_fct = nn.CrossEntropyLoss(
                ignore_index=self.config.label_ignore_index
            )
            loss = loss_fct(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
            )
        return Yasa2ForConditionalGenerationModelOutput(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            vision_hidden_states=outputs.vision_hidden_states,
            language_model_outputs=outputs,
        )

    def generate(
        self,
        input_ids: Optional[torch.LongTensor],
        attention_mask: Optional[torch.Tensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        patch_attention_mask: Optional[torch.Tensor] = None,
        **generate_kwargs,
    ) -> torch.LongTensor:
        """Generate text tokens conditioned on vision and/or language inputs.

        Args:
            input_ids (Optional[torch.LongTensor]): Seed language tokens.
            attention_mask (Optional[torch.Tensor]): Language attention mask.
            pixel_values (Optional[torch.Tensor]): Vision inputs appended to prompts.
            patch_attention_mask (Optional[torch.Tensor]): Mask for vision patches.
            **generate_kwargs: Additional generation options forwarded to the `super().generate`.

        Returns:
            torch.LongTensor: Generated token IDs.
        """
        # Thin wrapper: only exists to surface the vision kwargs in the signature.
        return super().generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pixel_values=pixel_values,
            patch_attention_mask=patch_attention_mask,
            **generate_kwargs,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[
            Union[Cache, Tuple[Tuple[torch.FloatTensor]]]
        ] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.Tensor] = None,
        patch_attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Prepare multimodal inputs for generation bookkeeping.

        Args:
            input_ids (torch.LongTensor): Current token IDs for generation.
            past_key_values (Optional[Union[Cache, Tuple[Tuple[torch.FloatTensor]]]]): Cached past key/value tensors.
            inputs_embeds (Optional[torch.FloatTensor]): Optional token embeddings.
            attention_mask (Optional[torch.Tensor]): Language attention mask.
            cache_position (Optional[torch.LongTensor]): Cache alignment positions.
            pixel_values (Optional[torch.Tensor]): Vision inputs that should be reused.
            patch_attention_mask (Optional[torch.Tensor]): Vision patch mask for the prefill step.
            **kwargs: Additional arguments forwarded to the base implementation.

        Returns:
            Dict[str, Any]: Prepared inputs for the next generation step.
        """
        model_inputs = super().prepare_inputs_for_generation(
            input_ids=input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            cache_position=cache_position,
            **kwargs,
        )
        # Vision inputs are consumed only on the prefill step; decode steps
        # reuse the KV cache and must not re-encode the images.
        is_prefill = past_key_values is None or (
            cache_position is not None and cache_position[0] == 0
        )
        if is_prefill:
            model_inputs["pixel_values"] = pixel_values
            model_inputs["patch_attention_mask"] = patch_attention_mask
        return model_inputs
# Register the model so AutoModelForImageTextToText can resolve it from this
# repository's custom code (trust_remote_code loading).
Yasa2ForConditionalGeneration.register_for_auto_class(
    "AutoModelForImageTextToText"
)