import math
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Optional, Tuple, Union
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from torch.nn.init import _calculate_fan_in_and_fan_out
from transformers.activations import ACT2FN
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
# from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
can_return_tuple,
logging,
replace_return_docstrings,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
)
from .configuration_siglip2 import Siglip2Config, Siglip2TextConfig, Siglip2VisionConfig
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "Siglip2Config"
is_aiter_available = False
if is_flash_attn_2_available():
    try:
        from aiter import flash_attn_varlen_func

        is_aiter_available = True
    except ImportError:
        from flash_attn import flash_attn_varlen_func
    try:
        # `apply_rotary_emb` is provided by flash-attn and is used by `apply_rotary_pos_emb_flashatt`.
        from flash_attn.layers.rotary import apply_rotary_emb
    except ImportError:
        apply_rotary_emb = None
    from transformers.modeling_flash_attention_utils import _flash_attention_forward
else:
    flash_attn_varlen_func = None
    apply_rotary_emb = None
@dataclass
class Siglip2VisionOutput(ModelOutput):
"""
Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
Args:
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when
            model is initialized with `with_projection=True`):
The image embeddings obtained by applying the projection layer to the pooler_output.
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True`
is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or
when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
"""
image_embeds: Optional[torch.FloatTensor] = None
last_hidden_state: Optional[torch.FloatTensor] = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class Siglip2TextOutput(ModelOutput):
"""
Base class for text model's outputs that also contains a pooling of the last hidden states.
Args:
        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`, *optional*, returned when
            model is initialized with `with_projection=True`):
The text embeddings obtained by applying the projection layer to the pooler_output.
last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned
when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed
or when `config.output_attentions=True`):
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
sequence_length)`.
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
"""
text_embeds: Optional[torch.FloatTensor] = None
last_hidden_state: Optional[torch.FloatTensor] = None
hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
@dataclass
class Siglip2Output(ModelOutput):
"""
Args:
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`):
Contrastive loss for image-text similarity.
logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text
similarity scores.
logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image
similarity scores.
        text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
The text embeddings obtained by applying the projection layer to the pooled output of [`Siglip2TextModel`].
        image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)`):
The image embeddings obtained by applying the projection layer to
the pooled output of [`Siglip2VisionModel`].
text_model_output (`BaseModelOutputWithPooling`):
The output of the [`Siglip2TextModel`].
vision_model_output (`BaseModelOutputWithPooling`):
The output of the [`Siglip2VisionModel`].
"""
loss: Optional[torch.FloatTensor] = None
logits_per_image: Optional[torch.FloatTensor] = None
logits_per_text: Optional[torch.FloatTensor] = None
text_embeds: Optional[torch.FloatTensor] = None
image_embeds: Optional[torch.FloatTensor] = None
text_model_output: BaseModelOutputWithPooling = None
vision_model_output: BaseModelOutputWithPooling = None
def to_tuple(self) -> Tuple[Any]:
return tuple(
self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
for k in self.keys()
)
class Siglip2VisionEmbeddings(nn.Module):
def __init__(self, config: Siglip2VisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.patch_size = config.patch_size
if hasattr(config, 'in_features') and config.in_features > 0:
self.in_features = config.in_features
else:
self.in_features = config.num_channels * self.patch_size * self.patch_size
self.patch_embedding = nn.Linear(
in_features=self.in_features,
out_features=self.embed_dim,
)
self.num_patches = config.num_patches
self.position_embedding_size = int(self.num_patches**0.5)
self.position_embedding = nn.Embedding(self.num_patches, self.embed_dim)
@staticmethod
def resize_positional_embeddings(
positional_embeddings: torch.Tensor,
spatial_shapes: torch.LongTensor,
max_length: int,
) -> torch.Tensor:
"""
Resize positional embeddings to image-specific size and pad to a fixed size.
Args:
positional_embeddings (`torch.Tensor`):
Position embeddings of shape (height, width, embed_dim)
spatial_shapes (`torch.LongTensor`):
Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
max_length (`int`):
Maximum length of the positional embeddings to pad resized positional embeddings to
Returns:
`torch.Tensor`: Embeddings of shape (batch_size, max_length, embed_dim)
"""
batch_size = spatial_shapes.shape[0]
embed_dim = positional_embeddings.shape[-1]
source_dtype = positional_embeddings.dtype
resulted_positional_embeddings = torch.empty(
(batch_size, max_length, embed_dim),
device=positional_embeddings.device,
dtype=source_dtype,
)
positional_embeddings = positional_embeddings.permute(2, 0, 1).unsqueeze(0)
if positional_embeddings.device.type == "cpu":
positional_embeddings = positional_embeddings.to(torch.float32)
for i in range(batch_size):
height, width = spatial_shapes[i]
resized_embeddings = F.interpolate(
positional_embeddings,
size=(height, width),
mode="bilinear",
align_corners=False,
antialias=True,
)
resized_embeddings = resized_embeddings.reshape(embed_dim, height * width).transpose(0, 1)
resized_embeddings = resized_embeddings.to(source_dtype)
resulted_positional_embeddings[i, : height * width] = resized_embeddings
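            # Fill the padded tail (positions beyond height * width) with the first resized embedding.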
resulted_positional_embeddings[i, height * width :] = resized_embeddings[0]
return resulted_positional_embeddings
def forward(self, pixel_values: torch.FloatTensor, spatial_shapes: torch.LongTensor) -> torch.Tensor:
"""
Args:
pixel_values (`torch.FloatTensor`):
Pixel values of shape (batch_size, max_num_patches, num_channels * patch_size * patch_size)
spatial_shapes (`List[Tuple[int, int]]`):
Spatial shapes of shape (batch_size, 2) to resize the positional embeddings to
"""
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
positional_embeddings = self.position_embedding.weight.reshape(
self.position_embedding_size, self.position_embedding_size, -1
)
resized_positional_embeddings = self.resize_positional_embeddings(
positional_embeddings, spatial_shapes, max_length=pixel_values.shape[1]
)
embeddings = patch_embeds + resized_positional_embeddings
return embeddings
class Siglip2VisionEmbeddingsWoPos(nn.Module):
def __init__(self, config: Siglip2VisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.patch_size = config.patch_size
if hasattr(config, 'in_features') and config.in_features > 0:
self.in_features = config.in_features
else:
self.in_features = config.num_channels * self.patch_size * self.patch_size
self.patch_embedding = nn.Linear(
in_features=self.in_features,
out_features=self.embed_dim,
)
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype))
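        # Flatten (batch, num_patches, dim) into one packed sequence of patch embeddings; per-image
        # boundaries are recovered later in the encoder from `spatial_shapes` / cu_seqlens.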
patch_embeds = patch_embeds.view(-1, self.embed_dim)
return patch_embeds
def eager_attention_forward(
module: nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor],
scaling: float,
dropout: float = 0.0,
**kwargs,
):
attn_weights = torch.matmul(query, key.transpose(-1, -2)) * scaling
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output, attn_weights
def apply_rotary_pos_emb_flashatt(
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
cos = cos.chunk(2, dim=-1)[0].contiguous()
sin = sin.chunk(2, dim=-1)[0].contiguous()
q_embed = apply_rotary_emb(q.float(), cos.float(), sin.float()).type_as(q)
k_embed = apply_rotary_emb(k.float(), cos.float(), sin.float()).type_as(k)
return q_embed, k_embed
class Siglip2Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: Union[Siglip2VisionConfig, Siglip2TextConfig]):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
f" {self.num_heads})."
)
self.scale = self.head_dim**-0.5
self.dropout = config.attention_dropout
self.is_causal = False
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Input shape: Batch x Time x Channel"""
batch_size, seq_length, embed_dim = hidden_states.shape
queries = self.q_proj(hidden_states)
keys = self.k_proj(hidden_states)
values = self.v_proj(hidden_states)
queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and output_attentions:
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support"
"`output_attentions=True`. Falling back to 'eager attention. This warning"
'can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
queries,
keys,
values,
attention_mask,
is_causal=self.is_causal,
scaling=self.scale,
dropout=0.0 if not self.training else self.dropout,
)
attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous()
attn_output = self.out_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights
class Vision_FlashAttention2(nn.Module):
def __init__(self, config: Union[Siglip2VisionConfig, Siglip2TextConfig]) -> None:
super().__init__()
dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.k_proj = nn.Linear(dim, dim)
self.v_proj = nn.Linear(dim, dim)
self.q_proj = nn.Linear(dim, dim)
self.out_proj = nn.Linear(dim, dim)
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
seq_length = hidden_states.shape[0]
q = self.q_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
k = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
v = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
"removed and `position_embeddings` will be mandatory."
)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
cos = emb.cos()
sin = emb.sin()
else:
cos, sin = position_embeddings
q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)
q = q.squeeze(0)
k = k.squeeze(0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
if is_aiter_available:
            attn_output = flash_attn_varlen_func(
                q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, return_lse=True
            )[0].reshape(seq_length, -1)
        else:
            attn_output = flash_attn_varlen_func(
                q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen
            ).reshape(seq_length, -1)
attn_output = self.out_proj(attn_output)
return attn_output, None
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb_vision(
q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
orig_q_dtype = q.dtype
orig_k_dtype = k.dtype
q, k = q.float(), k.float()
cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
q_embed = q_embed.to(orig_q_dtype)
k_embed = k_embed.to(orig_k_dtype)
return q_embed, k_embed
class Vision_EagerAttention(nn.Module):
def __init__(self, config: Union[Siglip2VisionConfig, Siglip2TextConfig]) -> None:
super().__init__()
dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.k_proj = nn.Linear(dim, dim)
self.v_proj = nn.Linear(dim, dim)
self.q_proj = nn.Linear(dim, dim)
self.out_proj = nn.Linear(dim, dim)
self.head_dim = dim // self.num_heads
def forward(
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
seq_length = hidden_states.shape[0]
q = self.q_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
k = self.k_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
v = self.v_proj(hidden_states).reshape(seq_length, self.num_heads, -1)
if position_embeddings is None:
logger.warning_once(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
"removed and `position_embeddings` will be mandatory."
)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
cos = emb.cos()
sin = emb.sin()
else:
cos, sin = position_embeddings
q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
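        # Build a block-diagonal additive mask from cu_seqlens so each token only attends within its
        # own segment (attention window or image).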
attention_mask = torch.full(
[1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
)
for i in range(1, len(cu_seqlens)):
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
q = q.transpose(0, 1)
k = k.transpose(0, 1)
v = v.transpose(0, 1)
attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
attn_output = torch.matmul(attn_weights, v)
attn_output = attn_output.transpose(0, 1)
attn_output = attn_output.reshape(seq_length, -1)
attn_output = self.out_proj(attn_output)
return attn_output, None
VISION_ATTENTION_CLASSES = {
'eager': Vision_EagerAttention,
'flash_attention_2': Vision_FlashAttention2,
}
class Siglip2MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = self.activation_fn(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states
# class Siglip2EncoderLayer(GradientCheckpointingLayer):
class Siglip2EncoderLayer(nn.Module):
def __init__(self, config: Union[Siglip2VisionConfig, Siglip2TextConfig]):
super().__init__()
self.embed_dim = config.hidden_size
self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.self_attn = VISION_ATTENTION_CLASSES[config._attn_implementation](config=config)
self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
self.mlp = Siglip2MLP(config)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: Optional[torch.Tensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.FloatTensor]:
"""
Args:
hidden_states (`torch.FloatTensor`):
Input to the layer of shape `(batch, seq_len, embed_dim)`.
attention_mask (`torch.FloatTensor`):
Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements
are indicated by very large negative values.
output_attentions (`bool`, *optional*, defaults to `False`):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
"""
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states, attn_weights = self.self_attn(
hidden_states=hidden_states,
cu_seqlens=cu_seqlens,
rotary_pos_emb=rotary_pos_emb,
position_embeddings=position_embeddings,
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class VisionRope(nn.Module):
def __init__(self, dim: int, theta: float = 10000.0) -> None:
super().__init__()
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
def forward(self, seqlen: int) -> torch.Tensor:
seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
freqs = torch.outer(seq, self.inv_freq)
return freqs
class Siglip2Encoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
[`Siglip2EncoderLayer`].
Args:
config: Siglip2Config
"""
def __init__(self, config: Siglip2Config):
super().__init__()
self.config = config
self.layers = nn.ModuleList([Siglip2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
self.spatial_merge_size = 2
self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
self.patch_size = config.patch_size
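        # Attention window side length in pixels: 8 merged 2x2 tokens (16 patches) per side.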
self.window_size = self.patch_size * 2 * 8
        assert config.hidden_size % (config.num_attention_heads * 2) == 0
        self.rotary_pos_emb = VisionRope(config.hidden_size // config.num_attention_heads // 2)
def rot_pos_emb(self, spatial_shapes):
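        # Build 2D rotary position ids per patch: (row, col) indices are rearranged so the 2x2 patches
        # merged into one token stay contiguous, then used to index a precomputed frequency table of
        # size max(H, W). Returns per-patch rotary frequencies of shape (num_patches, head_dim // 2).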
pos_ids = []
for h, w in spatial_shapes:
t = 1
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
hpos_ids = hpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
)
hpos_ids = hpos_ids.permute(0, 2, 1, 3)
hpos_ids = hpos_ids.flatten()
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
wpos_ids = wpos_ids.reshape(
h // self.spatial_merge_size,
self.spatial_merge_size,
w // self.spatial_merge_size,
self.spatial_merge_size,
)
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
wpos_ids = wpos_ids.flatten()
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = spatial_shapes.max()
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
return rotary_pos_emb
def get_window_index(self, spatial_shapes):
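        # For each image, compute a permutation that groups spatially adjacent merged tokens into
        # fixed-size windows (incomplete windows are padded with -100 and those slots dropped), plus
        # cumulative per-window sequence lengths, later used as cu_seqlens for windowed attention.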
window_index: list = []
cu_window_seqlens: list = [0]
window_index_id = 0
vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size
for grid_h, grid_w in spatial_shapes:
grid_t = 1
llm_grid_h, llm_grid_w = (
grid_h // self.spatial_merge_size,
grid_w // self.spatial_merge_size,
)
index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
pad_h = (vit_merger_window_size - llm_grid_h % vit_merger_window_size) % vit_merger_window_size
pad_w = (vit_merger_window_size - llm_grid_w % vit_merger_window_size) % vit_merger_window_size
num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
index_padded = index_padded.reshape(
grid_t,
num_windows_h,
vit_merger_window_size,
num_windows_w,
vit_merger_window_size,
)
index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
grid_t,
num_windows_h * num_windows_w,
vit_merger_window_size,
vit_merger_window_size,
)
seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
index_padded = index_padded.reshape(-1)
index_new = index_padded[index_padded != -100]
window_index.append(index_new + window_index_id)
cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
window_index = torch.cat(window_index, dim=0)
return window_index, cu_window_seqlens
@can_return_tuple
def forward(
self,
inputs_embeds,
spatial_shapes: torch.LongTensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> BaseModelOutput:
r"""
Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
hidden_states = inputs_embeds
rotary_pos_emb = self.rot_pos_emb(spatial_shapes)
window_index, cu_window_seqlens = self.get_window_index(spatial_shapes)
cu_window_seqlens = torch.tensor(
cu_window_seqlens,
device=hidden_states.device,
dtype=spatial_shapes.dtype if torch.jit.is_tracing() else torch.int32,
)
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
seq_len, _ = hidden_states.size()
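        # Reorder the packed patch sequence into window order: keep each 2x2 merge unit together, then
        # permute the units with window_index so tokens of the same attention window are contiguous.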
hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
hidden_states = hidden_states[window_index, :, :]
hidden_states = hidden_states.reshape(seq_len, -1)
rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
rotary_pos_emb = rotary_pos_emb[window_index, :, :]
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
position_embeddings = (emb.cos(), emb.sin())
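        # Per-image cumulative patch counts; these serve as cu_seqlens for the full-attention layers.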
cu_seqlens = torch.repeat_interleave(spatial_shapes[:, 0] * spatial_shapes[:, 1], 1).cumsum(
dim=0,
dtype=spatial_shapes.dtype if torch.jit.is_tracing() else torch.int32,
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
for layer_num, encoder_layer in enumerate(self.layers):
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
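            # Every 8th layer and the last layer attend over the full image; all other layers use
            # windowed attention restricted to cu_window_seqlens.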
            if (layer_num + 1) % 8 == 0 or layer_num == len(self.layers) - 1:
cu_seqlens_now = cu_seqlens
else:
cu_seqlens_now = cu_window_seqlens
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
cu_seqlens=cu_seqlens_now,
position_embeddings=position_embeddings
)
hidden_states = layer_outputs[0]
if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
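        # Undo the window permutation so the output patches are returned in their original order.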
hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
reverse_indices = torch.argsort(window_index)
hidden_states = hidden_states[reverse_indices, :, :]
hidden_states = hidden_states.reshape(seq_len, -1)
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=encoder_states,
attentions=all_attentions,
)
SIGLIP2_VISION_INPUTS_DOCSTRING = r"""
Args:
        pixel_values (`torch.FloatTensor` of shape `(batch_size, max_num_patches, num_channels * patch_size * patch_size)`):
            Flattened patch pixel values. Padding will be ignored by default should you provide it. Pixel values can
            be obtained using [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class Siglip2VisionTransformer(nn.Module):
def __init__(self, config: Siglip2VisionConfig):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = Siglip2VisionEmbeddingsWoPos(config)
self.encoder = Siglip2Encoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.use_head = False
if self.use_head:
self.head = Siglip2MultiheadAttentionPoolingHead(config)
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
@can_return_tuple
@add_start_docstrings_to_model_forward(SIGLIP2_VISION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Siglip2VisionConfig)
def forward(
self,
pixel_values: torch.FloatTensor,
attention_mask: torch.Tensor,
spatial_shapes: torch.LongTensor,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> BaseModelOutputWithPooling:
r"""
Returns:
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
bs, length, dim = pixel_values.shape
hidden_states = self.embeddings(pixel_values)
if attention_mask is not None and not self._use_flash_attention_2:
encoder_attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
else:
encoder_attention_mask = attention_mask
encoder_outputs: BaseModelOutput = self.encoder(
inputs_embeds=hidden_states,
spatial_shapes=spatial_shapes,
attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
last_hidden_state = encoder_outputs.last_hidden_state
last_hidden_state = self.post_layernorm(last_hidden_state)
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=None,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
class Siglip2TextEmbeddings(nn.Module):
def __init__(self, config: Siglip2TextConfig):
super().__init__()
embed_dim = config.hidden_size
self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
self.register_buffer(
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
) -> torch.Tensor:
seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
max_position_embedding = self.position_embedding.weight.shape[0]
if seq_length > max_position_embedding:
raise ValueError(
f"Sequence length must be less than max_position_embeddings (got `sequence length`: "
f"{seq_length} and max_position_embeddings: {max_position_embedding}"
)
if position_ids is None:
position_ids = self.position_ids[:, :seq_length]
if inputs_embeds is None:
inputs_embeds = self.token_embedding(input_ids)
position_embeddings = self.position_embedding(position_ids)
embeddings = inputs_embeds + position_embeddings
return embeddings
def _trunc_normal_(tensor, mean, std, a, b):
def norm_cdf(x):
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2,
)
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
tensor.uniform_(2 * l - 1, 2 * u - 1)
tensor.erfinv_()
tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)
tensor.clamp_(min=a, max=b)
def trunc_normal_tf_(
tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0
) -> torch.Tensor:
"""
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
"""
with torch.no_grad():
_trunc_normal_(tensor, 0, 1.0, a, b)
tensor.mul_(std).add_(mean)
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
if mode == "fan_in":
denom = fan_in
elif mode == "fan_out":
denom = fan_out
elif mode == "fan_avg":
denom = (fan_in + fan_out) / 2
variance = scale / denom
if distribution == "truncated_normal":
trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
elif distribution == "normal":
with torch.no_grad():
tensor.normal_(std=math.sqrt(variance))
elif distribution == "uniform":
bound = math.sqrt(3 * variance)
with torch.no_grad():
tensor.uniform_(-bound, bound)
else:
raise ValueError(f"invalid distribution {distribution}")
def lecun_normal_(tensor):
variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
def default_flax_embed_init(tensor):
variance_scaling_(tensor, mode="fan_in", distribution="normal")
SIGLIP2_TEXT_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class Siglip2TextTransformer(nn.Module):
def __init__(self, config: Siglip2TextConfig):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = Siglip2TextEmbeddings(config)
self.encoder = Siglip2Encoder(config)
self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self.head = nn.Linear(embed_dim, config.projection_size)
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
@can_return_tuple
@add_start_docstrings_to_model_forward(SIGLIP2_TEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Siglip2TextConfig)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> BaseModelOutputWithPooling:
r"""
Returns:
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
if input_ids is None:
raise ValueError("You have to specify input_ids")
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
if attention_mask is not None and not self._use_flash_attention_2:
attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
encoder_outputs: BaseModelOutput = self.encoder(
inputs_embeds=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
last_hidden_state = encoder_outputs.last_hidden_state
last_hidden_state = self.final_layer_norm(last_hidden_state)
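        # Pool the last token; assuming "sticky" EOS tokenization, the last token is always EOS.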
pooled_output = last_hidden_state[:, -1, :]
pooled_output = self.head(pooled_output)
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
SIGLIP2_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`Siglip2Config`]): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
SIGLIP2_INPUTS_DOCSTRING = r"""
Args:
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.max_position_embeddings - 1]`.
[What are position IDs?](../glossary#position-ids)
        pixel_values (`torch.FloatTensor` of shape `(batch_size, max_num_patches, num_channels * patch_size * patch_size)`):
            Flattened patch pixel values. Padding will be ignored by default should you provide it. Pixel values can
            be obtained using [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
return_loss (`bool`, *optional*):
Whether or not to return the contrastive loss.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
class Siglip2PreTrainedModel(PreTrainedModel):
config_class = Siglip2Config
base_model_prefix = "siglip2"
supports_gradient_checkpointing = True
    _no_split_modules = [
        "Siglip2TextEmbeddings",
        "Siglip2EncoderLayer",
        "Siglip2VisionEmbeddings",
        "Siglip2MultiheadAttentionPoolingHead",
    ]
_supports_flash_attn_2 = True
_supports_sdpa = True
def _init_weights(self, module):
"""Initialize the weights"""
if isinstance(module, Siglip2VisionEmbeddings):
width = (
self.config.vision_config.hidden_size
if isinstance(self.config, Siglip2Config)
else self.config.hidden_size
)
nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width))
elif isinstance(module, nn.Embedding):
default_flax_embed_init(module.weight)
elif isinstance(module, Siglip2Attention):
nn.init.xavier_uniform_(module.q_proj.weight)
nn.init.xavier_uniform_(module.k_proj.weight)
nn.init.xavier_uniform_(module.v_proj.weight)
nn.init.xavier_uniform_(module.out_proj.weight)
nn.init.zeros_(module.q_proj.bias)
nn.init.zeros_(module.k_proj.bias)
nn.init.zeros_(module.v_proj.bias)
nn.init.zeros_(module.out_proj.bias)
elif isinstance(module, Siglip2MLP):
nn.init.xavier_uniform_(module.fc1.weight)
nn.init.xavier_uniform_(module.fc2.weight)
nn.init.normal_(module.fc1.bias, std=1e-6)
nn.init.normal_(module.fc2.bias, std=1e-6)
elif isinstance(module, Siglip2MultiheadAttentionPoolingHead):
nn.init.xavier_uniform_(module.probe.data)
nn.init.xavier_uniform_(module.attention.in_proj_weight.data)
nn.init.zeros_(module.attention.in_proj_bias.data)
elif isinstance(module, Siglip2Model):
logit_scale_init = torch.log(torch.tensor(1.0))
module.logit_scale.data.fill_(logit_scale_init)
module.logit_bias.data.zero_()
elif isinstance(module, Siglip2ForImageClassification):
nn.init.normal_(
module.classifier.weight,
std=self.config.vision_config.hidden_size**-0.5 * self.config.initializer_factor,
)
elif isinstance(module, (nn.Linear, nn.Conv2d)):
lecun_normal_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
@add_start_docstrings(
"""The text model from Siglip2 without any head or projection on top.""",
SIGLIP2_START_DOCSTRING,
)
class Siglip2TextModel(Siglip2PreTrainedModel):
config_class = Siglip2TextConfig
def __init__(self, config: Siglip2TextConfig):
super().__init__(config)
self.text_model = Siglip2TextTransformer(config)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> nn.Module:
return self.text_model.embeddings.token_embedding
def set_input_embeddings(self, value):
self.text_model.embeddings.token_embedding = value
@can_return_tuple
@add_start_docstrings_to_model_forward(SIGLIP2_TEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Siglip2TextConfig)
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> BaseModelOutputWithPooling:
r"""
Returns:
"""
return self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
class Siglip2MultiheadAttentionPoolingHead(nn.Module):
"""Multihead Attention Pooling."""
def __init__(self, config: Siglip2VisionConfig):
super().__init__()
self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.mlp = Siglip2MLP(config)
self.num_heads = config.num_attention_heads
def forward(self, hidden_state: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
batch_size = hidden_state.shape[0]
probe = self.probe.repeat(batch_size, 1, 1)
if attention_mask is not None:
target_len, source_len = probe.shape[1], hidden_state.shape[1]
attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_state.dtype, target_len)
attention_mask = attention_mask.repeat(1, self.num_heads, target_len, 1)
attention_mask = attention_mask.reshape(-1, target_len, source_len)
hidden_state = self.attention(probe, hidden_state, hidden_state, attn_mask=attention_mask)[0]
residual = hidden_state
hidden_state = self.layernorm(hidden_state)
hidden_state = residual + self.mlp(hidden_state)
return hidden_state[:, 0]
@add_start_docstrings(
"""The vision model from Siglip2 without any head or projection on top.""",
SIGLIP2_START_DOCSTRING,
)
class Siglip2VisionModel(Siglip2PreTrainedModel):
config_class = Siglip2VisionConfig
main_input_name = "pixel_values"
def __init__(self, config: Siglip2VisionConfig):
super().__init__(config)
self.vision_model = Siglip2VisionTransformer(config)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self) -> nn.Module:
return self.vision_model.embeddings.patch_embedding
@can_return_tuple
@add_start_docstrings_to_model_forward(SIGLIP2_VISION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=Siglip2VisionConfig)
def forward(
self,
pixel_values: torch.FloatTensor,
pixel_attention_mask: torch.Tensor,
spatial_shapes: torch.LongTensor,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> BaseModelOutputWithPooling:
r"""
Returns:
        """
return self.vision_model(
pixel_values=pixel_values,
attention_mask=pixel_attention_mask,
spatial_shapes=spatial_shapes,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
@add_start_docstrings(SIGLIP2_START_DOCSTRING)
class Siglip2Model(Siglip2PreTrainedModel):
config_class = Siglip2Config
def __init__(self, config: Siglip2Config):
super().__init__(config)
if not isinstance(config.text_config, Siglip2TextConfig):
raise TypeError(
"config.text_config is expected to be of type Siglip2TextConfig but is of type"
f" {type(config.text_config)}."
)
if not isinstance(config.vision_config, Siglip2VisionConfig):
raise TypeError(
"config.vision_config is expected to be of type Siglip2VisionConfig but is of type"
f" {type(config.vision_config)}."
)
text_config = config.text_config
vision_config = config.vision_config
text_model = Siglip2TextModel._from_config(text_config)
vision_model = Siglip2VisionModel._from_config(vision_config)
self.text_model = text_model.text_model
self.vision_model = vision_model.vision_model
self.logit_scale = nn.Parameter(torch.randn(1))
self.logit_bias = nn.Parameter(torch.randn(1))
self.post_init()
@add_start_docstrings_to_model_forward(SIGLIP2_TEXT_INPUTS_DOCSTRING)
def get_text_features(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> torch.FloatTensor:
r"""
Returns:
            text_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The text embeddings obtained by
applying the projection layer to the pooled output of [`Siglip2TextModel`].
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
text_outputs: BaseModelOutputWithPooling = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
pooled_output = text_outputs.pooler_output
return pooled_output
@add_start_docstrings_to_model_forward(SIGLIP2_VISION_INPUTS_DOCSTRING)
def get_image_features(
self,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_attention_mask: Optional[torch.Tensor] = None,
spatial_shapes: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> torch.FloatTensor:
r"""
Returns:
            image_features (`torch.FloatTensor` of shape `(batch_size, output_dim)`): The image embeddings obtained by
applying the projection layer to the pooled output of [`Siglip2VisionModel`].
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
vision_outputs: BaseModelOutputWithPooling = self.vision_model(
pixel_values=pixel_values,
attention_mask=pixel_attention_mask,
spatial_shapes=spatial_shapes,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
pooled_output = vision_outputs.pooler_output
return pooled_output
@can_return_tuple
@add_start_docstrings_to_model_forward(SIGLIP2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Siglip2Output, config_class=Siglip2Config)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
pixel_attention_mask: Optional[torch.Tensor] = None,
spatial_shapes: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
return_loss: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> Siglip2Output:
r"""
Returns:
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
vision_outputs: BaseModelOutputWithPooling = self.vision_model(
pixel_values=pixel_values,
attention_mask=pixel_attention_mask,
spatial_shapes=spatial_shapes,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
text_outputs: BaseModelOutputWithPooling = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
image_embeds = vision_outputs.pooler_output
text_embeds = text_outputs.pooler_output
image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
logits_per_text = torch.matmul(text_embeds, image_embeds.t().to(text_embeds.device))
logit_scale, logit_bias = self.logit_scale.to(text_embeds.device), self.logit_bias.to(text_embeds.device)
logits_per_text = logits_per_text * logit_scale.exp() + logit_bias
logits_per_image = logits_per_text.t()
loss = None
if return_loss:
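            # SigLIP sigmoid loss: matched image-text pairs (the diagonal) get label +1, all other
            # pairs get label -1; the loss is the mean over rows of -sum(log sigmoid(label * logit)).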
eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device)
m1_diag1 = -torch.ones_like(logits_per_text) + 2 * eye
loglik = torch.nn.functional.logsigmoid(m1_diag1 * logits_per_text)
nll = -torch.sum(loglik, dim=-1)
loss = nll.mean()
return Siglip2Output(
loss=loss,
logits_per_image=logits_per_image,
logits_per_text=logits_per_text,
text_embeds=text_embeds,
image_embeds=image_embeds,
text_model_output=text_outputs,
vision_model_output=vision_outputs,
)
@add_start_docstrings(
"""
Siglip2 vision encoder with an image classification head on top (a
linear layer on top of the pooled final hidden states of
the patch tokens) e.g. for ImageNet.
""",
SIGLIP2_START_DOCSTRING,
)
class Siglip2ForImageClassification(Siglip2PreTrainedModel):
main_input_name = "pixel_values"
def __init__(self, config: Siglip2Config) -> None:
super().__init__(config)
self.num_labels = config.num_labels
vision_model = Siglip2VisionModel._from_config(config.vision_config)
self.vision_model = vision_model.vision_model
self.classifier = (
nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
)
self.post_init()
@can_return_tuple
@add_start_docstrings_to_model_forward(SIGLIP2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
pixel_attention_mask: Optional[torch.Tensor] = None,
spatial_shapes: Optional[torch.LongTensor] = None,
labels: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> ImageClassifierOutput:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
Returns:
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
outputs: BaseModelOutputWithPooling = self.vision_model(
pixel_values,
attention_mask=pixel_attention_mask,
spatial_shapes=spatial_shapes,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = outputs.last_hidden_state
if pixel_attention_mask is not None:
pool_mask = pixel_attention_mask[..., None].to(sequence_output.device)
sequence_output = torch.sum(sequence_output * pool_mask, dim=1) / torch.sum(pool_mask, dim=1)
else:
sequence_output = torch.mean(sequence_output, dim=1)
logits = self.classifier(sequence_output)
loss = None
if labels is not None:
labels = labels.to(logits.device)
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(logits, labels)
return ImageClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
__all__ = [
"Siglip2Model",
"Siglip2PreTrainedModel",
"Siglip2TextModel",
"Siglip2VisionModel",
"Siglip2ForImageClassification",
]