# HyperCLOVAX-SEED-CLIP / modeling_hyperclovax_seed_clip.py
# Uploaded by bigshanedogg via huggingface_hub ("Upload folder using huggingface_hub"), commit f2f8be1 (verified).
# coding=utf-8
# Copyright 2026 NAVER Cloud Corp. and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""HyperCLOVAX SEED CLIP model.
Architecture:
- Vision: HyperCLOVAXSeedCLIPVisionEncoder (spatio-temporal ViT, no merger)
+ post_layernorm + Siglip2MultiheadAttentionPoolingHead
- Text: SiglipTextTransformer (reused from HuggingFace transformers)
- Contrastive: logit_scale + logit_bias (SigLIP-style sigmoid contrastive loss)
Acknowledgements:
- Training objective based on SigLIP
(https://github.com/google-research/big_vision), Apache-2.0 License.
"""
from collections.abc import Callable
from dataclasses import dataclass
from typing import Optional, Tuple, Union, Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.activations import ACT2FN
from transformers.modeling_utils import PreTrainedModel
from transformers.models.siglip.modeling_siglip import SiglipTextModel
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.models.siglip2.configuration_siglip2 import Siglip2VisionConfig
from transformers.models.siglip2.modeling_siglip2 import Siglip2MultiheadAttentionPoolingHead
from transformers.utils import ModelOutput, auto_docstring
try:
from transformers.modeling_layers import GradientCheckpointingLayer
except ImportError:
class GradientCheckpointingLayer(nn.Module): # transformers < 4.46
pass
try:
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
except ImportError:
ALL_ATTENTION_FUNCTIONS = {} # transformers < 4.46
from .configuration_hyperclovax_seed_clip import HyperCLOVAXSeedCLIPConfig, HyperCLOVAXSeedCLIPVisionConfig
class HyperCLOVAXSeedVisionRMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale.

    The mean of squares is computed in float32 whatever the input dtype; the
    normalised activations are cast back to the input dtype before the learned
    per-channel ``weight`` (initialised to ones) is applied.

    Args:
        hidden_size: size of the normalised (last) dimension.
        eps: value added to the mean of squares before the reciprocal square root.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        original_dtype = hidden_states.dtype
        as_fp32 = hidden_states.to(torch.float32)
        mean_square = as_fp32.pow(2).mean(dim=-1, keepdim=True)
        normalised = as_fp32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalised.to(original_dtype)

    def extra_repr(self) -> str:
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate the last dimension by half: ``[a, b] -> [-b, a]``.

    The last dimension is split at ``size // 2``; the second part is negated and
    placed in front of the first part.
    """
    split_at = x.shape[-1] // 2
    front, back = x[..., :split_at], x[..., split_at:]
    return torch.cat((back.neg(), front), dim=-1)
def apply_rotary_pos_emb_vision(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key tensors with precomputed RoPE ``cos``/``sin`` tables.

    ``cos`` and ``sin`` get a singleton axis inserted at position -2 so they
    broadcast over the second-to-last axis of ``q``/``k``. The rotation is
    evaluated in float32 and each result is cast back to its own input dtype.

    Returns:
        ``(q_rotated, k_rotated)``.
    """
    cos_b = cos.unsqueeze(-2).float()
    sin_b = sin.unsqueeze(-2).float()

    def _rotate(tensor: torch.Tensor) -> torch.Tensor:
        as_fp32 = tensor.float()
        rotated = (as_fp32 * cos_b) + (rotate_half(as_fp32) * sin_b)
        return rotated.to(tensor.dtype)

    return _rotate(q), _rotate(k)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head ``n_rep`` times along dimension 1.

    Equivalent to ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``:
    ``(batch, num_key_value_heads, seqlen, head_dim)`` ->
    ``(batch, num_key_value_heads * n_rep, seqlen, head_dim)``.
    When ``n_rep == 1`` the input object itself is returned.
    """
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    expanded = hidden_states.unsqueeze(2).expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """Reference (non-fused) scaled dot-product attention.

    Used by the vision attention layer when no fused implementation is selected.
    Reads ``module.num_key_value_groups`` to expand key/value heads and
    ``module.training`` to decide whether dropout is applied. Extra keyword
    arguments are accepted for interface compatibility and ignored.

    Returns:
        ``(attn_output, attn_weights)``; ``attn_output`` is ``softmax(QK^T * scaling) V``
        with dimensions 1 and 2 swapped, ``attn_weights`` are the softmax probabilities.
    """
    groups = module.num_key_value_groups
    keys = repeat_kv(key, groups)
    values = repeat_kv(value, groups)

    scores = torch.matmul(query, keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        # Trim the additive mask to the key length before adding it.
        scores = scores + attention_mask[:, :, :, : keys.shape[-2]]

    # Softmax in float32 for stability, then return to the query dtype.
    probs = nn.functional.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = nn.functional.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, values).transpose(1, 2).contiguous()
    return context, probs
class HyperCLOVAXSeedVisionMLP(nn.Module):
    """Gated feed-forward network ``down(act(gate(x)) * up(x))`` used in each vision block.

    Args:
        config: vision config; ``hidden_size``, ``intermediate_size`` and
            ``hidden_act`` (looked up in ``ACT2FN``) are read from it.
        bias: whether the gate/up/down projections carry a bias term.
    """

    def __init__(self, config: HyperCLOVAXSeedCLIPVisionConfig, bias: bool = False):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        hidden, inner = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(hidden, inner, bias=bias)
        self.up_proj = nn.Linear(hidden, inner, bias=bias)
        self.down_proj = nn.Linear(inner, hidden, bias=bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        gated = self.act_fn(self.gate_proj(hidden_state))
        return self.down_proj(gated * self.up_proj(hidden_state))
class HyperCLOVAXSeedVisionPatchEmbed(nn.Module):
    """Embed flattened spatio-temporal patches with a bias-free Conv3d.

    Every input row is one patch of ``in_channels * temporal_patch_size *
    patch_size * patch_size`` values. Kernel and stride both equal the patch
    shape, so the convolution acts as a linear projection of each patch to
    ``embed_dim`` features.
    """

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim
        kernel = [temporal_patch_size, patch_size, patch_size]
        self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel, stride=kernel, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        patch_shape = (self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size)
        # Cast to the conv weight dtype so mixed-precision inputs match the layer.
        patches = hidden_states.view(-1, *patch_shape).to(dtype=self.proj.weight.dtype)
        return self.proj(patches).view(-1, self.embed_dim)
class HyperCLOVAXSeedVisionRotaryEmbedding(nn.Module):
    """Rotary-angle table for integer vision positions.

    ``forward(seqlen)`` returns ``outer(arange(seqlen), inv_freq)`` with
    ``inv_freq = 1 / theta ** (arange(0, dim, 2) / dim)``, i.e. a
    ``(seqlen, ceil(dim / 2))`` float32 tensor.

    ``inv_freq`` is recomputed on every call rather than read from the buffer; the
    non-persistent buffer only determines the device. This keeps the module correct
    if the buffer contents are zeroed after construction (``no_init_weights()`` in
    transformers 5.x) and not restored from the checkpoint.
    """

    inv_freq: torch.Tensor  # declared for type checkers; registered as a buffer in __init__

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.register_buffer("inv_freq", self._inverse_frequencies(), persistent=False)

    def _inverse_frequencies(self, device: Optional[torch.device] = None) -> torch.Tensor:
        """Return ``1 / theta ** (arange(0, dim, 2) / dim)`` as float32 on ``device``."""
        exponents = torch.arange(0, self.dim, 2, dtype=torch.float, device=device) / self.dim
        return 1.0 / (self.theta ** exponents)

    def forward(self, seqlen: int) -> torch.Tensor:
        device = self.inv_freq.device
        inv_freq = self._inverse_frequencies(device)
        positions = torch.arange(seqlen, device=device, dtype=inv_freq.dtype)
        return torch.outer(positions, inv_freq)
class HyperCLOVAXSeedVisionAttention(nn.Module):
    """Multi-head self-attention over packed vision tokens with 2D RoPE.

    Tokens of all inputs are packed along one sequence axis and ``cu_seqlens``
    holds cumulative segment boundaries; attention never crosses a boundary.
    With a registered ``flash_attention_2`` the whole packed sequence goes to the
    varlen kernel. Every other implementation (including the eager fallback) is
    called once per segment.
    """

    def __init__(self, config: HyperCLOVAXSeedCLIPVisionConfig) -> None:
        super().__init__()
        self.dim = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim = self.dim // self.num_heads
        self.num_key_value_groups = 1  # needed for eager attention
        self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
        self.proj = nn.Linear(self.dim, self.dim)
        self.scaling = self.head_dim**-0.5
        self.config = config
        self.attention_dropout = 0.0
        self.is_causal = False

    def _resolve_attention_interface(self) -> tuple[Callable, bool]:
        """Return ``(attention_function, use_varlen_path)`` for the configured implementation.

        If the requested implementation is not registered in ``ALL_ATTENTION_FUNCTIONS``
        (on older transformers that mapping is the empty-dict fallback), fall back to
        ``eager_attention_forward``. That is only safe on the per-segment path: eager
        attention ignores ``cu_seqlens`` and would attend across segments. An
        unavailable ``flash_attention_2`` therefore raises instead.

        Raises:
            ValueError: if ``flash_attention_2`` is requested but not registered.
        """
        implementation = self.config._attn_implementation
        if implementation == "eager":
            return eager_attention_forward, False
        if implementation in ALL_ATTENTION_FUNCTIONS:
            return ALL_ATTENTION_FUNCTIONS[implementation], implementation == "flash_attention_2"
        if implementation == "flash_attention_2":
            raise ValueError(
                "attn_implementation='flash_attention_2' was requested but no such attention "
                "function is registered in this transformers installation."
            )
        # e.g. "sdpa" on transformers without ALL_ATTENTION_FUNCTIONS: eager is equivalent here.
        return eager_attention_forward, False

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Attend within each ``cu_seqlens`` segment of the packed sequence.

        Args:
            hidden_states: packed tokens of shape ``(seq_len, hidden_size)``.
            cu_seqlens: cumulative segment boundaries, starting at 0.
            rotary_pos_emb: rotary angles; only used to build ``(cos, sin)`` when
                ``position_embeddings`` is not given.
            position_embeddings: precomputed ``(cos, sin)`` RoPE tables.
            **kwargs: forwarded to the attention function.

        Returns:
            ``(seq_len, hidden_size)`` output after the output projection.

        Raises:
            ValueError: if neither ``position_embeddings`` nor ``rotary_pos_emb`` is
                given, or ``flash_attention_2`` is requested but unavailable.
        """
        seq_length = hidden_states.shape[0]
        # (seq, 3 * hidden) -> three tensors of shape (seq, num_heads, head_dim).
        query_states, key_states, value_states = (
            self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        )
        if position_embeddings is None:
            if rotary_pos_emb is None:
                raise ValueError(
                    "HyperCLOVAXSeedVisionAttention requires `position_embeddings` or `rotary_pos_emb`."
                )
            # Duplicate the angles so both halves paired by rotate_half share them.
            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
            position_embeddings = (emb.cos(), emb.sin())
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
        # (seq, heads, head_dim) -> (1, heads, seq, head_dim)
        query_states = query_states.transpose(0, 1).unsqueeze(0)
        key_states = key_states.transpose(0, 1).unsqueeze(0)
        value_states = value_states.transpose(0, 1).unsqueeze(0)

        attention_interface, use_varlen = self._resolve_attention_interface()
        dropout = 0.0 if not self.training else self.attention_dropout

        if use_varlen:
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
            attn_output, _ = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask=None,
                scaling=self.scaling,
                dropout=dropout,
                cu_seq_lens_q=cu_seqlens,
                cu_seq_lens_k=cu_seqlens,
                max_length_q=max_seqlen,
                max_length_k=max_seqlen,
                is_causal=False,
                **kwargs,
            )
        else:
            # Split the sequence axis (dim 2) into segments and attend to each separately.
            lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
            splits = [
                torch.split(tensor, lengths, dim=2) for tensor in (query_states, key_states, value_states)
            ]
            attn_outputs = [
                attention_interface(
                    self,
                    q,
                    k,
                    v,
                    attention_mask=None,
                    scaling=self.scaling,
                    dropout=dropout,
                    is_causal=False,
                    **kwargs,
                )[0]
                for q, k, v in zip(*splits)
            ]
            # Outputs are (1, seq_i, heads, head_dim); re-join along the sequence axis.
            attn_output = torch.cat(attn_outputs, dim=1)
        attn_output = attn_output.reshape(seq_length, -1).contiguous()
        attn_output = self.proj(attn_output)
        return attn_output
class HyperCLOVAXSeedVisionBlock(GradientCheckpointingLayer):
    """Pre-norm transformer block: RMSNorm -> attention and RMSNorm -> MLP, each with a residual.

    Whether the block attends within windows or over larger segments is decided by
    the ``cu_seqlens`` the caller passes in. ``is_fullatt`` and ``is_last`` are only
    used to select how the MLP branch is computed for float16 inputs (see ``forward``).

    Args:
        config: vision config, passed to the attention and MLP sub-modules.
        is_fullatt: whether this block is one of the full-attention blocks.
        is_last: whether this is the final block of the encoder.
    """

    def __init__(
        self,
        config: HyperCLOVAXSeedCLIPVisionConfig,
        is_fullatt: bool = False,
        is_last: bool = False,
    ) -> None:
        super().__init__()
        self.norm1 = HyperCLOVAXSeedVisionRMSNorm(config.hidden_size, eps=1e-6)
        self.norm2 = HyperCLOVAXSeedVisionRMSNorm(config.hidden_size, eps=1e-6)
        self.attn = HyperCLOVAXSeedVisionAttention(config=config)
        self.mlp = HyperCLOVAXSeedVisionMLP(config, bias=True)
        self.is_fullatt = is_fullatt
        self.is_last = is_last

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Apply the attention and MLP sub-layers to packed tokens.

        Args:
            hidden_states: packed token features, passed through ``norm1`` to the attention layer.
            cu_seqlens: cumulative segment boundaries forwarded to the attention layer.
            rotary_pos_emb: forwarded unchanged to the attention layer.
            position_embeddings: ``(cos, sin)`` RoPE tables forwarded to the attention layer.
            **kwargs: forwarded to the attention layer.

        Returns:
            Updated token features. In the float16 special path, non-last blocks cast the
            result back to the input dtype and the last block casts it to float32.
        """
        # Attention sub-layer with residual connection.
        hidden_states = hidden_states + self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        # fp16 full-attention blocks and the last block accumulate rounding error
        # in the MLP; promote to float32 for numerical stability.
        # Regular path: windowed non-last blocks, or any input that is not float16.
        if (
            (not self.is_fullatt and not self.is_last)
            or hidden_states.dtype != torch.float16
        ):
            hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
        else:
            org_type = hidden_states.dtype
            # NOTE(review): torch.amp.autocast is documented for lower-precision target
            # dtypes (float16/bfloat16). Confirm that dtype=torch.float32 really runs the
            # MLP in float32 on the target device rather than disabling autocast.
            with torch.amp.autocast(device_type=hidden_states.device.type, dtype=torch.float32):
                mlp_out = self.mlp(self.norm2(hidden_states))
            if self.is_last:
                # The last block hands float32 activations to whatever follows the encoder.
                hidden_states = (hidden_states + mlp_out).to(torch.float32)
            else:
                hidden_states = (hidden_states + mlp_out).to(org_type)
        return hidden_states
class HyperCLOVAXSeedCLIPVisionEncoder(PreTrainedModel):
    """HyperCLOVAX SEED CLIP Vision Encoder.

    A spatio-temporal vision transformer that encodes images and videos into
    sequential patch token sequences. Used as the vision backbone in the CLIP model;
    the patch merger is not applied here — pooling is handled by
    HyperCLOVAXSeedCLIPVisionModel.

    Tokens are processed in a packed layout: the patches of all inputs are
    concatenated along one axis. ``forward`` proceeds as follows:
        - It permutes the tokens into window order. Groups of
          ``spatial_merge_size ** 2`` patches are moved as units.
        - Blocks listed in ``fullatt_block_indexes`` attend over each temporal slice
          (``h * w`` patches) of every input. The other blocks attend within windows.
        - It restores the original patch order before returning.
    """

    config_class = HyperCLOVAXSeedCLIPVisionConfig
    _no_split_modules = ["HyperCLOVAXSeedVisionBlock"]
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True

    def __init__(self, config: HyperCLOVAXSeedCLIPVisionConfig, *inputs, **kwargs) -> None:
        super().__init__(config, *inputs, **kwargs)
        self.spatial_merge_size = config.spatial_merge_size
        self.patch_size = config.patch_size
        self.fullatt_block_indexes = config.fullatt_block_indexes
        self.window_size = config.window_size
        # Patches per spatial_merge_size x spatial_merge_size group; these groups are
        # the units that get permuted into window order.
        self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size
        self.patch_embed = HyperCLOVAXSeedVisionPatchEmbed(
            patch_size=config.patch_size,
            temporal_patch_size=config.temporal_patch_size,
            in_channels=config.in_channels,
            embed_dim=config.hidden_size,
        )
        head_dim = config.hidden_size // config.num_heads
        # 2D RoPE: the table has head_dim // 4 frequencies (for head_dim divisible by 4).
        # rot_pos_emb uses them for both the row and the column coordinate
        # (head_dim // 2 angles per token), and forward duplicates those angles to
        # span head_dim.
        self.rotary_pos_emb = HyperCLOVAXSeedVisionRotaryEmbedding(head_dim // 2)
        self.blocks = nn.ModuleList([
            HyperCLOVAXSeedVisionBlock(
                config,
                is_fullatt=(_block_idx in config.fullatt_block_indexes),
                is_last=(_block_idx == config.depth - 1),
            )
            for _block_idx in range(config.depth)
        ])
        self.gradient_checkpointing = False
        self.post_init()

    def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
        """Compute 2D rotary position embeddings for all patches in the batch.

        For every item ``(t, h, w)`` in ``grid_thw``, the (row, column) coordinates of
        its patches are enumerated in merge-group order. The groups
        (``spatial_merge_size x spatial_merge_size``) are visited row-major, and the
        patches inside each group are also visited row-major. The coordinates are
        repeated for each of the ``t`` temporal slices.

        NOTE(review): this must match the row order of the patchified pixel values
        supplied by the image processor — presumably the same merge-group order; confirm.

        Returns:
            Rotary angles of shape ``(sum(t * h * w), 2 * n_freq)``, where ``n_freq`` is
            the number of frequencies produced by ``self.rotary_pos_emb``. Row angles
            come first, then column angles.
        """
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            # (group_row, in_row, group_col, in_col) -> (group_row, group_col, in_row, in_col)
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()
            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        # One frequency table sized for the largest row/column index in the batch.
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        # (N, 2) coordinates -> (N, 2, n_freq) angles -> (N, 2 * n_freq).
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def get_window_index(self, grid_thw: torch.Tensor) -> tuple[torch.Tensor, list]:
        """Build a flat index that reorders tokens into non-overlapping windows.

        Works at the granularity of merge groups (``spatial_merge_unit`` patches each).
        A window covers ``vit_merger_window_size x vit_merger_window_size`` merge groups
        of one temporal slice. Windows at the right/bottom edge may be partial.

        Returns:
            window_index: permutation indices to gather tokens in window order
                (indices of merge groups, global across all items)
            cu_window_seqlens: cumulative window sequence lengths for varlen attention,
                counted in patches and starting at 0. Empty (all-padding) windows
                produce repeated values.
        """
        window_index: list = []
        cu_window_seqlens: list = [0]
        window_index_id = 0
        # Window side length measured in merge groups.
        vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size
        for grid_t, grid_h, grid_w in grid_thw:
            llm_grid_h, llm_grid_w = (
                grid_h // self.spatial_merge_size,
                grid_w // self.spatial_merge_size,
            )
            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
            # Pad the group grid with -100 sentinels so it splits evenly into windows.
            # When a side is already a multiple of the window size this still adds one
            # full window of padding; such windows get seqlen 0.
            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
            index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
            index_padded = index_padded.reshape(
                grid_t,
                num_windows_h,
                vit_merger_window_size,
                num_windows_w,
                vit_merger_window_size,
            )
            # -> (t, window, win_row, win_col), windows enumerated row-major.
            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
                grid_t,
                num_windows_h * num_windows_w,
                vit_merger_window_size,
                vit_merger_window_size,
            )
            # Number of real (non-padding) merge groups in each window.
            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
            index_padded = index_padded.reshape(-1)
            index_new = index_padded[index_padded != -100]
            # Offset by the groups of previous items so the indices are global.
            window_index.append(index_new + window_index_id)
            # Convert group counts to patch counts and continue the running total.
            cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
        window_index = torch.cat(window_index, dim=0)
        return window_index, cu_window_seqlens

    def forward(
        self,
        hidden_states: torch.Tensor,
        grid_thw: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(total_patches, patch_dim)`):
                Flattened patch pixels passed to the patch embedding layer.
            grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
                Temporal, height and width grid dimensions for each input item.
            **kwargs: forwarded to each block, except on the gradient-checkpointing path.
        Returns:
            `torch.Tensor` of shape `(total_tokens, hidden_size)` in sequential patch order.
        """
        hidden_states = self.patch_embed(hidden_states)
        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=hidden_states.device,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        # Drop repeated boundaries produced by empty (all-padding) windows.
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
        seq_len, _ = hidden_states.size()
        # Permute whole merge groups into window order.
        hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        hidden_states = hidden_states[window_index, :, :]
        hidden_states = hidden_states.reshape(seq_len, -1)
        # Apply the same permutation to the rotary angles so they stay aligned with tokens.
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        # Duplicate the angles so both halves paired by rotate_half share them.
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())
        # Full-attention boundaries: one segment of h * w patches per temporal slice.
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
        for layer_num, blk in enumerate(self.blocks):
            # Full-attention layers use per-slice boundaries; the rest use window boundaries.
            cu_seqlens_now = cu_seqlens if layer_num in self.fullatt_block_indexes else cu_window_seqlens
            if self.gradient_checkpointing and self.training:
                # NOTE(review): **kwargs are not forwarded on this path.
                hidden_states = self._gradient_checkpointing_func(
                    blk,
                    hidden_states,
                    cu_seqlens_now,
                    None,  # rotary_pos_emb (unused; position_embeddings used instead)
                    position_embeddings,
                )
            else:
                hidden_states = blk(
                    hidden_states,
                    cu_seqlens=cu_seqlens_now,
                    position_embeddings=position_embeddings,
                    **kwargs,
                )
        # Un-reorder from window order back to sequential patch order
        reverse_indices = torch.argsort(window_index)
        hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        hidden_states = hidden_states[reverse_indices]
        hidden_states = hidden_states.reshape(seq_len, -1)
        return hidden_states
@dataclass
@auto_docstring
# Adapted from transformers.models.clip.modeling_clip.CLIPOutput (not a verbatim copy).
class HyperCLOVAXSeedCLIPOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `()`, *optional*, returned when `return_loss` is `True`):
        SigLIP-style pairwise sigmoid contrastive loss: the negative log-likelihood of every
        text/image pair is summed per text row and then averaged over rows.
    logits_per_image (`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`):
        `logit_scale.exp() * cosine_similarity + logit_bias` for each image/text pair
        (the transpose of `logits_per_text`).
    logits_per_text (`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`):
        `logit_scale.exp() * cosine_similarity + logit_bias` for each text/image pair.
    text_embeds (`torch.FloatTensor` of shape `(text_batch_size, embed_dim)`):
        L2-normalised pooled output of the SigLIP text transformer (no extra projection).
    image_embeds (`torch.FloatTensor` of shape `(image_batch_size, embed_dim)`):
        L2-normalised attention-pooled output of [`HyperCLOVAXSeedCLIPVisionModel`].
    text_model_output (`BaseModelOutputWithPooling`, *optional*):
        The raw output of the text transformer.
    vision_model_output (`BaseModelOutputWithPooling`, *optional*):
        Not populated by `HyperCLOVAXSeedCLIPModel.forward` (always `None`); kept for
        interface parity with CLIP/SigLIP outputs.
    """

    loss: Optional[torch.FloatTensor] = None
    logits_per_image: Optional[torch.FloatTensor] = None
    logits_per_text: Optional[torch.FloatTensor] = None
    text_embeds: Optional[torch.FloatTensor] = None
    image_embeds: Optional[torch.FloatTensor] = None
    text_model_output: Optional[BaseModelOutputWithPooling] = None
    vision_model_output: Optional[BaseModelOutputWithPooling] = None

    def to_tuple(self) -> tuple[Any]:
        """Convert to a tuple of the set fields, flattening nested sub-model outputs to tuples."""
        return tuple(
            self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple()
            for k in self.keys()
        )
class HyperCLOVAXSeedCLIPVisionModel(nn.Module):
    """Vision encoder with attention pooling.

    Combines:
        1. HyperCLOVAXSeedCLIPVisionEncoder (spatio-temporal ViT, no merger)
        2. post_layernorm (LayerNorm)
        3. attn_pool (Siglip2MultiheadAttentionPoolingHead)

    Output: single pooled vector per image/video of shape (batch, hidden_size).

    Inputs may have different grid sizes. The encoder emits one token per patch, so
    item ``i`` owns ``t_i * h_i * w_i`` consecutive rows of its output.
        - If every item has the same token count, the rows are stacked with a plain
          reshape.
        - Otherwise each item is right-padded to the longest one, and the padding is
          masked out of the attention pooling. Slicing by ``total // batch`` instead
          would silently mix tokens of different images.
    """

    def __init__(self, config: HyperCLOVAXSeedCLIPVisionConfig):
        super().__init__()
        self.config = config
        # 1. Vision encoder (no merger — pooling is done here instead)
        self.encoder = HyperCLOVAXSeedCLIPVisionEncoder(config)
        # 2. Post-layernorm before attention pooling
        self.post_layernorm = nn.LayerNorm(config.hidden_size)
        # 3. Siglip2-style attention pooling head
        attn_pool_config = Siglip2VisionConfig(
            hidden_size=config.hidden_size,
            num_attention_heads=config.attn_pool_heads,
            intermediate_size=int(config.hidden_size * config.attn_pool_mlp_ratio),
        )
        self.attn_pool = Siglip2MultiheadAttentionPoolingHead(attn_pool_config)

    def forward(
        self,
        pixel_values: torch.Tensor,
        grid_thw: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
            pixel_values: patchified tensor (total_patches, patch_dim)
            grid_thw: (num_images, 3) - [temporal, height, width]
        Returns:
            pooled: (batch, hidden_size)
        """
        # Vision encoder forward -> (total_tokens, hidden_size) in sequential order
        hidden_states = self.encoder(pixel_values, grid_thw=grid_thw)
        tokens_per_item = grid_thw.prod(dim=-1).tolist()
        batch_size = len(tokens_per_item)
        hidden_size = hidden_states.shape[-1]

        if len(set(tokens_per_item)) == 1:
            # Uniform token counts: a plain reshape is exact and no mask is needed.
            hidden_states = hidden_states.reshape(batch_size, tokens_per_item[0], hidden_size)
            hidden_states = self.post_layernorm(hidden_states)
            return self.attn_pool(hidden_states)  # (batch, hidden_size)

        # Variable token counts: right-pad every item to the longest and mask the padding.
        max_tokens = max(tokens_per_item)
        padded = hidden_states.new_zeros(batch_size, max_tokens, hidden_size)
        attention_mask = torch.zeros(batch_size, max_tokens, dtype=torch.long, device=hidden_states.device)
        for item_idx, item_tokens in enumerate(torch.split(hidden_states, tokens_per_item, dim=0)):
            num_tokens = item_tokens.shape[0]
            padded[item_idx, :num_tokens] = item_tokens
            attention_mask[item_idx, :num_tokens] = 1
        padded = self.post_layernorm(padded)
        return self.attn_pool(padded, attention_mask=attention_mask)  # (batch, hidden_size)
class HyperCLOVAXSeedCLIPPreTrainedModel(PreTrainedModel):
    """Common ``PreTrainedModel`` base for the HyperCLOVAX SEED CLIP models.

    Binds the composite config class and the weight-name prefix, and sets the
    gradient-checkpointing and FlashAttention-2 capability flags read by transformers.
    """

    config_class = HyperCLOVAXSeedCLIPConfig
    base_model_prefix = "hyperclovax_seed_clip"
    _supports_flash_attn_2 = True
    supports_gradient_checkpointing = True
class HyperCLOVAXSeedCLIPModel(HyperCLOVAXSeedCLIPPreTrainedModel):
    """HyperCLOVAX SEED CLIP dual encoder.

    Sub-modules:
        - ``vision_model``: HyperCLOVAXSeedCLIPVisionModel (vision encoder, post-LayerNorm
          and attention pooling), producing one vector per image/video.
        - ``text_model``: the inner SiglipTextTransformer of a SiglipTextModel built from
          ``config.text_config``. Element 1 of its outputs (the pooled output) is used
          as the text embedding.
        - ``logit_scale`` / ``logit_bias``: scalar parameters of the sigmoid contrastive head.
    """

    config_class = HyperCLOVAXSeedCLIPConfig

    def __init__(self, config: HyperCLOVAXSeedCLIPConfig):
        super().__init__(config)
        # Vision tower: encoder + post-LN + attention pooling.
        self.vision_model = HyperCLOVAXSeedCLIPVisionModel(config.vision_config)
        # Text tower: build SiglipTextModel with this model's attention implementation
        # and keep only its inner transformer.
        siglip_text = SiglipTextModel._from_config(
            config.text_config, attn_implementation=config._attn_implementation
        )
        self.text_model = siglip_text.text_model
        # Contrastive head: logits = exp(logit_scale) * cosine_similarity + logit_bias.
        self.logit_scale = nn.Parameter(torch.randn(1))
        self.logit_bias = nn.Parameter(torch.randn(1))
        self.post_init()

    def _run_text_model(
        self,
        input_ids: Optional[torch.Tensor],
        position_ids: Optional[torch.Tensor],
        output_attentions: Optional[bool],
        output_hidden_states: Optional[bool],
        return_dict: Optional[bool],
    ):
        """Call the text transformer; an attention mask is never passed to it."""
        # NOTE(review): caller-supplied attention masks are intentionally dropped —
        # presumably the model expects fixed-length padded text with "last" pooling;
        # confirm before forwarding masks.
        return self.text_model(
            input_ids=input_ids,
            attention_mask=None,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    @staticmethod
    def _sigmoid_contrastive_loss(logits_per_text: torch.Tensor) -> torch.Tensor:
        """SigLIP pairwise sigmoid loss for a square batch of matched text/image pairs.

        Entry ``(i, i)`` is the positive pair of row ``i``; every other entry is a
        negative. Each row sums ``-log sigmoid(+/-logit)``, and the rows are averaged.
        """
        eye = torch.eye(logits_per_text.size(0), device=logits_per_text.device)
        labels = -torch.ones_like(logits_per_text) + 2 * eye  # +1 on the diagonal, -1 elsewhere
        log_likelihood = F.logsigmoid(labels * logits_per_text)
        return (-torch.sum(log_likelihood, dim=-1)).mean()

    def get_text_features(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.FloatTensor:
        """Return the text tower's pooled output (not L2-normalised).

        ``attention_mask`` is accepted for API compatibility but is not forwarded to
        the text transformer.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict
        text_outputs = self._run_text_model(
            input_ids, position_ids, output_attentions, output_hidden_states, return_dict
        )
        return text_outputs[1]

    def get_image_features(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        grid_thw: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        """Return the attention-pooled image/video embeddings (not L2-normalised)."""
        return self.vision_model(pixel_values=pixel_values, grid_thw=grid_thw)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        grid_thw: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, HyperCLOVAXSeedCLIPOutput]:
        """Embed images and texts, score every pair, and optionally compute the sigmoid loss.

        Returns:
            ``HyperCLOVAXSeedCLIPOutput``, or, when ``return_dict`` is false, the tuple
            ``([loss,] logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs)``.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        image_embeds = self.vision_model(pixel_values=pixel_values, grid_thw=grid_thw)
        text_outputs = self._run_text_model(
            input_ids, position_ids, output_attentions, output_hidden_states, return_dict
        )
        text_embeds = text_outputs[1]  # pooled output

        # Cosine similarity via L2-normalised embeddings, then the affine logit head.
        image_embeds = F.normalize(image_embeds, p=2, dim=-1)
        text_embeds = F.normalize(text_embeds, p=2, dim=-1)
        similarity = torch.matmul(text_embeds, image_embeds.t())
        logits_per_text = similarity * self.logit_scale.exp() + self.logit_bias
        logits_per_image = logits_per_text.t()

        loss = self._sigmoid_contrastive_loss(logits_per_text) if return_loss else None

        if not return_dict:
            output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs)
            return output if loss is None else (loss,) + output

        return HyperCLOVAXSeedCLIPOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=None,
        )